diff --git a/README.md b/README.md
index 5e8129c96e..fab9e05ef6 100644
--- a/README.md
+++ b/README.md
@@ -61,6 +61,7 @@
- [Writing MCP Clients](#writing-mcp-clients)
- [Client Display Utilities](#client-display-utilities)
- [OAuth Authentication for Clients](#oauth-authentication-for-clients)
+ - [Enterprise Managed Authorization](#enterprise-managed-authorization)
- [Parsing Tool Results](#parsing-tool-results)
- [MCP Primitives](#mcp-primitives)
- [Server Capabilities](#server-capabilities)
@@ -2356,6 +2357,129 @@ _Full example: [examples/snippets/clients/oauth_client.py](https://github.com/mo
For a complete working example, see [`examples/clients/simple-auth-client/`](examples/clients/simple-auth-client/).
+#### Enterprise Managed Authorization
+
+The SDK includes support for Enterprise Managed Authorization (SEP-990), which enables MCP clients to connect to protected servers using enterprise Single Sign-On (SSO) systems. This implementation supports:
+
+- **RFC 8693**: OAuth 2.0 Token Exchange (ID Token → ID-JAG)
+- **RFC 7523**: JSON Web Token (JWT) Profile for OAuth 2.0 Authorization Grants (ID-JAG → Access Token)
+- Integration with enterprise identity providers (Okta, Azure AD, etc.)
+
+**Key Components:**
+
+The `EnterpriseAuthOAuthClientProvider` class extends the standard OAuth provider to implement the enterprise authorization flow:
+
+**Token Exchange Flow:**
+
+1. **Obtain ID Token** from your enterprise IdP (e.g., Okta, Azure AD)
+2. **Exchange ID Token for ID-JAG** using RFC 8693 Token Exchange
+3. **Exchange ID-JAG for Access Token** using RFC 7523 JWT Bearer Grant
+4. **Use Access Token** to call protected MCP server tools
+
+**Example Usage:**
+
+```python
+import asyncio
+import httpx
+from pydantic import AnyUrl
+
+from mcp.client.auth.extensions import (
+ EnterpriseAuthOAuthClientProvider,
+ TokenExchangeParameters,
+)
+from mcp.shared.auth import OAuthClientMetadata
+from mcp.client.auth import TokenStorage
+
+# Define token storage implementation
+class SimpleTokenStorage(TokenStorage):
+ def __init__(self):
+ self._tokens = None
+ self._client_info = None
+
+ async def get_tokens(self):
+ return self._tokens
+
+ async def set_tokens(self, tokens):
+ self._tokens = tokens
+
+ async def get_client_info(self):
+ return self._client_info
+
+ async def set_client_info(self, client_info):
+ self._client_info = client_info
+
+async def main():
+ # Step 1: Get ID token from your IdP (example with Okta)
+ id_token = await get_id_token_from_idp() # Your IdP authentication
+
+ # Step 2: Configure token exchange parameters
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=id_token,
+ mcp_server_auth_issuer="https://your-idp.com", # IdP issuer URL
+ mcp_server_resource_id="https://mcp-server.example.com", # MCP server resource ID
+ scope="mcp:tools mcp:resources", # Optional scopes
+ )
+
+ # Step 3: Create enterprise auth provider
+ enterprise_auth = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example.com",
+ client_metadata=OAuthClientMetadata(
+ client_name="Enterprise MCP Client",
+ client_id="your-client-id",
+ redirect_uris=[AnyUrl("http://localhost:3000/callback")],
+ grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"],
+ response_types=["token"],
+ ),
+ storage=SimpleTokenStorage(),
+ idp_token_endpoint="https://your-idp.com/oauth2/v1/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Step 4: Perform token exchange and get access token
+ async with httpx.AsyncClient() as client:
+ # Exchange ID token for ID-JAG
+ id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
+ print(f"Obtained ID-JAG: {id_jag[:50]}...")
+
+ # Exchange ID-JAG for access token
+ access_token = await enterprise_auth.exchange_id_jag_for_access_token(
+ client, id_jag
+ )
+ print(f"Access token obtained, expires in: {access_token.expires_in}s")
+
+if __name__ == "__main__":
+ asyncio.run(main())
+```
+
+**Working with SAML Assertions:**
+
+If your enterprise uses SAML instead of OIDC, you can exchange SAML assertions:
+
+```python
+token_exchange_params = TokenExchangeParameters.from_saml_assertion(
+ saml_assertion=saml_assertion_string,
+ mcp_server_auth_issuer="https://your-idp.com",
+ mcp_server_resource_id="https://mcp-server.example.com",
+ scope="mcp:tools",
+)
+```
+
+**Decoding and Inspecting ID-JAG Tokens:**
+
+You can decode ID-JAG tokens to inspect their claims:
+
+```python
+from mcp.client.auth.extensions import decode_id_jag
+
+# Decode without signature verification (for inspection only)
+claims = decode_id_jag(id_jag)
+print(f"Subject: {claims.sub}")
+print(f"Issuer: {claims.iss}")
+print(f"Audience: {claims.aud}")
+print(f"Client ID: {claims.client_id}")
+print(f"Resource: {claims.resource}")
+```
+
### Parsing Tool Results
When calling tools through MCP, the `CallToolResult` object contains the tool's response in a structured format. Understanding how to parse this result is essential for properly handling tool outputs.
diff --git a/src/mcp/client/auth/extensions/__init__.py b/src/mcp/client/auth/extensions/__init__.py
index e69de29bb2..56ba368ef8 100644
--- a/src/mcp/client/auth/extensions/__init__.py
+++ b/src/mcp/client/auth/extensions/__init__.py
@@ -0,0 +1,19 @@
+"""MCP Client Auth Extensions."""
+
+from mcp.client.auth.extensions.enterprise_managed_auth import (
+ EnterpriseAuthOAuthClientProvider,
+ IDJAGClaims,
+ TokenExchangeParameters,
+ TokenExchangeResponse,
+ decode_id_jag,
+ validate_token_exchange_params,
+)
+
+__all__ = [
+ "EnterpriseAuthOAuthClientProvider",
+ "IDJAGClaims",
+ "TokenExchangeParameters",
+ "TokenExchangeResponse",
+ "decode_id_jag",
+ "validate_token_exchange_params",
+]
diff --git a/src/mcp/client/auth/extensions/enterprise_managed_auth.py b/src/mcp/client/auth/extensions/enterprise_managed_auth.py
new file mode 100644
index 0000000000..e55283e968
--- /dev/null
+++ b/src/mcp/client/auth/extensions/enterprise_managed_auth.py
@@ -0,0 +1,412 @@
+"""
+Enterprise Managed Authorization extension for MCP (SEP-990).
+
+Implements RFC 8693 Token Exchange and RFC 7523 JWT Bearer Grant for
+enterprise SSO integration.
+"""
+
+import logging
+from typing import Any
+
+import httpx
+from pydantic import BaseModel, Field
+
+from mcp.client.auth import OAuthClientProvider, OAuthFlowError, OAuthTokenError, TokenStorage
+from mcp.shared.auth import OAuthClientMetadata, OAuthToken
+
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# Data Models
+# ============================================================================
+
+
+class TokenExchangeParameters(BaseModel):
+ """Parameters for RFC 8693 Token Exchange request."""
+
+ requested_token_type: str = Field(
+ default="urn:ietf:params:oauth:token-type:id-jag",
+ description="Type of token being requested (ID-JAG)",
+ )
+
+ audience: str = Field(
+ ...,
+ description="Issuer URL of the MCP Server's authorization server",
+ )
+
+ resource: str = Field(
+ ...,
+ description="RFC 9728 Resource Identifier of the MCP Server",
+ )
+
+ scope: str | None = Field(
+ default=None,
+ description="Space-separated list of scopes being requested",
+ )
+
+ subject_token: str = Field(
+ ...,
+ description="ID Token or SAML assertion for the end user",
+ )
+
+ subject_token_type: str = Field(
+ ...,
+ description="Type of subject token (id_token or saml2)",
+ )
+
+ @classmethod
+ def from_id_token(
+ cls,
+ id_token: str,
+ mcp_server_auth_issuer: str,
+ mcp_server_resource_id: str,
+ scope: str | None = None,
+ ) -> "TokenExchangeParameters":
+ """Create parameters for OIDC ID Token exchange."""
+ return cls(
+ subject_token=id_token,
+ subject_token_type="urn:ietf:params:oauth:token-type:id_token",
+ audience=mcp_server_auth_issuer,
+ resource=mcp_server_resource_id,
+ scope=scope,
+ )
+
+ @classmethod
+ def from_saml_assertion(
+ cls,
+ saml_assertion: str,
+ mcp_server_auth_issuer: str,
+ mcp_server_resource_id: str,
+ scope: str | None = None,
+ ) -> "TokenExchangeParameters":
+ """Create parameters for SAML assertion exchange."""
+ return cls(
+ subject_token=saml_assertion,
+ subject_token_type="urn:ietf:params:oauth:token-type:saml2",
+ audience=mcp_server_auth_issuer,
+ resource=mcp_server_resource_id,
+ scope=scope,
+ )
+
+
+class TokenExchangeResponse(BaseModel):
+ """Response from RFC 8693 Token Exchange."""
+
+ issued_token_type: str = Field(
+ ...,
+ description="Type of token issued (should be id-jag)",
+ )
+
+ access_token: str = Field(
+ ...,
+ description="The ID-JAG token (named access_token per RFC 8693)",
+ )
+
+ token_type: str = Field(
+ ...,
+ description="Token type (should be N_A for ID-JAG)",
+ )
+
+ scope: str | None = Field(
+ default=None,
+ description="Granted scopes",
+ )
+
+ expires_in: int | None = Field(
+ default=None,
+ description="Lifetime in seconds",
+ )
+
+ @property
+ def id_jag(self) -> str:
+ """Get the ID-JAG token."""
+ return self.access_token
+
+
+class IDJAGClaims(BaseModel):
+ """Claims structure for Identity Assertion JWT Authorization Grant."""
+
+ model_config = {"extra": "allow"}
+
+ # JWT header
+ typ: str = Field(
+ ...,
+ description="JWT type - must be 'oauth-id-jag+jwt'",
+ )
+
+ # Required claims
+ jti: str = Field(..., description="Unique JWT ID")
+ iss: str = Field(..., description="IdP issuer URL")
+ sub: str = Field(..., description="Subject (user) identifier")
+ aud: str = Field(..., description="MCP Server's auth server issuer")
+ resource: str = Field(..., description="MCP Server resource identifier")
+ client_id: str = Field(..., description="MCP Client identifier")
+ exp: int = Field(..., description="Expiration timestamp")
+ iat: int = Field(..., description="Issued-at timestamp")
+
+ # Optional claims
+ scope: str | None = Field(None, description="Space-separated scopes")
+ email: str | None = Field(None, description="User email")
+
+
+class EnterpriseAuthOAuthClientProvider(OAuthClientProvider):
+ """
+ OAuth client provider for Enterprise Managed Authorization (SEP-990).
+
+ Implements:
+ - RFC 8693: Token Exchange (ID Token → ID-JAG)
+ - RFC 7523: JWT Bearer Grant (ID-JAG → Access Token)
+ """
+
+ def __init__(
+ self,
+ server_url: str,
+ client_metadata: OAuthClientMetadata,
+ storage: TokenStorage,
+ idp_token_endpoint: str,
+ token_exchange_params: TokenExchangeParameters,
+ redirect_handler: Any = None,
+ callback_handler: Any = None,
+ timeout: float = 300.0,
+ ) -> None:
+ """
+ Initialize Enterprise Auth OAuth Client.
+
+ Args:
+ server_url: MCP server URL
+ client_metadata: OAuth client metadata
+ storage: Token storage implementation
+ idp_token_endpoint: Enterprise IdP token endpoint URL
+ token_exchange_params: Token exchange parameters
+ redirect_handler: Optional redirect handler
+ callback_handler: Optional callback handler
+ timeout: Request timeout in seconds
+ """
+ super().__init__(
+ server_url=server_url,
+ client_metadata=client_metadata,
+ storage=storage,
+ redirect_handler=redirect_handler,
+ callback_handler=callback_handler,
+ timeout=timeout,
+ )
+ self.idp_token_endpoint = idp_token_endpoint
+ self.token_exchange_params = token_exchange_params
+ self._id_jag: str | None = None
+
+ async def exchange_token_for_id_jag(
+ self,
+ client: httpx.AsyncClient,
+ ) -> str:
+ """
+ Exchange ID Token for ID-JAG using RFC 8693 Token Exchange.
+
+ Args:
+ client: HTTP client for making requests
+
+ Returns:
+ The ID-JAG token string
+
+ Raises:
+ OAuthTokenError: If token exchange fails
+ """
+ logger.info("Starting token exchange for ID-JAG")
+
+ # Build token exchange request
+ token_data = {
+ "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
+ "requested_token_type": self.token_exchange_params.requested_token_type,
+ "audience": self.token_exchange_params.audience,
+ "resource": self.token_exchange_params.resource,
+ "subject_token": self.token_exchange_params.subject_token,
+ "subject_token_type": self.token_exchange_params.subject_token_type,
+ }
+
+ if self.token_exchange_params.scope:
+ token_data["scope"] = self.token_exchange_params.scope
+
+ # Add client authentication if needed
+ if self.context.client_info:
+ if self.context.client_info.client_id is not None:
+ token_data["client_id"] = self.context.client_info.client_id
+ if self.context.client_info.client_secret is not None:
+ token_data["client_secret"] = self.context.client_info.client_secret
+
+ try:
+ response = await client.post(
+ self.idp_token_endpoint,
+ data=token_data,
+ timeout=self.context.timeout,
+ )
+
+ if response.status_code != 200:
+ error_data: dict[str, str] = (
+ response.json() if response.headers.get("content-type", "").startswith("application/json") else {}
+ )
+ error: str = error_data.get("error", "unknown_error")
+ error_description: str = error_data.get("error_description", "Token exchange failed")
+ raise OAuthTokenError(f"Token exchange failed: {error} - {error_description}")
+
+ # Parse response
+ token_response = TokenExchangeResponse.model_validate_json(response.content)
+
+ # Validate response
+ if token_response.issued_token_type != "urn:ietf:params:oauth:token-type:id-jag":
+ raise OAuthTokenError(f"Unexpected token type: {token_response.issued_token_type}")
+
+ if token_response.token_type != "N_A":
+ logger.warning(f"Expected token_type 'N_A', got '{token_response.token_type}'")
+
+ logger.info("Successfully obtained ID-JAG")
+ self._id_jag = token_response.id_jag
+ return token_response.id_jag
+
+ except httpx.HTTPError as e:
+ raise OAuthTokenError(f"HTTP error during token exchange: {e}") from e
+
+ async def exchange_id_jag_for_access_token(
+ self,
+ client: httpx.AsyncClient,
+ id_jag: str,
+ ) -> OAuthToken:
+ """
+ Exchange ID-JAG for access token using RFC 7523 JWT Bearer Grant.
+
+ Args:
+ client: HTTP client for making requests
+ id_jag: The ID-JAG token
+
+ Returns:
+ OAuth access token
+
+ Raises:
+ OAuthTokenError: If JWT bearer grant fails
+ """
+ logger.info("Exchanging ID-JAG for access token")
+
+ # Discover token endpoint from MCP server if not already done
+ if not self.context.oauth_metadata or not self.context.oauth_metadata.token_endpoint:
+ raise OAuthFlowError("MCP server token endpoint not discovered")
+
+ token_endpoint = str(self.context.oauth_metadata.token_endpoint)
+
+ # Build JWT bearer grant request
+ token_data = {
+ "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
+ "assertion": id_jag,
+ }
+
+ # Add client authentication
+ if self.context.client_info:
+ if self.context.client_info.client_id is not None:
+ token_data["client_id"] = self.context.client_info.client_id
+ if self.context.client_info.client_secret is not None:
+ token_data["client_secret"] = self.context.client_info.client_secret
+
+ try:
+ response = await client.post(
+ token_endpoint,
+ data=token_data,
+ timeout=self.context.timeout,
+ )
+
+ if response.status_code != 200:
+ error_data: dict[str, str] = (
+ response.json() if response.headers.get("content-type", "").startswith("application/json") else {}
+ )
+ error: str = error_data.get("error", "unknown_error")
+ error_description: str = error_data.get("error_description", "JWT bearer grant failed")
+ raise OAuthTokenError(f"JWT bearer grant failed: {error} - {error_description}")
+
+ # Parse OAuth token response
+ token = OAuthToken.model_validate_json(response.content)
+
+ # Store tokens
+ self.context.current_tokens = token
+ self.context.update_token_expiry(token)
+ await self.context.storage.set_tokens(token)
+
+ logger.info("Successfully obtained access token via ID-JAG")
+ return token
+
+ except httpx.HTTPError as e:
+ raise OAuthTokenError(f"HTTP error during JWT bearer grant: {e}") from e
+
+ async def _perform_authorization(self) -> httpx.Request:
+ """
+ Perform enterprise authorization flow.
+
+ Overrides parent method to use token exchange + JWT bearer grant
+ instead of standard authorization code flow.
+ """
+ # Check if we already have valid tokens
+ if self.context.is_token_valid():
+ # Return a dummy request - we don't need to make any request
+ return httpx.Request("GET", self.context.server_url)
+
+ # For now, raise NotImplementedError as this requires integration
+ # with the full httpx auth flow
+ raise NotImplementedError(
+ "Full enterprise auth flow integration not yet implemented. "
+ "Use exchange_token_for_id_jag and exchange_id_jag_for_access_token directly."
+ )
+
+
+# ============================================================================
+# Helper Functions
+# ============================================================================
+
+
+def decode_id_jag(id_jag: str, verify: bool = False) -> IDJAGClaims:
+ """
+ Decode an ID-JAG token without verification.
+
+ Args:
+ id_jag: The ID-JAG token string
+ verify: Whether to verify signature (requires key)
+
+ Returns:
+ Decoded ID-JAG claims
+
+ Note:
+ For verification, use server-side validation instead.
+ """
+ import jwt
+
+ # Decode without verification for inspection
+ claims = jwt.decode(id_jag, options={"verify_signature": False})
+ header = jwt.get_unverified_header(id_jag)
+
+ # Add typ from header to claims
+ claims["typ"] = header.get("typ", "")
+
+ return IDJAGClaims.model_validate(claims)
+
+
+def validate_token_exchange_params(
+ params: TokenExchangeParameters,
+) -> None:
+ """
+ Validate token exchange parameters.
+
+ Args:
+ params: Token exchange parameters to validate
+
+ Raises:
+ ValueError: If parameters are invalid
+ """
+ if not params.subject_token:
+ raise ValueError("subject_token is required")
+
+ if not params.audience:
+ raise ValueError("audience is required")
+
+ if not params.resource:
+ raise ValueError("resource is required")
+
+ if params.subject_token_type not in [
+ "urn:ietf:params:oauth:token-type:id_token",
+ "urn:ietf:params:oauth:token-type:saml2",
+ ]:
+ raise ValueError(f"Invalid subject_token_type: {params.subject_token_type}")
diff --git a/tests/client/auth/test_enterprise_managed_auth_client.py b/tests/client/auth/test_enterprise_managed_auth_client.py
new file mode 100644
index 0000000000..4a10f4f664
--- /dev/null
+++ b/tests/client/auth/test_enterprise_managed_auth_client.py
@@ -0,0 +1,1126 @@
+"""Tests for Enterprise Managed Authorization client-side implementation."""
+
+import time
+from typing import Any
+from unittest.mock import AsyncMock, Mock, patch
+
+import httpx
+import jwt
+import pytest
+from pydantic import AnyHttpUrl, AnyUrl
+
+from mcp.client.auth import OAuthTokenError
+from mcp.client.auth.extensions.enterprise_managed_auth import (
+ EnterpriseAuthOAuthClientProvider,
+ IDJAGClaims,
+ TokenExchangeParameters,
+ TokenExchangeResponse,
+ decode_id_jag,
+ validate_token_exchange_params,
+)
+from mcp.shared.auth import OAuthClientMetadata
+
+# ============================================================================
+# Fixtures
+# ============================================================================
+
+
+@pytest.fixture
+def sample_id_token() -> str:
+ """Generate a sample ID token for testing."""
+ payload = {
+ "iss": "https://idp.example.com",
+ "sub": "user123",
+ "aud": "mcp-client-app",
+ "exp": int(time.time()) + 3600,
+ "iat": int(time.time()),
+ "email": "user@example.com",
+ }
+ return jwt.encode(payload, "secret", algorithm="HS256")
+
+
+@pytest.fixture
+def sample_id_jag() -> str:
+ """Generate a sample ID-JAG token for testing."""
+ payload = {
+ "jti": "unique-jwt-id-12345",
+ "iss": "https://idp.example.com",
+ "sub": "user123",
+ "aud": "https://auth.mcp-server.example/",
+ "resource": "https://mcp-server.example/",
+ "client_id": "mcp-client-app",
+ "exp": int(time.time()) + 300,
+ "iat": int(time.time()),
+ "scope": "read write",
+ }
+ token = jwt.encode(payload, "secret", algorithm="HS256")
+
+ # Manually add typ to header
+ header = jwt.get_unverified_header(token)
+ header["typ"] = "oauth-id-jag+jwt"
+
+ return jwt.encode(payload, "secret", algorithm="HS256", headers={"typ": "oauth-id-jag+jwt"})
+
+
+@pytest.fixture
+def mock_token_storage() -> Any:
+ """Create a mock token storage."""
+ storage = Mock()
+ storage.get_tokens = AsyncMock(return_value=None)
+ storage.set_tokens = AsyncMock()
+ storage.get_client_info = AsyncMock(return_value=None)
+ storage.set_client_info = AsyncMock()
+ return storage
+
+
+# ============================================================================
+# Tests for TokenExchangeParameters
+# ============================================================================
+
+
+def test_token_exchange_params_from_id_token():
+ """Test creating TokenExchangeParameters from ID token."""
+ params = TokenExchangeParameters.from_id_token(
+ id_token="eyJhbGc...",
+ mcp_server_auth_issuer="https://auth.server.example/",
+ mcp_server_resource_id="https://server.example/",
+ scope="read write",
+ )
+
+ assert params.subject_token == "eyJhbGc..."
+ assert params.subject_token_type == "urn:ietf:params:oauth:token-type:id_token"
+ assert params.audience == "https://auth.server.example/"
+ assert params.resource == "https://server.example/"
+ assert params.scope == "read write"
+ assert params.requested_token_type == "urn:ietf:params:oauth:token-type:id-jag"
+
+
+def test_token_exchange_params_from_saml_assertion():
+ """Test creating TokenExchangeParameters from SAML assertion."""
+ params = TokenExchangeParameters.from_saml_assertion(
+ saml_assertion="...",
+ mcp_server_auth_issuer="https://auth.server.example/",
+ mcp_server_resource_id="https://server.example/",
+ scope="read",
+ )
+
+ assert params.subject_token == "..."
+ assert params.subject_token_type == "urn:ietf:params:oauth:token-type:saml2"
+ assert params.audience == "https://auth.server.example/"
+ assert params.resource == "https://server.example/"
+ assert params.scope == "read"
+
+
+def test_validate_token_exchange_params_valid():
+ """Test validating valid token exchange parameters."""
+ params = TokenExchangeParameters.from_id_token(
+ id_token="token",
+ mcp_server_auth_issuer="https://auth.example/",
+ mcp_server_resource_id="https://server.example/",
+ )
+
+ # Should not raise
+ validate_token_exchange_params(params)
+
+
+def test_validate_token_exchange_params_invalid_token_type():
+ """Test validation fails for invalid subject token type."""
+ params = TokenExchangeParameters(
+ subject_token="token",
+ subject_token_type="invalid:type",
+ audience="https://auth.example/",
+ resource="https://server.example/",
+ )
+
+ with pytest.raises(ValueError, match="Invalid subject_token_type"):
+ validate_token_exchange_params(params)
+
+
+def test_validate_token_exchange_params_missing_subject_token():
+ """Test validation fails for missing subject token."""
+ params = TokenExchangeParameters(
+ subject_token="",
+ subject_token_type="urn:ietf:params:oauth:token-type:id_token",
+ audience="https://auth.example/",
+ resource="https://server.example/",
+ )
+
+ with pytest.raises(ValueError, match="subject_token is required"):
+ validate_token_exchange_params(params)
+
+
+# ============================================================================
+# Tests for TokenExchangeResponse
+# ============================================================================
+
+
+def test_token_exchange_response_parsing():
+ """Test parsing token exchange response."""
+ response_json = """{
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": "eyJhbGc...",
+ "token_type": "N_A",
+ "scope": "read write",
+ "expires_in": 300
+ }"""
+
+ response = TokenExchangeResponse.model_validate_json(response_json)
+
+ assert response.issued_token_type == "urn:ietf:params:oauth:token-type:id-jag"
+ assert response.id_jag == "eyJhbGc..."
+ assert response.access_token == "eyJhbGc..."
+ assert response.token_type == "N_A"
+ assert response.scope == "read write"
+ assert response.expires_in == 300
+
+
+def test_token_exchange_response_id_jag_property():
+ """Test id_jag property returns access_token."""
+ response = TokenExchangeResponse(
+ issued_token_type="urn:ietf:params:oauth:token-type:id-jag",
+ access_token="the-id-jag-token",
+ token_type="N_A",
+ )
+
+ assert response.id_jag == "the-id-jag-token"
+
+
+# ============================================================================
+# Tests for IDJAGClaims
+# ============================================================================
+
+
+def test_decode_id_jag(sample_id_jag: str):
+ """Test decoding ID-JAG token."""
+ claims = decode_id_jag(sample_id_jag)
+
+ assert claims.iss == "https://idp.example.com"
+ assert claims.sub == "user123"
+ assert claims.aud == "https://auth.mcp-server.example/"
+ assert claims.resource == "https://mcp-server.example/"
+ assert claims.client_id == "mcp-client-app"
+ assert claims.scope == "read write"
+
+
+def test_id_jag_claims_with_extra_fields():
+ """Test IDJAGClaims allows extra fields."""
+ claims_data = {
+ "typ": "oauth-id-jag+jwt",
+ "jti": "jti123",
+ "iss": "https://idp.example.com",
+ "sub": "user123",
+ "aud": "https://auth.server.example/",
+ "resource": "https://server.example/",
+ "client_id": "client123",
+ "exp": int(time.time()) + 300,
+ "iat": int(time.time()),
+ "scope": "read",
+ "email": "user@example.com",
+ "custom_claim": "custom_value", # Extra field
+ }
+
+ claims = IDJAGClaims.model_validate(claims_data)
+ assert claims.email == "user@example.com"
+ # Extra field should be preserved
+ assert claims.model_extra is not None and claims.model_extra.get("custom_claim") == "custom_value"
+
+
+# ============================================================================
+# Tests for EnterpriseAuthOAuthClientProvider
+# ============================================================================
+
+
+@pytest.mark.anyio
+async def test_exchange_token_for_id_jag_success(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any):
+ """Test successful token exchange for ID-JAG."""
+ # Create provider
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ scope="read write",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ client_name="Test Client",
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ "scope": "read write",
+ "expires_in": 300,
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform token exchange
+ id_jag = await provider.exchange_token_for_id_jag(mock_client)
+
+ # Verify
+ assert id_jag == sample_id_jag
+ assert provider._id_jag == sample_id_jag
+
+ # Verify request was made correctly
+ mock_client.post.assert_called_once()
+ call_args = mock_client.post.call_args
+ assert call_args[0][0] == "https://idp.example.com/oauth2/token"
+ assert call_args[1]["data"]["grant_type"] == "urn:ietf:params:oauth:grant-type:token-exchange"
+ assert call_args[1]["data"]["requested_token_type"] == "urn:ietf:params:oauth:token-type:id-jag"
+ assert call_args[1]["data"]["audience"] == "https://auth.mcp-server.example/"
+ assert call_args[1]["data"]["resource"] == "https://mcp-server.example/"
+
+
+@pytest.mark.anyio
+async def test_exchange_token_for_id_jag_error(sample_id_token: str, mock_token_storage: Any):
+ """Test token exchange failure handling."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Mock error response
+ mock_response = httpx.Response(
+ status_code=400,
+ json={
+ "error": "invalid_request",
+ "error_description": "Invalid subject token",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should raise OAuthTokenError
+ with pytest.raises(OAuthTokenError, match="Token exchange failed"):
+ await provider.exchange_token_for_id_jag(mock_client)
+
+
+@pytest.mark.anyio
+async def test_exchange_token_for_id_jag_unexpected_token_type(sample_id_token: str, mock_token_storage: Any):
+ """Test token exchange with unexpected token type."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Mock response with wrong token type
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
+ "access_token": "some-token",
+ "token_type": "Bearer",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should raise OAuthTokenError
+ with pytest.raises(OAuthTokenError, match="Unexpected token type"):
+ await provider.exchange_token_for_id_jag(mock_client)
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_for_access_token_success(sample_id_jag: str, mock_token_storage: Any):
+ """Test successful JWT bearer grant to get access token."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set up OAuth metadata
+ from mcp.shared.auth import OAuthMetadata
+
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "token_type": "Bearer",
+ "access_token": "mcp-access-token-12345",
+ "expires_in": 3600,
+ "scope": "read write",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform JWT bearer grant
+ token = await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag)
+
+ # Verify
+ assert token.access_token == "mcp-access-token-12345"
+ assert token.token_type == "Bearer"
+ assert token.expires_in == 3600
+
+ # Verify tokens were stored
+ mock_token_storage.set_tokens.assert_called_once()
+
+ # Verify request was made correctly
+ mock_client.post.assert_called_once()
+ call_args = mock_client.post.call_args
+ assert call_args[1]["data"]["grant_type"] == "urn:ietf:params:oauth:grant-type:jwt-bearer"
+ assert call_args[1]["data"]["assertion"] == sample_id_jag
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_for_access_token_no_metadata(sample_id_jag: str, mock_token_storage: Any):
+ """Test JWT bearer grant fails without OAuth metadata."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # No OAuth metadata set
+ mock_client = Mock(spec=httpx.AsyncClient)
+
+ # Should raise OAuthFlowError
+ from mcp.client.auth import OAuthFlowError
+
+ with pytest.raises(OAuthFlowError, match="token endpoint not discovered"):
+ await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag)
+
+
+@pytest.mark.anyio
+async def test_perform_authorization_not_implemented(mock_token_storage: Any):
+ """Test that _perform_authorization raises NotImplementedError."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Should raise NotImplementedError
+ with pytest.raises(NotImplementedError, match="not yet implemented"):
+ await provider._perform_authorization()
+
+
+@pytest.mark.anyio
+async def test_perform_authorization_with_valid_tokens(mock_token_storage: Any):
+ """Test that _perform_authorization returns dummy request when tokens are valid."""
+ from mcp.shared.auth import OAuthToken
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set valid tokens
+ provider.context.current_tokens = OAuthToken(
+ token_type="Bearer",
+ access_token="valid-token",
+ expires_in=3600,
+ )
+ provider.context.token_expiry_time = time.time() + 3600
+
+ # Should return a dummy request
+ request = await provider._perform_authorization()
+ assert request.method == "GET"
+ assert str(request.url) == "https://mcp-server.example/"
+
+
+@pytest.mark.anyio
+async def test_exchange_token_with_client_authentication(
+ sample_id_token: str, sample_id_jag: str, mock_token_storage: Any
+):
+ """Test token exchange with client authentication."""
+ from mcp.shared.auth import OAuthClientInformationFull
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ scope="read write",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ client_name="Test Client",
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set client info with secret
+ provider.context.client_info = OAuthClientInformationFull(
+ client_id="test-client-id",
+ client_secret="test-client-secret",
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ "scope": "read write",
+ "expires_in": 300,
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform token exchange
+ id_jag = await provider.exchange_token_for_id_jag(mock_client)
+
+ # Verify the ID-JAG was returned
+ assert id_jag == sample_id_jag
+
+ # Verify client credentials were included
+ call_args = mock_client.post.call_args
+ assert call_args[1]["data"]["client_id"] == "test-client-id"
+ assert call_args[1]["data"]["client_secret"] == "test-client-secret"
+
+
+@pytest.mark.anyio
+async def test_exchange_token_with_client_id_only(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any):
+ """Test token exchange with client_id but no client_secret (covers branch 232->235)."""
+ from mcp.shared.auth import OAuthClientInformationFull
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ scope="read write",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ client_name="Test Client",
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set client info WITHOUT secret (client_secret=None)
+ provider.context.client_info = OAuthClientInformationFull(
+ client_id="test-client-id",
+ client_secret=None, # No secret
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ "scope": "read write",
+ "expires_in": 300,
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform token exchange
+ id_jag = await provider.exchange_token_for_id_jag(mock_client)
+
+ # Verify the ID-JAG was returned
+ assert id_jag == sample_id_jag
+
+ # Verify client_id was included but NOT client_secret
+ call_args = mock_client.post.call_args
+ assert call_args[1]["data"]["client_id"] == "test-client-id"
+ assert "client_secret" not in call_args[1]["data"]
+
+
+@pytest.mark.anyio
+async def test_exchange_token_http_error(sample_id_token: str, mock_token_storage: Any):
+ """Test token exchange with HTTP error."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(side_effect=httpx.ConnectError("Connection failed"))
+
+ # Should raise OAuthTokenError
+ with pytest.raises(OAuthTokenError, match="HTTP error during token exchange"):
+ await provider.exchange_token_for_id_jag(mock_client)
+
+
+@pytest.mark.anyio
+async def test_exchange_token_non_json_error_response(sample_id_token: str, mock_token_storage: Any):
+ """Test token exchange with non-JSON error response."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Mock error response with non-JSON content
+ mock_response = httpx.Response(
+ status_code=500,
+ content=b"Internal Server Error",
+ headers={"content-type": "text/plain"},
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should raise OAuthTokenError with default error
+ with pytest.raises(OAuthTokenError, match="Token exchange failed: unknown_error"):
+ await provider.exchange_token_for_id_jag(mock_client)
+
+
+@pytest.mark.anyio
+async def test_exchange_token_warning_for_non_na_token_type(
+ sample_id_token: str, sample_id_jag: str, mock_token_storage: Any
+):
+ """Test token exchange logs warning for non-N_A token type."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Mock response with different token_type
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "Bearer", # Not N_A
+ "scope": "read write",
+ "expires_in": 300,
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should succeed but log warning
+ import logging
+
+ with patch.object(
+ logging.getLogger("mcp.client.auth.extensions.enterprise_managed_auth"), "warning"
+ ) as mock_warning:
+ id_jag = await provider.exchange_token_for_id_jag(mock_client)
+ assert id_jag == sample_id_jag
+ mock_warning.assert_called_once()
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_with_client_authentication(sample_id_jag: str, mock_token_storage: Any):
+ """Test JWT bearer grant with client authentication."""
+ from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set client info with secret
+ provider.context.client_info = OAuthClientInformationFull(
+ client_id="test-client-id",
+ client_secret="test-client-secret",
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "token_type": "Bearer",
+ "access_token": "mcp-access-token-12345",
+ "expires_in": 3600,
+ "scope": "read write",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform JWT bearer grant
+ token = await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag)
+
+ # Verify token was returned
+ assert token.access_token == "mcp-access-token-12345"
+
+ # Verify client credentials were included
+ call_args = mock_client.post.call_args
+ assert call_args[1]["data"]["client_id"] == "test-client-id"
+ assert call_args[1]["data"]["client_secret"] == "test-client-secret"
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_with_client_id_only(sample_id_jag: str, mock_token_storage: Any):
+ """Test JWT bearer grant with client_id but no client_secret (covers branch 304->307)."""
+ from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set client info WITHOUT secret (client_secret=None)
+ provider.context.client_info = OAuthClientInformationFull(
+ client_id="test-client-id",
+ client_secret=None, # No secret
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "token_type": "Bearer",
+ "access_token": "mcp-access-token-12345",
+ "expires_in": 3600,
+ "scope": "read write",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform JWT bearer grant
+ token = await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag)
+
+ # Verify token was returned correctly
+ assert token.access_token == "mcp-access-token-12345"
+ assert token.token_type == "Bearer"
+
+ # Verify client_id was included but NOT client_secret
+ call_args = mock_client.post.call_args
+ assert call_args[1]["data"]["client_id"] == "test-client-id"
+ assert "client_secret" not in call_args[1]["data"]
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_error_response(sample_id_jag: str, mock_token_storage: Any):
+ """Test JWT bearer grant with error response."""
+ from mcp.shared.auth import OAuthMetadata
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Mock error response
+ mock_response = httpx.Response(
+ status_code=400,
+ json={
+ "error": "invalid_grant",
+ "error_description": "Invalid assertion",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should raise OAuthTokenError
+ with pytest.raises(OAuthTokenError, match="JWT bearer grant failed"):
+ await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag)
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_non_json_error(sample_id_jag: str, mock_token_storage: Any):
+ """Test JWT bearer grant with non-JSON error response."""
+ from mcp.shared.auth import OAuthMetadata
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Mock error response with non-JSON content
+ mock_response = httpx.Response(
+ status_code=503,
+ content=b"Service Unavailable",
+ headers={"content-type": "text/html"},
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should raise OAuthTokenError with default error
+ with pytest.raises(OAuthTokenError, match="JWT bearer grant failed: unknown_error"):
+ await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag)
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_http_error(sample_id_jag: str, mock_token_storage: Any):
+ """Test JWT bearer grant with HTTP error."""
+ from mcp.shared.auth import OAuthMetadata
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(side_effect=httpx.ReadTimeout("Request timeout"))
+
+ # Should raise OAuthTokenError
+ with pytest.raises(OAuthTokenError, match="HTTP error during JWT bearer grant"):
+ await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag)
+
+
+@pytest.mark.anyio
+async def test_exchange_token_with_client_info_but_no_client_id(
+ sample_id_token: str, sample_id_jag: str, mock_token_storage: Any
+):
+ """Test token exchange when client_info exists but client_id is None (covers line 231)."""
+ from mcp.shared.auth import OAuthClientInformationFull
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ scope="read write",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ client_name="Test Client",
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set client info with client_id=None
+ provider.context.client_info = OAuthClientInformationFull(
+ client_id=None, # This should skip the client_id assignment
+ client_secret="test-secret",
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ "scope": "read write",
+ "expires_in": 300,
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform token exchange
+ id_jag = await provider.exchange_token_for_id_jag(mock_client)
+
+ # Verify the ID-JAG was returned
+ assert id_jag == sample_id_jag
+
+ # Verify client_id was not included (None), but client_secret was included
+ call_args = mock_client.post.call_args
+ assert "client_id" not in call_args[1]["data"]
+ assert call_args[1]["data"]["client_secret"] == "test-secret"
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_with_client_info_but_no_client_id(sample_id_jag: str, mock_token_storage: Any):
+ """Test ID-JAG exchange when client_info exists but client_id is None (covers line 302)."""
+ from pydantic import AnyHttpUrl
+
+ from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Set client info with client_id=None
+ provider.context.client_info = OAuthClientInformationFull(
+ client_id=None, # This should skip the client_id assignment
+ client_secret="test-secret",
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "token_type": "Bearer",
+ "access_token": "mcp-access-token-12345",
+ "expires_in": 3600,
+ "scope": "read write",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform JWT bearer grant
+ token = await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag)
+
+ # Verify
+ assert token.access_token == "mcp-access-token-12345"
+ assert token.token_type == "Bearer"
+ assert token.expires_in == 3600
+
+ # Verify client_id was not included (None), but client_secret was included
+ call_args = mock_client.post.call_args
+ assert "client_id" not in call_args[1]["data"]
+ assert call_args[1]["data"]["client_secret"] == "test-secret"
+
+
+def test_validate_token_exchange_params_missing_audience():
+ """Test validation fails for missing audience."""
+ params = TokenExchangeParameters(
+ subject_token="token",
+ subject_token_type="urn:ietf:params:oauth:token-type:id_token",
+ audience="",
+ resource="https://server.example/",
+ )
+
+ with pytest.raises(ValueError, match="audience is required"):
+ validate_token_exchange_params(params)
+
+
+def test_validate_token_exchange_params_missing_resource():
+ """Test validation fails for missing resource."""
+ params = TokenExchangeParameters(
+ subject_token="token",
+ subject_token_type="urn:ietf:params:oauth:token-type:id_token",
+ audience="https://auth.example/",
+ resource="",
+ )
+
+ with pytest.raises(ValueError, match="resource is required"):
+ validate_token_exchange_params(params)