Skip to content

Commit e83136b

Browse files
committed
feat: Add token cache to Agent Gateway client and LoB flow
1 parent 6f7f760 commit e83136b

6 files changed

Lines changed: 641 additions & 106 deletions

File tree

src/sap_cloud_sdk/agentgateway/_lob.py

Lines changed: 141 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
)
2424

2525
from sap_cloud_sdk.agentgateway._models import MCPTool
26-
from sap_cloud_sdk.agentgateway.exceptions import MCPServerNotFoundError
26+
from sap_cloud_sdk.agentgateway._token_cache import _TokenCache
27+
from sap_cloud_sdk.agentgateway.exceptions import AgentGatewaySDKError, MCPServerNotFoundError
2728

2829
logger = logging.getLogger(__name__)
2930

@@ -145,21 +146,29 @@ def get_ias_fragment_name(tenant_subdomain: str) -> str:
145146

146147
async def get_system_auth(
147148
tenant_subdomain: str,
149+
cache: _TokenCache,
148150
) -> str:
149151
"""Get system-scoped auth (Phase 1 - client credentials).
150152
151-
Looks up the IAS fragment (subscriber.ias label) and uses it to acquire
152-
a client-credentials token via BTP Destination Service.
153+
Checks the token cache first. On a miss, looks up the IAS fragment
154+
(subscriber.ias label) and acquires a client-credentials token via
155+
BTP Destination Service, then caches the result.
153156
154157
Args:
155158
tenant_subdomain: Tenant subdomain for multi-tenant lookup.
159+
cache: Token cache shared across calls.
156160
157161
Returns:
158162
Authorization header value (e.g., "Bearer xxx").
159163
160164
Raises:
161165
MCPServerNotFoundError: If no IAS fragment or auth token is found.
162166
"""
167+
cached = cache.get_system_token(tenant_subdomain)
168+
if cached:
169+
logger.debug("System token cache hit for tenant '%s'", tenant_subdomain)
170+
return cached
171+
163172
loop = asyncio.get_running_loop()
164173

165174
def _fetch_system_auth_sync():
@@ -179,27 +188,48 @@ def _fetch_system_auth_sync():
179188

180189
return _fetch_auth_token(dest_name, tenant_subdomain, options)
181190

182-
return await loop.run_in_executor(None, _fetch_system_auth_sync)
191+
auth = await loop.run_in_executor(None, _fetch_system_auth_sync)
192+
expires_at = cache.compute_expires_at_from_bearer(auth)
193+
cache.set_system_token(auth, expires_at, tenant_subdomain)
194+
return auth
183195

184196

185197
async def get_user_auth(
186198
mcp_fragment_name: str,
187199
user_token: str,
188200
tenant_subdomain: str,
201+
cache: _TokenCache,
189202
) -> str:
190203
"""Get user-scoped auth (Phase 2 - token exchange).
191204
205+
Checks the token cache first. On a miss, exchanges the user token via
206+
BTP Destination Service, then caches the result.
207+
208+
Cache key is scoped to (user_token, mcp_fragment_name, tenant_subdomain)
209+
since different fragments may yield differently-scoped tokens.
210+
192211
Args:
193212
mcp_fragment_name: MCP fragment name for token exchange.
194213
user_token: User's JWT for principal propagation.
195214
tenant_subdomain: Tenant subdomain for multi-tenant lookup.
215+
cache: Token cache shared across calls.
196216
197217
Returns:
198218
Authorization header value with user identity embedded.
199219
200220
Raises:
201221
MCPServerNotFoundError: If no auth token is returned.
202222
"""
223+
scope_key = f"{mcp_fragment_name}|{tenant_subdomain}"
224+
cached = cache.get_user_token(user_token, scope_key)
225+
if cached:
226+
logger.debug(
227+
"User token cache hit for tenant '%s', fragment '%s'",
228+
tenant_subdomain,
229+
mcp_fragment_name,
230+
)
231+
return cached
232+
203233
loop = asyncio.get_running_loop()
204234

