diff --git a/src/iceberg/catalog/rest/CMakeLists.txt b/src/iceberg/catalog/rest/CMakeLists.txt index b862bc869..8fb2e93c0 100644 --- a/src/iceberg/catalog/rest/CMakeLists.txt +++ b/src/iceberg/catalog/rest/CMakeLists.txt @@ -23,6 +23,7 @@ set(ICEBERG_REST_SOURCES auth/auth_properties.cc auth/auth_session.cc auth/oauth2_util.cc + auth/token_refresh_scheduler.cc catalog_properties.cc endpoint.cc error_handlers.cc diff --git a/src/iceberg/catalog/rest/auth/auth_session.cc b/src/iceberg/catalog/rest/auth/auth_session.cc index 7251dc4a9..0b4c391d2 100644 --- a/src/iceberg/catalog/rest/auth/auth_session.cc +++ b/src/iceberg/catalog/rest/auth/auth_session.cc @@ -19,9 +19,16 @@ #include "iceberg/catalog/rest/auth/auth_session.h" +#include +#include +#include +#include #include +#include "iceberg/catalog/rest/auth/auth_properties.h" #include "iceberg/catalog/rest/auth/oauth2_util.h" +#include "iceberg/catalog/rest/auth/token_refresh_scheduler.h" +#include "iceberg/catalog/rest/http_client.h" namespace iceberg::rest::auth { @@ -44,6 +51,183 @@ class DefaultAuthSession : public AuthSession { std::unordered_map headers_; }; +/// \brief OAuth2 session with automatic token refresh. +class OAuth2AuthSession : public AuthSession, + public std::enable_shared_from_this { + public: + struct Config { + std::string token_endpoint; + std::string client_id; // may be empty: client_secret alone is then the credential + std::string client_secret; + std::string scope; + bool keep_refreshed; // when false, no refresh task is ever scheduled + }; + + /// \brief Create a session (schedules a refresh when enabled); `client` must outlive it. 
+ static std::shared_ptr Create( + const OAuthTokenResponse& initial_token, Config config, HttpClient& client) { + auto session = std::shared_ptr( + new OAuth2AuthSession(std::move(config), client)); + session->SetInitialToken(initial_token); + return session; + } + + Status Authenticate(std::unordered_map& headers) override { + std::shared_lock lock(mutex_); + for (const auto& [key, value] : headers_) { + headers.insert_or_assign(key, value); + } + return {}; + } + + Status Close() override { + bool expected = false; + if (!closed_.compare_exchange_strong(expected, true)) { + return {}; // Already closed + } + TokenRefreshScheduler::Instance().Cancel(scheduled_task_id_.load()); + return {}; + } + + private: + OAuth2AuthSession(Config config, HttpClient& client) + : config_(std::move(config)), client_(client) {} + + void SetInitialToken(const OAuthTokenResponse& token_response) { + token_ = token_response.access_token; + headers_ = {{std::string(kAuthorizationHeader), std::string(kBearerPrefix) + token_}}; + + // Determine expiry: prefer expires_in_secs, else fall back to the JWT "exp" claim + if (token_response.expires_in_secs.has_value()) { + expires_at_ = std::chrono::steady_clock::now() + + std::chrono::seconds(*token_response.expires_in_secs); + } else if (auto exp_ms = ExpiresAtMillis(token_); exp_ms.has_value()) { + // Convert absolute epoch millis to steady_clock time_point + auto now_sys = std::chrono::system_clock::now(); + auto now_steady = std::chrono::steady_clock::now(); + auto exp_sys = + std::chrono::system_clock::time_point(std::chrono::milliseconds(*exp_ms)); + expires_at_ = now_steady + (exp_sys - now_sys); + } + + if (config_.keep_refreshed && + expires_at_ != std::chrono::steady_clock::time_point{}) { + ScheduleRefresh(); + } + } + + void DoRefresh() { DoRefreshAttempt(0, std::chrono::milliseconds(200)); } + + /// \brief Single refresh attempt. On failure, schedules a retry via the + /// scheduler (non-blocking) instead of sleeping on the worker thread. 
+ void DoRefreshAttempt(int attempt, std::chrono::milliseconds backoff) { + static constexpr int kMaxRetries = 5; + static constexpr auto kMaxBackoff = std::chrono::milliseconds(10'000); + + if (closed_.load()) return; + + // Build credential and properties from config_ (rebuilt on every attempt, retries too) + std::string credential = config_.client_id.empty() + ? config_.client_secret + : config_.client_id + ":" + config_.client_secret; + + // Use an empty session for the refresh request (no auth headers — + // avoids circular dependency of using an expired token to refresh itself) + auto empty_session = AuthSession::MakeDefault({}); + + AuthProperties props; + props.Set(AuthProperties::kCredential, credential); + props.Set(AuthProperties::kScope, config_.scope); + props.Set(AuthProperties::kOAuth2ServerUri, config_.token_endpoint); + + auto result = FetchToken(client_, *empty_session, props); + if (result.has_value()) { + auto& response = result.value(); + { + std::unique_lock lock(mutex_); + token_ = response.access_token; + headers_ = { + {std::string(kAuthorizationHeader), std::string(kBearerPrefix) + token_}}; + + // Update expiration + if (response.expires_in_secs.has_value()) { + expires_at_ = std::chrono::steady_clock::now() + + std::chrono::seconds(*response.expires_in_secs); + } else if (auto exp_ms = ExpiresAtMillis(token_); exp_ms.has_value()) { + auto now_sys = std::chrono::system_clock::now(); + auto now_steady = std::chrono::steady_clock::now(); + auto exp_sys = + std::chrono::system_clock::time_point(std::chrono::milliseconds(*exp_ms)); + expires_at_ = now_steady + (exp_sys - now_sys); + } + } + // Note: ScheduleRefresh must be called outside the lock. 
+ ScheduleRefresh(); + return; // Success + } + + // Schedule a non-blocking retry; the handle is dropped, closed_ guards it after Close() + if (attempt + 1 < kMaxRetries) { + auto next_backoff = + std::min(std::chrono::duration_cast(backoff * 2), + kMaxBackoff); + std::weak_ptr weak_self = shared_from_this(); + TokenRefreshScheduler::Instance().Schedule( + backoff, + [weak_self = std::move(weak_self), next_attempt = attempt + 1, next_backoff] { + if (auto self = weak_self.lock()) { + self->DoRefreshAttempt(next_attempt, next_backoff); + } + }); + } + // Reached after a retry was scheduled or once retries are exhausted; in the latter + // case refreshing stops silently and later requests keep sending the stale token. + } + + /// \brief Schedule the next token refresh based on expiration time. + /// + /// Must be called outside any lock on mutex_ (CalculateRefreshDelay + /// acquires shared_lock internally). Schedules nothing for an already-expired token. + void ScheduleRefresh() { + if (!config_.keep_refreshed || closed_.load()) return; + + auto delay = CalculateRefreshDelay(); + if (delay <= std::chrono::milliseconds::zero()) return; + + std::weak_ptr weak_self = shared_from_this(); + auto new_id = TokenRefreshScheduler::Instance().Schedule( + delay, [weak_self = std::move(weak_self)] { + if (auto self = weak_self.lock()) { + self->DoRefresh(); + } + }); + scheduled_task_id_.store(new_id); + } + + std::chrono::milliseconds CalculateRefreshDelay() const { + std::shared_lock lock(mutex_); + auto now = std::chrono::steady_clock::now(); + if (expires_at_ <= now) return std::chrono::milliseconds::zero(); + + auto expires_in = + std::chrono::duration_cast(expires_at_ - now); + // Refresh window: 10% of remaining time, capped at 5 minutes + auto refresh_window = std::min(expires_in / 10, std::chrono::milliseconds(300'000)); + auto wait_time = expires_in - refresh_window; + return std::max(wait_time, std::chrono::milliseconds(10)); + } + + mutable std::shared_mutex mutex_; // protects token_, headers_, expires_at_ + std::string token_; + std::unordered_map headers_; + 
std::chrono::steady_clock::time_point expires_at_{}; + + Config config_; + HttpClient& client_; // non-owning; must outlive this session + std::atomic scheduled_task_id_{0}; // last ScheduleRefresh() handle; 0 = none + std::atomic closed_{false}; +}; + +} // namespace std::shared_ptr AuthSession::MakeDefault( @@ -52,12 +236,17 @@ std::shared_ptr AuthSession::MakeDefault( } std::shared_ptr AuthSession::MakeOAuth2( - const OAuthTokenResponse& initial_token, const std::string& /*token_endpoint*/, - const std::string& /*client_id*/, const std::string& /*client_secret*/, - const std::string& /*scope*/, HttpClient& /*client*/) { - // TODO(lishuxu): Create OAuth2AuthSession with auto-refresh support. - return MakeDefault({{std::string(kAuthorizationHeader), - std::string(kBearerPrefix) + initial_token.access_token}}); + const OAuthTokenResponse& initial_token, const std::string& token_endpoint, + const std::string& client_id, const std::string& client_secret, + const std::string& scope, HttpClient& client) { + OAuth2AuthSession::Config config{ + .token_endpoint = token_endpoint, + .client_id = client_id, + .client_secret = client_secret, + .scope = scope, + .keep_refreshed = true, // sessions built here always auto-refresh + }; + return OAuth2AuthSession::Create(initial_token, std::move(config), client); } } // namespace iceberg::rest::auth diff --git a/src/iceberg/catalog/rest/auth/oauth2_util.cc b/src/iceberg/catalog/rest/auth/oauth2_util.cc index 3d209d2bd..3cc7808dc 100644 --- a/src/iceberg/catalog/rest/auth/oauth2_util.cc +++ b/src/iceberg/catalog/rest/auth/oauth2_util.cc @@ -29,6 +29,7 @@ #include "iceberg/catalog/rest/json_serde_internal.h" #include "iceberg/json_serde_internal.h" #include "iceberg/util/macros.h" +#include "iceberg/util/transform_util.h" namespace iceberg::rest::auth { @@ -74,4 +75,48 @@ Result FetchToken(HttpClient& client, AuthSession& session, return token_response; } +std::optional ExpiresAtMillis(std::string_view token) { + if (token.empty()) { + return std::nullopt; + } + + // A JWT has exactly 3 dot-separated parts: header.payload.signature + auto first_dot = 
token.find('.'); + if (first_dot == std::string_view::npos) { + return std::nullopt; + } + auto second_dot = token.find('.', first_dot + 1); + if (second_dot == std::string_view::npos) { + return std::nullopt; + } + // Ensure there are exactly 3 parts (no additional dots after the signature). + // Note: JWE tokens have 5 segments — they are intentionally not supported here + // and will return nullopt (graceful degradation to not scheduling refresh). + if (token.find('.', second_dot + 1) != std::string_view::npos) { + return std::nullopt; + } + + // Extract and decode the payload (second part). + // Note: Base64UrlDecode returns empty string on both empty input and decode failure. + // A valid JWT payload is never empty (at minimum "{}"), so an empty result means + // the segment was empty or undecodable; either way the token is rejected. + std::string_view payload_b64 = token.substr(first_dot + 1, second_dot - first_dot - 1); + std::string payload = TransformUtil::Base64UrlDecode(payload_b64); + if (payload.empty()) { + return std::nullopt; + } + + // Parse JSON without exceptions (is_discarded() signals failure), then read "exp" + auto json = nlohmann::json::parse(payload, nullptr, false); + if (json.is_discarded() || !json.is_object()) { + return std::nullopt; + } + auto it = json.find("exp"); + if (it == json.end() || !it->is_number()) { + return std::nullopt; + } + auto exp_seconds = static_cast(it->get()); + return exp_seconds * 1000; // Seconds to milliseconds; no overflow guard +} + } // namespace iceberg::rest::auth diff --git a/src/iceberg/catalog/rest/auth/oauth2_util.h b/src/iceberg/catalog/rest/auth/oauth2_util.h index 39dd12964..428ebc385 100644 --- a/src/iceberg/catalog/rest/auth/oauth2_util.h +++ b/src/iceberg/catalog/rest/auth/oauth2_util.h @@ -19,6 +19,8 @@ #pragma once +#include +#include #include #include #include @@ -53,4 +55,14 @@ ICEBERG_REST_EXPORT Result FetchToken( ICEBERG_REST_EXPORT std::unordered_map AuthHeaders( const std::string& token); +/// \brief Extract expiration time from a JWT token. 
+/// +/// Decodes the JWT payload (base64url) and reads the "exp" claim; the signature is +/// not verified. Returns std::nullopt if the token is not a JWT or has no numeric "exp". +/// +/// \param token A token string. If it is a JWT (three dot-separated base64url +/// segments), the "exp" claim is extracted from the payload. +/// \return Expiration time as milliseconds since epoch, or std::nullopt. +ICEBERG_REST_EXPORT std::optional ExpiresAtMillis(std::string_view token); + } // namespace iceberg::rest::auth diff --git a/src/iceberg/catalog/rest/auth/token_refresh_scheduler.cc b/src/iceberg/catalog/rest/auth/token_refresh_scheduler.cc new file mode 100644 index 000000000..ac23dbbe0 --- /dev/null +++ b/src/iceberg/catalog/rest/auth/token_refresh_scheduler.cc @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "iceberg/catalog/rest/auth/token_refresh_scheduler.h" + +#include + +namespace iceberg::rest::auth { + +TokenRefreshScheduler& TokenRefreshScheduler::Instance() { + static TokenRefreshScheduler instance; + return instance; +} + +TokenRefreshScheduler::TokenRefreshScheduler() : worker_([this] { Run(); }) {} + +TokenRefreshScheduler::~TokenRefreshScheduler() { Shutdown(); } + +uint64_t TokenRefreshScheduler::Schedule(std::chrono::milliseconds delay, + std::function callback) { + std::lock_guard lock(mutex_); + if (shutdown_) { + return 0; + } + uint64_t id = next_id_++; + tasks_.push_back(Task{.id = id, + .fire_at = std::chrono::steady_clock::now() + delay, + .callback = std::move(callback)}); + cv_.notify_one(); + return id; +} + +void TokenRefreshScheduler::Cancel(uint64_t handle) { + if (handle == 0) return; + std::lock_guard lock(mutex_); + std::erase_if(tasks_, [handle](const Task& t) { return t.id == handle; }); +} + +void TokenRefreshScheduler::Shutdown() { + { + std::lock_guard lock(mutex_); + if (shutdown_) return; + shutdown_ = true; + tasks_.clear(); + } + cv_.notify_one(); + if (worker_.joinable()) { + worker_.join(); + } +} + +void TokenRefreshScheduler::Run() { + while (true) { + std::function callback; + + { + std::unique_lock lock(mutex_); + + if (tasks_.empty() && !shutdown_) { + // Wait until a task is added or shutdown is requested + cv_.wait(lock, [this] { return !tasks_.empty() || shutdown_; }); + } + + if (shutdown_) break; + if (tasks_.empty()) continue; + + // Find the task with the earliest fire_at + auto earliest_it = std::min_element( + tasks_.begin(), tasks_.end(), + [](const Task& a, const Task& b) { return a.fire_at < b.fire_at; }); + + auto fire_at = earliest_it->fire_at; + auto target_id = earliest_it->id; + + // Wait until fire_at or a notify (Schedule/Shutdown; Cancel() does not notify). + // Note: The predicate does O(n) scan on each spurious wakeup. This is + // acceptable for the expected task count (< 10). 
If task count grows + // significantly, consider replacing vector with a priority queue. + cv_.wait_until(lock, fire_at, [&] { + // Wake up if: shutdown, task list changed, or time is up + if (shutdown_) return true; + if (tasks_.empty()) return true; + // Check if the earliest task has changed (new task added or cancelled) + auto new_earliest = std::min_element( + tasks_.begin(), tasks_.end(), + [](const Task& a, const Task& b) { return a.fire_at < b.fire_at; }); + return new_earliest->id != target_id; + }); + + if (shutdown_) break; + + // Pick the first due task in insertion order; if none is due yet, loop again + auto now = std::chrono::steady_clock::now(); + auto due_it = std::find_if(tasks_.begin(), tasks_.end(), + [now](const Task& t) { return t.fire_at <= now; }); + if (due_it == tasks_.end()) continue; + + callback = std::move(due_it->callback); + tasks_.erase(due_it); + } + + // Execute callback outside the lock + if (callback) { + callback(); + } + } +} + +} // namespace iceberg::rest::auth diff --git a/src/iceberg/catalog/rest/auth/token_refresh_scheduler.h b/src/iceberg/catalog/rest/auth/token_refresh_scheduler.h new file mode 100644 index 000000000..f5737e51d --- /dev/null +++ b/src/iceberg/catalog/rest/auth/token_refresh_scheduler.h @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "iceberg/catalog/rest/iceberg_rest_export.h" + +/// \file iceberg/catalog/rest/auth/token_refresh_scheduler.h +/// \brief Global scheduler for OAuth2 token refresh tasks. + +namespace iceberg::rest::auth { + +/// \brief A process-global scheduler for delayed token refresh tasks. +/// +/// Uses a single background thread that sleeps until the next task is due. +/// All OAuth2AuthSession instances share this scheduler. Tasks are lightweight +/// (a single HTTP POST to refresh a token), so one thread is sufficient. +/// +/// Thread safety: All public methods are thread-safe. +class ICEBERG_REST_EXPORT TokenRefreshScheduler { + public: + /// \brief Get the global singleton instance. + /// + /// The instance is created on first access (lazy initialization) and lives + /// until process exit. The background thread is not detached; it is + /// joined when the instance is destroyed. + static TokenRefreshScheduler& Instance(); + + /// \brief Schedule a callback to run after a delay. + /// + /// \param delay Time to wait before executing the callback. + /// \param callback Function to execute when the delay expires. + /// \return A unique non-zero handle for Cancel(), or 0 if the scheduler is shut down. + uint64_t Schedule(std::chrono::milliseconds delay, std::function callback); + + /// \brief Cancel a previously scheduled task. + /// + /// If the task has already fired or does not exist, this is a no-op; a callback + /// that is already executing is not interrupted. + /// \param handle The handle returned by Schedule(). + void Cancel(uint64_t handle); + + /// \brief Shutdown the scheduler, cancelling all pending tasks. + /// + /// After shutdown, Schedule() calls are no-ops (return 0). + /// This is called automatically on destruction. 
+ /// + /// WARNING: Do not call this on the global Instance() unless you intend to + /// permanently stop all token refresh for the entire process. This is mainly + /// useful for testing with locally-constructed scheduler instances. + void Shutdown(); + + ~TokenRefreshScheduler(); + + // Non-copyable, non-movable + TokenRefreshScheduler(const TokenRefreshScheduler&) = delete; + TokenRefreshScheduler& operator=(const TokenRefreshScheduler&) = delete; + TokenRefreshScheduler(TokenRefreshScheduler&&) = delete; + TokenRefreshScheduler& operator=(TokenRefreshScheduler&&) = delete; + + /// \brief Construct a scheduler (prefer Instance() for production use). + /// + /// This constructor is public to allow testing with isolated instances. + /// In production code, use Instance() to get the global singleton. + TokenRefreshScheduler(); + + private: + /// \brief Worker loop that processes tasks. + void Run(); + + struct Task { + uint64_t id; + std::chrono::steady_clock::time_point fire_at; + std::function callback; + }; + + std::mutex mutex_; + std::condition_variable cv_; + std::vector tasks_; + uint64_t next_id_ = 1; // 0 is reserved as "invalid handle" + bool shutdown_ = false; + std::thread worker_; // keep last: the constructor starts Run() on this thread +}; + +} // namespace iceberg::rest::auth diff --git a/src/iceberg/catalog/rest/meson.build b/src/iceberg/catalog/rest/meson.build index a1f8ce973..204b5bc78 100644 --- a/src/iceberg/catalog/rest/meson.build +++ b/src/iceberg/catalog/rest/meson.build @@ -21,6 +21,7 @@ iceberg_rest_sources = files( 'auth/auth_properties.cc', 'auth/auth_session.cc', 'auth/oauth2_util.cc', + 'auth/token_refresh_scheduler.cc', 'catalog_properties.cc', 'endpoint.cc', 'error_handlers.cc', @@ -87,6 +88,7 @@ install_headers( 'auth/auth_properties.h', 'auth/auth_session.h', 'auth/oauth2_util.h', + 'auth/token_refresh_scheduler.h', ], subdir: 'iceberg/catalog/rest/auth', ) diff --git a/src/iceberg/test/auth_manager_test.cc b/src/iceberg/test/auth_manager_test.cc index bd06fee3f..6e89e3872 100644 --- 
a/src/iceberg/test/auth_manager_test.cc +++ b/src/iceberg/test/auth_manager_test.cc @@ -358,4 +358,359 @@ TEST_F(AuthManagerTest, OAuthTokenResponseNATokenType) { EXPECT_EQ(result->token_type, "N_A"); } +// ---- ExpiresAtMillis tests ---- + +// Helper: build a minimal JWT with a given payload JSON string. +// JWT = base64url(header) + "." + base64url(payload) + "." + base64url(signature) +namespace { + +// Base64url encode (no padding) for test token construction +std::string Base64UrlEncode(std::string_view input) { + static constexpr std::string_view kChars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"; + std::string output; + uint32_t buffer = 0; + int bits = 0; + for (uint8_t c : input) { + buffer = (buffer << 8) | c; + bits += 8; + while (bits >= 6) { + bits -= 6; + output.push_back(kChars[(buffer >> bits) & 0x3F]); + } + } + if (bits > 0) { + output.push_back(kChars[(buffer << (6 - bits)) & 0x3F]); + } + return output; +} + +std::string MakeJwt(const std::string& payload_json) { + std::string header = R"({"alg":"HS256","typ":"JWT"})"; + std::string signature = "test-signature"; + return Base64UrlEncode(header) + "." + Base64UrlEncode(payload_json) + "." 
+ + Base64UrlEncode(signature); +} + +} // namespace + +// Verifies ExpiresAtMillis extracts exp claim from a valid JWT +TEST_F(AuthManagerTest, ExpiresAtMillisValidJwt) { + // exp = 1700000000 (seconds since epoch) + std::string token = MakeJwt(R"({"sub":"user","exp":1700000000})"); + auto result = ExpiresAtMillis(token); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), 1700000000LL * 1000); // milliseconds +} + +// Verifies the millisecond result is not truncated to 32 bits (2e9 s -> 2e12 ms) +TEST_F(AuthManagerTest, ExpiresAtMillisLargeExp) { + std::string token = MakeJwt(R"({"exp":2000000000})"); + auto result = ExpiresAtMillis(token); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), 2000000000LL * 1000); +} + +// Verifies ExpiresAtMillis truncates floating-point exp to integer +TEST_F(AuthManagerTest, ExpiresAtMillisFloatExp) { + std::string token = MakeJwt(R"({"exp":1700000000.5})"); + auto result = ExpiresAtMillis(token); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), 1700000000LL * 1000); // truncated to int +} + +// Verifies ExpiresAtMillis returns nullopt for non-JWT token without dots +TEST_F(AuthManagerTest, ExpiresAtMillisNonJwtNoDots) { + auto result = ExpiresAtMillis("just-a-plain-token"); + EXPECT_FALSE(result.has_value()); +} + +// Verifies ExpiresAtMillis returns nullopt for token with only one dot +TEST_F(AuthManagerTest, ExpiresAtMillisOneDot) { + auto result = ExpiresAtMillis("part1.part2"); + EXPECT_FALSE(result.has_value()); +} + +// Verifies ExpiresAtMillis returns nullopt for token with too many segments +TEST_F(AuthManagerTest, ExpiresAtMillisFourParts) { + auto result = ExpiresAtMillis("a.b.c.d"); + EXPECT_FALSE(result.has_value()); +} + +// Verifies ExpiresAtMillis returns nullopt when JWT has no exp claim +TEST_F(AuthManagerTest, ExpiresAtMillisNoExpClaim) { + std::string token = MakeJwt(R"({"sub":"user","iat":1700000000})"); + auto result = ExpiresAtMillis(token); + EXPECT_FALSE(result.has_value()); +} + +// 
Verifies ExpiresAtMillis returns nullopt when exp is not a number +TEST_F(AuthManagerTest, ExpiresAtMillisExpNotInteger) { + std::string token = MakeJwt(R"({"exp":"not-a-number"})"); + auto result = ExpiresAtMillis(token); + EXPECT_FALSE(result.has_value()); +} + +// Verifies ExpiresAtMillis returns nullopt for malformed base64 payload +TEST_F(AuthManagerTest, ExpiresAtMillisMalformedBase64) { + // Use invalid base64url characters in the payload part + std::string token = "eyJhbGciOiJIUzI1NiJ9.!!!invalid!!!.signature"; + auto result = ExpiresAtMillis(token); + EXPECT_FALSE(result.has_value()); +} + +// Verifies ExpiresAtMillis returns nullopt for empty token +TEST_F(AuthManagerTest, ExpiresAtMillisEmptyToken) { + auto result = ExpiresAtMillis(""); + EXPECT_FALSE(result.has_value()); +} + +// Verifies ExpiresAtMillis returns nullopt when payload is not valid JSON +TEST_F(AuthManagerTest, ExpiresAtMillisInvalidJson) { + std::string header = R"({"alg":"HS256"})"; + std::string invalid_json = "this is not json"; + std::string token = + Base64UrlEncode(header) + "." + Base64UrlEncode(invalid_json) + "." 
+ "sig"; + auto result = ExpiresAtMillis(token); + EXPECT_FALSE(result.has_value()); +} + +// ---- TokenRefreshScheduler tests ---- + +} // namespace iceberg::rest::auth + +#include +#include +#include +#include + +#include "iceberg/catalog/rest/auth/token_refresh_scheduler.h" +#include "iceberg/catalog/rest/types.h" + +namespace iceberg::rest::auth { + +// Verifies that a scheduled task fires after the specified delay +TEST(TokenRefreshSchedulerTest, ScheduleFiresAfterDelay) { + TokenRefreshScheduler scheduler; + std::atomic fired{false}; + + scheduler.Schedule(std::chrono::milliseconds(50), [&] { fired.store(true); }); + + // Should not have fired immediately + EXPECT_FALSE(fired.load()); + + // Wait enough time for it to fire + std::this_thread::sleep_for(std::chrono::milliseconds(150)); + EXPECT_TRUE(fired.load()); + + scheduler.Shutdown(); +} + +// Verifies that cancelling a task prevents it from executing +TEST(TokenRefreshSchedulerTest, CancelPreventsExecution) { + TokenRefreshScheduler scheduler; + std::atomic fired{false}; + + auto handle = + scheduler.Schedule(std::chrono::milliseconds(100), [&] { fired.store(true); }); + + // Cancel before it fires + scheduler.Cancel(handle); + + // Wait past the scheduled time + std::this_thread::sleep_for(std::chrono::milliseconds(200)); + EXPECT_FALSE(fired.load()); + + scheduler.Shutdown(); +} + +// Verifies that multiple tasks fire in chronological order +TEST(TokenRefreshSchedulerTest, MultipleTasksFireInOrder) { + TokenRefreshScheduler scheduler; + std::vector order; + std::mutex order_mutex; + + auto append = [&](int val) { + std::lock_guard lock(order_mutex); + order.push_back(val); + }; + + // Schedule out of fire-time order (150, 50, 100 ms) + scheduler.Schedule(std::chrono::milliseconds(150), [&] { append(3); }); + scheduler.Schedule(std::chrono::milliseconds(50), [&] { append(1); }); + scheduler.Schedule(std::chrono::milliseconds(100), [&] { append(2); }); + + // Wait for all to fire + 
std::this_thread::sleep_for(std::chrono::milliseconds(300)); + + std::lock_guard lock(order_mutex); + ASSERT_EQ(3u, order.size()); + EXPECT_EQ(1, order[0]); + EXPECT_EQ(2, order[1]); + EXPECT_EQ(3, order[2]); + + scheduler.Shutdown(); +} + +// Verifies that shutdown with pending tasks does not crash +TEST(TokenRefreshSchedulerTest, ShutdownWithPendingTasks) { + TokenRefreshScheduler scheduler; + std::atomic fired{false}; + + scheduler.Schedule(std::chrono::milliseconds(5000), [&] { fired.store(true); }); + + // Shutdown immediately — should not crash and task should not fire + scheduler.Shutdown(); + + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + EXPECT_FALSE(fired.load()); +} + +// Verifies that Schedule after shutdown returns invalid handle (0) +TEST(TokenRefreshSchedulerTest, ScheduleAfterShutdownIsNoop) { + TokenRefreshScheduler scheduler; + scheduler.Shutdown(); + + auto handle = scheduler.Schedule(std::chrono::milliseconds(10), [] {}); + EXPECT_EQ(0u, handle); +} + +// Verifies that cancelling an invalid handle does not crash +TEST(TokenRefreshSchedulerTest, CancelInvalidHandleIsNoop) { + TokenRefreshScheduler scheduler; + // Should not crash + scheduler.Cancel(0); + scheduler.Cancel(999); + scheduler.Shutdown(); +} + +// ---- OAuth2AuthSession integration tests ---- + +// Verifies that MakeOAuth2 creates a session with correct initial Bearer header +TEST(OAuth2AuthSessionTest, InitialTokenIsUsed) { + HttpClient client({}); + OAuthTokenResponse token_response; + token_response.access_token = "initial-token-123"; + token_response.token_type = "bearer"; + token_response.expires_in_secs = 3600; + + // Create session (refresh will fail since there's no real server, but + // initial token should work) + auto session = AuthSession::MakeOAuth2(token_response, "http://localhost/oauth/tokens", + "client_id", "client_secret", "catalog", client); + + std::unordered_map headers; + ASSERT_THAT(session->Authenticate(headers), IsOk()); + 
EXPECT_EQ(headers["Authorization"], "Bearer initial-token-123"); + + session->Close(); +} + +// Verifies that session without expiration does not schedule refresh +TEST(OAuth2AuthSessionTest, NoExpirationNoRefresh) { + HttpClient client({}); + OAuthTokenResponse token_response; + token_response.access_token = "static-token"; + token_response.token_type = "bearer"; + // No expires_in_secs set — token is not a JWT either + + auto session = AuthSession::MakeOAuth2(token_response, "http://localhost/oauth/tokens", + "id", "secret", "catalog", client); + + std::unordered_map headers; + ASSERT_THAT(session->Authenticate(headers), IsOk()); + EXPECT_EQ(headers["Authorization"], "Bearer static-token"); + + // Wait a bit — no crash, no refresh attempt + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + headers.clear(); + ASSERT_THAT(session->Authenticate(headers), IsOk()); + EXPECT_EQ(headers["Authorization"], "Bearer static-token"); + + session->Close(); +} + +// Verifies that Close prevents further refresh callbacks +TEST(OAuth2AuthSessionTest, CloseStopsRefresh) { + HttpClient client({}); + OAuthTokenResponse token_response; + token_response.access_token = "token-before-close"; + token_response.token_type = "bearer"; + token_response.expires_in_secs = 1; // Expires in 1 second + + auto session = AuthSession::MakeOAuth2(token_response, "http://localhost:9999/tokens", + "id", "secret", "catalog", client); + + // Close immediately — should cancel the scheduled refresh + session->Close(); + + // Wait past expiration + refresh window + std::this_thread::sleep_for(std::chrono::milliseconds(2000)); + + // Token should still be the original (no refresh happened) + std::unordered_map headers; + ASSERT_THAT(session->Authenticate(headers), IsOk()); + EXPECT_EQ(headers["Authorization"], "Bearer token-before-close"); +} + +// Verifies that concurrent Authenticate calls are thread-safe +TEST(OAuth2AuthSessionTest, ConcurrentAuthenticate) { + HttpClient client({}); + 
OAuthTokenResponse token_response; + token_response.access_token = "concurrent-token"; + token_response.token_type = "bearer"; + token_response.expires_in_secs = 3600; + + auto session = AuthSession::MakeOAuth2(token_response, "http://localhost/oauth/tokens", + "id", "secret", "catalog", client); + + // Launch multiple threads calling Authenticate concurrently + std::vector threads; + std::atomic success_count{0}; + + for (int i = 0; i < 10; ++i) { + threads.emplace_back([&] { + for (int j = 0; j < 100; ++j) { + std::unordered_map headers; + auto status = session->Authenticate(headers); + if (status.has_value()) { + success_count.fetch_add(1); + } + } + }); + } + + for (auto& t : threads) { + t.join(); + } + + EXPECT_EQ(1000, success_count.load()); + session->Close(); +} + +// Verifies that session still returns last known token after all refresh retries fail +TEST(OAuth2AuthSessionTest, RefreshFailureKeepsLastToken) { + HttpClient client({}); + OAuthTokenResponse token_response; + token_response.access_token = "original-token"; + token_response.token_type = "bearer"; + token_response.expires_in_secs = 1; // Very short — will trigger refresh soon + + auto session = AuthSession::MakeOAuth2( + token_response, "http://localhost:9999/nonexistent", // Will fail + "id", "secret", "catalog", client); + + // Wait for refresh to be attempted and fail (all retries) + // First attempt fires ~900ms after creation; retry delays 200+400+800+1600ms ≈ 3s more + std::this_thread::sleep_for(std::chrono::milliseconds(5000)); + + // Session should still return the original token (no crash) + std::unordered_map headers; + ASSERT_THAT(session->Authenticate(headers), IsOk()); + EXPECT_EQ(headers["Authorization"], "Bearer original-token"); + + session->Close(); +} + } // namespace iceberg::rest::auth diff --git a/src/iceberg/test/meson.build b/src/iceberg/test/meson.build index 3e88fc7ea..b5d4d2283 100644 --- a/src/iceberg/test/meson.build +++ b/src/iceberg/test/meson.build @@ -135,6 +135,10 @@ if 
get_option('rest').enabled() ), 'dependencies': [iceberg_rest_dep], }, + 'polaris_oauth2_integration_test': { + 'sources': files('polaris_oauth2_test.cc'), + 'dependencies': [iceberg_rest_dep], + }, } endif endif diff --git a/src/iceberg/test/transform_util_test.cc b/src/iceberg/test/transform_util_test.cc index 54f36cd07..dd5f0fe38 100644 --- a/src/iceberg/test/transform_util_test.cc +++ b/src/iceberg/test/transform_util_test.cc @@ -159,6 +159,56 @@ TEST(TransformUtilTest, Base64Encode) { EXPECT_EQ("AA==", TransformUtil::Base64Encode({"\x00", 1})); } +TEST(TransformUtilTest, Base64Decode) { + // Empty string + EXPECT_EQ("", TransformUtil::Base64Decode("")); + + // Round-trip with Base64Encode + EXPECT_EQ("a", TransformUtil::Base64Decode("YQ==")); + EXPECT_EQ("ab", TransformUtil::Base64Decode("YWI=")); + EXPECT_EQ("abc", TransformUtil::Base64Decode("YWJj")); + EXPECT_EQ("abcde", TransformUtil::Base64Decode("YWJjZGU=")); + EXPECT_EQ("abcdef", TransformUtil::Base64Decode("YWJjZGVm")); + EXPECT_EQ("hello", TransformUtil::Base64Decode("aGVsbG8=")); + EXPECT_EQ("test string", TransformUtil::Base64Decode("dGVzdCBzdHJpbmc=")); + + // Without padding (should still work) + EXPECT_EQ("a", TransformUtil::Base64Decode("YQ")); + EXPECT_EQ("ab", TransformUtil::Base64Decode("YWI")); + + // Invalid characters return empty + EXPECT_EQ("", TransformUtil::Base64Decode("!!!")); +} + +TEST(TransformUtilTest, Base64UrlDecode) { + // Empty string + EXPECT_EQ("", TransformUtil::Base64UrlDecode("")); + + // Standard cases (same as Base64Decode for alphanumeric) + EXPECT_EQ("hello", TransformUtil::Base64UrlDecode("aGVsbG8")); + EXPECT_EQ("abc", TransformUtil::Base64UrlDecode("YWJj")); + + // URL-safe characters: '-' and '_' instead of '+' and '/' + // "?>" in standard base64 is "Pz4=" (contains '+' and '/') + // In base64url it would use '-' and '_' + // Let's test with a known value: bytes {0xFB, 0xFF, 0xFE} encode to "+//+" in + // standard base64, and "-__-" in base64url + std::string 
decoded = TransformUtil::Base64UrlDecode("-__-"); + EXPECT_EQ(3u, decoded.size()); + EXPECT_EQ('\xFB', decoded[0]); + EXPECT_EQ('\xFF', decoded[1]); + EXPECT_EQ('\xFE', decoded[2]); + + // Standard base64 chars '+' and '/' should be invalid in base64url + EXPECT_EQ("", TransformUtil::Base64UrlDecode("+//+")); + + // With padding (should handle gracefully) + EXPECT_EQ("hello", TransformUtil::Base64UrlDecode("aGVsbG8=")); + + // Invalid characters return empty + EXPECT_EQ("", TransformUtil::Base64UrlDecode("!!!invalid!!!")); +} + struct ParseRoundTripParam { std::string name; std::string str; diff --git a/src/iceberg/util/transform_util.cc b/src/iceberg/util/transform_util.cc index a9221310e..fb3c4d67e 100644 --- a/src/iceberg/util/transform_util.cc +++ b/src/iceberg/util/transform_util.cc @@ -283,4 +283,76 @@ std::string TransformUtil::Base64Encode(std::string_view str_to_encode) { return encoded; } +namespace { + +// Shared base64 decode logic. The decode table maps ASCII char → 6-bit value. +// 0xFF means invalid character. 
+std::string Base64DecodeWithTable(std::string_view input, + const std::array& table) { + // Strip trailing padding + while (!input.empty() && input.back() == '=') { + input.remove_suffix(1); + } + if (input.empty()) { + return {}; + } + + std::string output; + output.reserve((input.size() * 3) / 4); + + uint32_t buffer = 0; + int bits_collected = 0; + + for (char c : input) { + uint8_t val = table[static_cast(c)]; + if (val == 0xFF) { + return {}; // Invalid character + } + buffer = (buffer << 6) | val; + bits_collected += 6; + if (bits_collected >= 8) { + bits_collected -= 8; + output.push_back(static_cast((buffer >> bits_collected) & 0xFF)); + } + } + + return output; +} + +// Standard base64 decode table: A-Z=0-25, a-z=26-51, 0-9=52-61, +=62, /=63 +constexpr std::array kBase64DecodeTable = [] { + std::array table{}; + table.fill(0xFF); + for (int i = 0; i < 26; ++i) { + table[static_cast('A' + i)] = static_cast(i); + table[static_cast('a' + i)] = static_cast(26 + i); + } + for (int i = 0; i < 10; ++i) { + table[static_cast('0' + i)] = static_cast(52 + i); + } + table[static_cast('+')] = 62; + table[static_cast('/')] = 63; + return table; +}(); + +// Base64url decode table: same as standard but '-'=62, '_'=63 (RFC 4648 §5) +constexpr std::array kBase64UrlDecodeTable = [] { + auto table = kBase64DecodeTable; + table[static_cast('+')] = 0xFF; // '+' is invalid in base64url + table[static_cast('/')] = 0xFF; // '/' is invalid in base64url + table[static_cast('-')] = 62; + table[static_cast('_')] = 63; + return table; +}(); + +} // namespace + +std::string TransformUtil::Base64Decode(std::string_view encoded) { + return Base64DecodeWithTable(encoded, kBase64DecodeTable); +} + +std::string TransformUtil::Base64UrlDecode(std::string_view encoded) { + return Base64DecodeWithTable(encoded, kBase64UrlDecodeTable); +} + } // namespace iceberg diff --git a/src/iceberg/util/transform_util.h b/src/iceberg/util/transform_util.h index c23d08c8c..b751e6531 100644 --- 
a/src/iceberg/util/transform_util.h +++ b/src/iceberg/util/transform_util.h @@ -20,6 +20,7 @@ #pragma once #include +#include #include "iceberg/iceberg_export.h" #include "iceberg/result.h" @@ -139,6 +140,19 @@ class ICEBERG_EXPORT TransformUtil { /// \brief Base64 encode a string static std::string Base64Encode(std::string_view str_to_encode); + + /// \brief Base64 decode a string (standard alphabet: +/). + /// + /// Handles optional padding ('='). Returns an empty string if the input + /// contains invalid characters. + static std::string Base64Decode(std::string_view encoded); + + /// \brief Base64url decode a string (URL-safe alphabet: -_). + /// + /// Handles optional padding ('='). Returns an empty string if the input + /// contains invalid characters. This variant uses '-' and '_' instead of + /// '+' and '/' per RFC 4648 §5. + static std::string Base64UrlDecode(std::string_view encoded); }; } // namespace iceberg