diff --git a/README.md b/README.md index 5e8129c96e..fab9e05ef6 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,7 @@ - [Writing MCP Clients](#writing-mcp-clients) - [Client Display Utilities](#client-display-utilities) - [OAuth Authentication for Clients](#oauth-authentication-for-clients) + - [Enterprise Managed Authorization](#enterprise-managed-authorization) - [Parsing Tool Results](#parsing-tool-results) - [MCP Primitives](#mcp-primitives) - [Server Capabilities](#server-capabilities) @@ -2356,6 +2357,129 @@ _Full example: [examples/snippets/clients/oauth_client.py](https://github.com/mo For a complete working example, see [`examples/clients/simple-auth-client/`](examples/clients/simple-auth-client/). +#### Enterprise Managed Authorization + +The SDK includes support for Enterprise Managed Authorization (SEP-990), which enables MCP clients to connect to protected servers using enterprise Single Sign-On (SSO) systems. This implementation supports: + +- **RFC 8693**: OAuth 2.0 Token Exchange (ID Token → ID-JAG) +- **RFC 7523**: JSON Web Token (JWT) Profile for OAuth 2.0 Authorization Grants (ID-JAG → Access Token) +- Integration with enterprise identity providers (Okta, Azure AD, etc.) + +**Key Components:** + +The `EnterpriseAuthOAuthClientProvider` class extends the standard OAuth provider to implement the enterprise authorization flow: + +**Token Exchange Flow:** + +1. **Obtain ID Token** from your enterprise IdP (e.g., Okta, Azure AD) +2. **Exchange ID Token for ID-JAG** using RFC 8693 Token Exchange +3. **Exchange ID-JAG for Access Token** using RFC 7523 JWT Bearer Grant +4. **Use Access Token** to call protected MCP server tools + +**Example Usage:** + +```python +import asyncio +import httpx +from pydantic import AnyUrl + +from mcp.client.auth.extensions import ( + EnterpriseAuthOAuthClientProvider, + TokenExchangeParameters, +) +from mcp.shared.auth import OAuthClientMetadata +from mcp.client.auth import TokenStorage + +# Define token storage implementation +class SimpleTokenStorage(TokenStorage): + def __init__(self): + self._tokens = None + self._client_info = None + + async def get_tokens(self): + return self._tokens + + async def set_tokens(self, tokens): + self._tokens = tokens + + async def get_client_info(self): + return self._client_info + + async def set_client_info(self, client_info): + self._client_info = client_info + +async def main(): + # Step 1: Get ID token from your IdP (example with Okta) + id_token = await get_id_token_from_idp() # Your IdP authentication + + # Step 2: Configure token exchange parameters + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=id_token, + mcp_server_auth_issuer="https://your-idp.com", # IdP issuer URL + mcp_server_resource_id="https://mcp-server.example.com", # MCP server resource ID + scope="mcp:tools mcp:resources", # Optional scopes + ) + + # Step 3: Create enterprise auth provider + enterprise_auth = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example.com", + client_metadata=OAuthClientMetadata( + client_name="Enterprise MCP Client", + client_id="your-client-id", + redirect_uris=[AnyUrl("http://localhost:3000/callback")], + grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"], + response_types=["token"], + ), + storage=SimpleTokenStorage(), + idp_token_endpoint="https://your-idp.com/oauth2/v1/token", + token_exchange_params=token_exchange_params, + ) + + # Step 4: Perform token exchange and get access token + async with httpx.AsyncClient() as client: + # Exchange ID token for ID-JAG + id_jag = await enterprise_auth.exchange_token_for_id_jag(client) + print(f"Obtained ID-JAG: {id_jag[:50]}...") + + # Exchange ID-JAG for access token + access_token = await enterprise_auth.exchange_id_jag_for_access_token( + client, id_jag + ) + print(f"Access token obtained, expires in: {access_token.expires_in}s") + +if __name__ == "__main__": + asyncio.run(main()) +``` + +**Working with SAML Assertions:** + +If your enterprise uses SAML instead of OIDC, you can exchange SAML assertions: + +```python +token_exchange_params = TokenExchangeParameters.from_saml_assertion( + saml_assertion=saml_assertion_string, + mcp_server_auth_issuer="https://your-idp.com", + mcp_server_resource_id="https://mcp-server.example.com", + scope="mcp:tools", +) +``` + +**Decoding and Inspecting ID-JAG Tokens:** + +You can decode ID-JAG tokens to inspect their claims: + +```python +from mcp.client.auth.extensions import decode_id_jag + +# Decode without signature verification (for inspection only) +claims = decode_id_jag(id_jag) +print(f"Subject: {claims.sub}") +print(f"Issuer: {claims.iss}") +print(f"Audience: {claims.aud}") +print(f"Client ID: {claims.client_id}") +print(f"Resource: {claims.resource}") +``` + ### Parsing Tool Results When calling tools through MCP, the `CallToolResult` object contains the tool's response in a structured format. Understanding how to parse this result is essential for properly handling tool outputs. diff --git a/src/mcp/client/auth/extensions/__init__.py b/src/mcp/client/auth/extensions/__init__.py index e69de29bb2..56ba368ef8 100644 --- a/src/mcp/client/auth/extensions/__init__.py +++ b/src/mcp/client/auth/extensions/__init__.py @@ -0,0 +1,19 @@ +"""MCP Client Auth Extensions.""" + +from mcp.client.auth.extensions.enterprise_managed_auth import ( + EnterpriseAuthOAuthClientProvider, + IDJAGClaims, + TokenExchangeParameters, + TokenExchangeResponse, + decode_id_jag, + validate_token_exchange_params, +) + +__all__ = [ + "EnterpriseAuthOAuthClientProvider", + "IDJAGClaims", + "TokenExchangeParameters", + "TokenExchangeResponse", + "decode_id_jag", + "validate_token_exchange_params", +] diff --git a/src/mcp/client/auth/extensions/enterprise_managed_auth.py b/src/mcp/client/auth/extensions/enterprise_managed_auth.py new file mode 100644 index 0000000000..e55283e968 --- /dev/null +++ b/src/mcp/client/auth/extensions/enterprise_managed_auth.py @@ -0,0 +1,412 @@ +""" +Enterprise Managed Authorization extension for MCP (SEP-990). + +Implements RFC 8693 Token Exchange and RFC 7523 JWT Bearer Grant for +enterprise SSO integration. +""" + +import logging +from typing import Any + +import httpx +from pydantic import BaseModel, Field + +from mcp.client.auth import OAuthClientProvider, OAuthFlowError, OAuthTokenError, TokenStorage +from mcp.shared.auth import OAuthClientMetadata, OAuthToken + +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Data Models +# ============================================================================ + + +class TokenExchangeParameters(BaseModel): + """Parameters for RFC 8693 Token Exchange request.""" + + requested_token_type: str = Field( + default="urn:ietf:params:oauth:token-type:id-jag", + description="Type of token being requested (ID-JAG)", + ) + + audience: str = Field( + ..., + description="Issuer URL of the MCP Server's authorization server", + ) + + resource: str = Field( + ..., + description="RFC 9728 Resource Identifier of the MCP Server", + ) + + scope: str | None = Field( + default=None, + description="Space-separated list of scopes being requested", + ) + + subject_token: str = Field( + ..., + description="ID Token or SAML assertion for the end user", + ) + + subject_token_type: str = Field( + ..., + description="Type of subject token (id_token or saml2)", + ) + + @classmethod + def from_id_token( + cls, + id_token: str, + mcp_server_auth_issuer: str, + mcp_server_resource_id: str, + scope: str | None = None, + ) -> "TokenExchangeParameters": + """Create parameters for OIDC ID Token exchange.""" + return cls( + subject_token=id_token, + subject_token_type="urn:ietf:params:oauth:token-type:id_token", + audience=mcp_server_auth_issuer, + resource=mcp_server_resource_id, + scope=scope, + ) + + @classmethod + def from_saml_assertion( + cls, + saml_assertion: str, + mcp_server_auth_issuer: str, + mcp_server_resource_id: str, + scope: str | None = None, + ) -> "TokenExchangeParameters": + """Create parameters for SAML assertion exchange.""" + return cls( + subject_token=saml_assertion, + subject_token_type="urn:ietf:params:oauth:token-type:saml2", + audience=mcp_server_auth_issuer, + resource=mcp_server_resource_id, + scope=scope, + ) + + +class TokenExchangeResponse(BaseModel): + """Response from RFC 8693 Token Exchange.""" + + issued_token_type: str = Field( + ..., + description="Type of token issued (should be id-jag)", + ) + + access_token: str = Field( + ..., + description="The ID-JAG token (named access_token per RFC 8693)", + ) + + token_type: str = Field( + ..., + description="Token type (should be N_A for ID-JAG)", + ) + + scope: str | None = Field( + default=None, + description="Granted scopes", + ) + + expires_in: int | None = Field( + default=None, + description="Lifetime in seconds", + ) + + @property + def id_jag(self) -> str: + """Get the ID-JAG token.""" + return self.access_token + + +class IDJAGClaims(BaseModel): + """Claims structure for Identity Assertion JWT Authorization Grant.""" + + model_config = {"extra": "allow"} + + # JWT header + typ: str = Field( + ..., + description="JWT type - must be 'oauth-id-jag+jwt'", + ) + + # Required claims + jti: str = Field(..., description="Unique JWT ID") + iss: str = Field(..., description="IdP issuer URL") + sub: str = Field(..., description="Subject (user) identifier") + aud: str = Field(..., description="MCP Server's auth server issuer") + resource: str = Field(..., description="MCP Server resource identifier") + client_id: str = Field(..., description="MCP Client identifier") + exp: int = Field(..., description="Expiration timestamp") + iat: int = Field(..., description="Issued-at timestamp") + + # Optional claims + scope: str | None = Field(None, description="Space-separated scopes") + email: str | None = Field(None, description="User email") + + +class EnterpriseAuthOAuthClientProvider(OAuthClientProvider): + """ + OAuth client provider for Enterprise Managed Authorization (SEP-990). + + Implements: + - RFC 8693: Token Exchange (ID Token → ID-JAG) + - RFC 7523: JWT Bearer Grant (ID-JAG → Access Token) + """ + + def __init__( + self, + server_url: str, + client_metadata: OAuthClientMetadata, + storage: TokenStorage, + idp_token_endpoint: str, + token_exchange_params: TokenExchangeParameters, + redirect_handler: Any = None, + callback_handler: Any = None, + timeout: float = 300.0, + ) -> None: + """ + Initialize Enterprise Auth OAuth Client. + + Args: + server_url: MCP server URL + client_metadata: OAuth client metadata + storage: Token storage implementation + idp_token_endpoint: Enterprise IdP token endpoint URL + token_exchange_params: Token exchange parameters + redirect_handler: Optional redirect handler + callback_handler: Optional callback handler + timeout: Request timeout in seconds + """ + super().__init__( + server_url=server_url, + client_metadata=client_metadata, + storage=storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + timeout=timeout, + ) + self.idp_token_endpoint = idp_token_endpoint + self.token_exchange_params = token_exchange_params + self._id_jag: str | None = None + + async def exchange_token_for_id_jag( + self, + client: httpx.AsyncClient, + ) -> str: + """ + Exchange ID Token for ID-JAG using RFC 8693 Token Exchange. + + Args: + client: HTTP client for making requests + + Returns: + The ID-JAG token string + + Raises: + OAuthTokenError: If token exchange fails + """ + logger.info("Starting token exchange for ID-JAG") + + # Build token exchange request + token_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "requested_token_type": self.token_exchange_params.requested_token_type, + "audience": self.token_exchange_params.audience, + "resource": self.token_exchange_params.resource, + "subject_token": self.token_exchange_params.subject_token, + "subject_token_type": self.token_exchange_params.subject_token_type, + } + + if self.token_exchange_params.scope: + token_data["scope"] = self.token_exchange_params.scope + + # Add client authentication if needed + if self.context.client_info: + if self.context.client_info.client_id is not None: + token_data["client_id"] = self.context.client_info.client_id + if self.context.client_info.client_secret is not None: + token_data["client_secret"] = self.context.client_info.client_secret + + try: + response = await client.post( + self.idp_token_endpoint, + data=token_data, + timeout=self.context.timeout, + ) + + if response.status_code != 200: + error_data: dict[str, str] = ( + response.json() if response.headers.get("content-type", "").startswith("application/json") else {} + ) + error: str = error_data.get("error", "unknown_error") + error_description: str = error_data.get("error_description", "Token exchange failed") + raise OAuthTokenError(f"Token exchange failed: {error} - {error_description}") + + # Parse response + token_response = TokenExchangeResponse.model_validate_json(response.content) + + # Validate response + if token_response.issued_token_type != "urn:ietf:params:oauth:token-type:id-jag": + raise OAuthTokenError(f"Unexpected token type: {token_response.issued_token_type}") + + if token_response.token_type != "N_A": + logger.warning(f"Expected token_type 'N_A', got '{token_response.token_type}'") + + logger.info("Successfully obtained ID-JAG") + self._id_jag = token_response.id_jag + return token_response.id_jag + + except httpx.HTTPError as e: + raise OAuthTokenError(f"HTTP error during token exchange: {e}") from e + + async def exchange_id_jag_for_access_token( + self, + client: httpx.AsyncClient, + id_jag: str, + ) -> OAuthToken: + """ + Exchange ID-JAG for access token using RFC 7523 JWT Bearer Grant. + + Args: + client: HTTP client for making requests + id_jag: The ID-JAG token + + Returns: + OAuth access token + + Raises: + OAuthTokenError: If JWT bearer grant fails + """ + logger.info("Exchanging ID-JAG for access token") + + # Discover token endpoint from MCP server if not already done + if not self.context.oauth_metadata or not self.context.oauth_metadata.token_endpoint: + raise OAuthFlowError("MCP server token endpoint not discovered") + + token_endpoint = str(self.context.oauth_metadata.token_endpoint) + + # Build JWT bearer grant request + token_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "assertion": id_jag, + } + + # Add client authentication + if self.context.client_info: + if self.context.client_info.client_id is not None: + token_data["client_id"] = self.context.client_info.client_id + if self.context.client_info.client_secret is not None: + token_data["client_secret"] = self.context.client_info.client_secret + + try: + response = await client.post( + token_endpoint, + data=token_data, + timeout=self.context.timeout, + ) + + if response.status_code != 200: + error_data: dict[str, str] = ( + response.json() if response.headers.get("content-type", "").startswith("application/json") else {} + ) + error: str = error_data.get("error", "unknown_error") + error_description: str = error_data.get("error_description", "JWT bearer grant failed") + raise OAuthTokenError(f"JWT bearer grant failed: {error} - {error_description}") + + # Parse OAuth token response + token = OAuthToken.model_validate_json(response.content) + + # Store tokens + self.context.current_tokens = token + self.context.update_token_expiry(token) + await self.context.storage.set_tokens(token) + + logger.info("Successfully obtained access token via ID-JAG") + return token + + except httpx.HTTPError as e: + raise OAuthTokenError(f"HTTP error during JWT bearer grant: {e}") from e + + async def _perform_authorization(self) -> httpx.Request: + """ + Perform enterprise authorization flow. + + Overrides parent method to use token exchange + JWT bearer grant + instead of standard authorization code flow. + """ + # Check if we already have valid tokens + if self.context.is_token_valid(): + # Return a dummy request - we don't need to make any request + return httpx.Request("GET", self.context.server_url) + + # For now, raise NotImplementedError as this requires integration + # with the full httpx auth flow + raise NotImplementedError( + "Full enterprise auth flow integration not yet implemented. " + "Use exchange_token_for_id_jag and exchange_id_jag_for_access_token directly." + ) + + +# ============================================================================ +# Helper Functions +# ============================================================================ + + +def decode_id_jag(id_jag: str, verify: bool = False) -> IDJAGClaims: + """ + Decode an ID-JAG token without verification. + + Args: + id_jag: The ID-JAG token string + verify: Whether to verify signature (requires key) + + Returns: + Decoded ID-JAG claims + + Note: + For verification, use server-side validation instead. + """ + import jwt + + # Decode without verification for inspection + claims = jwt.decode(id_jag, options={"verify_signature": False}) + header = jwt.get_unverified_header(id_jag) + + # Add typ from header to claims + claims["typ"] = header.get("typ", "") + + return IDJAGClaims.model_validate(claims) + + +def validate_token_exchange_params( + params: TokenExchangeParameters, +) -> None: + """ + Validate token exchange parameters. + + Args: + params: Token exchange parameters to validate + + Raises: + ValueError: If parameters are invalid + """ + if not params.subject_token: + raise ValueError("subject_token is required") + + if not params.audience: + raise ValueError("audience is required") + + if not params.resource: + raise ValueError("resource is required") + + if params.subject_token_type not in [ + "urn:ietf:params:oauth:token-type:id_token", + "urn:ietf:params:oauth:token-type:saml2", + ]: + raise ValueError(f"Invalid subject_token_type: {params.subject_token_type}") diff --git a/tests/client/auth/test_enterprise_managed_auth_client.py b/tests/client/auth/test_enterprise_managed_auth_client.py new file mode 100644 index 0000000000..4a10f4f664 --- /dev/null +++ b/tests/client/auth/test_enterprise_managed_auth_client.py @@ -0,0 +1,1126 @@ +"""Tests for Enterprise Managed Authorization client-side implementation.""" + +import time +from typing import Any +from unittest.mock import AsyncMock, Mock, patch + +import httpx +import jwt +import pytest +from pydantic import AnyHttpUrl, AnyUrl + +from mcp.client.auth import OAuthTokenError +from mcp.client.auth.extensions.enterprise_managed_auth import ( + EnterpriseAuthOAuthClientProvider, + IDJAGClaims, + TokenExchangeParameters, + TokenExchangeResponse, + decode_id_jag, + validate_token_exchange_params, +) +from mcp.shared.auth import OAuthClientMetadata + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +def sample_id_token() -> str: + """Generate a sample ID token for testing.""" + payload = { + "iss": "https://idp.example.com", + "sub": "user123", + "aud": "mcp-client-app", + "exp": int(time.time()) + 3600, + "iat": int(time.time()), + "email": "user@example.com", + } + return jwt.encode(payload, "secret", algorithm="HS256") + + +@pytest.fixture +def sample_id_jag() -> str: + """Generate a sample ID-JAG token for testing.""" + payload = { + "jti": "unique-jwt-id-12345", + "iss": "https://idp.example.com", + "sub": "user123", + "aud": "https://auth.mcp-server.example/", + "resource": "https://mcp-server.example/", + "client_id": "mcp-client-app", + "exp": int(time.time()) + 300, + "iat": int(time.time()), + "scope": "read write", + } + token = jwt.encode(payload, "secret", algorithm="HS256") + + # Manually add typ to header + header = jwt.get_unverified_header(token) + header["typ"] = "oauth-id-jag+jwt" + + return jwt.encode(payload, "secret", algorithm="HS256", headers={"typ": "oauth-id-jag+jwt"}) + + +@pytest.fixture +def mock_token_storage() -> Any: + """Create a mock token storage.""" + storage = Mock() + storage.get_tokens = AsyncMock(return_value=None) + storage.set_tokens = AsyncMock() + storage.get_client_info = AsyncMock(return_value=None) + storage.set_client_info = AsyncMock() + return storage + + +# ============================================================================ +# Tests for TokenExchangeParameters +# ============================================================================ + + +def test_token_exchange_params_from_id_token(): + """Test creating TokenExchangeParameters from ID token.""" + params = TokenExchangeParameters.from_id_token( + id_token="eyJhbGc...", + mcp_server_auth_issuer="https://auth.server.example/", + mcp_server_resource_id="https://server.example/", + scope="read write", + ) + + assert params.subject_token == "eyJhbGc..." + assert params.subject_token_type == "urn:ietf:params:oauth:token-type:id_token" + assert params.audience == "https://auth.server.example/" + assert params.resource == "https://server.example/" + assert params.scope == "read write" + assert params.requested_token_type == "urn:ietf:params:oauth:token-type:id-jag" + + +def test_token_exchange_params_from_saml_assertion(): + """Test creating TokenExchangeParameters from SAML assertion.""" + params = TokenExchangeParameters.from_saml_assertion( + saml_assertion="...", + mcp_server_auth_issuer="https://auth.server.example/", + mcp_server_resource_id="https://server.example/", + scope="read", + ) + + assert params.subject_token == "..." + assert params.subject_token_type == "urn:ietf:params:oauth:token-type:saml2" + assert params.audience == "https://auth.server.example/" + assert params.resource == "https://server.example/" + assert params.scope == "read" + + +def test_validate_token_exchange_params_valid(): + """Test validating valid token exchange parameters.""" + params = TokenExchangeParameters.from_id_token( + id_token="token", + mcp_server_auth_issuer="https://auth.example/", + mcp_server_resource_id="https://server.example/", + ) + + # Should not raise + validate_token_exchange_params(params) + + +def test_validate_token_exchange_params_invalid_token_type(): + """Test validation fails for invalid subject token type.""" + params = TokenExchangeParameters( + subject_token="token", + subject_token_type="invalid:type", + audience="https://auth.example/", + resource="https://server.example/", + ) + + with pytest.raises(ValueError, match="Invalid subject_token_type"): + validate_token_exchange_params(params) + + +def test_validate_token_exchange_params_missing_subject_token(): + """Test validation fails for missing subject token.""" + params = TokenExchangeParameters( + subject_token="", + subject_token_type="urn:ietf:params:oauth:token-type:id_token", + audience="https://auth.example/", + resource="https://server.example/", + ) + + with pytest.raises(ValueError, match="subject_token is required"): + validate_token_exchange_params(params) + + +# ============================================================================ +# Tests for TokenExchangeResponse +# ============================================================================ + + +def test_token_exchange_response_parsing(): + """Test parsing token exchange response.""" + response_json = """{ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": "eyJhbGc...", + "token_type": "N_A", + "scope": "read write", + "expires_in": 300 + }""" + + response = TokenExchangeResponse.model_validate_json(response_json) + + assert response.issued_token_type == "urn:ietf:params:oauth:token-type:id-jag" + assert response.id_jag == "eyJhbGc..." + assert response.access_token == "eyJhbGc..." + assert response.token_type == "N_A" + assert response.scope == "read write" + assert response.expires_in == 300 + + +def test_token_exchange_response_id_jag_property(): + """Test id_jag property returns access_token.""" + response = TokenExchangeResponse( + issued_token_type="urn:ietf:params:oauth:token-type:id-jag", + access_token="the-id-jag-token", + token_type="N_A", + ) + + assert response.id_jag == "the-id-jag-token" + + +# ============================================================================ +# Tests for IDJAGClaims +# ============================================================================ + + +def test_decode_id_jag(sample_id_jag: str): + """Test decoding ID-JAG token.""" + claims = decode_id_jag(sample_id_jag) + + assert claims.iss == "https://idp.example.com" + assert claims.sub == "user123" + assert claims.aud == "https://auth.mcp-server.example/" + assert claims.resource == "https://mcp-server.example/" + assert claims.client_id == "mcp-client-app" + assert claims.scope == "read write" + + +def test_id_jag_claims_with_extra_fields(): + """Test IDJAGClaims allows extra fields.""" + claims_data = { + "typ": "oauth-id-jag+jwt", + "jti": "jti123", + "iss": "https://idp.example.com", + "sub": "user123", + "aud": "https://auth.server.example/", + "resource": "https://server.example/", + "client_id": "client123", + "exp": int(time.time()) + 300, + "iat": int(time.time()), + "scope": "read", + "email": "user@example.com", + "custom_claim": "custom_value", # Extra field + } + + claims = IDJAGClaims.model_validate(claims_data) + assert claims.email == "user@example.com" + # Extra field should be preserved + assert claims.model_extra is not None and claims.model_extra.get("custom_claim") == "custom_value" + + +# ============================================================================ +# Tests for EnterpriseAuthOAuthClientProvider +# ============================================================================ + + +@pytest.mark.anyio +async def test_exchange_token_for_id_jag_success(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any): + """Test successful token exchange for ID-JAG.""" + # Create provider + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + scope="read write", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + client_name="Test Client", + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "N_A", + "scope": "read write", + "expires_in": 300, + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform token exchange + id_jag = await provider.exchange_token_for_id_jag(mock_client) + + # Verify + assert id_jag == sample_id_jag + assert provider._id_jag == sample_id_jag + + # Verify request was made correctly + mock_client.post.assert_called_once() + call_args = mock_client.post.call_args + assert call_args[0][0] == "https://idp.example.com/oauth2/token" + assert call_args[1]["data"]["grant_type"] == "urn:ietf:params:oauth:grant-type:token-exchange" + assert call_args[1]["data"]["requested_token_type"] == "urn:ietf:params:oauth:token-type:id-jag" + assert call_args[1]["data"]["audience"] == "https://auth.mcp-server.example/" + assert call_args[1]["data"]["resource"] == "https://mcp-server.example/" + + +@pytest.mark.anyio +async def test_exchange_token_for_id_jag_error(sample_id_token: str, mock_token_storage: Any): + """Test token exchange failure handling.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Mock error response + mock_response = httpx.Response( + status_code=400, + json={ + "error": "invalid_request", + "error_description": "Invalid subject token", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should raise OAuthTokenError + with pytest.raises(OAuthTokenError, match="Token exchange failed"): + await provider.exchange_token_for_id_jag(mock_client) + + +@pytest.mark.anyio +async def test_exchange_token_for_id_jag_unexpected_token_type(sample_id_token: str, mock_token_storage: Any): + """Test token exchange with unexpected token type.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Mock response with wrong token type + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + "access_token": "some-token", + "token_type": "Bearer", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should raise OAuthTokenError + with pytest.raises(OAuthTokenError, match="Unexpected token type"): + await provider.exchange_token_for_id_jag(mock_client) + + +@pytest.mark.anyio +async def test_exchange_id_jag_for_access_token_success(sample_id_jag: str, mock_token_storage: Any): + """Test successful JWT bearer grant to get access token.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set up OAuth metadata + from mcp.shared.auth import OAuthMetadata + + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "token_type": "Bearer", + "access_token": "mcp-access-token-12345", + "expires_in": 3600, + "scope": "read write", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform JWT bearer grant + token = await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) + + # Verify + assert token.access_token == "mcp-access-token-12345" + assert token.token_type == "Bearer" + assert token.expires_in == 3600 + + # Verify tokens were stored + mock_token_storage.set_tokens.assert_called_once() + + # Verify request was made correctly + mock_client.post.assert_called_once() + call_args = mock_client.post.call_args + assert call_args[1]["data"]["grant_type"] == "urn:ietf:params:oauth:grant-type:jwt-bearer" + assert call_args[1]["data"]["assertion"] == sample_id_jag + + +@pytest.mark.anyio +async def test_exchange_id_jag_for_access_token_no_metadata(sample_id_jag: str, mock_token_storage: Any): + """Test JWT bearer grant fails without OAuth metadata.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # No OAuth metadata set + mock_client = Mock(spec=httpx.AsyncClient) + + # Should raise OAuthFlowError + from mcp.client.auth import OAuthFlowError + + with pytest.raises(OAuthFlowError, match="token endpoint not discovered"): + await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) + + +@pytest.mark.anyio +async def test_perform_authorization_not_implemented(mock_token_storage: Any): + """Test that _perform_authorization raises NotImplementedError.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Should raise NotImplementedError + with pytest.raises(NotImplementedError, match="not yet implemented"): + await provider._perform_authorization() + + +@pytest.mark.anyio +async def test_perform_authorization_with_valid_tokens(mock_token_storage: Any): + """Test that _perform_authorization returns dummy request when tokens are valid.""" + from mcp.shared.auth import OAuthToken + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set valid tokens + provider.context.current_tokens = OAuthToken( + token_type="Bearer", + access_token="valid-token", + expires_in=3600, + ) + provider.context.token_expiry_time = time.time() + 3600 + + # Should return a dummy request + request = await provider._perform_authorization() + assert request.method == "GET" + assert str(request.url) == "https://mcp-server.example/" + + +@pytest.mark.anyio +async def test_exchange_token_with_client_authentication( + sample_id_token: str, sample_id_jag: str, mock_token_storage: Any +): + """Test token exchange with client authentication.""" + from mcp.shared.auth import OAuthClientInformationFull + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + scope="read write", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + client_name="Test Client", + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set client info with secret + provider.context.client_info = OAuthClientInformationFull( + client_id="test-client-id", + client_secret="test-client-secret", + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "N_A", + "scope": "read write", + "expires_in": 300, + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform token exchange + id_jag = await provider.exchange_token_for_id_jag(mock_client) + + # Verify the ID-JAG was returned + assert id_jag == sample_id_jag + + # Verify client credentials were included + call_args = mock_client.post.call_args + assert call_args[1]["data"]["client_id"] == "test-client-id" + assert call_args[1]["data"]["client_secret"] == "test-client-secret" + + +@pytest.mark.anyio +async def test_exchange_token_with_client_id_only(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any): + """Test token exchange with client_id but no client_secret (covers branch 232->235).""" + from mcp.shared.auth import OAuthClientInformationFull + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + scope="read write", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + client_name="Test Client", + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set client info WITHOUT secret (client_secret=None) + provider.context.client_info = OAuthClientInformationFull( + client_id="test-client-id", + client_secret=None, # No secret + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "N_A", + "scope": "read write", + "expires_in": 300, + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform token exchange + id_jag = await provider.exchange_token_for_id_jag(mock_client) + + # Verify the ID-JAG was returned + assert id_jag == sample_id_jag + + # Verify client_id was included but NOT client_secret + call_args = mock_client.post.call_args + assert call_args[1]["data"]["client_id"] == "test-client-id" + assert "client_secret" not in call_args[1]["data"] + + +@pytest.mark.anyio +async def test_exchange_token_http_error(sample_id_token: str, mock_token_storage: Any): + """Test token exchange with HTTP error.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(side_effect=httpx.ConnectError("Connection failed")) + + # Should raise OAuthTokenError + with pytest.raises(OAuthTokenError, match="HTTP error during token exchange"): + await provider.exchange_token_for_id_jag(mock_client) + + +@pytest.mark.anyio +async def test_exchange_token_non_json_error_response(sample_id_token: str, mock_token_storage: Any): + """Test token exchange with non-JSON error response.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Mock error response with non-JSON content + mock_response = httpx.Response( + status_code=500, + content=b"Internal Server Error", + headers={"content-type": "text/plain"}, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should raise OAuthTokenError with default error + with pytest.raises(OAuthTokenError, match="Token exchange failed: unknown_error"): + await provider.exchange_token_for_id_jag(mock_client) + + +@pytest.mark.anyio +async def test_exchange_token_warning_for_non_na_token_type( + sample_id_token: str, sample_id_jag: str, mock_token_storage: Any +): + """Test token exchange logs warning for non-N_A token type.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Mock response with different token_type + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "Bearer", # Not N_A + "scope": "read write", + "expires_in": 300, + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should succeed but log warning + import logging + + with patch.object( + logging.getLogger("mcp.client.auth.extensions.enterprise_managed_auth"), "warning" + ) as mock_warning: + id_jag = await provider.exchange_token_for_id_jag(mock_client) + assert id_jag == sample_id_jag + mock_warning.assert_called_once() + + +@pytest.mark.anyio +async def test_exchange_id_jag_with_client_authentication(sample_id_jag: str, mock_token_storage: Any): + """Test JWT bearer grant with client authentication.""" + from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set client info with secret + provider.context.client_info = OAuthClientInformationFull( + client_id="test-client-id", + client_secret="test-client-secret", + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "token_type": "Bearer", + "access_token": "mcp-access-token-12345", + "expires_in": 3600, + "scope": "read write", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform JWT bearer grant + token = await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) + + # Verify token was returned + assert token.access_token == "mcp-access-token-12345" + + # Verify client credentials were included + call_args = mock_client.post.call_args + assert call_args[1]["data"]["client_id"] == "test-client-id" + assert call_args[1]["data"]["client_secret"] == "test-client-secret" + + +@pytest.mark.anyio +async def test_exchange_id_jag_with_client_id_only(sample_id_jag: str, mock_token_storage: Any): + """Test JWT bearer grant with client_id but no client_secret (covers branch 304->307).""" + from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set client info WITHOUT secret (client_secret=None) + provider.context.client_info = OAuthClientInformationFull( + client_id="test-client-id", + client_secret=None, # No secret + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "token_type": "Bearer", + "access_token": "mcp-access-token-12345", + "expires_in": 3600, + "scope": "read write", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform JWT bearer grant + token = await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) + + # Verify token was returned correctly + assert token.access_token == "mcp-access-token-12345" + assert token.token_type == "Bearer" + + # Verify client_id was included but NOT client_secret + call_args = mock_client.post.call_args + assert call_args[1]["data"]["client_id"] == "test-client-id" + assert "client_secret" not in call_args[1]["data"] + + +@pytest.mark.anyio +async def test_exchange_id_jag_error_response(sample_id_jag: str, mock_token_storage: Any): + """Test JWT bearer grant with error response.""" + from mcp.shared.auth import OAuthMetadata + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Mock error response + mock_response = httpx.Response( + status_code=400, + json={ + "error": "invalid_grant", + "error_description": "Invalid assertion", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should raise OAuthTokenError + with pytest.raises(OAuthTokenError, match="JWT bearer grant failed"): + await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) + + +@pytest.mark.anyio +async def test_exchange_id_jag_non_json_error(sample_id_jag: str, mock_token_storage: Any): + """Test JWT bearer grant with non-JSON error response.""" + from mcp.shared.auth import OAuthMetadata + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Mock error response with non-JSON content + mock_response = httpx.Response( + status_code=503, + content=b"Service Unavailable", + headers={"content-type": "text/html"}, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should raise OAuthTokenError with default error + with pytest.raises(OAuthTokenError, match="JWT bearer grant failed: unknown_error"): + await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) + + +@pytest.mark.anyio +async def test_exchange_id_jag_http_error(sample_id_jag: str, mock_token_storage: Any): + """Test JWT bearer grant with HTTP error.""" + from mcp.shared.auth import OAuthMetadata + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(side_effect=httpx.ReadTimeout("Request timeout")) + + # Should raise OAuthTokenError + with pytest.raises(OAuthTokenError, match="HTTP error during JWT bearer grant"): + await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) + + +@pytest.mark.anyio +async def test_exchange_token_with_client_info_but_no_client_id( + sample_id_token: str, sample_id_jag: str, mock_token_storage: Any +): + """Test token exchange when client_info exists but client_id is None (covers line 231).""" + from mcp.shared.auth import OAuthClientInformationFull + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + scope="read write", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + client_name="Test Client", + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set client info with client_id=None + provider.context.client_info = OAuthClientInformationFull( + client_id=None, # This should skip the client_id assignment + client_secret="test-secret", + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "N_A", + "scope": "read write", + "expires_in": 300, + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform token exchange + id_jag = await provider.exchange_token_for_id_jag(mock_client) + + # Verify the ID-JAG was returned + assert id_jag == sample_id_jag + + # Verify client_id was not included (None), but client_secret was included + call_args = mock_client.post.call_args + assert "client_id" not in call_args[1]["data"] + assert call_args[1]["data"]["client_secret"] == "test-secret" + + +@pytest.mark.anyio +async def test_exchange_id_jag_with_client_info_but_no_client_id(sample_id_jag: str, mock_token_storage: Any): + """Test ID-JAG exchange when client_info exists but client_id is None (covers line 302).""" + from pydantic import AnyHttpUrl + + from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Set client info with client_id=None + provider.context.client_info = OAuthClientInformationFull( + client_id=None, # This should skip the client_id assignment + client_secret="test-secret", + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "token_type": "Bearer", + "access_token": "mcp-access-token-12345", + "expires_in": 3600, + "scope": "read write", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform JWT bearer grant + token = await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) + + # Verify + assert token.access_token == "mcp-access-token-12345" + assert token.token_type == "Bearer" + assert token.expires_in == 3600 + + # Verify client_id was not included (None), but client_secret was included + call_args = mock_client.post.call_args + assert "client_id" not in call_args[1]["data"] + assert call_args[1]["data"]["client_secret"] == "test-secret" + + +def test_validate_token_exchange_params_missing_audience(): + """Test validation fails for missing audience.""" + params = TokenExchangeParameters( + subject_token="token", + subject_token_type="urn:ietf:params:oauth:token-type:id_token", + audience="", + resource="https://server.example/", + ) + + with pytest.raises(ValueError, match="audience is required"): + validate_token_exchange_params(params) + + +def test_validate_token_exchange_params_missing_resource(): + """Test validation fails for missing resource.""" + params = TokenExchangeParameters( + subject_token="token", + subject_token_type="urn:ietf:params:oauth:token-type:id_token", + audience="https://auth.example/", + resource="", + ) + + with pytest.raises(ValueError, match="resource is required"): + validate_token_exchange_params(params)