205235
def _fetch_user_auth_sync():
@@ -220,7 +250,10 @@ def _fetch_user_auth_sync():
220250

221251
return _fetch_auth_token(dest_name, tenant_subdomain, options)
222252

223-
return await loop.run_in_executor(None, _fetch_user_auth_sync)
253+
auth = await loop.run_in_executor(None, _fetch_user_auth_sync)
254+
expires_at = cache.compute_expires_at_from_bearer(auth)
255+
cache.set_user_token(user_token, auth, expires_at, scope_key)
256+
return auth
224257

225258

226259
async def list_server_tools(
@@ -271,13 +304,18 @@ async def list_server_tools(
271304
async def get_mcp_tools_lob(
272305
tenant_subdomain: str,
273306
timeout: float,
307+
cache: _TokenCache,
274308
) -> list[MCPTool]:
275309
"""List all MCP tools using LoB flow (destination-based).
276310
277311
Uses Phase 1 auth (client-scoped) via BTP Destination Service.
312+
On a 401 from an MCP server, invalidates the cached system token and
313+
retries once before skipping the fragment.
278314
279315
Args:
280316
tenant_subdomain: Tenant subdomain for multi-tenant lookup.
317+
timeout: HTTP timeout in seconds.
318+
cache: Token cache shared across calls.
281319
282320
Returns:
283321
List of MCPTool objects from all MCP servers.
@@ -295,6 +333,11 @@ async def get_mcp_tools_lob(
295333
)
296334
return tools
297335

336+
system_auth = None
337+
338+
async def _refetch_system_auth() -> str:
339+
return await get_system_auth(tenant_subdomain, cache)
340+
298341
for fragment in fragments:
299342
fragment_name = fragment.name
300343
mcp_url = fragment.properties.get("URL") or fragment.properties.get("url")
@@ -305,22 +348,36 @@ async def get_mcp_tools_lob(
305348
)
306349
continue
307350

308-
try:
309-
system_auth = await get_system_auth(tenant_subdomain)
310-
server_tools = await list_server_tools(
311-
mcp_url, system_auth, fragment_name, timeout
312-
)
313-
tools.extend(server_tools)
314-
logger.debug(
315-
"Loaded %d tool(s) from fragment '%s'",
316-
len(server_tools),
317-
fragment_name,
318-
)
319-
except Exception:
320-
logger.exception(
321-
"Failed to load tools from fragment '%s' — skipping",
322-
fragment_name,
323-
)
351+
for attempt in (1, 2):
352+
if not system_auth:
353+
# Auth failure here is immediately fatal — same token is needed for
354+
# all fragments, so there is no point continuing.
355+
system_auth = await _refetch_system_auth()
356+
357+
try:
358+
server_tools = await list_server_tools(
359+
mcp_url, system_auth, fragment_name, timeout
360+
)
361+
tools.extend(server_tools)
362+
logger.debug(
363+
"Loaded %d tool(s) from fragment '%s'",
364+
len(server_tools),
365+
fragment_name,
366+
)
367+
except Exception as exc:
368+
unwrapped = _unwrap_exception_group(exc)
369+
if _is_unauthorized(unwrapped) and attempt == 1:
370+
logger.info(
371+
"401 from '%s' — invalidating cached system token and retrying",
372+
fragment_name,
373+
)
374+
cache.invalidate_system_token(tenant_subdomain)
375+
system_auth = None
376+
continue
377+
logger.exception(
378+
"Failed to load tools from fragment '%s' — skipping", fragment_name
379+
)
380+
break
324381

325382
logger.info("Loaded %d MCP tool(s) from %d fragment(s)", len(tools), len(fragments))
326383
return tools
@@ -331,45 +388,90 @@ async def call_mcp_tool_lob(
331388
user_token: str,
332389
tenant_subdomain: str,
333390
timeout: float,
391+
cache: _TokenCache,
334392
**kwargs,
335393
) -> str:
336394
"""Invoke an MCP tool using LoB flow (destination-based).
337395
338396
Uses Phase 2 auth (user-scoped) via token exchange.
339397
Principal propagation ensures LoB systems see user identity.
398+
On a 401, invalidates the cached user token and retries once.
340399
341400
Args:
342401
tool: MCPTool object (from list_mcp_tools).
343402
user_token: User's JWT for principal propagation.
344403
tenant_subdomain: Tenant subdomain for token exchange.
404+
timeout: HTTP timeout in seconds.
405+
cache: Token cache shared across calls.
345406
**kwargs: Tool input parameters.
346407
347408
Returns:
348409
Tool execution result as string.
349410
350411
Raises:
351412
MCPServerNotFoundError: If destination/auth fails.
413+
AgentGatewaySDKError: If tool invocation fails after 401 retry.
352414
"""
353415
if not tool.fragment_name:
354416
raise MCPServerNotFoundError(
355417
f"Tool '{tool.name}' missing fragment_name for LoB invocation"
356418
)
357-
user_auth = await get_user_auth(tool.fragment_name, user_token, tenant_subdomain)
358419

359-
async with httpx.AsyncClient(
360-
headers={"Authorization": user_auth, "x-correlation-id": str(uuid.uuid4())},
361-
timeout=timeout,
362-
) as http_client:
363-
async with streamable_http_client(tool.url, http_client=http_client) as (
364-
read,
365-
write,
366-
_,
367-
):
368-
async with ClientSession(read, write) as session:
369-
await session.initialize()
370-
result = await session.call_tool(tool.name, kwargs)
371-
if not result.content:
372-
logger.warning("Tool '%s' returned empty content", tool.name)
373-
return ""
374-
first = result.content[0]
375-
return str(getattr(first, "text", ""))
420+
scope_key = f"{tool.fragment_name}|{tenant_subdomain}"
421+
last_exc: Exception | None = None
422+
423+
for attempt in (1, 2):
424+
user_auth = await get_user_auth(
425+
tool.fragment_name, user_token, tenant_subdomain, cache
426+
)
427+
try:
428+
async with httpx.AsyncClient(
429+
headers={
430+
"Authorization": user_auth,
431+
"x-correlation-id": str(uuid.uuid4()),
432+
},
433+
timeout=timeout,
434+
) as http_client:
435+
async with streamable_http_client(
436+
tool.url, http_client=http_client
437+
) as (read, write, _):
438+
async with ClientSession(read, write) as session:
439+
await session.initialize()
440+
result = await session.call_tool(tool.name, kwargs)
441+
if not result.content:
442+
logger.warning(
443+
"Tool '%s' returned empty content", tool.name
444+
)
445+
return ""
446+
first = result.content[0]
447+
return str(getattr(first, "text", ""))
448+
except Exception as exc:
449+
unwrapped = _unwrap_exception_group(exc)
450+
if _is_unauthorized(unwrapped) and attempt == 1:
451+
logger.info(
452+
"401 from MCP server for tool '%s' — invalidating cached token and retrying",
453+
tool.name,
454+
)
455+
cache.invalidate_user_token(user_token, scope_key)
456+
last_exc = exc
457+
continue
458+
raise
459+
460+
# Defensive — should not be reachable; second attempt either returns or raises.
461+
raise AgentGatewaySDKError(
462+
f"Tool invocation for '{tool.name}' failed after 401 retry: {last_exc}"
463+
)
464+
465+
466+
def _unwrap_exception_group(exc: BaseException) -> BaseException:
467+
"""Unwrap nested ExceptionGroups to find the underlying cause."""
468+
while isinstance(exc, BaseExceptionGroup) and exc.exceptions:
469+
exc = exc.exceptions[0]
470+
return exc
471+
472+
473+
def _is_unauthorized(exc: BaseException) -> bool:
474+
"""Detect a 401 response from the MCP server (httpx-based)."""
475+
if isinstance(exc, httpx.HTTPStatusError):
476+
return exc.response is not None and exc.response.status_code == 401
477+
return False

src/sap_cloud_sdk/agentgateway/_token_cache.py

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
"""Token cache for Agent Gateway customer flow.
1+
"""Token cache for Agent Gateway flows.
22
33
Caches IAS tokens (system + user-exchanged) per client to avoid redundant
4-
mTLS token requests during agentic loops. LoB flow uses BTP Destination
5-
Service which has its own caching, so this module only serves the customer
6-
flow.
4+
token requests during agentic loops. Used by both customer flow (mTLS) and
5+
LoB flow (BTP Destination Service).
76
87
Keying:
98
- System tokens are keyed by `client_id` (or "_default" when unset).
@@ -61,6 +60,36 @@ def _parse_jwt_exp(jwt: str) -> int | None:
6160
return None
6261

6362

63+
def compute_expires_at(token_data: dict, config: ClientConfig) -> float:
64+
"""Resolve the cache expiry timestamp (monotonic) for a token response.
65+
66+
Resolution order:
67+
1. `expires_in` from the response, minus the buffer.
68+
2. `exp` claim from `id_token` (translated from wall clock to monotonic),
69+
minus the buffer.
70+
3. Config-provided fallback TTL.
71+
"""
72+
now_mono = time.monotonic()
73+
buffer = config.token_expiry_buffer_seconds
74+
75+
expires_in = token_data.get("expires_in")
76+
if expires_in is not None:
77+
try:
78+
return now_mono + int(expires_in) - buffer
79+
except (ValueError, TypeError):
80+
pass
81+
82+
id_token = token_data.get("id_token")
83+
if id_token:
84+
exp = _parse_jwt_exp(id_token)
85+
if exp is not None:
86+
remaining = exp - time.time()
87+
if remaining > buffer:
88+
return now_mono + remaining - buffer
89+
90+
return now_mono + config.fallback_token_ttl_seconds
91+
92+
6493
class _TokenCache:
6594
"""Per-client token cache with TTL and LRU eviction.
6695
@@ -144,31 +173,24 @@ def invalidate_user_token(self, user_jwt: str, client_id: str) -> None:
144173
# --- Utility ---
145174

146175
def compute_expires_at(self, token_data: dict) -> float:
147-
"""Resolve the cache expiry timestamp (monotonic) for a token response.
176+
"""Resolve the cache expiry timestamp (monotonic) for a token response."""
177+
return compute_expires_at(token_data, self._config)
178+
179+
def compute_expires_at_from_bearer(self, auth_header: str) -> float:
180+
"""Resolve the cache expiry timestamp for a bearer auth header string.
148181
149-
Resolution order:
150-
1. `expires_in` from the response, minus the buffer.
151-
2. `exp` claim from `id_token` (translated from wall clock to monotonic),
152-
minus the buffer.
153-
3. Config-provided fallback TTL.
182+
Strips the 'Bearer ' prefix and parses the `exp` claim from the JWT.
183+
Falls back to the config-provided fallback TTL if parsing fails.
154184
"""
155185
now_mono = time.monotonic()
156186
buffer = self._config.token_expiry_buffer_seconds
157187

158-
expires_in = token_data.get("expires_in")
159-
if expires_in is not None:
160-
try:
161-
return now_mono + int(expires_in) - buffer
162-
except (ValueError, TypeError):
163-
pass
164-
165-
id_token = token_data.get("id_token")
166-
if id_token:
167-
exp = _parse_jwt_exp(id_token)
168-
if exp is not None:
169-
remaining = exp - time.time()
170-
if remaining > buffer:
171-
return now_mono + remaining - buffer
188+
jwt = auth_header[7:] if auth_header.lower().startswith("bearer ") else auth_header
189+
exp = _parse_jwt_exp(jwt)
190+
if exp is not None:
191+
remaining = exp - time.time()
192+
if remaining > buffer:
193+
return now_mono + remaining - buffer
172194

173195
return now_mono + self._config.fallback_token_ttl_seconds
174196

0 commit comments

Comments
 (0)