Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/Access/AccessControl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,12 @@ void AccessControl::restoreFromBackup(RestorerFromBackup & restorer, const Strin

void AccessControl::setExternalAuthenticatorsConfig(const Poco::Util::AbstractConfiguration & config)
{
external_authenticators->setConfiguration(config, getLogger());
/// Re-read `enable_token_auth` on every config reload. `setupFromMainConfig`
/// runs only once at startup, so without this re-sync flipping the flag in
/// the config and triggering a reload would silently leave the previous
/// value in place -- operators who toggle token auth off in response to an
/// IdP outage or a credential leak would see no effect until restart.
external_authenticators->setConfiguration(config, getLogger(), config.getBool("enable_token_auth", true));
}


Expand Down
56 changes: 55 additions & 1 deletion src/Access/AuthenticationData.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,13 @@ bool AuthenticationData::Util::checkPasswordBcrypt(std::string_view password [[m

bool operator ==(const AuthenticationData & lhs, const AuthenticationData & rhs)
{
/// `MemoryAccessStorage::updateNoLock` short-circuits when the existing
/// entity equals the new one, so any field omitted from this comparator
/// becomes invisible to ALTER USER -- same-type ALTER would silently
/// no-op. JWT users carry two extra fields (`token_processor_name` and
/// `jwt_claims`) and they MUST take part in equality, otherwise re-pinning
/// a JWT user via ALTER USER is a no-op (CREATE USER OR REPLACE works
/// only by accident, via storage->insertOrReplace).
return (lhs.type == rhs.type) && (lhs.password_hash == rhs.password_hash)
&& (lhs.ldap_server_name == rhs.ldap_server_name) && (lhs.kerberos_realm == rhs.kerberos_realm)
#if USE_SSL
Expand All @@ -157,6 +164,8 @@ bool operator ==(const AuthenticationData & lhs, const AuthenticationData & rhs)
#endif
&& (lhs.http_auth_scheme == rhs.http_auth_scheme)
&& (lhs.http_auth_server_name == rhs.http_auth_server_name)
&& (lhs.token_processor_name == rhs.token_processor_name)
&& (lhs.jwt_claims == rhs.jwt_claims)
&& (lhs.valid_until == rhs.valid_until);
}

Expand Down Expand Up @@ -411,7 +420,23 @@ boost::intrusive_ptr<ASTAuthenticationData> AuthenticationData::toAST() const
}
case AuthenticationType::JWT:
{
throw Exception(ErrorCodes::SUPPORT_IS_DISABLED, "JWT is available only in ClickHouse Cloud");
/// Round-trip into the same shape the parser produces: PROCESSOR
/// child first (when set), CLAIMS child after (when set), with the
/// AST flags telling the formatter which slot is which.
const auto & processor_name = getTokenProcessorName();
if (!processor_name.empty())
{
node->has_jwt_processor = true;
node->children.push_back(make_intrusive<ASTLiteral>(processor_name));
}

const auto & claims = getJWTClaims();
if (!claims.empty())
{
node->has_jwt_claims = true;
node->children.push_back(make_intrusive<ASTLiteral>(claims));
}
break;
}
case AuthenticationType::KERBEROS:
{
Expand Down Expand Up @@ -689,6 +714,35 @@ AuthenticationData AuthenticationData::fromAST(const ASTAuthenticationData & que
auth_data.setHTTPAuthenticationServerName(server);
auth_data.setHTTPAuthenticationScheme(scheme);
}
#if USE_JWT_CPP
else if (query.type == AuthenticationType::JWT)
{
/// `query.has_jwt_processor` and `query.has_jwt_claims` describe which
/// of the two optional clauses the parser saw. Children are pushed in
/// PROCESSOR-then-CLAIMS order, so we walk them in that order.
size_t arg_idx = 0;

if (query.has_jwt_processor)
{
String processor_name = checkAndGetLiteralArgument<String>(args[arg_idx++], "processor");
if (processor_name.empty())
throw Exception(ErrorCodes::BAD_ARGUMENTS, "JWT 'PROCESSOR' name must not be empty");
auth_data.setTokenProcessorName(processor_name);
}

if (query.has_jwt_claims)
{
String value = checkAndGetLiteralArgument<String>(args[arg_idx++], "claims");
picojson::value json_obj;
auto error = picojson::parse(json_obj, value);
if (!error.empty())
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Bad JWT claims: {}", error);
if (!json_obj.is<picojson::object>())
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Bad JWT claims: is not an object");
auth_data.setJWTClaims(value);
}
}
#endif
else
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected ASTAuthenticationData structure");
Expand Down
8 changes: 8 additions & 0 deletions src/Access/AuthenticationData.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ class AuthenticationData
const String & getHTTPAuthenticationServerName() const { return http_auth_server_name; }
void setHTTPAuthenticationServerName(const String & name) { http_auth_server_name = name; }

const String & getTokenProcessorName() const { return token_processor_name; }
void setTokenProcessorName(const String & name) { token_processor_name = name; }

const String & getJWTClaims() const { return jwt_claims; }
void setJWTClaims(const String & claims) { jwt_claims = claims; }

time_t getValidUntil() const { return valid_until; }
void setValidUntil(time_t valid_until_) { valid_until = valid_until_; }

Expand Down Expand Up @@ -121,6 +127,8 @@ class AuthenticationData
String http_auth_server_name;
HTTPAuthenticationScheme http_auth_scheme = HTTPAuthenticationScheme::BASIC;
time_t valid_until = 0;
String token_processor_name;
String jwt_claims;
};

}
232 changes: 232 additions & 0 deletions src/Access/Common/JWKSProvider.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
#include <Access/Common/JWKSProvider.h>

#if USE_JWT_CPP
#include <Common/Exception.h>
#include <Common/logger_useful.h>
#include <filesystem>
#include <mutex>
#include <shared_mutex>
#include <system_error>
#include <Poco/Net/HTTPRequest.h>
#include <Poco/Net/HTTPResponse.h>
#include <Poco/Net/HTTPSClientSession.h>
#include <Poco/StreamCopier.h>
#include <fstream>


namespace DB
{

namespace ErrorCodes
{
extern const int AUTHENTICATION_FAILED;
extern const int INVALID_CONFIG_PARAMETER;
}

JWKSType JWKSClient::getJWKS()
{
/// `last_request_send` semantics: timestamp of the most recent fetch
/// *attempt*, success or failure. Updated unconditionally before the
/// HTTP call so a failed fetch doesn't leave the timestamp stale and
/// invite every concurrent thread to re-hammer a failing endpoint
/// (L-02). Within `refresh_timeout` of an attempt:
/// - if a previously-successful JWKS is cached, serve it.
/// - otherwise, throw a "fetch in cooldown" exception so callers
/// don't queue up new attempts during the back-off window.

{
std::shared_lock lock(mutex);
auto now = std::chrono::steady_clock::now();
if (last_request_send.has_value())
{
auto diff = std::chrono::duration<double>(now - *last_request_send).count();
if (diff < static_cast<double>(refresh_timeout))
{
if (cached_jwks.has_value())
return cached_jwks.value();
throw Exception(ErrorCodes::AUTHENTICATION_FAILED,
"JWKS endpoint at '{}' is in cooldown after a recent failed fetch; will retry after the cache lifetime elapses",
jwks_uri.toString());
}
}
}

std::unique_lock lock(mutex);
auto now = std::chrono::steady_clock::now();
if (last_request_send.has_value())
{
auto diff = std::chrono::duration<double>(now - *last_request_send).count();
if (diff < static_cast<double>(refresh_timeout))
{
if (cached_jwks.has_value())
return cached_jwks.value();
throw Exception(ErrorCodes::AUTHENTICATION_FAILED,
"JWKS endpoint at '{}' is in cooldown after a recent failed fetch; will retry after the cache lifetime elapses",
jwks_uri.toString());
}
}

/// Mark the attempt before issuing the network call so that even if the
/// fetch throws, subsequent waiters on this mutex see an updated
/// `last_request_send` and short-circuit via the cooldown branches above
/// instead of repeating the failing fetch back-to-back.
last_request_send = now;

Poco::Net::HTTPResponse response;
std::string response_string;

Poco::Net::HTTPRequest request{Poco::Net::HTTPRequest::HTTP_GET, jwks_uri.getPathAndQuery()};

/// Bound every JWKS fetch to a known limit. Without this, Poco's default
/// `HTTPSession` timeout of 60 seconds applies, and because the JWKS fetch
/// runs while `ExternalAuthenticators::mutex` is held by the outer
/// `checkTokenCredentials` call, a single slow or hung JWKS endpoint would
/// stall the whole auth subsystem (LDAP, Kerberos, HTTP basic, all other
/// token auth paths) for up to a full minute per request. 10 seconds is a
/// conservative cap: well above any healthy provider latency, well below
/// the default.
const Poco::Timespan jwks_http_timeout(/*seconds=*/10, 0);

if (jwks_uri.getScheme() == "https")
{
Poco::Net::HTTPSClientSession session = Poco::Net::HTTPSClientSession(jwks_uri.getHost(), jwks_uri.getPort());
session.setTimeout(jwks_http_timeout, jwks_http_timeout, jwks_http_timeout);
session.sendRequest(request);
std::istream & response_stream = session.receiveResponse(response);
if (response.getStatus() != Poco::Net::HTTPResponse::HTTP_OK || !response_stream)
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Failed to get user info by access token, code: {}, reason: {}",
response.getStatus(), response.getReason());
Poco::StreamCopier::copyToString(response_stream, response_string);
}
else
{
Poco::Net::HTTPClientSession session = Poco::Net::HTTPClientSession(jwks_uri.getHost(), jwks_uri.getPort());
session.setTimeout(jwks_http_timeout, jwks_http_timeout, jwks_http_timeout);
session.sendRequest(request);
std::istream & response_stream = session.receiveResponse(response);
if (response.getStatus() != Poco::Net::HTTPResponse::HTTP_OK || !response_stream)
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Failed to get user info by access token, code: {}, reason: {}", response.getStatus(), response.getReason());
Poco::StreamCopier::copyToString(response_stream, response_string);
}

JWKSType parsed_jwks;

try
{
parsed_jwks = jwt::parse_jwks(response_string);
}
catch (const std::exception & e)
{
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Failed to parse JWKS: {}", e.what());
}

cached_jwks = std::move(parsed_jwks);
return cached_jwks.value();
}

StaticJWKSParams::StaticJWKSParams(const std::string & static_jwks_, const std::string & static_jwks_file_)
{
if (static_jwks_.empty() && static_jwks_file_.empty())
throw Exception(ErrorCodes::INVALID_CONFIG_PARAMETER,
"JWT validator misconfigured: `static_jwks` or `static_jwks_file` keys must be present in static JWKS validator configuration");
if (!static_jwks_.empty() && !static_jwks_file_.empty())
throw Exception(ErrorCodes::INVALID_CONFIG_PARAMETER,
"JWT validator misconfigured: `static_jwks` and `static_jwks_file` keys cannot both be present in static JWKS validator configuration");

static_jwks = static_jwks_;
static_jwks_file = static_jwks_file_;
}

StaticJWKS::StaticJWKS(const StaticJWKSParams & params)
{
static_jwks_file = params.static_jwks_file;

String content = String(params.static_jwks);
if (!static_jwks_file.empty())
{
std::ifstream ifs(static_jwks_file);
Poco::StreamCopier::copyToString(ifs, content);
/// Record the mtime so subsequent `getJWKS()` calls can notice rotation.
std::error_code ec;
const auto write_time = std::filesystem::last_write_time(static_jwks_file, ec);
if (!ec)
last_loaded_mtime = write_time;
}
try
{
auto keys = jwt::parse_jwks(content);
jwks = std::move(keys);
}
catch (const std::exception & e)
{
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Failed to parse JWKS: {}", e.what());
}
}

void StaticJWKS::reloadFromFileIfChangedNoLock()
{
/// Inline `static_jwks` source: nothing to refresh from disk.
if (static_jwks_file.empty())
return;

std::error_code ec;
const auto mtime = std::filesystem::last_write_time(static_jwks_file, ec);
if (ec)
{
/// File disappeared or became unreadable. Keep the previously-loaded
/// keys -- failing closed here would lock everyone out on a transient
/// filesystem hiccup. The operator gets a log signal.
LOG_WARNING(getLogger("TokenAuthentication"),
"StaticJWKS: failed to stat '{}' for refresh ({}); keeping previously-loaded keys.",
static_jwks_file, ec.message());
return;
}
if (mtime <= last_loaded_mtime)
return;

/// File has been rotated. Read + parse + swap.
String content;
try
{
std::ifstream ifs(static_jwks_file);
Poco::StreamCopier::copyToString(ifs, content);
auto new_keys = jwt::parse_jwks(content);
jwks = std::move(new_keys);
last_loaded_mtime = mtime;
LOG_INFO(getLogger("TokenAuthentication"),
"StaticJWKS: reloaded keys from '{}' after detecting mtime change.", static_jwks_file);
}
catch (const std::exception & e)
{
/// Malformed new JWKS: keep the old one. Loud signal so the operator
/// knows the rotation didn't take.
LOG_ERROR(getLogger("TokenAuthentication"),
"StaticJWKS: failed to parse '{}' on refresh: {}; keeping previously-loaded keys.",
static_jwks_file, e.what());
}
}

JWKSType StaticJWKS::getJWKS()
{
/// Fast path: shared lock + mtime check. Refresh under exclusive lock only
/// when the file actually changed.
{
std::shared_lock lock(mutex);
if (static_jwks_file.empty())
return jwks;

std::error_code ec;
const auto mtime = std::filesystem::last_write_time(static_jwks_file, ec);
if (ec)
return jwks;
if (mtime <= last_loaded_mtime)
return jwks;
}

std::unique_lock lock(mutex);
reloadFromFileIfChangedNoLock();
return jwks;
}

}
#endif
Loading
Loading