From ffef4a6ee02947f8c565f335d7c7dd7198accd84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=86=AF=E5=9F=BA=E9=AD=81?= <1412414664@qq.com> Date: Tue, 23 Jun 2026 23:38:04 +0800 Subject: [PATCH] fix: request JSON token responses --- src/mcp/client/auth/extensions/client_credentials.py | 9 ++++----- src/mcp/client/auth/oauth2.py | 5 +++-- src/mcp/client/auth/utils.py | 7 +++++++ tests/client/auth/extensions/test_client_credentials.py | 3 +++ tests/client/test_auth.py | 2 ++ 5 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/mcp/client/auth/extensions/client_credentials.py b/src/mcp/client/auth/extensions/client_credentials.py index 5efd596110..5ff94f11e6 100644 --- a/src/mcp/client/auth/extensions/client_credentials.py +++ b/src/mcp/client/auth/extensions/client_credentials.py @@ -18,6 +18,7 @@ from pydantic import BaseModel, Field from mcp.client.auth import OAuthClientProvider, OAuthFlowError, OAuthTokenError, TokenStorage +from mcp.client.auth.utils import create_token_request_headers from mcp.shared.auth import AuthorizationCodeResult, OAuthClientInformationFull, OAuthClientMetadata @@ -92,7 +93,7 @@ async def _exchange_token_client_credentials(self) -> httpx.Request: "grant_type": "client_credentials", } - headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"} + headers = create_token_request_headers() # Use standard auth methods (client_secret_basic, client_secret_post, none) token_data, headers = self.context.prepare_token_auth(token_data, headers) @@ -320,7 +321,7 @@ async def _exchange_token_client_credentials(self) -> httpx.Request: "grant_type": "client_credentials", } - headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"} + headers = create_token_request_headers() # Add JWT client authentication (RFC 7523 Section 2.2) await self._add_client_authentication_jwt(token_data=token_data) @@ -480,6 +481,4 @@ async def _exchange_token_jwt_bearer(self) -> httpx.Request: token_data["scope"] = self.context.client_metadata.scope token_url = self._get_token_endpoint() - return httpx.Request( - "POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"} - ) + return httpx.Request("POST", token_url, data=token_data, headers=create_token_request_headers()) diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 39858cba44..3c3609667b 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -25,6 +25,7 @@ create_client_info_from_metadata_url, create_client_registration_request, create_oauth_metadata_request, + create_token_request_headers, credentials_match_issuer, extract_field_from_www_auth, extract_resource_metadata_from_www_auth, @@ -409,7 +410,7 @@ async def _exchange_token_authorization_code( token_data["resource"] = self.context.get_resource_url() # RFC 8707 # Prepare authentication based on preferred method - headers = {"Content-Type": "application/x-www-form-urlencoded"} + headers = create_token_request_headers() token_data, headers = self.context.prepare_token_auth(token_data, headers) return httpx.Request("POST", token_url, data=token_data, headers=headers) @@ -461,7 +462,7 @@ async def _refresh_token(self) -> httpx.Request: refresh_data["resource"] = self.context.get_resource_url() # RFC 8707 # Prepare authentication based on preferred method - headers = {"Content-Type": "application/x-www-form-urlencoded"} + headers = create_token_request_headers() refresh_data, headers = self.context.prepare_token_auth(refresh_data, headers) return httpx.Request("POST", token_url, data=refresh_data, headers=headers) diff --git a/src/mcp/client/auth/utils.py b/src/mcp/client/auth/utils.py index f10264a330..2273d12306 100644 --- a/src/mcp/client/auth/utils.py +++ b/src/mcp/client/auth/utils.py @@ -16,6 +16,13 @@ from mcp.types import LATEST_PROTOCOL_VERSION +def create_token_request_headers() -> dict[str, str]: + return { + "Accept": "application/json", + "Content-Type": "application/x-www-form-urlencoded", + } + + def extract_field_from_www_auth(response: Response, field_name: str) -> str | None: """Extract field from WWW-Authenticate header. diff --git a/tests/client/auth/extensions/test_client_credentials.py b/tests/client/auth/extensions/test_client_credentials.py index a964891316..a348b4b316 100644 --- a/tests/client/auth/extensions/test_client_credentials.py +++ b/tests/client/auth/extensions/test_client_credentials.py @@ -108,6 +108,7 @@ async def test_token_exchange_request_jwt_predefined(self, rfc7523_oauth_provide assert request.method == "POST" assert str(request.url) == "https://api.example.com/token" + assert request.headers["Accept"] == "application/json" assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" # Check form data @@ -252,6 +253,7 @@ async def test_exchange_token_client_credentials(self, mock_storage: MockTokenSt assert request.method == "POST" assert str(request.url) == "https://api.example.com/token" + assert request.headers["Accept"] == "application/json" content = urllib.parse.unquote_plus(request.content.decode()) assert "grant_type=client_credentials" in content @@ -398,6 +400,7 @@ async def mock_assertion_provider(audience: str) -> str: assert request.method == "POST" assert str(request.url) == "https://auth.example.com/token" + assert request.headers["Accept"] == "application/json" content = urllib.parse.unquote_plus(request.content.decode()) assert "grant_type=client_credentials" in content diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index cdbba1b588..e5f0e92275 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -602,6 +602,7 @@ async def test_token_exchange_request_authorization_code(self, oauth_provider: O assert request.method == "POST" assert str(request.url) == "https://api.example.com/token" + assert request.headers["Accept"] == "application/json" assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" # Check form data @@ -628,6 +629,7 @@ async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider, assert request.method == "POST" assert str(request.url) == "https://api.example.com/token" + assert request.headers["Accept"] == "application/json" assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" # Check form data