From a79bcf2c2227dea732b9ee42e9a53bc55abec590 Mon Sep 17 00:00:00 2001 From: I763688 Date: Mon, 1 Jun 2026 12:00:43 +0100 Subject: [PATCH 1/5] feat: Add client-owned token cache --- src/sap_cloud_sdk/agentgateway/_customer.py | 61 +++- src/sap_cloud_sdk/agentgateway/_lob.py | 78 ++++- .../agentgateway/_token_cache.py | 276 +++++++++++++++ src/sap_cloud_sdk/agentgateway/agw_client.py | 18 +- src/sap_cloud_sdk/agentgateway/config.py | 14 + src/sap_cloud_sdk/agentgateway/user-guide.md | 33 +- tests/agentgateway/unit/test_agw_client.py | 52 ++- tests/agentgateway/unit/test_config.py | 7 +- tests/agentgateway/unit/test_customer.py | 81 +++++ tests/agentgateway/unit/test_lob.py | 68 ++++ tests/agentgateway/unit/test_token_cache.py | 328 ++++++++++++++++++ 11 files changed, 994 insertions(+), 22 deletions(-) create mode 100644 src/sap_cloud_sdk/agentgateway/_token_cache.py create mode 100644 tests/agentgateway/unit/test_token_cache.py diff --git a/src/sap_cloud_sdk/agentgateway/_customer.py b/src/sap_cloud_sdk/agentgateway/_customer.py index eb8a364b..be5019ea 100644 --- a/src/sap_cloud_sdk/agentgateway/_customer.py +++ b/src/sap_cloud_sdk/agentgateway/_customer.py @@ -24,6 +24,7 @@ IntegrationDependency, MCPTool, ) +from sap_cloud_sdk.agentgateway._token_cache import _TokenCache from sap_cloud_sdk.agentgateway.exceptions import AgentGatewaySDKError logger = logging.getLogger(__name__) @@ -42,6 +43,11 @@ _GRANT_TYPE_JWT_BEARER = "urn:ietf:params:oauth:grant-type:jwt-bearer" +def _cache_scope_key(credentials: CustomerCredentials, app_tid: str | None) -> str: + """Build a cache scope key for customer-flow tokens.""" + return f"customer::{credentials.client_id}::{app_tid or ''}" + + class _CredentialFields: """Field names in the credentials JSON file.""" @@ -212,17 +218,18 @@ def _request_token_mtls( timeout: float, app_tid: str | None = None, extra_data: dict | None = None, -) -> str: +) -> dict: """Make mTLS token request to IAS. Args: credentials: Customer credentials with certificate and private key. grant_type: OAuth2 grant type. + timeout: HTTP timeout in seconds. app_tid: BTP Application Tenant ID of subscriber (optional). extra_data: Additional form data for the token request. Returns: - Access token string. + Token response payload. Raises: AgentGatewaySDKError: If token request fails. @@ -282,7 +289,7 @@ def _request_token_mtls( ) logger.debug("Token acquired successfully (length: %d)", len(access_token)) - return access_token + return token_data except httpx.RequestError as e: raise AgentGatewaySDKError(f"Token request failed: {e}") @@ -292,6 +299,7 @@ def get_system_token_mtls( credentials: CustomerCredentials, timeout: float, app_tid: str | None = None, + token_cache: _TokenCache | None = None, ) -> str: """Get system-scoped token using mTLS client credentials flow. @@ -301,18 +309,36 @@ def get_system_token_mtls( credentials: Customer credentials. timeout: HTTP timeout in seconds. app_tid: BTP Application Tenant ID of subscriber (optional). + token_cache: Optional token cache used to reuse still-valid tokens. Returns: - System-scoped access token. + System-scoped access token, fetched or served from cache. """ + scope_key = _cache_scope_key(credentials, app_tid) + if token_cache: + cached_token = token_cache.get_system_token(scope_key) + if cached_token: + logger.debug("Using cached system token for scope '%s'", scope_key) + return cached_token + logger.info("Acquiring system token via mTLS client credentials") - return _request_token_mtls( + token_data = _request_token_mtls( credentials, grant_type=_GRANT_TYPE_CLIENT_CREDENTIALS, timeout=timeout, app_tid=app_tid, extra_data={"response_type": "token"}, ) + access_token = token_data["access_token"] + + if token_cache: + token_cache.set_system_token( + access_token, + token_cache.compute_expires_at(token_data), + scope_key, + ) + + return access_token def exchange_user_token( @@ -320,6 +346,7 @@ def exchange_user_token( user_token: str, timeout: float, app_tid: str | None = None, + token_cache: _TokenCache | None = None, ) -> str: """Exchange user token for AGW-scoped token using jwt-bearer grant. @@ -331,12 +358,21 @@ def exchange_user_token( user_token: User's JWT token to exchange. timeout: HTTP timeout in seconds. app_tid: BTP Application Tenant ID of subscriber (optional). + token_cache: Optional token cache used to reuse still-valid exchanged + tokens. Returns: - AGW-scoped access token with user identity. + AGW-scoped access token with user identity, fetched or served from cache. """ + scope_key = _cache_scope_key(credentials, app_tid) + if token_cache: + cached_token = token_cache.get_user_token(user_token, scope_key) + if cached_token: + logger.debug("Using cached exchanged user token for scope '%s'", scope_key) + return cached_token + logger.info("Exchanging user token for AGW-scoped token via jwt-bearer grant") - return _request_token_mtls( + token_data = _request_token_mtls( credentials, grant_type=_GRANT_TYPE_JWT_BEARER, timeout=timeout, @@ -346,6 +382,17 @@ def exchange_user_token( "token_format": "jwt", }, ) + access_token = token_data["access_token"] + + if token_cache: + token_cache.set_user_token( + user_token, + access_token, + token_cache.compute_expires_at(token_data), + scope_key, + ) + + return access_token def _build_mcp_url(gateway_url: str, ord_id: str, gt_id: str) -> str: diff --git a/src/sap_cloud_sdk/agentgateway/_lob.py b/src/sap_cloud_sdk/agentgateway/_lob.py index 42258f93..c4944b78 100644 --- a/src/sap_cloud_sdk/agentgateway/_lob.py +++ b/src/sap_cloud_sdk/agentgateway/_lob.py @@ -23,6 +23,7 @@ ) from sap_cloud_sdk.agentgateway._models import MCPTool +from sap_cloud_sdk.agentgateway._token_cache import _GatewayUrlCache, _TokenCache from sap_cloud_sdk.agentgateway.exceptions import MCPServerNotFoundError logger = logging.getLogger(__name__) @@ -38,6 +39,16 @@ _DESTINATION_INSTANCE = "default" +def _system_scope_key(tenant_subdomain: str) -> str: + """Build the cache scope key for tenant-scoped system auth.""" + return f"lob-system::{tenant_subdomain}" + + +def _user_scope_key(tenant_subdomain: str) -> str: + """Build the cache scope key for tenant-scoped user auth.""" + return f"lob-user::{tenant_subdomain}" + + def _ias_dest_name() -> str: """Get IAS destination name based on landscape. @@ -184,6 +195,8 @@ def get_ias_user_fragment_name(tenant_subdomain: str) -> str: async def fetch_system_auth( tenant_subdomain: str, + token_cache: _TokenCache | None = None, + gateway_url_cache: _GatewayUrlCache | None = None, ) -> tuple[str, str]: """Fetch system-scoped auth (Phase 1 - client credentials). @@ -192,13 +205,29 @@ async def fetch_system_auth( Args: tenant_subdomain: Tenant subdomain for multi-tenant lookup. + token_cache: Optional token cache used to reuse still-valid system + tokens. + gateway_url_cache: Optional cache for gateway URLs associated with the + cached system-token scope. Returns: - Tuple of (raw_access_token, gateway_url). + Tuple of `(raw_access_token, gateway_url)`, fetched or served from cache. Raises: MCPServerNotFoundError: If no IAS fragment or auth token is found. """ + scope_key = _system_scope_key(tenant_subdomain) + if (token_cache is None) != (gateway_url_cache is None): + raise ValueError( + "token_cache and gateway_url_cache must both be provided or both be None" + ) + if token_cache and gateway_url_cache is not None: + cached_token = token_cache.get_system_token(scope_key) + cached_gateway_url = gateway_url_cache.get(scope_key) + if cached_token and cached_gateway_url: + logger.debug("Using cached system auth for tenant '%s'", tenant_subdomain) + return cached_token, cached_gateway_url + loop = asyncio.get_running_loop() def _fetch_system_auth_sync(): @@ -218,12 +247,25 @@ def _fetch_system_auth_sync(): return _fetch_auth_token(dest_name, tenant_subdomain, options) - return await loop.run_in_executor(None, _fetch_system_auth_sync) + token, gateway_url = await loop.run_in_executor(None, _fetch_system_auth_sync) + + if token_cache: + token_cache.set_system_token( + token, + token_cache.compute_expires_at_from_bearer(token), + scope_key, + ) + if gateway_url_cache is not None: + gateway_url_cache[scope_key] = gateway_url + + return token, gateway_url async def fetch_user_auth( user_token: str, tenant_subdomain: str, + token_cache: _TokenCache | None = None, + gateway_url_cache: _GatewayUrlCache | None = None, ) -> tuple[str, str]: """Fetch user-scoped auth (Phase 2 - token exchange). @@ -234,13 +276,29 @@ async def fetch_user_auth( Args: user_token: User's JWT for principal propagation. tenant_subdomain: Tenant subdomain for multi-tenant lookup. + token_cache: Optional token cache used to reuse still-valid exchanged + user tokens. + gateway_url_cache: Optional cache for gateway URLs associated with the + cached user-token scope. Returns: - Tuple of (raw_access_token, gateway_url). + Tuple of `(raw_access_token, gateway_url)`, fetched or served from cache. Raises: MCPServerNotFoundError: If no IAS user fragment or auth token is found. """ + scope_key = _user_scope_key(tenant_subdomain) + if (token_cache is None) != (gateway_url_cache is None): + raise ValueError( + "token_cache and gateway_url_cache must both be provided or both be None" + ) + if token_cache and gateway_url_cache is not None: + cached_token = token_cache.get_user_token(user_token, scope_key) + cached_gateway_url = gateway_url_cache.get(scope_key) + if cached_token and cached_gateway_url: + logger.debug("Using cached user auth for tenant '%s'", tenant_subdomain) + return cached_token, cached_gateway_url + loop = asyncio.get_running_loop() def _fetch_user_auth_sync(): @@ -262,7 +320,19 @@ def _fetch_user_auth_sync(): return _fetch_auth_token(dest_name, tenant_subdomain, options) - return await loop.run_in_executor(None, _fetch_user_auth_sync) + token, gateway_url = await loop.run_in_executor(None, _fetch_user_auth_sync) + + if token_cache: + token_cache.set_user_token( + user_token, + token, + token_cache.compute_expires_at_from_bearer(token), + scope_key, + ) + if gateway_url_cache is not None: + gateway_url_cache[scope_key] = gateway_url + + return token, gateway_url async def list_server_tools( diff --git a/src/sap_cloud_sdk/agentgateway/_token_cache.py b/src/sap_cloud_sdk/agentgateway/_token_cache.py new file mode 100644 index 00000000..42397ab5 --- /dev/null +++ b/src/sap_cloud_sdk/agentgateway/_token_cache.py @@ -0,0 +1,276 @@ +"""Token cache for Agent Gateway flows. + +Caches IAS tokens (system + user-exchanged) per client to avoid redundant +token requests during agentic loops. Used by both customer flow (mTLS) and +LoB flow (BTP Destination Service). + +Keying: +- System tokens are keyed by a flow-specific scope key. +- User tokens are keyed by `sha256(user_jwt + "|" + scope_key)[:16]`. + +Thread safety: +Token fetches run in the default `ThreadPoolExecutor` via +`loop.run_in_executor`. CPython GIL makes individual dict / OrderedDict +operations atomic, but compound check-then-set is not. Two concurrent +coroutines for the same key may both miss and both fetch; the race +produces redundant token requests, not corruption. +""" + +import base64 +import hashlib +import json +import logging +import time +from collections import OrderedDict +from dataclasses import dataclass +from datetime import datetime, timezone + +from sap_cloud_sdk.agentgateway.config import ClientConfig + +logger = logging.getLogger(__name__) + + +@dataclass +class _CachedToken: + """A cached token with monotonic expiry.""" + + token: str + expires_at: float # time.monotonic() value + + def is_valid(self) -> bool: + """Return True if the token has not yet reached its monotonic expiry.""" + return time.monotonic() < self.expires_at + + +def _parse_jwt_exp(jwt: str) -> int | None: + """Extract `exp` claim (seconds since epoch) from a JWT without verification. + + Returns None if the JWT is malformed or has no `exp` claim. The result + is used only as a hint for cache TTL — never for security decisions. + """ + try: + parts = jwt.split(".") + if len(parts) < 2: + return None + payload_b64 = parts[1] + payload_b64 += "=" * (-len(payload_b64) % 4) + claims = json.loads(base64.urlsafe_b64decode(payload_b64)) + exp = claims.get("exp") + return int(exp) if exp is not None else None + except (ValueError, KeyError, TypeError, json.JSONDecodeError): + return None + + +def _parse_response_expires_at(expires_at: object) -> float | None: + """Parse a token response `expires_at` value into epoch seconds.""" + if expires_at is None or isinstance(expires_at, bool): + return None + + if isinstance(expires_at, (int, float)): + return float(expires_at) + + if not isinstance(expires_at, str): + return None + + normalized = expires_at.strip() + if not normalized: + return None + + try: + return float(normalized) + except ValueError: + pass + + try: + parsed = datetime.fromisoformat(normalized.replace("Z", "+00:00")) + except ValueError: + return None + + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=timezone.utc) + + return parsed.timestamp() + + +def _monotonic_expiry_from_epoch(expiry_epoch_seconds: float, buffer: float) -> float: + """Translate a wall-clock expiry into a monotonic deadline.""" + return time.monotonic() + (expiry_epoch_seconds - time.time()) - buffer + + +def _monotonic_expiry_from_ttl(ttl_seconds: float, buffer: float) -> float: + """Translate a TTL into a monotonic deadline.""" + return time.monotonic() + ttl_seconds - buffer + + +def compute_expires_at(token_data: dict, config: ClientConfig) -> float: + """Resolve the cache expiry timestamp (monotonic) for a token response. + + Resolution order: + 1. `expires_at` from the response, minus the buffer. + 2. `expires_in` from the response, minus the buffer. + 3. `exp` claim from `access_token`, minus the buffer. + 4. `exp` claim from `id_token`, minus the buffer. + 5. Config-provided fallback TTL. + """ + buffer = config.token_expiry_buffer_seconds + + expires_at = _parse_response_expires_at(token_data.get("expires_at")) + if expires_at is not None: + return _monotonic_expiry_from_epoch(expires_at, buffer) + + expires_in = token_data.get("expires_in") + if expires_in is not None: + try: + return _monotonic_expiry_from_ttl(float(expires_in), buffer) + except (ValueError, TypeError): + pass + + for token_field in ("access_token", "id_token"): + jwt = token_data.get(token_field) + if not jwt: + continue + + exp = _parse_jwt_exp(jwt) + if exp is not None: + return _monotonic_expiry_from_epoch(float(exp), buffer) + + return time.monotonic() + config.fallback_token_ttl_seconds + + +class _GatewayUrlCache: + """LRU-bounded cache for gateway URLs keyed by scope key. + + URLs are assumed stable for the lifetime of a client instance. Bounded to + avoid unbounded growth in long-lived clients serving many tenants. + """ + + def __init__(self, max_size: int = 64): + self._max_size = max_size + self._cache: OrderedDict[str, str] = OrderedDict() + + def get(self, scope_key: str) -> str | None: + value = self._cache.get(scope_key) + if value is not None: + self._cache.move_to_end(scope_key) + return value + + def __setitem__(self, scope_key: str, url: str) -> None: + self._cache[scope_key] = url + self._cache.move_to_end(scope_key) + while len(self._cache) > self._max_size: + self._cache.popitem(last=False) + + def __contains__(self, scope_key: str) -> bool: + return scope_key in self._cache + + +class _TokenCache: + """Per-client token cache with TTL and LRU eviction. + + Both system and user tokens use OrderedDict for LRU ordering. + """ + + def __init__(self, config: ClientConfig): + """Initialize empty caches bounded by sizes from `config`.""" + self._config = config + self._system_tokens: OrderedDict[str, _CachedToken] = OrderedDict() + self._user_tokens: OrderedDict[str, _CachedToken] = OrderedDict() + + # --- System Token --- + + def get_system_token(self, scope_key: str) -> str | None: + """Return a valid cached system token for `scope_key`, or None.""" + cached = self._system_tokens.get(scope_key) + if cached and cached.is_valid(): + self._system_tokens.move_to_end(scope_key) + return cached.token + if cached: + del self._system_tokens[scope_key] + return None + + def set_system_token(self, token: str, expires_at: float, scope_key: str) -> None: + """Cache a system token under `scope_key`; evict LRU once size exceeds limit.""" + self._system_tokens[scope_key] = _CachedToken( + token=token, expires_at=expires_at + ) + self._system_tokens.move_to_end(scope_key) + while len(self._system_tokens) > self._config.max_system_token_cache_size: + evicted, _ = self._system_tokens.popitem(last=False) + logger.debug("System token cache full — evicted '%s'", evicted) + + def invalidate_system_token(self, scope_key: str) -> None: + """Drop the cached system token for `scope_key` (no-op if absent).""" + if self._system_tokens.pop(scope_key, None): + logger.debug("Invalidated system token (scope_key=%s)", scope_key) + + # --- User Tokens --- + + def get_user_token(self, user_jwt: str, scope_key: str) -> str | None: + """Return a valid cached exchanged token for `(user_jwt, scope_key)`, or None.""" + key = self._hash_key(user_jwt, scope_key) + cached = self._user_tokens.get(key) + if cached and cached.is_valid(): + self._user_tokens.move_to_end(key) + return cached.token + if cached: + del self._user_tokens[key] + return None + + def set_user_token( + self, + user_jwt: str, + token: str, + expires_at: float, + scope_key: str, + ) -> None: + """Cache an exchanged user token; evict LRU once size exceeds limit.""" + key = self._hash_key(user_jwt, scope_key) + self._user_tokens[key] = _CachedToken(token=token, expires_at=expires_at) + self._user_tokens.move_to_end(key) + while len(self._user_tokens) > self._config.max_user_token_cache_size: + evicted, _ = self._user_tokens.popitem(last=False) + logger.debug("User token cache full — evicted '%s'", evicted) + + def invalidate_user_token(self, user_jwt: str, scope_key: str) -> None: + """Drop the cached user token for `(user_jwt, scope_key)` (no-op if absent).""" + key = self._hash_key(user_jwt, scope_key) + if self._user_tokens.pop(key, None): + logger.debug("Invalidated user token (scope_key=%s)", scope_key) + + # --- Utility --- + + def compute_expires_at(self, token_data: dict) -> float: + """Resolve the cache expiry timestamp (monotonic) for a token response.""" + return compute_expires_at(token_data, self._config) + + def compute_expires_at_from_bearer(self, auth_header: str) -> float: + """Resolve the cache expiry timestamp for a bearer auth header string. + + Strips the 'Bearer ' prefix and parses the `exp` claim from the JWT. + Falls back to the config-provided fallback TTL if parsing fails. + """ + buffer = self._config.token_expiry_buffer_seconds + + jwt = ( + auth_header[7:] + if auth_header.lower().startswith("bearer ") + else auth_header + ) + exp = _parse_jwt_exp(jwt) + if exp is not None: + return _monotonic_expiry_from_epoch(float(exp), buffer) + + return time.monotonic() + self._config.fallback_token_ttl_seconds + + # --- Maintenance --- + + def clear(self) -> None: + """Drop all cached tokens. Forces a fresh fetch on next access.""" + self._system_tokens.clear() + self._user_tokens.clear() + + @staticmethod + def _hash_key(user_jwt: str, scope_key: str) -> str: + """Derive a short, stable cache key from `(user_jwt, scope_key)` via sha256.""" + material = f"{user_jwt}|{scope_key}" + return hashlib.sha256(material.encode()).hexdigest()[:16] diff --git a/src/sap_cloud_sdk/agentgateway/agw_client.py b/src/sap_cloud_sdk/agentgateway/agw_client.py index f385118e..3b7a0c64 100644 --- a/src/sap_cloud_sdk/agentgateway/agw_client.py +++ b/src/sap_cloud_sdk/agentgateway/agw_client.py @@ -27,6 +27,7 @@ get_mcp_tools_lob, ) from sap_cloud_sdk.agentgateway._models import AuthResult, MCPTool +from sap_cloud_sdk.agentgateway._token_cache import _GatewayUrlCache, _TokenCache from sap_cloud_sdk.agentgateway.exceptions import AgentGatewaySDKError from sap_cloud_sdk.core.telemetry import Module, Operation, record_metrics @@ -110,6 +111,8 @@ def __init__( """ self._tenant_subdomain = tenant_subdomain self._config = config or ClientConfig() + self._token_cache = _TokenCache(self._config) + self._gateway_url_cache = _GatewayUrlCache() @staticmethod def _resolve_value( @@ -182,6 +185,7 @@ async def get_system_auth(self, app_tid: str | None = None) -> AuthResult: credentials, self._config.timeout, app_tid, + self._token_cache, ) return AuthResult( access_token=token, @@ -193,7 +197,11 @@ async def get_system_auth(self, app_tid: str | None = None) -> AuthResult: logger.warning("app_tid parameter ignored for LoB agent flow") tenant = self._resolve_tenant_subdomain() - token, gateway_url = await fetch_system_auth(tenant) + token, gateway_url = await fetch_system_auth( + tenant, + token_cache=self._token_cache, + gateway_url_cache=self._gateway_url_cache, + ) return AuthResult(access_token=token, gateway_url=gateway_url) except AgentGatewaySDKError: @@ -255,6 +263,7 @@ async def get_user_auth( resolved_user_token, self._config.timeout, app_tid, + self._token_cache, ) return AuthResult( access_token=token, @@ -266,7 +275,12 @@ async def get_user_auth( logger.warning("app_tid parameter ignored for LoB agent flow") tenant = self._resolve_tenant_subdomain() - token, gateway_url = await fetch_user_auth(resolved_user_token, tenant) + token, gateway_url = await fetch_user_auth( + resolved_user_token, + tenant, + token_cache=self._token_cache, + gateway_url_cache=self._gateway_url_cache, + ) return AuthResult(access_token=token, gateway_url=gateway_url) except AgentGatewaySDKError: diff --git a/src/sap_cloud_sdk/agentgateway/config.py b/src/sap_cloud_sdk/agentgateway/config.py index 427f96b3..96f44b1a 100644 --- a/src/sap_cloud_sdk/agentgateway/config.py +++ b/src/sap_cloud_sdk/agentgateway/config.py @@ -3,6 +3,10 @@ from dataclasses import dataclass DEFAULT_TIMEOUT_SECONDS = 60.0 +DEFAULT_FALLBACK_TOKEN_TTL_SECONDS = 300.0 +DEFAULT_TOKEN_EXPIRY_BUFFER_SECONDS = 30.0 +DEFAULT_MAX_SYSTEM_TOKEN_CACHE_SIZE = 32 +DEFAULT_MAX_USER_TOKEN_CACHE_SIZE = 256 @dataclass @@ -12,6 +16,16 @@ class ClientConfig: Attributes: timeout: HTTP timeout in seconds for token requests and MCP server calls. Defaults to 60 seconds. + fallback_token_ttl_seconds: Fallback cache TTL used when a token + response does not provide expiry metadata. + token_expiry_buffer_seconds: Safety buffer subtracted from explicit + token expiries before a cached token is considered stale. + max_system_token_cache_size: Maximum number of cached system tokens. + max_user_token_cache_size: Maximum number of cached user tokens. """ timeout: float = DEFAULT_TIMEOUT_SECONDS + fallback_token_ttl_seconds: float = DEFAULT_FALLBACK_TOKEN_TTL_SECONDS + token_expiry_buffer_seconds: float = DEFAULT_TOKEN_EXPIRY_BUFFER_SECONDS + max_system_token_cache_size: int = DEFAULT_MAX_SYSTEM_TOKEN_CACHE_SIZE + max_user_token_cache_size: int = DEFAULT_MAX_USER_TOKEN_CACHE_SIZE diff --git a/src/sap_cloud_sdk/agentgateway/user-guide.md b/src/sap_cloud_sdk/agentgateway/user-guide.md index 240c6391..6ad21576 100644 --- a/src/sap_cloud_sdk/agentgateway/user-guide.md +++ b/src/sap_cloud_sdk/agentgateway/user-guide.md @@ -43,9 +43,10 @@ result = await agw_client.call_mcp_tool( LoB agents use BTP Destination Service for credential management. Tools are auto-discovered from destination fragments. ```python -from sap_cloud_sdk.agentgateway import create_client +from sap_cloud_sdk.agentgateway import ClientConfig, create_client -agw_client = create_client(tenant_subdomain="my-tenant") +config = ClientConfig(timeout=30.0) +agw_client = create_client(tenant_subdomain="my-tenant", config=config) # Discover tools (auto-discovered from destination fragments) tools = await agw_client.list_mcp_tools() @@ -99,10 +100,38 @@ The SDK automatically detects the agent type based on the presence of a credenti ```python def create_client( tenant_subdomain: str | Callable[[], str] | None = None, + config: ClientConfig | None = None, ) -> AgentGatewayClient ``` - `tenant_subdomain`: Required for LoB agents, ignored for Customer agents. Can be a string or callable. +- `config`: Optional `ClientConfig` used to control HTTP timeout and in-memory token cache behavior. + +### ClientConfig + +Use `ClientConfig` to tune request timeouts and token cache behavior for a client instance. + +```python +from sap_cloud_sdk.agentgateway import ClientConfig, create_client + +config = ClientConfig( + timeout=30.0, + fallback_token_ttl_seconds=300.0, + token_expiry_buffer_seconds=30.0, + max_system_token_cache_size=32, + max_user_token_cache_size=256, +) + +agw_client = create_client(tenant_subdomain="my-tenant", config=config) +``` + +- `timeout`: HTTP timeout in seconds for token requests and MCP calls. Default: `60.0`. +- `fallback_token_ttl_seconds`: Used when the token response does not include expiry metadata. Default: `300.0`. +- `token_expiry_buffer_seconds`: Safety buffer subtracted from explicit token expiries before a cached token is reused. Default: `30.0`. +- `max_system_token_cache_size`: Maximum number of cached system tokens per client instance. Default: `32`. +- `max_user_token_cache_size`: Maximum number of cached exchanged user tokens per client instance. Default: `256`. + +The SDK keeps token caches per `AgentGatewayClient` instance and reuses valid cached tokens for repeated authentication calls. System and user token caches are bounded independently with least-recently-used eviction. ### AgentGatewayClient diff --git a/tests/agentgateway/unit/test_agw_client.py b/tests/agentgateway/unit/test_agw_client.py index 3521aaed..63d91d17 100644 --- a/tests/agentgateway/unit/test_agw_client.py +++ b/tests/agentgateway/unit/test_agw_client.py @@ -31,6 +31,16 @@ def mock_tool(): ) +def _client_token_cache(client: AgentGatewayClient): + """Access the client-owned token cache for white-box tests.""" + return getattr(client, "_token_cache") + + +def _client_gateway_url_cache(client: AgentGatewayClient): + """Access the client-owned gateway URL cache for white-box tests.""" + return getattr(client, "_gateway_url_cache") + + # ============================================================ # Test: create_client factory # ============================================================ @@ -117,13 +127,19 @@ async def test_lob_flow_returns_auth_result(self): return_value=("raw-system-jwt-token", "https://agw.example.com"), ) as mock_auth: agw_client = create_client(tenant_subdomain="my-tenant") + token_cache = _client_token_cache(agw_client) + gateway_url_cache = _client_gateway_url_cache(agw_client) result = await agw_client.get_system_auth() assert isinstance(result, AuthResult) assert result.access_token == "raw-system-jwt-token" assert result.gateway_url == "https://agw.example.com" - mock_auth.assert_called_once_with("my-tenant") + mock_auth.assert_called_once_with( + "my-tenant", + token_cache=token_cache, + gateway_url_cache=gateway_url_cache, + ) @pytest.mark.asyncio async def test_customer_flow_returns_auth_result(self): @@ -142,6 +158,7 @@ async def test_customer_flow_returns_auth_result(self): mock_load.return_value = mock_creds agw_client = create_client() + token_cache = _client_token_cache(agw_client) result = await agw_client.get_system_auth(app_tid="test-tid") @@ -149,7 +166,9 @@ async def test_customer_flow_returns_auth_result(self): assert result.access_token == "customer-system-token" assert result.gateway_url == "https://agw.customer.com" mock_load.assert_called_once_with("/path/to/credentials") - mock_mtls.assert_called_once_with(mock_creds, 60.0, "test-tid") + mock_mtls.assert_called_once_with( + mock_creds, 60.0, "test-tid", token_cache + ) @pytest.mark.asyncio async def test_missing_tenant_raises_for_lob(self): @@ -176,10 +195,16 @@ async def test_callable_tenant_subdomain(self): ) as mock_auth: get_tenant = lambda: "dynamic-tenant" agw_client = create_client(tenant_subdomain=get_tenant) + token_cache = _client_token_cache(agw_client) + gateway_url_cache = _client_gateway_url_cache(agw_client) await agw_client.get_system_auth() - mock_auth.assert_called_once_with("dynamic-tenant") + mock_auth.assert_called_once_with( + "dynamic-tenant", + token_cache=token_cache, + gateway_url_cache=gateway_url_cache, + ) @pytest.mark.asyncio async def test_wraps_unexpected_errors(self): @@ -218,13 +243,20 @@ async def test_lob_flow_returns_auth_result(self): return_value=("raw-user-jwt-token", "https://agw.example.com"), ) as mock_auth: agw_client = create_client(tenant_subdomain="my-tenant") + token_cache = _client_token_cache(agw_client) + gateway_url_cache = _client_gateway_url_cache(agw_client) result = await agw_client.get_user_auth(user_token="user-jwt") assert isinstance(result, AuthResult) assert result.access_token == "raw-user-jwt-token" assert result.gateway_url == "https://agw.example.com" - mock_auth.assert_called_once_with("user-jwt", "my-tenant") + mock_auth.assert_called_once_with( + "user-jwt", + "my-tenant", + token_cache=token_cache, + gateway_url_cache=gateway_url_cache, + ) @pytest.mark.asyncio async def test_customer_flow_exchanges_token(self): @@ -243,6 +275,7 @@ async def test_customer_flow_exchanges_token(self): mock_load.return_value = mock_creds agw_client = create_client() + token_cache = _client_token_cache(agw_client) result = await agw_client.get_user_auth( user_token="user-jwt", app_tid="test-tid" @@ -252,7 +285,7 @@ async def test_customer_flow_exchanges_token(self): assert result.access_token == "exchanged-token" assert result.gateway_url == "https://agw.customer.com" mock_exchange.assert_called_once_with( - mock_creds, "user-jwt", 60.0, "test-tid" + mock_creds, "user-jwt", 60.0, "test-tid", token_cache ) @pytest.mark.asyncio @@ -280,10 +313,17 @@ async def test_callable_user_token(self): ) as mock_auth: agw_client = create_client(tenant_subdomain="my-tenant") get_token = lambda: "dynamic-user-jwt" + token_cache = _client_token_cache(agw_client) + gateway_url_cache = _client_gateway_url_cache(agw_client) await agw_client.get_user_auth(user_token=get_token) - mock_auth.assert_called_once_with("dynamic-user-jwt", "my-tenant") + mock_auth.assert_called_once_with( + "dynamic-user-jwt", + "my-tenant", + token_cache=token_cache, + gateway_url_cache=gateway_url_cache, + ) @pytest.mark.asyncio async def test_missing_tenant_raises_for_lob(self): diff --git a/tests/agentgateway/unit/test_config.py b/tests/agentgateway/unit/test_config.py index 9499b3a3..98e6f774 100644 --- a/tests/agentgateway/unit/test_config.py +++ b/tests/agentgateway/unit/test_config.py @@ -10,6 +10,10 @@ def test_default_values(self): """ClientConfig has sensible defaults.""" config = ClientConfig() assert config.timeout == 60.0 + assert config.fallback_token_ttl_seconds == 300.0 + assert config.token_expiry_buffer_seconds == 30.0 + assert config.max_system_token_cache_size == 32 + assert config.max_user_token_cache_size == 256 def test_custom_timeout(self): """ClientConfig accepts custom timeout.""" @@ -18,9 +22,10 @@ def test_custom_timeout(self): def test_create_client_with_config(self): """create_client accepts a ClientConfig.""" - config = ClientConfig(timeout=90.0) + config = ClientConfig(timeout=90.0, fallback_token_ttl_seconds=90.0) client = create_client(config=config) assert client._config.timeout == 90.0 + assert client._config.fallback_token_ttl_seconds == 90.0 def test_create_client_without_config_uses_defaults(self): """create_client uses default config when none provided.""" diff --git a/tests/agentgateway/unit/test_customer.py b/tests/agentgateway/unit/test_customer.py index f8db5fe9..4469ab47 100644 --- a/tests/agentgateway/unit/test_customer.py +++ b/tests/agentgateway/unit/test_customer.py @@ -21,6 +21,8 @@ IntegrationDependency, MCPTool, ) +from sap_cloud_sdk.agentgateway._token_cache import _TokenCache +from sap_cloud_sdk.agentgateway.config import ClientConfig from sap_cloud_sdk.agentgateway.exceptions import AgentGatewaySDKError @@ -334,6 +336,60 @@ def test_raises_on_failed_request(self, credentials): with pytest.raises(AgentGatewaySDKError, match="Token request failed"): get_system_token_mtls(credentials, timeout=60.0) + def test_reuses_cached_system_token(self, credentials): + """Reuse cached system token until it expires.""" + token_cache = _TokenCache(ClientConfig()) + + with patch( + "sap_cloud_sdk.agentgateway._customer._request_token_mtls", + return_value={"access_token": "system-token-123", "expires_in": 300}, + ) as mock_request: + first = get_system_token_mtls( + credentials, timeout=60.0, token_cache=token_cache + ) + second = get_system_token_mtls( + credentials, timeout=60.0, token_cache=token_cache + ) + + assert first == "system-token-123" + assert second == "system-token-123" + mock_request.assert_called_once() + + def test_scopes_system_token_cache_by_app_tid(self, credentials): + """Keep app-tid-specific system tokens isolated in the cache.""" + token_cache = _TokenCache(ClientConfig()) + + with patch( + "sap_cloud_sdk.agentgateway._customer._request_token_mtls", + side_effect=[ + {"access_token": "token-tid-1", "expires_in": 300}, + {"access_token": "token-tid-2", "expires_in": 300}, + ], + ) as mock_request: + first = get_system_token_mtls( + credentials, + timeout=60.0, + app_tid="tid-1", + token_cache=token_cache, + ) + second = get_system_token_mtls( + credentials, + timeout=60.0, + app_tid="tid-1", + token_cache=token_cache, + ) + third = get_system_token_mtls( + credentials, + timeout=60.0, + app_tid="tid-2", + token_cache=token_cache, + ) + + assert first == "token-tid-1" + assert second == "token-tid-1" + assert third == "token-tid-2" + assert mock_request.call_count == 2 + # ============================================================ # Test: exchange_user_token @@ -411,6 +467,31 @@ def test_passes_app_tid_when_provided(self, credentials): data = call_args.kwargs.get("data", {}) assert data["app_tid"] == "test-tid" + def test_reuses_cached_user_token(self, credentials): + """Reuse exchanged user token until it expires.""" + token_cache = _TokenCache(ClientConfig()) + + with patch( + "sap_cloud_sdk.agentgateway._customer._request_token_mtls", + return_value={"access_token": "exchanged-token-123", "expires_in": 300}, + ) as mock_request: + first = exchange_user_token( + credentials, + "user-jwt-token", + timeout=60.0, + token_cache=token_cache, + ) + second = exchange_user_token( + credentials, + "user-jwt-token", + timeout=60.0, + token_cache=token_cache, + ) + + assert first == "exchanged-token-123" + assert second == "exchanged-token-123" + mock_request.assert_called_once() + # ============================================================ # Test: get_mcp_tools_customer diff --git a/tests/agentgateway/unit/test_lob.py b/tests/agentgateway/unit/test_lob.py index 8088a4c8..4ee48294 100644 --- a/tests/agentgateway/unit/test_lob.py +++ b/tests/agentgateway/unit/test_lob.py @@ -21,6 +21,8 @@ _IAS_USER_LABEL_VALUE, ) from sap_cloud_sdk.agentgateway._models import MCPTool +from sap_cloud_sdk.agentgateway._token_cache import _TokenCache +from sap_cloud_sdk.agentgateway.config import ClientConfig from sap_cloud_sdk.agentgateway.exceptions import MCPServerNotFoundError from sap_cloud_sdk.destination import ConsumptionLevel @@ -347,6 +349,38 @@ async def test_fetches_system_auth(self): ) assert call_args[0][2].fragment_level == ConsumptionLevel.INSTANCE + @pytest.mark.asyncio + async def test_reuses_cached_system_auth(self): + """Reuse tenant-scoped system auth until it expires.""" + token_cache = _TokenCache(ClientConfig()) + gateway_url_cache: dict[str, str] = {} + + with patch.dict(os.environ, {"APPFND_CONHOS_LANDSCAPE": "eu10"}): + with ( + patch( + "sap_cloud_sdk.agentgateway._lob.get_ias_fragment_name", + return_value="sap-managed-runtime-agw-subscriber-ias-abc", + ), + patch( + "sap_cloud_sdk.agentgateway._lob._fetch_auth_token", + return_value=("system-token", "https://agw.example.com"), + ) as mock_fetch, + ): + first = await fetch_system_auth( + "tenant-sub", + token_cache=token_cache, + gateway_url_cache=gateway_url_cache, + ) + second = await fetch_system_auth( + "tenant-sub", + token_cache=token_cache, + gateway_url_cache=gateway_url_cache, + ) + + assert first == ("system-token", "https://agw.example.com") + assert second == ("system-token", "https://agw.example.com") + mock_fetch.assert_called_once() + # ============================================================ # Test: fetch_user_auth @@ -383,6 +417,40 @@ async def test_fetches_user_auth_with_ias_user_fragment(self): assert options.fragment_name == "sap-managed-runtime-agw-subscriber-ias-user-abc" assert options.fragment_level == ConsumptionLevel.INSTANCE + @pytest.mark.asyncio + async def test_reuses_cached_user_auth(self): + """Reuse tenant-scoped user auth until it expires.""" + token_cache = _TokenCache(ClientConfig()) + gateway_url_cache: dict[str, str] = {} + + with patch.dict(os.environ, {"APPFND_CONHOS_LANDSCAPE": "eu10"}): + with ( + patch( + "sap_cloud_sdk.agentgateway._lob.get_ias_user_fragment_name", + return_value="sap-managed-runtime-agw-subscriber-ias-user-abc", + ), + patch( + "sap_cloud_sdk.agentgateway._lob._fetch_auth_token", + return_value=("user-token", "https://agw.example.com"), + ) as mock_fetch, + ): + first = await fetch_user_auth( + "user-jwt", + "tenant-sub", + token_cache=token_cache, + gateway_url_cache=gateway_url_cache, + ) + second = await fetch_user_auth( + "user-jwt", + "tenant-sub", + token_cache=token_cache, + gateway_url_cache=gateway_url_cache, + ) + + assert first == ("user-token", "https://agw.example.com") + assert second == ("user-token", "https://agw.example.com") + mock_fetch.assert_called_once() + # ============================================================ # Test: get_mcp_tools_lob diff --git a/tests/agentgateway/unit/test_token_cache.py b/tests/agentgateway/unit/test_token_cache.py new file mode 100644 index 00000000..f66ed31a --- /dev/null +++ b/tests/agentgateway/unit/test_token_cache.py @@ -0,0 +1,328 @@ +"""Unit tests for token cache helpers with non-trivial logic. + +Cache class behavior is tested through AgentGatewayClient in other files. +Only helper logic and scope-key semantics are exercised directly here. +""" + +import base64 +import json +import time +from unittest.mock import patch + +from sap_cloud_sdk.agentgateway._token_cache import ( + _TokenCache, + _parse_jwt_exp, + compute_expires_at, +) +from sap_cloud_sdk.agentgateway.config import ClientConfig + + +def _make_jwt(claims: dict) -> str: + """Build a non-signed JWT for testing (header.payload.signature).""" + header = base64.urlsafe_b64encode(json.dumps({ + "alg": "none" + }).encode()).rstrip(b"=") + payload = base64.urlsafe_b64encode( + json.dumps(claims).encode()).rstrip(b"=") + return f"{header.decode()}.{payload.decode()}.signature" + + +class TestParseJwtExp: + """Tests for the unverified JWT `exp` claim parser.""" + + def test_extracts_exp(self): + """Extract `exp` claim from a well-formed JWT payload.""" + jwt = _make_jwt({"exp": 1700000000, "iat": 1699996400}) + assert _parse_jwt_exp(jwt) == 1700000000 + + def test_returns_none_when_exp_missing(self): + """Return None when payload has no `exp` claim.""" + jwt = _make_jwt({"iat": 1699996400}) + assert _parse_jwt_exp(jwt) is None + + def test_returns_none_for_malformed_jwt(self): + """Return None for strings that are not valid JWTs.""" + assert _parse_jwt_exp("not-a-jwt") is None + assert _parse_jwt_exp("") is None + assert _parse_jwt_exp("only.two") is None + + def test_returns_none_for_garbage_payload(self): + """Return None when the payload segment is not valid base64 or JSON.""" + assert _parse_jwt_exp("aaa.@@not-base64@@.bbb") is None + + +class TestComputeExpiresAt: + """Tests for cache expiry resolution from token responses.""" + + def test_prefers_response_expires_at(self): + """Use expires_at before other response metadata.""" + cfg = ClientConfig(token_expiry_buffer_seconds=30.0) + token_data = { + "expires_at": "1600", + "expires_in": 999, + "access_token": _make_jwt({"exp": 1900}), + } + + with ( + patch("sap_cloud_sdk.agentgateway._token_cache.time.time", + return_value=1000.0), + patch( + "sap_cloud_sdk.agentgateway._token_cache.time.monotonic", + return_value=50.0, + ), + ): + result = compute_expires_at(token_data, cfg) + + assert result == 620.0 + + def test_uses_expires_in_when_present(self): + """Prefer expires_in when no absolute expiry is present.""" + cfg = ClientConfig(token_expiry_buffer_seconds=15.0) + + with patch( + "sap_cloud_sdk.agentgateway._token_cache.time.monotonic", + return_value=20.0, + ): + result = compute_expires_at({"expires_in": 120}, cfg) + + assert result == 125.0 + + def test_falls_back_to_access_token_exp(self): + """Use access_token exp before id_token or fallback TTL.""" + cfg = ClientConfig(token_expiry_buffer_seconds=30.0) + token_data = { + "access_token": _make_jwt({"exp": 1500}), + "id_token": _make_jwt({"exp": 1700}), + } + + with ( + patch("sap_cloud_sdk.agentgateway._token_cache.time.time", + return_value=1000.0), + patch( + "sap_cloud_sdk.agentgateway._token_cache.time.monotonic", + return_value=75.0, + ), + ): + result = compute_expires_at(token_data, cfg) + + assert result == 545.0 + + def test_falls_back_to_id_token_exp(self): + """Use id_token exp when the access token is opaque.""" + cfg = ClientConfig(token_expiry_buffer_seconds=30.0) + token_data = { + "access_token": "opaque-token", + "id_token": _make_jwt({"exp": 1400}), + } + + with ( + patch("sap_cloud_sdk.agentgateway._token_cache.time.time", + return_value=1000.0), + patch( + "sap_cloud_sdk.agentgateway._token_cache.time.monotonic", + return_value=25.0, + ), + ): + result = compute_expires_at(token_data, cfg) + + assert result == 395.0 + + def test_uses_fallback_when_no_expiry_info(self): + """Use config fallback TTL when no expiry metadata is available.""" + cfg = ClientConfig(fallback_token_ttl_seconds=180.0) + + with patch( + "sap_cloud_sdk.agentgateway._token_cache.time.monotonic", + return_value=10.0, + ): + result = compute_expires_at({"access_token": "opaque"}, cfg) + + assert result == 190.0 + + def test_respects_jwt_exp_even_within_buffer(self): + """Treat JWTs inside the buffer as stale instead of extending them.""" + cfg = ClientConfig(token_expiry_buffer_seconds=30.0) + token_data = {"id_token": _make_jwt({"exp": 1020})} + + with ( + patch("sap_cloud_sdk.agentgateway._token_cache.time.time", + return_value=1000.0), + patch( + "sap_cloud_sdk.agentgateway._token_cache.time.monotonic", + return_value=20.0, + ), + ): + result = compute_expires_at(token_data, cfg) + + assert result == 10.0 + + +class TestComputeExpiresAtFromBearer: + """Tests for cache expiry resolution from a bearer auth header string.""" + + def test_uses_exp_from_bearer_jwt(self): + """Parse exp claim from Bearer JWT and apply buffer.""" + cache = _TokenCache(ClientConfig(token_expiry_buffer_seconds=20.0)) + auth_header = f"Bearer {_make_jwt({'exp': 2000})}" + + with ( + patch("sap_cloud_sdk.agentgateway._token_cache.time.time", + return_value=1250.0), + patch( + "sap_cloud_sdk.agentgateway._token_cache.time.monotonic", + return_value=40.0, + ), + ): + result = cache.compute_expires_at_from_bearer(auth_header) + + assert result == 770.0 + + def test_falls_back_when_no_exp_in_jwt(self): + """Use fallback TTL when JWT has no exp claim.""" + cache = _TokenCache(ClientConfig(fallback_token_ttl_seconds=300.0)) + jwt = _make_jwt({"sub": "user"}) + + with patch( + "sap_cloud_sdk.agentgateway._token_cache.time.monotonic", + return_value=5.0, + ): + result = cache.compute_expires_at_from_bearer(f"Bearer {jwt}") + + assert result == 305.0 + + def test_strips_bearer_prefix_case_insensitively(self): + """Strip 'bearer ' prefix regardless of case.""" + cache = _TokenCache( + ClientConfig(token_expiry_buffer_seconds=60, + fallback_token_ttl_seconds=300)) + future_exp = int(time.time()) + 600 + jwt = _make_jwt({"exp": future_exp}) + result_lower = cache.compute_expires_at_from_bearer(f"bearer {jwt}") + result_upper = cache.compute_expires_at_from_bearer(f"Bearer {jwt}") + assert abs(result_lower - result_upper) < 1 + + +class TestScopeIsolation: + """Tokens are isolated by scope key and user JWT.""" + + def test_system_tokens_isolated_by_scope_key(self): + cache = _TokenCache(ClientConfig()) + expires_at = time.monotonic() + 600 + + cache.set_system_token("token-a", expires_at, "scope-a") + cache.set_system_token("token-b", expires_at, "scope-b") + + assert cache.get_system_token("scope-a") == "token-a" + assert cache.get_system_token("scope-b") == "token-b" + + def test_user_tokens_isolated_by_scope_key(self): + cache = _TokenCache(ClientConfig()) + expires_at = time.monotonic() + 600 + + cache.set_user_token("user-jwt", "token-a", expires_at, "scope-a") + cache.set_user_token("user-jwt", "token-b", expires_at, "scope-b") + + assert cache.get_user_token("user-jwt", "scope-a") == "token-a" + assert cache.get_user_token("user-jwt", "scope-b") == "token-b" + + def test_invalidate_system_token_does_not_affect_other_scopes(self): + cache = _TokenCache(ClientConfig()) + expires_at = time.monotonic() + 600 + + cache.set_system_token("token-a", expires_at, "scope-a") + cache.set_system_token("token-b", expires_at, "scope-b") + cache.invalidate_system_token("scope-a") + + assert cache.get_system_token("scope-a") is None + assert cache.get_system_token("scope-b") == "token-b" + + def test_invalidate_user_token_does_not_affect_other_scopes(self): + cache = _TokenCache(ClientConfig()) + expires_at = time.monotonic() + 600 + + cache.set_user_token("user-jwt", "token-a", expires_at, "scope-a") + cache.set_user_token("user-jwt", "token-b", expires_at, "scope-b") + cache.invalidate_user_token("user-jwt", "scope-a") + + assert cache.get_user_token("user-jwt", "scope-a") is None + assert cache.get_user_token("user-jwt", "scope-b") == "token-b" + + +class TestLruEviction: + """LRU eviction respects max cache size and evicts least-recently-used entry.""" + + def test_system_token_evicts_lru_when_full(self): + cfg = ClientConfig(max_system_token_cache_size=2) + cache = _TokenCache(cfg) + expires_at = time.monotonic() + 600 + + cache.set_system_token("token-a", expires_at, "scope-a") + cache.set_system_token("token-b", expires_at, "scope-b") + # Access scope-a so scope-b becomes LRU + cache.get_system_token("scope-a") + # Adding a third entry should evict scope-b (LRU) + cache.set_system_token("token-c", expires_at, "scope-c") + + assert cache.get_system_token("scope-b") is None + assert cache.get_system_token("scope-a") == "token-a" + assert cache.get_system_token("scope-c") == "token-c" + + def test_user_token_evicts_lru_when_full(self): + cfg = ClientConfig(max_user_token_cache_size=2) + cache = _TokenCache(cfg) + expires_at = time.monotonic() + 600 + + cache.set_user_token("jwt-a", "token-a", expires_at, "scope") + cache.set_user_token("jwt-b", "token-b", expires_at, "scope") + # Access jwt-a so jwt-b becomes LRU + cache.get_user_token("jwt-a", "scope") + # Adding a third entry should evict jwt-b (LRU) + cache.set_user_token("jwt-c", "token-c", expires_at, "scope") + + assert cache.get_user_token("jwt-b", "scope") is None + assert cache.get_user_token("jwt-a", "scope") == "token-a" + assert cache.get_user_token("jwt-c", "scope") == "token-c" + + def test_system_token_never_exceeds_max_size(self): + max_size = 5 + cfg = ClientConfig(max_system_token_cache_size=max_size) + cache = _TokenCache(cfg) + expires_at = time.monotonic() + 600 + + for i in range(max_size + 3): + cache.set_system_token(f"token-{i}", expires_at, f"scope-{i}") + + assert len(cache._system_tokens) == max_size + + def test_user_token_never_exceeds_max_size(self): + max_size = 4 + cfg = ClientConfig(max_user_token_cache_size=max_size) + cache = _TokenCache(cfg) + expires_at = time.monotonic() + 600 + + for i in range(max_size + 3): + cache.set_user_token(f"jwt-{i}", f"token-{i}", expires_at, "scope") + + assert len(cache._user_tokens) == max_size + + +class TestExpiredTokenEviction: + """Expired tokens are removed from the cache on get.""" + + def test_get_system_token_removes_expired_entry(self): + cache = _TokenCache(ClientConfig()) + cache.set_system_token("stale-token", time.monotonic() - 1, "scope-x") + + result = cache.get_system_token("scope-x") + + assert result is None + assert "scope-x" not in cache._system_tokens + + def test_get_user_token_removes_expired_entry(self): + cache = _TokenCache(ClientConfig()) + cache.set_user_token("user-jwt", "stale-token", time.monotonic() - 1, "scope-x") + + result = cache.get_user_token("user-jwt", "scope-x") + + assert result is None + assert len(cache._user_tokens) == 0 From 0b9814b9ab04f5554cf67b629b9ed726e0befb8e Mon Sep 17 00:00:00 2001 From: I763688 Date: Mon, 1 Jun 2026 13:54:17 +0100 Subject: [PATCH 2/5] fix: use correct cache type in tests --- tests/agentgateway/unit/test_lob.py | 30 ++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/tests/agentgateway/unit/test_lob.py b/tests/agentgateway/unit/test_lob.py index 4ee48294..6c605b96 100644 --- a/tests/agentgateway/unit/test_lob.py +++ b/tests/agentgateway/unit/test_lob.py @@ -21,7 +21,7 @@ _IAS_USER_LABEL_VALUE, ) from sap_cloud_sdk.agentgateway._models import MCPTool -from sap_cloud_sdk.agentgateway._token_cache import _TokenCache +from sap_cloud_sdk.agentgateway._token_cache import _GatewayUrlCache, _TokenCache from sap_cloud_sdk.agentgateway.config import ClientConfig from sap_cloud_sdk.agentgateway.exceptions import MCPServerNotFoundError from sap_cloud_sdk.destination import ConsumptionLevel @@ -353,7 +353,7 @@ async def test_fetches_system_auth(self): async def test_reuses_cached_system_auth(self): """Reuse tenant-scoped system auth until it expires.""" token_cache = _TokenCache(ClientConfig()) - gateway_url_cache: dict[str, str] = {} + gateway_url_cache = _GatewayUrlCache() with patch.dict(os.environ, {"APPFND_CONHOS_LANDSCAPE": "eu10"}): with ( @@ -381,6 +381,18 @@ async def test_reuses_cached_system_auth(self): assert second == ("system-token", "https://agw.example.com") mock_fetch.assert_called_once() + @pytest.mark.asyncio + async def test_raises_when_only_token_cache_provided(self): + """Raise ValueError when token_cache given without gateway_url_cache.""" + with pytest.raises(ValueError, match="both be provided or both be None"): + await fetch_system_auth("tenant-sub", token_cache=_TokenCache(ClientConfig())) + + @pytest.mark.asyncio + async def test_raises_when_only_gateway_url_cache_provided(self): + """Raise ValueError when gateway_url_cache given without token_cache.""" + with pytest.raises(ValueError, match="both be provided or both be None"): + await fetch_system_auth("tenant-sub", gateway_url_cache=_GatewayUrlCache()) + # ============================================================ # Test: fetch_user_auth @@ -421,7 +433,7 @@ async def test_fetches_user_auth_with_ias_user_fragment(self): async def test_reuses_cached_user_auth(self): """Reuse tenant-scoped user auth until it expires.""" token_cache = _TokenCache(ClientConfig()) - gateway_url_cache: dict[str, str] = {} + gateway_url_cache = _GatewayUrlCache() with patch.dict(os.environ, {"APPFND_CONHOS_LANDSCAPE": "eu10"}): with ( @@ -451,6 +463,18 @@ async def test_reuses_cached_user_auth(self): assert second == ("user-token", "https://agw.example.com") mock_fetch.assert_called_once() + @pytest.mark.asyncio + async def test_raises_when_only_token_cache_provided(self): + """Raise ValueError when token_cache given without gateway_url_cache.""" + with pytest.raises(ValueError, match="both be provided or both be None"): + await fetch_user_auth("user-jwt", "tenant-sub", token_cache=_TokenCache(ClientConfig())) + + @pytest.mark.asyncio + async def test_raises_when_only_gateway_url_cache_provided(self): + """Raise ValueError when gateway_url_cache given without token_cache.""" + with pytest.raises(ValueError, match="both be provided or both be None"): + await fetch_user_auth("user-jwt", "tenant-sub", gateway_url_cache=_GatewayUrlCache()) + # ============================================================ # Test: get_mcp_tools_lob From a16d227083382f30797dca3e369d4e4deaf8efe5 Mon Sep 17 00:00:00 2001 From: I763688 Date: Mon, 1 Jun 2026 13:55:24 +0100 Subject: [PATCH 3/5] fix: add version bump --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index da7a9bf1..b12be5af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "sap-cloud-sdk" -version = "0.22.0" +version = "0.23.0" description = "SAP Cloud SDK for Python" readme = "README.md" license = "Apache-2.0" From 2cb225cbc82382990b1dd0d64bfc76785ea6e95a Mon Sep 17 00:00:00 2001 From: I763688 Date: Tue, 2 Jun 2026 08:24:15 +0100 Subject: [PATCH 4/5] chore: remove dead code and add config validation --- .../agentgateway/_token_cache.py | 27 +++++++++---------- src/sap_cloud_sdk/agentgateway/config.py | 7 +++++ 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/src/sap_cloud_sdk/agentgateway/_token_cache.py b/src/sap_cloud_sdk/agentgateway/_token_cache.py index 42397ab5..9c9b4c8d 100644 --- a/src/sap_cloud_sdk/agentgateway/_token_cache.py +++ b/src/sap_cloud_sdk/agentgateway/_token_cache.py @@ -92,7 +92,8 @@ def _parse_response_expires_at(expires_at: object) -> float | None: return parsed.timestamp() -def _monotonic_expiry_from_epoch(expiry_epoch_seconds: float, buffer: float) -> float: +def _monotonic_expiry_from_epoch(expiry_epoch_seconds: float, + buffer: float) -> float: """Translate a wall-clock expiry into a monotonic deadline.""" return time.monotonic() + (expiry_epoch_seconds - time.time()) - buffer @@ -160,9 +161,6 @@ def __setitem__(self, scope_key: str, url: str) -> None: while len(self._cache) > self._max_size: self._cache.popitem(last=False) - def __contains__(self, scope_key: str) -> bool: - return scope_key in self._cache - class _TokenCache: """Per-client token cache with TTL and LRU eviction. @@ -188,13 +186,14 @@ def get_system_token(self, scope_key: str) -> str | None: del self._system_tokens[scope_key] return None - def set_system_token(self, token: str, expires_at: float, scope_key: str) -> None: + def set_system_token(self, token: str, expires_at: float, + scope_key: str) -> None: """Cache a system token under `scope_key`; evict LRU once size exceeds limit.""" - self._system_tokens[scope_key] = _CachedToken( - token=token, expires_at=expires_at - ) + self._system_tokens[scope_key] = _CachedToken(token=token, + expires_at=expires_at) self._system_tokens.move_to_end(scope_key) - while len(self._system_tokens) > self._config.max_system_token_cache_size: + while len(self._system_tokens + ) > self._config.max_system_token_cache_size: evicted, _ = self._system_tokens.popitem(last=False) logger.debug("System token cache full — evicted '%s'", evicted) @@ -225,7 +224,8 @@ def set_user_token( ) -> None: """Cache an exchanged user token; evict LRU once size exceeds limit.""" key = self._hash_key(user_jwt, scope_key) - self._user_tokens[key] = _CachedToken(token=token, expires_at=expires_at) + self._user_tokens[key] = _CachedToken(token=token, + expires_at=expires_at) self._user_tokens.move_to_end(key) while len(self._user_tokens) > self._config.max_user_token_cache_size: evicted, _ = self._user_tokens.popitem(last=False) @@ -251,11 +251,8 @@ def compute_expires_at_from_bearer(self, auth_header: str) -> float: """ buffer = self._config.token_expiry_buffer_seconds - jwt = ( - auth_header[7:] - if auth_header.lower().startswith("bearer ") - else auth_header - ) + jwt = (auth_header[7:] + if auth_header.lower().startswith("bearer ") else auth_header) exp = _parse_jwt_exp(jwt) if exp is not None: return _monotonic_expiry_from_epoch(float(exp), buffer) diff --git a/src/sap_cloud_sdk/agentgateway/config.py b/src/sap_cloud_sdk/agentgateway/config.py index 96f44b1a..17495dbd 100644 --- a/src/sap_cloud_sdk/agentgateway/config.py +++ b/src/sap_cloud_sdk/agentgateway/config.py @@ -29,3 +29,10 @@ class ClientConfig: token_expiry_buffer_seconds: float = DEFAULT_TOKEN_EXPIRY_BUFFER_SECONDS max_system_token_cache_size: int = DEFAULT_MAX_SYSTEM_TOKEN_CACHE_SIZE max_user_token_cache_size: int = DEFAULT_MAX_USER_TOKEN_CACHE_SIZE + + def __post_init__(self) -> None: + if self.token_expiry_buffer_seconds >= self.fallback_token_ttl_seconds: + raise ValueError( + f"token_expiry_buffer_seconds ({self.token_expiry_buffer_seconds}) " + f"must be less than fallback_token_ttl_seconds ({self.fallback_token_ttl_seconds})" + ) From f2e5408411628787d5f3072636b38f8c55d7eaba Mon Sep 17 00:00:00 2001 From: I763688 Date: Tue, 2 Jun 2026 13:12:25 +0100 Subject: [PATCH 5/5] fix: format --- .../agentgateway/_token_cache.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/sap_cloud_sdk/agentgateway/_token_cache.py b/src/sap_cloud_sdk/agentgateway/_token_cache.py index 9c9b4c8d..7ea3a49b 100644 --- a/src/sap_cloud_sdk/agentgateway/_token_cache.py +++ b/src/sap_cloud_sdk/agentgateway/_token_cache.py @@ -92,8 +92,7 @@ def _parse_response_expires_at(expires_at: object) -> float | None: return parsed.timestamp() -def _monotonic_expiry_from_epoch(expiry_epoch_seconds: float, - buffer: float) -> float: +def _monotonic_expiry_from_epoch(expiry_epoch_seconds: float, buffer: float) -> float: """Translate a wall-clock expiry into a monotonic deadline.""" return time.monotonic() + (expiry_epoch_seconds - time.time()) - buffer @@ -186,14 +185,13 @@ def get_system_token(self, scope_key: str) -> str | None: del self._system_tokens[scope_key] return None - def set_system_token(self, token: str, expires_at: float, - scope_key: str) -> None: + def set_system_token(self, token: str, expires_at: float, scope_key: str) -> None: """Cache a system token under `scope_key`; evict LRU once size exceeds limit.""" - self._system_tokens[scope_key] = _CachedToken(token=token, - expires_at=expires_at) + self._system_tokens[scope_key] = _CachedToken( + token=token, expires_at=expires_at + ) self._system_tokens.move_to_end(scope_key) - while len(self._system_tokens - ) > self._config.max_system_token_cache_size: + while len(self._system_tokens) > self._config.max_system_token_cache_size: evicted, _ = self._system_tokens.popitem(last=False) logger.debug("System token cache full — evicted '%s'", evicted) @@ -224,8 +222,7 @@ def set_user_token( ) -> None: """Cache an exchanged user token; evict LRU once size exceeds limit.""" key = self._hash_key(user_jwt, scope_key) - self._user_tokens[key] = _CachedToken(token=token, - expires_at=expires_at) + self._user_tokens[key] = _CachedToken(token=token, expires_at=expires_at) self._user_tokens.move_to_end(key) while len(self._user_tokens) > self._config.max_user_token_cache_size: evicted, _ = self._user_tokens.popitem(last=False) @@ -251,8 +248,11 @@ def compute_expires_at_from_bearer(self, auth_header: str) -> float: """ buffer = self._config.token_expiry_buffer_seconds - jwt = (auth_header[7:] - if auth_header.lower().startswith("bearer ") else auth_header) + jwt = ( + auth_header[7:] + if auth_header.lower().startswith("bearer ") + else auth_header + ) exp = _parse_jwt_exp(jwt) if exp is not None: return _monotonic_expiry_from_epoch(float(exp), buffer)