Skip to content

Commit 457f6ee

Browse files
committed
fix(client): preserve existing query params on OAuth authorization_endpoint
Closes #2776 The authorization code grant built the redirect URL with `f"{auth_endpoint}?{urlencode(auth_params)}"`, which produces an invalid URL when the server-advertised authorization_endpoint already carries a query string. For example Salesforce advertises `.../services/oauth2/authorize?prompt=select_account`, yielding `...authorize?prompt=select_account?response_type=code&...` (two `?` separators), so the client navigates to a malformed URL and the server rejects the request. Fix: parse the endpoint, merge its existing query params with the flow-generated auth_params (flow params win on conflict), and re-encode into a single well-formed query string. None-valued params are dropped rather than serialized as the literal "None". Tests: add TestAuthorizationEndpointWithQuery covering the helper (no/with/conflicting existing query) plus an end-to-end _perform_authorization_code_grant assertion that the captured redirect URL preserves the server param and stays well-formed. 101 passed.
1 parent 4472428 commit 457f6ee

2 files changed

Lines changed: 102 additions & 3 deletions

File tree

src/mcp/client/auth/oauth2.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
import secrets
1010
import string
1111
import time
12-
from collections.abc import AsyncGenerator, Awaitable, Callable
12+
from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping
1313
from dataclasses import dataclass, field
1414
from typing import Any, Protocol
15-
from urllib.parse import quote, urlencode, urljoin, urlparse
15+
from urllib.parse import parse_qsl, quote, urlencode, urljoin, urlparse, urlunparse
1616

1717
import anyio
1818
import httpx
@@ -59,6 +59,22 @@
5959
logger = logging.getLogger(__name__)
6060

6161

62+
def _build_authorization_url(auth_endpoint: str, auth_params: Mapping[str, str | None]) -> str:
63+
"""Build an authorization URL, preserving any query params already on the endpoint.
64+
65+
Servers may advertise an ``authorization_endpoint`` that already carries query
66+
parameters (e.g. ``https://example.com/authorize?prompt=select_account``).
67+
Naively appending ``?<params>`` would produce an invalid URL with two ``?``
68+
separators, so the existing query is parsed and merged with ``auth_params``.
69+
Flow-generated params take precedence on key conflicts; ``None`` values are
70+
dropped rather than serialized as the literal string ``"None"``.
71+
"""
72+
parsed = urlparse(auth_endpoint)
73+
merged_params = dict(parse_qsl(parsed.query, keep_blank_values=True))
74+
merged_params.update({key: value for key, value in auth_params.items() if value is not None})
75+
return urlunparse(parsed._replace(query=urlencode(merged_params)))
76+
77+
6278
class PKCEParameters(BaseModel):
6379
"""PKCE (Proof Key for Code Exchange) parameters."""
6480

@@ -357,7 +373,7 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]:
357373
if "offline_access" in self.context.client_metadata.scope.split():
358374
auth_params["prompt"] = "consent"
359375

360-
authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}"
376+
authorization_url = _build_authorization_url(auth_endpoint, auth_params)
361377
await self.context.redirect_handler(authorization_url)
362378

363379
# Wait for callback

