2323)
2424
2525from sap_cloud_sdk .agentgateway ._models import MCPTool
26- from sap_cloud_sdk .agentgateway .exceptions import MCPServerNotFoundError
26+ from sap_cloud_sdk .agentgateway ._token_cache import _TokenCache
27+ from sap_cloud_sdk .agentgateway .exceptions import AgentGatewaySDKError , MCPServerNotFoundError
2728
2829logger = logging .getLogger (__name__ )
2930
@@ -145,21 +146,29 @@ def get_ias_fragment_name(tenant_subdomain: str) -> str:
145146
146147async def get_system_auth (
147148 tenant_subdomain : str ,
149+ cache : _TokenCache ,
148150) -> str :
149151 """Get system-scoped auth (Phase 1 - client credentials).
150152
151- Looks up the IAS fragment (subscriber.ias label) and uses it to acquire
152- a client-credentials token via BTP Destination Service.
153+ Checks the token cache first. On a miss, looks up the IAS fragment
154+ (subscriber.ias label) and acquires a client-credentials token via
155+ BTP Destination Service, then caches the result.
153156
154157 Args:
155158 tenant_subdomain: Tenant subdomain for multi-tenant lookup.
159+ cache: Token cache shared across calls.
156160
157161 Returns:
158162 Authorization header value (e.g., "Bearer xxx").
159163
160164 Raises:
161165 MCPServerNotFoundError: If no IAS fragment or auth token is found.
162166 """
167+ cached = cache .get_system_token (tenant_subdomain )
168+ if cached :
169+ logger .debug ("System token cache hit for tenant '%s'" , tenant_subdomain )
170+ return cached
171+
163172 loop = asyncio .get_running_loop ()
164173
165174 def _fetch_system_auth_sync ():
@@ -179,27 +188,48 @@ def _fetch_system_auth_sync():
179188
180189 return _fetch_auth_token (dest_name , tenant_subdomain , options )
181190
182- return await loop .run_in_executor (None , _fetch_system_auth_sync )
191+ auth = await loop .run_in_executor (None , _fetch_system_auth_sync )
192+ expires_at = cache .compute_expires_at_from_bearer (auth )
193+ cache .set_system_token (auth , expires_at , tenant_subdomain )
194+ return auth
183195
184196
185197async def get_user_auth (
186198 mcp_fragment_name : str ,
187199 user_token : str ,
188200 tenant_subdomain : str ,
201+ cache : _TokenCache ,
189202) -> str :
190203 """Get user-scoped auth (Phase 2 - token exchange).
191204
205+ Checks the token cache first. On a miss, exchanges the user token via
206+ BTP Destination Service, then caches the result.
207+
208+ Cache key is scoped to (user_token, mcp_fragment_name, tenant_subdomain)
209+ since different fragments may yield differently-scoped tokens.
210+
192211 Args:
193212 mcp_fragment_name: MCP fragment name for token exchange.
194213 user_token: User's JWT for principal propagation.
195214 tenant_subdomain: Tenant subdomain for multi-tenant lookup.
215+ cache: Token cache shared across calls.
196216
197217 Returns:
198218 Authorization header value with user identity embedded.
199219
200220 Raises:
201221 MCPServerNotFoundError: If no auth token is returned.
202222 """
223+ scope_key = f"{ mcp_fragment_name } |{ tenant_subdomain } "
224+ cached = cache .get_user_token (user_token , scope_key )
225+ if cached :
226+ logger .debug (
227+ "User token cache hit for tenant '%s', fragment '%s'" ,
228+ tenant_subdomain ,
229+ mcp_fragment_name ,
230+ )
231+ return cached
232+
203233 loop = asyncio .get_running_loop ()
204234
205235 def _fetch_user_auth_sync ():
@@ -220,7 +250,10 @@ def _fetch_user_auth_sync():
220250
221251 return _fetch_auth_token (dest_name , tenant_subdomain , options )
222252
223- return await loop .run_in_executor (None , _fetch_user_auth_sync )
253+ auth = await loop .run_in_executor (None , _fetch_user_auth_sync )
254+ expires_at = cache .compute_expires_at_from_bearer (auth )
255+ cache .set_user_token (user_token , auth , expires_at , scope_key )
256+ return auth
224257
225258
226259async def list_server_tools (
@@ -271,13 +304,18 @@ async def list_server_tools(
271304async def get_mcp_tools_lob (
272305 tenant_subdomain : str ,
273306 timeout : float ,
307+ cache : _TokenCache ,
274308) -> list [MCPTool ]:
275309 """List all MCP tools using LoB flow (destination-based).
276310
277311 Uses Phase 1 auth (client-scoped) via BTP Destination Service.
312+ On a 401 from an MCP server, invalidates the cached system token and
313+ retries once before skipping the fragment.
278314
279315 Args:
280316 tenant_subdomain: Tenant subdomain for multi-tenant lookup.
317+ timeout: HTTP timeout in seconds.
318+ cache: Token cache shared across calls.
281319
282320 Returns:
283321 List of MCPTool objects from all MCP servers.
@@ -295,6 +333,11 @@ async def get_mcp_tools_lob(
295333 )
296334 return tools
297335
336+ system_auth = None
337+
338+ async def _refetch_system_auth () -> str :
339+ return await get_system_auth (tenant_subdomain , cache )
340+
298341 for fragment in fragments :
299342 fragment_name = fragment .name
300343 mcp_url = fragment .properties .get ("URL" ) or fragment .properties .get ("url" )
@@ -305,22 +348,36 @@ async def get_mcp_tools_lob(
305348 )
306349 continue
307350
308- try :
309- system_auth = await get_system_auth (tenant_subdomain )
310- server_tools = await list_server_tools (
311- mcp_url , system_auth , fragment_name , timeout
312- )
313- tools .extend (server_tools )
314- logger .debug (
315- "Loaded %d tool(s) from fragment '%s'" ,
316- len (server_tools ),
317- fragment_name ,
318- )
319- except Exception :
320- logger .exception (
321- "Failed to load tools from fragment '%s' — skipping" ,
322- fragment_name ,
323- )
351+ for attempt in (1 , 2 ):
352+ if not system_auth :
353+ # Auth failure here is immediately fatal — same token is needed for
354+ # all fragments, so there is no point continuing.
355+ system_auth = await _refetch_system_auth ()
356+
357+ try :
358+ server_tools = await list_server_tools (
359+ mcp_url , system_auth , fragment_name , timeout
360+ )
361+ tools .extend (server_tools )
362+ logger .debug (
363+ "Loaded %d tool(s) from fragment '%s'" ,
364+ len (server_tools ),
365+ fragment_name ,
366+ )
367+ except Exception as exc :
368+ unwrapped = _unwrap_exception_group (exc )
369+ if _is_unauthorized (unwrapped ) and attempt == 1 :
370+ logger .info (
371+ "401 from '%s' — invalidating cached system token and retrying" ,
372+ fragment_name ,
373+ )
374+ cache .invalidate_system_token (tenant_subdomain )
375+ system_auth = None
376+ continue
377+ logger .exception (
378+ "Failed to load tools from fragment '%s' — skipping" , fragment_name
379+ )
380+ break
324381
325382 logger .info ("Loaded %d MCP tool(s) from %d fragment(s)" , len (tools ), len (fragments ))
326383 return tools
@@ -331,45 +388,90 @@ async def call_mcp_tool_lob(
331388 user_token : str ,
332389 tenant_subdomain : str ,
333390 timeout : float ,
391+ cache : _TokenCache ,
334392 ** kwargs ,
335393) -> str :
336394 """Invoke an MCP tool using LoB flow (destination-based).
337395
338396 Uses Phase 2 auth (user-scoped) via token exchange.
339397 Principal propagation ensures LoB systems see user identity.
398+ On a 401, invalidates the cached user token and retries once.
340399
341400 Args:
342401 tool: MCPTool object (from list_mcp_tools).
343402 user_token: User's JWT for principal propagation.
344403 tenant_subdomain: Tenant subdomain for token exchange.
404+ timeout: HTTP timeout in seconds.
405+ cache: Token cache shared across calls.
345406 **kwargs: Tool input parameters.
346407
347408 Returns:
348409 Tool execution result as string.
349410
350411 Raises:
351412 MCPServerNotFoundError: If destination/auth fails.
413+ AgentGatewaySDKError: If tool invocation fails after 401 retry.
352414 """
353415 if not tool .fragment_name :
354416 raise MCPServerNotFoundError (
355417 f"Tool '{ tool .name } ' missing fragment_name for LoB invocation"
356418 )
357- user_auth = await get_user_auth (tool .fragment_name , user_token , tenant_subdomain )
358419
359- async with httpx .AsyncClient (
360- headers = {"Authorization" : user_auth , "x-correlation-id" : str (uuid .uuid4 ())},
361- timeout = timeout ,
362- ) as http_client :
363- async with streamable_http_client (tool .url , http_client = http_client ) as (
364- read ,
365- write ,
366- _ ,
367- ):
368- async with ClientSession (read , write ) as session :
369- await session .initialize ()
370- result = await session .call_tool (tool .name , kwargs )
371- if not result .content :
372- logger .warning ("Tool '%s' returned empty content" , tool .name )
373- return ""
374- first = result .content [0 ]
375- return str (getattr (first , "text" , "" ))
420+ scope_key = f"{ tool .fragment_name } |{ tenant_subdomain } "
421+ last_exc : Exception | None = None
422+
423+ for attempt in (1 , 2 ):
424+ user_auth = await get_user_auth (
425+ tool .fragment_name , user_token , tenant_subdomain , cache
426+ )
427+ try :
428+ async with httpx .AsyncClient (
429+ headers = {
430+ "Authorization" : user_auth ,
431+ "x-correlation-id" : str (uuid .uuid4 ()),
432+ },
433+ timeout = timeout ,
434+ ) as http_client :
435+ async with streamable_http_client (
436+ tool .url , http_client = http_client
437+ ) as (read , write , _ ):
438+ async with ClientSession (read , write ) as session :
439+ await session .initialize ()
440+ result = await session .call_tool (tool .name , kwargs )
441+ if not result .content :
442+ logger .warning (
443+ "Tool '%s' returned empty content" , tool .name
444+ )
445+ return ""
446+ first = result .content [0 ]
447+ return str (getattr (first , "text" , "" ))
448+ except Exception as exc :
449+ unwrapped = _unwrap_exception_group (exc )
450+ if _is_unauthorized (unwrapped ) and attempt == 1 :
451+ logger .info (
452+ "401 from MCP server for tool '%s' — invalidating cached token and retrying" ,
453+ tool .name ,
454+ )
455+ cache .invalidate_user_token (user_token , scope_key )
456+ last_exc = exc
457+ continue
458+ raise
459+
460+ # Defensive — should not be reachable; second attempt either returns or raises.
461+ raise AgentGatewaySDKError (
462+ f"Tool invocation for '{ tool .name } ' failed after 401 retry: { last_exc } "
463+ )
464+
465+
466+ def _unwrap_exception_group (exc : BaseException ) -> BaseException :
467+ """Unwrap nested ExceptionGroups to find the underlying cause."""
468+ while isinstance (exc , BaseExceptionGroup ) and exc .exceptions :
469+ exc = exc .exceptions [0 ]
470+ return exc
471+
472+
def _is_unauthorized(exc: BaseException) -> bool:
    """Tell whether ``exc`` represents an HTTP 401 response.

    Args:
        exc: Exception to inspect (already unwrapped from any group).

    Returns:
        True only when ``exc`` is an ``httpx.HTTPStatusError`` carrying a
        response whose status code is 401; False for every other exception.
    """
    # Only httpx status errors carry a response to inspect.
    if not isinstance(exc, httpx.HTTPStatusError):
        return False
    response = exc.response
    return response is not None and response.status_code == 401
0 commit comments