Skip to content
179 changes: 152 additions & 27 deletions src/sap_cloud_sdk/agentgateway/_customer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
IntegrationDependency,
Comment thread
NicoleMGomes marked this conversation as resolved.
MCPTool,
)
from sap_cloud_sdk.agentgateway._token_cache import _TokenCache
from sap_cloud_sdk.agentgateway.config import ClientConfig
from sap_cloud_sdk.agentgateway.exceptions import AgentGatewaySDKError

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -211,19 +213,24 @@ def _request_token_mtls(
credentials: CustomerCredentials,
grant_type: str,
timeout: float,
cache: _TokenCache,
app_tid: str | None = None,
extra_data: dict | None = None,
) -> str:
) -> tuple[str, float]:
"""Make mTLS token request to IAS.

Args:
credentials: Customer credentials with certificate and private key.
grant_type: OAuth2 grant type.
timeout: HTTP timeout in seconds.
cache: Token cache to calculate expiry (for buffer and fallback TTL).
app_tid: BTP Application Tenant ID of subscriber (optional).
extra_data: Additional form data for the token request.

Returns:
Access token string.
Tuple of (access_token, expires_at) where expires_at is a
time.monotonic() value indicating when the cached token should
be refreshed (already includes the configured buffer).

Raises:
AgentGatewaySDKError: If token request fails.
Expand Down Expand Up @@ -282,8 +289,10 @@ def _request_token_mtls(
f"Token response missing 'access_token'. Keys: {list(token_data.keys())}"
)

expires_at = cache.compute_expires_at(token_data)

logger.debug("Token acquired successfully (length: %d)", len(access_token))
return access_token
return access_token, expires_at

except httpx.RequestError as e:
raise AgentGatewaySDKError(f"Token request failed: {e}")
Expand All @@ -292,61 +301,83 @@ def _request_token_mtls(
def get_system_token_mtls(
    credentials: CustomerCredentials,
    timeout: float,
    cache: _TokenCache,
    app_tid: str | None = None,
) -> str:
    """Get system-scoped token using mTLS client credentials flow.

    Used for tool discovery where user identity is not needed. Returns
    a cached token if still valid; otherwise acquires a fresh one and
    stores it in the cache before returning it.

    Args:
        credentials: Customer credentials.
        timeout: HTTP timeout in seconds.
        cache: Token cache to consult and update.
        app_tid: BTP Application Tenant ID of subscriber (optional).

    Returns:
        System-scoped access token.

    Raises:
        AgentGatewaySDKError: If the token request fails (propagated
            from ``_request_token_mtls``).
    """
    # NOTE(review): the cache entry is keyed by client_id only, while the
    # token request also depends on app_tid. Confirm that a cache instance
    # is never shared across different subscriber tenants.
    cached = cache.get_system_token(credentials.client_id)
    if cached:
        logger.debug("Using cached system token (client_id=%s)", credentials.client_id)
        return cached

    logger.info("Acquiring system token via mTLS client credentials")
    token, expires_at = _request_token_mtls(
        credentials,
        grant_type=_GRANT_TYPE_CLIENT_CREDENTIALS,
        timeout=timeout,
        cache=cache,
        app_tid=app_tid,
        extra_data={"response_type": "token"},
    )
    # Store under the same key used for the lookup above so the next call hits.
    cache.set_system_token(token, expires_at, credentials.client_id)
    return token


def exchange_user_token(
    credentials: CustomerCredentials,
    user_token: str,
    timeout: float,
    cache: _TokenCache,
    app_tid: str | None = None,
) -> str:
    """Exchange user token for AGW-scoped token using jwt-bearer grant.

    Used for tool invocation where user identity must be preserved
    for principal propagation. Returns a cached exchanged token if
    still valid; otherwise acquires a fresh one and stores it in the
    cache before returning it.

    Args:
        credentials: Customer credentials.
        user_token: User's JWT token to exchange.
        timeout: HTTP timeout in seconds.
        cache: Token cache to consult and update.
        app_tid: BTP Application Tenant ID of subscriber (optional).

    Returns:
        AGW-scoped access token with user identity.

    Raises:
        AgentGatewaySDKError: If the token request fails (propagated
            from ``_request_token_mtls``).
    """
    # NOTE(review): the cache entry is keyed by (user_token, client_id);
    # app_tid is not part of the key. Confirm that a cache instance is
    # never shared across different subscriber tenants.
    cached = cache.get_user_token(user_token, credentials.client_id)
    if cached:
        logger.debug("Using cached user token (client_id=%s)", credentials.client_id)
        return cached

    logger.info("Exchanging user token for AGW-scoped token via jwt-bearer grant")
    token, expires_at = _request_token_mtls(
        credentials,
        grant_type=_GRANT_TYPE_JWT_BEARER,
        timeout=timeout,
        cache=cache,
        app_tid=app_tid,
        extra_data={
            "assertion": user_token,
            "token_format": "jwt",
        },
    )
    # Store under the same key used for the lookup above so the next call hits.
    cache.set_user_token(user_token, token, expires_at, credentials.client_id)
    return token


def _build_mcp_url(gateway_url: str, ord_id: str, gt_id: str) -> str:
Expand Down Expand Up @@ -433,6 +464,8 @@ async def _list_server_tools(
async def get_mcp_tools_customer(
credentials: CustomerCredentials,
timeout: float,
config: ClientConfig,
cache: _TokenCache,
app_tid: str | None = None,
) -> list[MCPTool]:
"""List all MCP tools from servers defined in credentials.
Expand All @@ -442,6 +475,9 @@ async def get_mcp_tools_customer(

Args:
credentials: Customer credentials with integrationDependencies.
timeout: HTTP timeout in seconds.
config: Client configuration.
cache: Token cache shared across calls.
app_tid: BTP Application Tenant ID of subscriber (optional).

Returns:
Expand All @@ -459,11 +495,16 @@ async def get_mcp_tools_customer(

logger.info("Discovering tools from %d MCP server(s)", len(dependencies))

# Get system token for discovery
loop = asyncio.get_running_loop()
system_token = await loop.run_in_executor(
None, get_system_token_mtls, credentials, timeout, app_tid
)
system_token = None

# Define a helper closure to refetch the system token on demand, since it may need to be
# refreshed during the discovery loop if any server returns a 401
async def refetch_system_token() -> str:
loop = asyncio.get_running_loop()
new_token = await loop.run_in_executor(
None, get_system_token_mtls, credentials, timeout, cache, app_tid
)
return new_token

tools: list[MCPTool] = []

Expand All @@ -476,12 +517,28 @@ async def get_mcp_tools_customer(
dep.global_tenant_id,
)

try:
server_tools = await _list_server_tools(url, system_token, dep, timeout)
tools.extend(server_tools)
logger.debug("Loaded %d tool(s) from %s", len(server_tools), dep.ord_id)
except Exception:
logger.exception("Failed to load tools from %s — skipping", dep.ord_id)
for attempt in (1, 2):
if not system_token:
# won't catch exceptions here - if token acquisition fails,
# we want the discovery to fail immediately
system_token = await refetch_system_token()

try:
server_tools = await _list_server_tools(url, system_token, dep, timeout)
tools.extend(server_tools)
logger.debug("Loaded %d tool(s) from %s", len(server_tools), dep.ord_id)
except Exception as exc:
unwrapped = _unwrap_exception_group(exc)
if _is_unauthorized(unwrapped) and attempt == 1:
logger.info(
"401 from %s — invalidating cached system token and retrying",
dep.ord_id,
)
cache.invalidate_system_token(credentials.client_id)
system_token = None # Force refetch on next loop iteration
continue
logger.exception("Failed to load tools from %s — skipping", dep.ord_id)
break # Success, move to next server

logger.info(
"Loaded %d MCP tool(s) from %d server(s)", len(tools), len(dependencies)
Expand All @@ -494,6 +551,8 @@ async def call_mcp_tool_customer(
tool: MCPTool,
user_token: str | None,
timeout: float,
config: ClientConfig,
cache: _TokenCache,
app_tid: str | None = None,
**kwargs,
) -> str:
Expand All @@ -502,11 +561,16 @@ async def call_mcp_tool_customer(
If user_token is provided, exchanges it for an AGW-scoped token to preserve
user identity for principal propagation. Otherwise, falls back to system token.

On a 401 from the MCP server, drops the cached token and retries once.

Args:
credentials: Customer credentials.
tool: MCPTool to invoke.
user_token: User's JWT token for principal propagation (optional).
If None, system token is used instead (no principal propagation).
timeout: HTTP timeout in seconds.
config: Client configuration.
cache: Token cache shared across calls.
app_tid: BTP Application Tenant ID of subscriber (optional).
**kwargs: Tool input parameters.

Expand All @@ -517,26 +581,73 @@ async def call_mcp_tool_customer(

loop = asyncio.get_running_loop()

if user_token:
# Exchange user token for AGW-scoped token (with principal propagation)
agw_token = await loop.run_in_executor(
None, exchange_user_token, credentials, user_token, timeout, app_tid
)
else:
async def _acquire_token() -> str:
if user_token:
return await loop.run_in_executor(
None,
exchange_user_token,
credentials,
user_token,
timeout,
cache,
app_tid,
)
# TODO: IBD workaround - use system token when user_token is not available.
# This bypasses principal propagation. Remove this fallback once IBD
# supports proper user token flow.
logger.warning(
"No user_token provided - using system token for tool invocation. "
"Principal propagation will NOT work."
)
agw_token = await loop.run_in_executor(
None, get_system_token_mtls, credentials, timeout, app_tid
return await loop.run_in_executor(
None, get_system_token_mtls, credentials, timeout, cache, app_tid
)

def _invalidate_token() -> None:
if user_token:
cache.invalidate_user_token(user_token, credentials.client_id)
else:
cache.invalidate_system_token(credentials.client_id)

last_exc: Exception | None = None
for attempt in (1, 2):
agw_token = await _acquire_token()
try:
return await _invoke_tool(tool, agw_token, timeout, **kwargs)
except Exception as exc:
unwrapped = _unwrap_exception_group(exc)
if _is_unauthorized(unwrapped) and attempt == 1:
logger.info(
"401 from MCP server for tool '%s' — invalidating cached token and retrying",
tool.name,
)
_invalidate_token()
last_exc = exc
continue
raise

# Defensive — should not be reachable; second attempt either returns or raises.
raise AgentGatewaySDKError(
f"Tool invocation for '{tool.name}' failed after 401 retry: {last_exc}"
)


async def _invoke_tool(
tool: MCPTool,
auth_token: str,
timeout: float,
**kwargs,
) -> str:
"""Open an MCP session to `tool.url` and invoke `tool.name` with `kwargs`.

Returns the first content block's text, or empty string when content is
empty. Raises whatever the MCP transport / session raises (notably
`httpx.HTTPStatusError` on 401, which the caller uses to drive cache
invalidation and retry).
"""
async with httpx.AsyncClient(
headers={
"Authorization": f"Bearer {agw_token}",
"Authorization": f"Bearer {auth_token}",
"x-correlation-id": str(uuid.uuid4()),
},
timeout=timeout,
Expand All @@ -556,3 +667,17 @@ async def call_mcp_tool_customer(

first = result.content[0]
return str(getattr(first, "text", ""))


def _unwrap_exception_group(exc: BaseException) -> BaseException:
"""Unwrap nested ExceptionGroups to find the underlying cause."""
while isinstance(exc, BaseExceptionGroup) and exc.exceptions:
exc = exc.exceptions[0]
return exc


def _is_unauthorized(exc: BaseException) -> bool:
    """Detect a 401 response from the MCP server (httpx-based).

    Args:
        exc: Exception to classify. Callers in this module pass the
            result of ``_unwrap_exception_group``.

    Returns:
        True only when ``exc`` is an ``httpx.HTTPStatusError`` whose
        response has status code 401; False for every other exception.
    """
    if isinstance(exc, httpx.HTTPStatusError):
        # Guard against a missing response before reading the status code.
        return exc.response is not None and exc.response.status_code == 401
    return False
Loading
Loading