tests/client/test_auth.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from mcp.client.auth import OAuthClientProvider, PKCEParameters
1515
from mcp.client.auth.exceptions import OAuthFlowError
16+
from mcp.client.auth.oauth2 import _build_authorization_url
1617
from mcp.client.auth.utils import (
1718
build_oauth_authorization_server_metadata_discovery_urls,
1819
build_protected_resource_metadata_discovery_urls,
@@ -2833,3 +2834,85 @@ def test_credentials_match_issuer_url_shaped_dcr_id_is_not_portable():
28332834
issuer="https://as.example.com",
28342835
)
28352836
assert credentials_match_issuer(info, "https://other", "https://client.example/metadata.json") is False
2837+
2838+
class TestAuthorizationEndpointWithQuery:
2839+
"""Regression tests for #2776 - authorization_endpoint carrying query params."""
2840+
2841+
def test_build_authorization_url_no_existing_query(self):
2842+
url = _build_authorization_url(
2843+
"https://auth.example.com/authorize",
2844+
{"response_type": "code", "client_id": "abc"},
2845+
)
2846+
parsed = urlparse(url)
2847+
params = parse_qs(parsed.query)
2848+
assert parsed.path == "/authorize"
2849+
assert params["response_type"] == ["code"]
2850+
assert params["client_id"] == ["abc"]
2851+
# No malformed double "?" separator.
2852+
assert url.count("?") == 1
2853+
2854+
def test_build_authorization_url_preserves_existing_query(self):
2855+
# e.g. Salesforce advertises .../authorize?prompt=select_account
2856+
url = _build_authorization_url(
2857+
"https://test.salesforce.com/services/oauth2/authorize?prompt=select_account",
2858+
{"response_type": "code", "client_id": "abc"},
2859+
)
2860+
parsed = urlparse(url)
2861+
params = parse_qs(parsed.query)
2862+
assert parsed.path == "/services/oauth2/authorize"
2863+
# The server-provided param survives...
2864+
assert params["prompt"] == ["select_account"]
2865+
# ...alongside the flow-generated params.
2866+
assert params["response_type"] == ["code"]
2867+
assert params["client_id"] == ["abc"]
2868+
# Exactly one "?" - the old f-string produced "...?prompt=...?response_type=...".
2869+
assert url.count("?") == 1
2870+
2871+
def test_build_authorization_url_flow_params_win_on_conflict(self):
2872+
url = _build_authorization_url(
2873+
"https://auth.example.com/authorize?response_type=token",
2874+
{"response_type": "code"},
2875+
)
2876+
params = parse_qs(urlparse(url).query)
2877+
assert params["response_type"] == ["code"]
2878+
2879+
@pytest.mark.anyio
2880+
async def test_perform_authorization_preserves_endpoint_query(self, oauth_provider: OAuthClientProvider):
2881+
"""End-to-end: redirect URL stays valid when the endpoint has a query string."""
2882+
oauth_provider.context.oauth_metadata = OAuthMetadata(
2883+
issuer=AnyHttpUrl("https://test.salesforce.com"),
2884+
authorization_endpoint=AnyHttpUrl(
2885+
"https://test.salesforce.com/services/oauth2/authorize?prompt=select_account"
2886+
),
2887+
token_endpoint=AnyHttpUrl("https://test.salesforce.com/services/oauth2/token"),
2888+
)
2889+
oauth_provider.context.client_info = OAuthClientInformationFull(
2890+
client_id="test_client_id",
2891+
client_secret="test_client_secret",
2892+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
2893+
)
2894+
2895+
captured_url: str | None = None
2896+
captured_state: str | None = None
2897+
2898+
async def capture_redirect(url: str) -> None:
2899+
nonlocal captured_url, captured_state
2900+
captured_url = url
2901+
captured_state = parse_qs(urlparse(url).query).get("state", [None])[0]
2902+
2903+
async def mock_callback() -> AuthorizationCodeResult:
2904+
return AuthorizationCodeResult(code="test_auth_code", state=captured_state)
2905+
2906+
oauth_provider.context.redirect_handler = capture_redirect
2907+
oauth_provider.context.callback_handler = mock_callback
2908+
2909+
await oauth_provider._perform_authorization_code_grant()
2910+
2911+
assert captured_url is not None
2912+
parsed = urlparse(captured_url)
2913+
params = parse_qs(parsed.query)
2914+
assert parsed.path == "/services/oauth2/authorize"
2915+
assert params["prompt"] == ["select_account"]
2916+
assert params["response_type"] == ["code"]
2917+
assert params["client_id"] == ["test_client_id"]
2918+
assert captured_url.count("?") == 1

0 commit comments

Comments
 (0)