From cb9e2605cf7615bd695009e41f6059ce2ddfba6c Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 22 Jun 2026 14:10:03 +0000 Subject: [PATCH 01/22] Rename SUPPORTED_PROTOCOL_VERSIONS to HANDSHAKE_PROTOCOL_VERSIONS; consolidate header constant - HANDSHAKE_PROTOCOL_VERSIONS names what the constant actually holds (versions reachable via the initialize handshake); SUPPORTED_PROTOCOL_VERSIONS survives as a deprecated union of HANDSHAKE + MODERN for v1.x compatibility - The three handshake-ceiling call sites (initialize offer, server negotiate fallback, for_loop seed) now read HANDSHAKE_PROTOCOL_VERSIONS[-1] instead of LATEST_PROTOCOL_VERSION - Era-routing in the streamable-HTTP manager reads HANDSHAKE_PROTOCOL_VERSIONS (interim; the body-primary classifier is the structural fix) - mcp-protocol-version header constant: three duplicate definitions collapsed to the single MCP_PROTOCOL_VERSION_HEADER in shared/inbound; client and server importers point at the canonical module - migration.md documents the SUPPORTED deprecation --- docs/migration.md | 4 ++++ src/mcp/client/auth/oauth2.py | 4 ++-- src/mcp/client/auth/utils.py | 4 ++-- src/mcp/client/session.py | 6 +++--- src/mcp/client/streamable_http.py | 4 ++-- src/mcp/server/auth/routes.py | 2 +- src/mcp/server/connection.py | 8 +++++--- src/mcp/server/runner.py | 5 ++--- src/mcp/server/streamable_http.py | 4 ++-- src/mcp/server/streamable_http_manager.py | 6 +++--- src/mcp/shared/version.py | 17 +++++++++++++---- tests/client/test_session.py | 4 ++-- tests/client/test_streamable_http.py | 6 +++--- tests/server/test_connection.py | 8 ++++---- tests/server/test_runner.py | 4 ++-- tests/shared/test_version.py | 6 ++++-- 16 files changed, 54 insertions(+), 38 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index bf06690c45..19c534aeb8 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -764,6 +764,10 @@ async def my_tool(ctx: Context) -> str: ... async def my_tool(ctx: Context[MyLifespanState]) -> str: ... ``` +### Version constants + +`SUPPORTED_PROTOCOL_VERSIONS` is deprecated — it's now the union of `HANDSHAKE_PROTOCOL_VERSIONS` (initialize-handshake versions) and `MODERN_PROTOCOL_VERSIONS` (per-request-envelope versions). If you were using it to mean "versions the initialize handshake accepts", switch to `HANDSHAKE_PROTOCOL_VERSIONS`. + ### `ProgressContext` and `progress()` context manager removed The `mcp.shared.progress` module (`ProgressContext`, `Progress`, and the `progress()` context manager) has been removed. This module had no real-world adoption — all users send progress notifications via `Context.report_progress()` or `session.send_progress_notification()` directly. diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 39858cba44..f21231b924 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -40,7 +40,6 @@ validate_authorization_response_iss, validate_metadata_issuer, ) -from mcp.client.streamable_http import MCP_PROTOCOL_VERSION from mcp.shared.auth import ( AuthorizationCodeResult, OAuthClientInformationFull, @@ -54,6 +53,7 @@ check_resource_allowed, resource_url_from_server_url, ) +from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER from mcp.shared.version import is_version_at_least logger = logging.getLogger(__name__) @@ -534,7 +534,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. await self._initialize() # Capture protocol version from request headers - self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION) + self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER) if not self.context.is_token_valid() and self.context.can_refresh_token(): # Try to refresh token diff --git a/src/mcp/client/auth/utils.py b/src/mcp/client/auth/utils.py index f10264a330..16f711dd45 100644 --- a/src/mcp/client/auth/utils.py +++ b/src/mcp/client/auth/utils.py @@ -5,7 +5,6 @@ from pydantic import AnyUrl, ValidationError from mcp.client.auth import OAuthFlowError, OAuthRegistrationError, OAuthTokenError -from mcp.client.streamable_http import MCP_PROTOCOL_VERSION from mcp.shared.auth import ( OAuthClientInformationFull, OAuthClientMetadata, @@ -13,6 +12,7 @@ OAuthToken, ProtectedResourceMetadata, ) +from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER from mcp.types import LATEST_PROTOCOL_VERSION @@ -273,7 +273,7 @@ def validate_metadata_issuer(oauth_metadata: OAuthMetadata, expected_issuer: str def create_oauth_metadata_request(url: str) -> Request: - return Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) + return Request("GET", url, headers={MCP_PROTOCOL_VERSION_HEADER: LATEST_PROTOCOL_VERSION}) def create_client_registration_request( diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 4b24e98b1d..b9ebf6c9e1 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -21,7 +21,7 @@ from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder from mcp.shared.transport_context import TransportContext -from mcp.shared.version import MODERN_PROTOCOL_VERSIONS, SUPPORTED_PROTOCOL_VERSIONS +from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS, MODERN_PROTOCOL_VERSIONS from mcp.types import ( CLIENT_CAPABILITIES_META_KEY, CLIENT_INFO_META_KEY, @@ -320,7 +320,7 @@ async def initialize(self) -> types.InitializeResult: params=types.InitializeRequestParams( protocol_version=self._pinned_version if self._pinned_version is not None - else types.LATEST_PROTOCOL_VERSION, + else HANDSHAKE_PROTOCOL_VERSIONS[-1], capabilities=capabilities, client_info=self._client_info, ), @@ -328,7 +328,7 @@ async def initialize(self) -> types.InitializeResult: types.InitializeResult, ) - if result.protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: + if result.protocol_version not in HANDSHAKE_PROTOCOL_VERSIONS: raise RuntimeError(f"Unsupported protocol version from the server: {result.protocol_version}") self._initialize_result = result diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index a703a48afb..93a99831d4 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -20,6 +20,7 @@ from mcp.shared._compat import resync_tracer from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams from mcp.shared._httpx_utils import create_mcp_http_client +from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.shared.version import MODERN_PROTOCOL_VERSIONS from mcp.types import ( @@ -46,7 +47,6 @@ StreamReader = ContextReceiveStream[SessionMessage] MCP_SESSION_ID = "mcp-session-id" -MCP_PROTOCOL_VERSION = "mcp-protocol-version" MCP_METHOD = "mcp-method" MCP_NAME = "mcp-name" LAST_EVENT_ID = "last-event-id" @@ -138,7 +138,7 @@ def _prepare_headers(self) -> dict[str, str]: if self.session_id: headers[MCP_SESSION_ID] = self.session_id if self.protocol_version: - headers[MCP_PROTOCOL_VERSION] = self.protocol_version + headers[MCP_PROTOCOL_VERSION_HEADER] = self.protocol_version return headers def _is_initialization_request(self, message: JSONRPCMessage) -> bool: diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index a72e819477..d88b6d1b13 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -17,8 +17,8 @@ from mcp.server.auth.middleware.client_auth import ClientAuthenticator from mcp.server.auth.provider import OAuthAuthorizationServerProvider from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions -from mcp.server.streamable_http import MCP_PROTOCOL_VERSION_HEADER from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata +from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER def validate_issuer_url(url: AnyHttpUrl): diff --git a/src/mcp/server/connection.py b/src/mcp/server/connection.py index f5bfc18dfb..7412ccfc22 100644 --- a/src/mcp/server/connection.py +++ b/src/mcp/server/connection.py @@ -31,8 +31,8 @@ from mcp.shared.dispatcher import CallOptions, Outbound from mcp.shared.exceptions import MCPDeprecationWarning, NoBackChannelError from mcp.shared.peer import Meta, dump_params +from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS from mcp.types import ( - LATEST_PROTOCOL_VERSION, ClientCapabilities, CreateMessageRequest, CreateMessageResult, @@ -192,12 +192,14 @@ def for_loop( Not born-ready: `initialized` is set later by the kernel when `notifications/initialized` arrives. `protocol_version` is seeded from - the transport hint (or `LATEST_PROTOCOL_VERSION`) so it's never `None`; + the transport hint (or `HANDSHAKE_PROTOCOL_VERSIONS[-1]`) so it's never `None`; the handshake overwrites it once negotiated. """ return cls( outbound, - protocol_version=protocol_version_hint if protocol_version_hint is not None else LATEST_PROTOCOL_VERSION, + protocol_version=protocol_version_hint + if protocol_version_hint is not None + else HANDSHAKE_PROTOCOL_VERSIONS[-1], session_id=session_id, ) diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index bbc16abe9f..f8c8d2ae44 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -36,11 +36,10 @@ from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher, handler_exception_to_error_data from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.transport_context import TransportContext -from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS +from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS from mcp.types import ( INTERNAL_ERROR, INVALID_PARAMS, - LATEST_PROTOCOL_VERSION, METHOD_NOT_FOUND, ErrorData, Implementation, @@ -414,7 +413,7 @@ def _negotiate_initialize(params: Mapping[str, Any] | None) -> tuple[InitializeR """Validate `initialize` params and pick the protocol version.""" init = InitializeRequestParams.model_validate(params or {}, by_name=False) requested = init.protocol_version - negotiated = requested if requested in SUPPORTED_PROTOCOL_VERSIONS else LATEST_PROTOCOL_VERSION + negotiated = requested if requested in HANDSHAKE_PROTOCOL_VERSIONS else HANDSHAKE_PROTOCOL_VERSIONS[-1] return init, negotiated def _handle_initialize(self, params: Mapping[str, Any] | None) -> InitializeResult: diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index f8aec6c9e2..aa682cbf2a 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -28,6 +28,7 @@ from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams from mcp.shared._stream_protocols import ReadStream, WriteStream +from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.version import is_version_at_least from mcp.types import ( @@ -50,7 +51,6 @@ # Header names MCP_SESSION_ID_HEADER = "mcp-session-id" -MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version" LAST_EVENT_ID_HEADER = "last-event-id" # Content types @@ -818,7 +818,7 @@ async def _handle_unsupported_request(self, request: Request, send: Send) -> Non async def _validate_request_headers(self, request: Request, send: Send) -> bool: # Protocol-version validation lives in the manager's era-routing: only - # values in `SUPPORTED_PROTOCOL_VERSIONS` (or no header at all) reach + # values in `HANDSHAKE_PROTOCOL_VERSIONS` (or no header at all) reach # this transport, so the legacy version-gate is gone. return await self._validate_session(request, send) diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 648dcc827f..f9329f8564 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -19,16 +19,16 @@ from mcp.server.connection import Connection from mcp.server.runner import serve_connection, serve_loop from mcp.server.streamable_http import ( - MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, EventStore, StreamableHTTPServerTransport, ) from mcp.server.transport_security import TransportSecuritySettings from mcp.shared._compat import resync_tracer +from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher from mcp.shared.transport_context import TransportContext -from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS +from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS from mcp.types import DEFAULT_NEGOTIATED_VERSION, INVALID_REQUEST, ErrorData, JSONRPCError if TYPE_CHECKING: @@ -169,7 +169,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No # and return a structured rejection. 2025 paths below remain unchanged. header = MCP_PROTOCOL_VERSION_HEADER.encode("ascii") pv = next((v.decode("latin-1") for k, v in scope["headers"] if k == header), None) - if pv is not None and pv not in SUPPORTED_PROTOCOL_VERSIONS: + if pv is not None and pv not in HANDSHAKE_PROTOCOL_VERSIONS: await handle_modern_request(self.app, self.security_settings, self._lifespan_state, scope, receive, send) return diff --git a/src/mcp/shared/version.py b/src/mcp/shared/version.py index 09aacb6956..2232c1d3c4 100644 --- a/src/mcp/shared/version.py +++ b/src/mcp/shared/version.py @@ -9,8 +9,6 @@ from typing import Final -from mcp.types import LATEST_PROTOCOL_VERSION - KNOWN_PROTOCOL_VERSIONS: Final[tuple[str, ...]] = ( "2024-11-05", "2025-03-26", @@ -20,11 +18,22 @@ ) """Every released protocol revision, oldest to newest.""" +HANDSHAKE_PROTOCOL_VERSIONS: Final[tuple[str, ...]] = ( + "2024-11-05", + "2025-03-26", + "2025-06-18", + "2025-11-25", +) +"""Protocol revisions reachable via the initialize handshake.""" + MODERN_PROTOCOL_VERSIONS: Final[tuple[str, ...]] = ("2026-07-28",) """Protocol revisions that use the stateless per-request envelope.""" -SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", "2025-06-18", LATEST_PROTOCOL_VERSION] -"""Protocol revisions this SDK can negotiate.""" +SUPPORTED_PROTOCOL_VERSIONS: tuple[str, ...] = (*HANDSHAKE_PROTOCOL_VERSIONS, *MODERN_PROTOCOL_VERSIONS) +"""Deprecated: prefer HANDSHAKE_PROTOCOL_VERSIONS or MODERN_PROTOCOL_VERSIONS. + +Kept as the union for v1.x compatibility. +""" def is_version_at_least(version: str, minimum: str) -> bool: diff --git a/tests/client/test_session.py b/tests/client/test_session.py index c171360de2..1c3767303f 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -18,7 +18,7 @@ from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder from mcp.shared.transport_context import TransportContext -from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS +from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS from mcp.types import ( CONNECTION_CLOSED, INTERNAL_ERROR, @@ -313,7 +313,7 @@ async def mock_server(): # Assert the result with negotiated version assert isinstance(result, InitializeResult) assert result.protocol_version == "2024-11-05" - assert result.protocol_version in SUPPORTED_PROTOCOL_VERSIONS + assert result.protocol_version in HANDSHAKE_PROTOCOL_VERSIONS @pytest.mark.anyio diff --git a/tests/client/test_streamable_http.py b/tests/client/test_streamable_http.py index bbe3e67fee..b13de4a5ee 100644 --- a/tests/client/test_streamable_http.py +++ b/tests/client/test_streamable_http.py @@ -16,11 +16,11 @@ from mcp.client import ClientSession from mcp.client.streamable_http import ( - MCP_PROTOCOL_VERSION, StreamableHTTPTransport, _encode_header_value, streamable_http_client, ) +from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER from mcp.types import JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse @@ -160,7 +160,7 @@ def test_stateful_constructor_pin_is_ignored_and_the_negotiated_version_wins() - """A pre-2026 pin is a session-layer concern; the transport must not stamp it on the initialize request and must adopt the server's negotiated version for later headers.""" transport = StreamableHTTPTransport("http://test/mcp", protocol_version="2025-06-18") - assert MCP_PROTOCOL_VERSION not in transport._prepare_headers() # pyright: ignore[reportPrivateUsage] + assert MCP_PROTOCOL_VERSION_HEADER not in transport._prepare_headers() # pyright: ignore[reportPrivateUsage] init = JSONRPCResponse( jsonrpc="2.0", id=1, @@ -172,4 +172,4 @@ def test_stateful_constructor_pin_is_ignored_and_the_negotiated_version_wins() - ) transport._maybe_extract_protocol_version_from_message(init) # pyright: ignore[reportPrivateUsage] assert transport.protocol_version == "2025-03-26" - assert transport._prepare_headers()[MCP_PROTOCOL_VERSION] == "2025-03-26" # pyright: ignore[reportPrivateUsage] + assert transport._prepare_headers()[MCP_PROTOCOL_VERSION_HEADER] == "2025-03-26" # pyright: ignore[reportPrivateUsage] diff --git a/tests/server/test_connection.py b/tests/server/test_connection.py index 8ca1ae8a7a..5e504ff4d7 100644 --- a/tests/server/test_connection.py +++ b/tests/server/test_connection.py @@ -19,7 +19,7 @@ from mcp.server.connection import Connection from mcp.shared.dispatcher import CallOptions from mcp.shared.exceptions import NoBackChannelError -from mcp.shared.version import MODERN_PROTOCOL_VERSIONS +from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS, MODERN_PROTOCOL_VERSIONS from mcp.types import ( LATEST_PROTOCOL_VERSION, ClientCapabilities, @@ -115,10 +115,10 @@ def test_from_envelope_with_explicit_outbound_has_standalone_channel(): def test_for_loop_seeds_version_from_hint_or_latest_and_is_not_born_ready(): """SDK-defined: `for_loop` seeds `protocol_version` from the hint when given, - else `LATEST_PROTOCOL_VERSION`; the connection awaits the initialize handshake.""" + else `HANDSHAKE_PROTOCOL_VERSIONS[-1]`; the connection awaits the initialize handshake.""" out = StubOutbound() conn = Connection.for_loop(out) - assert conn.protocol_version == LATEST_PROTOCOL_VERSION + assert conn.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] assert conn.has_standalone_channel is True assert not conn.initialized.is_set() assert conn.initialize_accepted is False @@ -229,7 +229,7 @@ async def test_send_request_validates_the_client_result_against_the_surface_sche async def test_send_request_passes_a_spec_valid_client_result(): """A spec-valid client result passes the surface gate and parses to the typed model.""" conn = Connection.for_loop(StubOutbound(result={"roots": [{"uri": "file:///ws"}]})) - assert conn.protocol_version == LATEST_PROTOCOL_VERSION + assert conn.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] result = await conn.send_request(ListRootsRequest()) assert isinstance(result, ListRootsResult) assert str(result.roots[0].uri) == "file:///ws" diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index c5a99ae07a..50f9d0a2ec 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -39,7 +39,7 @@ from mcp.shared.message import MessageMetadata from mcp.shared.peer import dump_params from mcp.shared.transport_context import TransportContext -from mcp.shared.version import MODERN_PROTOCOL_VERSIONS, SUPPORTED_PROTOCOL_VERSIONS +from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS, MODERN_PROTOCOL_VERSIONS from mcp.types import ( INTERNAL_ERROR, INVALID_PARAMS, @@ -998,7 +998,7 @@ async def list_tools(ctx: Ctx, params: PaginatedRequestParams | None) -> ListToo @pytest.mark.anyio async def test_runner_initialize_echoes_supported_version_and_falls_back_to_latest(server: SrvT): - oldest = SUPPORTED_PROTOCOL_VERSIONS[0] + oldest = HANDSHAKE_PROTOCOL_VERSIONS[0] async with connected_runner(server, initialized=False) as (client, _): params = {**_initialize_params(), "protocolVersion": oldest} result = await client.send_raw_request("initialize", params) diff --git a/tests/shared/test_version.py b/tests/shared/test_version.py index baffa032fe..595bb03bc0 100644 --- a/tests/shared/test_version.py +++ b/tests/shared/test_version.py @@ -3,8 +3,9 @@ import pytest from mcp.shared.version import ( + HANDSHAKE_PROTOCOL_VERSIONS, KNOWN_PROTOCOL_VERSIONS, - SUPPORTED_PROTOCOL_VERSIONS, + MODERN_PROTOCOL_VERSIONS, is_version_at_least, ) @@ -51,7 +52,8 @@ def test_is_version_at_least_matches_lexicographic_for_known_versions(version: s def test_supported_versions_are_known() -> None: """Every negotiable revision must be in the ordering registry.""" - assert set(SUPPORTED_PROTOCOL_VERSIONS) <= set(KNOWN_PROTOCOL_VERSIONS) + assert set(HANDSHAKE_PROTOCOL_VERSIONS) <= set(KNOWN_PROTOCOL_VERSIONS) + assert set(MODERN_PROTOCOL_VERSIONS) <= set(KNOWN_PROTOCOL_VERSIONS) def test_known_versions_are_strictly_ordered() -> None: From e0990bff2f3b2b4cd2c064b38e04f9c4e5850ff8 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 22 Jun 2026 14:35:21 +0000 Subject: [PATCH 02/22] Thread per-message headers from session to transport via CallOptions/metadata sidecar Additive infrastructure for the client-side outbound stamp: - CallOptions gains a headers key; ClientMessageMetadata gains a headers field - _plan_outbound projects opts['headers'] onto the metadata (same path resumption tokens take); JSONRPCDispatcher.notify accepts opts and threads headers through - Outbound.notify Protocol grows opts=None; all implementers updated (Connection, _NoChannelOutbound, _SingleExchangeDispatchContext, peer, context, DirectDispatcher, test stubs) - StreamableHTTPTransport's POST path merges metadata.headers into the request (alongside existing _prepare_headers/_per_message_headers, which are removed in the next commit) - MCP_METHOD_HEADER, MCP_NAME_HEADER, encode_header_value moved to shared/inbound (single source for the header names) - Tests pin both new paths --- src/mcp/client/streamable_http.py | 22 ++++----------- src/mcp/server/_streamable_http_modern.py | 2 +- src/mcp/server/connection.py | 6 ++-- src/mcp/shared/context.py | 4 +-- src/mcp/shared/direct_dispatcher.py | 7 +++-- src/mcp/shared/dispatcher.py | 5 +++- src/mcp/shared/inbound.py | 22 +++++++++++++++ src/mcp/shared/jsonrpc_dispatcher.py | 12 +++++--- src/mcp/shared/message.py | 2 ++ src/mcp/shared/peer.py | 4 +-- tests/client/test_session.py | 4 +-- tests/client/test_streamable_http.py | 34 +++++++++++++++++++++-- tests/server/test_connection.py | 2 +- tests/server/test_runner.py | 2 +- tests/server/test_session.py | 2 +- tests/server/test_stateless_mode.py | 2 +- tests/shared/test_jsonrpc_dispatcher.py | 33 ++++++++++++++++++++++ tests/shared/test_peer.py | 2 +- 18 files changed, 124 insertions(+), 43 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 93a99831d4..b6e6e105aa 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -2,10 +2,8 @@ from __future__ import annotations as _annotations -import base64 import contextlib import logging -import re from collections.abc import AsyncGenerator, Awaitable, Callable from contextlib import asynccontextmanager from dataclasses import dataclass @@ -20,7 +18,7 @@ from mcp.shared._compat import resync_tracer from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams from mcp.shared._httpx_utils import create_mcp_http_client -from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER +from mcp.shared.inbound import MCP_METHOD_HEADER, MCP_NAME_HEADER, MCP_PROTOCOL_VERSION_HEADER, encode_header_value from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.shared.version import MODERN_PROTOCOL_VERSIONS from mcp.types import ( @@ -47,24 +45,12 @@ StreamReader = ContextReceiveStream[SessionMessage] MCP_SESSION_ID = "mcp-session-id" -MCP_METHOD = "mcp-method" -MCP_NAME = "mcp-name" LAST_EVENT_ID = "last-event-id" # Reconnection defaults DEFAULT_RECONNECTION_DELAY_MS = 1000 # 1 second fallback when server doesn't provide retry MAX_RECONNECTION_ATTEMPTS = 2 # Max retry attempts before giving up -_B64_SENTINEL = re.compile(r"^=\?base64\?.*\?=$") -# RFC 7230 token chars minus DEL; visible ASCII 0x20-0x7E is the practical bound for a header value. -_HEADER_SAFE = re.compile(r"^[\x20-\x7E]*$") - - -def _encode_header_value(value: str) -> str: - if _HEADER_SAFE.fullmatch(value) and value == value.strip() and not _B64_SENTINEL.fullmatch(value): - return value - return f"=?base64?{base64.b64encode(value.encode('utf-8')).decode('ascii')}?=" - class StreamableHTTPError(Exception): """Base exception for StreamableHTTP transport errors.""" @@ -112,7 +98,7 @@ def _per_message_headers(self, message: JSONRPCMessage) -> dict[str, str]: return {} if not isinstance(message, JSONRPCRequest | JSONRPCNotification): return {} - headers: dict[str, str] = {MCP_METHOD: message.method} + headers: dict[str, str] = {MCP_METHOD_HEADER: message.method} # TODO: Mcp-Name is also REQUIRED for prompts/get (params.name) and resources/read # (params.uri); a method->param-key map replaces this gate when those land. if ( @@ -121,7 +107,7 @@ def _per_message_headers(self, message: JSONRPCMessage) -> dict[str, str]: and message.params and isinstance(name := message.params.get("name"), str) ): - headers[MCP_NAME] = _encode_header_value(name) + headers[MCP_NAME_HEADER] = encode_header_value(name) return headers def _prepare_headers(self) -> dict[str, str]: @@ -302,6 +288,8 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: headers = self._prepare_headers() message = ctx.session_message.message headers.update(self._per_message_headers(message)) + if ctx.metadata is not None and ctx.metadata.headers is not None: + headers.update(ctx.metadata.headers) is_initialization = self._is_initialization_request(message) async with ctx.client.stream( diff --git a/src/mcp/server/_streamable_http_modern.py b/src/mcp/server/_streamable_http_modern.py index 6a81786b3b..2ddeb78ed3 100644 --- a/src/mcp/server/_streamable_http_modern.py +++ b/src/mcp/server/_streamable_http_modern.py @@ -77,7 +77,7 @@ async def send_raw_request( ) -> dict[str, Any]: raise NoBackChannelError(method) - async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: # TODO(D-005a): buffer and stream as SSE once the JSON-vs-SSE response mode lands. return None diff --git a/src/mcp/server/connection.py b/src/mcp/server/connection.py index 7412ccfc22..06691535c0 100644 --- a/src/mcp/server/connection.py +++ b/src/mcp/server/connection.py @@ -93,7 +93,7 @@ async def send_raw_request( ) -> dict[str, Any]: raise NoBackChannelError(method) - async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: logger.debug("dropped %s: no standalone channel", method) @@ -277,14 +277,14 @@ async def send_request( cls = result_type if result_type is not None else _RESULT_FOR[type(req)] return cls.model_validate(raw, by_name=False) - async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: """Send a best-effort notification on the standalone stream. Never raises. If there's no standalone channel or the stream is broken, the notification is dropped and debug-logged. """ try: - await self.outbound.notify(method, params) + await self.outbound.notify(method, params, opts) except (anyio.BrokenResourceError, anyio.ClosedResourceError): logger.debug("dropped %s: standalone stream closed", method) diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 11a0aae0a2..eef1fa3855 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -73,9 +73,9 @@ async def send_raw_request( """ return await self._dctx.send_raw_request(method, params, opts) - async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: """Send a notification to the peer on the back-channel.""" - await self._dctx.notify(method, params) + await self._dctx.notify(method, params, opts) async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: """Report progress for this request, if the peer supplied a progress token. diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index 4460be4e0d..d521840bef 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -63,7 +63,7 @@ class _DirectDispatchContext: def can_send_request(self) -> bool: return self.transport.can_send_request - async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: await self._back_notify(method, params) async def send_raw_request( @@ -133,12 +133,13 @@ async def send_raw_request( raise RuntimeError("DirectDispatcher.send_raw_request called before run()") return await self._peer._dispatch_request(method, params, opts) - async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: """Send a notification by invoking the peer's `on_notify` directly. Fire-and-forget: usable before `run()` (delivery waits for the peer to start), and after close it is silently dropped, matching - `JSONRPCDispatcher.notify`. + `JSONRPCDispatcher.notify`. `opts` is accepted for `Dispatcher` + conformance; there is no HTTP layer here so `headers` is ignored. """ if self._peer is None: raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()") diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py index 888e55ba33..8b343f2555 100644 --- a/src/mcp/shared/dispatcher.py +++ b/src/mcp/shared/dispatcher.py @@ -85,6 +85,9 @@ class CallOptions(TypedDict, total=False): resumption is removed in the next protocol revision. """ + headers: dict[str, str] + """Transport-layer hint: HTTP transports merge these onto the outgoing request; non-HTTP transports ignore.""" + @runtime_checkable class Outbound(Protocol): @@ -111,7 +114,7 @@ async def send_raw_request( """ ... - async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: """Send a fire-and-forget notification.""" ... diff --git a/src/mcp/shared/inbound.py b/src/mcp/shared/inbound.py index 04aa93c141..fb4c765246 100644 --- a/src/mcp/shared/inbound.py +++ b/src/mcp/shared/inbound.py @@ -8,6 +8,8 @@ status. """ +import base64 +import re from collections.abc import Mapping, Sequence from dataclasses import dataclass from types import MappingProxyType @@ -34,13 +36,33 @@ "ERROR_CODE_HTTP_STATUS", "InboundLadderRejection", "InboundModernRoute", + "MCP_METHOD_HEADER", + "MCP_NAME_HEADER", "MCP_PROTOCOL_VERSION_HEADER", "classify_inbound_request", + "encode_header_value", ] MCP_PROTOCOL_VERSION_HEADER: Final = "mcp-protocol-version" """Canonical lowercase name of the HTTP header carrying the MCP protocol version.""" +MCP_METHOD_HEADER: Final = "mcp-method" +"""Canonical lowercase name of the HTTP header carrying the JSON-RPC method.""" + +MCP_NAME_HEADER: Final = "mcp-name" +"""Canonical lowercase name of the HTTP header carrying the resource name (tool/prompt/resource URI).""" + +_B64_SENTINEL = re.compile(r"^=\?base64\?.*\?=$") +# RFC 7230 token chars minus DEL; visible ASCII 0x20-0x7E is the practical bound for a header value. +_HEADER_SAFE = re.compile(r"^[\x20-\x7E]*$") + + +def encode_header_value(value: str) -> str: + if _HEADER_SAFE.fullmatch(value) and value == value.strip() and not _B64_SENTINEL.fullmatch(value): + return value + return f"=?base64?{base64.b64encode(value.encode('utf-8')).decode('ascii')}?=" + + # INTERNAL_ERROR is deliberately unmapped (→ HTTP 200): the spec assigns no status to # -32603, and whether handler-origin errors get 5xx is an open S4 question — see TODO(L66). ERROR_CODE_HTTP_STATUS: Final[Mapping[int, int]] = MappingProxyType( diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 7fabafff65..859fd5d7d9 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -132,11 +132,11 @@ def request_id(self) -> RequestId | None: def can_send_request(self) -> bool: return self.transport.can_send_request and not self._closed - async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: if self._closed: logger.debug("dropped %s: dispatch context closed", method) return - await self._dispatcher.notify(method, params, _related_request_id=self._request_id) + await self._dispatcher.notify(method, params, opts, _related_request_id=self._request_id) async def send_raw_request( self, @@ -209,6 +209,7 @@ def _plan_outbound(related_request_id: RequestId | None, opts: CallOptions | Non cancel_on_abandon = opts.get("cancel_on_abandon", True) token = opts.get("resumption_token") on_token = opts.get("on_resumption_token") + headers = opts.get("headers") if related_request_id is not None: if token is not None or on_token is not None: logger.debug( @@ -217,9 +218,11 @@ def _plan_outbound(related_request_id: RequestId | None, opts: CallOptions | Non return _OutboundPlan(ServerMessageMetadata(related_request_id=related_request_id), cancel_on_abandon) if token is not None or on_token is not None: return _OutboundPlan( - ClientMessageMetadata(resumption_token=token, on_resumption_token_update=on_token), + ClientMessageMetadata(resumption_token=token, on_resumption_token_update=on_token, headers=headers), cancel_on_abandon=False, ) + if headers: + return _OutboundPlan(ClientMessageMetadata(headers=headers), cancel_on_abandon) return _OutboundPlan(None, cancel_on_abandon) @@ -395,6 +398,7 @@ async def notify( self, method: str, params: Mapping[str, Any] | None, + opts: CallOptions | None = None, *, _related_request_id: RequestId | None = None, ) -> None: @@ -414,7 +418,7 @@ async def notify( else: msg = JSONRPCNotification(jsonrpc="2.0", method=method) try: - await self._write(msg, _plan_outbound(_related_request_id, None).metadata) + await self._write(msg, _plan_outbound(_related_request_id, opts).metadata) except (anyio.BrokenResourceError, anyio.ClosedResourceError): # Transport tore down before run() noticed EOF. logger.debug("dropped %s: write stream closed", method) diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py index dba263ad5a..fda0fb8cc3 100644 --- a/src/mcp/shared/message.py +++ b/src/mcp/shared/message.py @@ -24,6 +24,8 @@ class ClientMessageMetadata: resumption_token: ResumptionToken | None = None on_resumption_token_update: Callable[[ResumptionToken], Awaitable[None]] | None = None + # Per-message HTTP headers (e.g. MCP-Protocol-Version, Mcp-Method) the transport should set. + headers: dict[str, str] | None = None @dataclass diff --git a/src/mcp/shared/peer.py b/src/mcp/shared/peer.py index ddf5c1c8ce..2cf7b36823 100644 --- a/src/mcp/shared/peer.py +++ b/src/mcp/shared/peer.py @@ -81,8 +81,8 @@ async def send_raw_request( ) -> dict[str, Any]: return await self._outbound.send_raw_request(method, params, opts) - async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: - await self._outbound.notify(method, params) + async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: + await self._outbound.notify(method, params, opts) @overload @deprecated("The sampling capability is deprecated as of 2026-07-28 (SEP-2577).", category=MCPDeprecationWarning) diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 1c3767303f..0751b5b81d 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -1333,7 +1333,7 @@ async def send_raw_request( ).model_dump(by_alias=True, mode="json", exclude_none=True) return {} - async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: pass dispatcher = RecordingDispatcher() @@ -1387,7 +1387,7 @@ async def send_raw_request( ) -> dict[str, Any]: raise NotImplementedError - async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: raise NotImplementedError session = ClientSession(dispatcher=NeverStartsDispatcher()) diff --git a/tests/client/test_streamable_http.py b/tests/client/test_streamable_http.py index b13de4a5ee..7803cef70a 100644 --- a/tests/client/test_streamable_http.py +++ b/tests/client/test_streamable_http.py @@ -17,10 +17,10 @@ from mcp.client import ClientSession from mcp.client.streamable_http import ( StreamableHTTPTransport, - _encode_header_value, streamable_http_client, ) -from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER +from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER, encode_header_value +from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.types import JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse @@ -89,7 +89,7 @@ def test_mcp_name_header_values_are_base64_wrapped_when_unsafe_for_an_http_field or trailing space is wrapped because RFC 7230 forbids it in field-values (h11 rejects on real transports); an empty value is allowed and passes verbatim. """ - encoded = _encode_header_value(raw) + encoded = encode_header_value(raw) assert encoded == expected if wrapped: assert encoded.startswith("=?base64?") and encoded.endswith("?=") @@ -140,6 +140,34 @@ def handler(request: httpx.Request) -> httpx.Response: assert all("mcp-session-id" not in r.headers for r in recorded) +@pytest.mark.anyio +async def test_post_request_merges_per_message_metadata_headers() -> None: + """`ClientMessageMetadata.headers` on a `SessionMessage` are merged into the outgoing POST headers + (SDK-defined: the headers sidecar is the path the session uses to reach the transport).""" + recorded: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + recorded.append(request) + body = json.loads(request.content) + return httpx.Response(200, json={"jsonrpc": "2.0", "id": body["id"], "result": {}}) + + with anyio.fail_after(5): + async with ( + httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http, + streamable_http_client("http://test/mcp", http_client=http, protocol_version="2026-07-28") as (read, write), + ): + await write.send( + SessionMessage( + message=JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/list", params={}), + metadata=ClientMessageMetadata(headers={"x-test": "v"}), + ) + ) + reply = await read.receive() + assert isinstance(reply, SessionMessage) + assert [r.method for r in recorded] == ["POST"] + assert recorded[0].headers["x-test"] == "v" + + def test_modern_constructor_pin_is_not_overwritten_by_an_initialize_result() -> None: """A 2026-07-28+ pin wins over the InitializeResult snoop (no initialize is ever sent).""" transport = StreamableHTTPTransport("http://test/mcp", protocol_version="2026-07-28") diff --git a/tests/server/test_connection.py b/tests/server/test_connection.py index 5e504ff4d7..9c2a1f5c46 100644 --- a/tests/server/test_connection.py +++ b/tests/server/test_connection.py @@ -58,7 +58,7 @@ async def send_raw_request( self.requests.append((method, params)) return self._result - async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: if self._raise_on_send is not None: raise self._raise_on_send() self.notifications.append((method, params)) diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 50f9d0a2ec..1987a05ce1 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -1310,7 +1310,7 @@ async def send_raw_request( ) -> dict[str, Any]: raise NotImplementedError - async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: raise NotImplementedError async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 84d5e3aa93..2379ccb8bb 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -44,7 +44,7 @@ async def send_raw_request( self.requests.append((method, params, opts)) return self.result - async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: self.notifications.append((method, params)) diff --git a/tests/server/test_stateless_mode.py b/tests/server/test_stateless_mode.py index 91d344253a..49c04dd847 100644 --- a/tests/server/test_stateless_mode.py +++ b/tests/server/test_stateless_mode.py @@ -37,7 +37,7 @@ async def send_raw_request( self.requests.append((method, params, opts)) return self.result - async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: self.notifications.append((method, params)) diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index 588c1dcc21..660b5cb3af 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -1755,6 +1755,39 @@ def test_plan_outbound_with_resumption_token_returns_client_metadata_and_suppres assert _plan_outbound(None, {}) == _OutboundPlan(metadata=None, cancel_on_abandon=True) +@pytest.mark.anyio +async def test_send_raw_request_projects_opts_headers_onto_message_metadata(): + """`opts["headers"]` alone yields `ClientMessageMetadata(headers=...)` on the outbound `SessionMessage` + (SDK-defined: the headers sidecar is the path the session uses to reach the transport).""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + + async def caller() -> None: + await client.send_raw_request("tools/list", None, {"headers": {"x-test": "v"}}) + + tg.start_soon(caller) + with anyio.fail_after(5): + outbound = await c2s_recv.receive() + assert isinstance(outbound, SessionMessage) + assert isinstance(outbound.message, JSONRPCRequest) + assert isinstance(outbound.metadata, ClientMessageMetadata) + assert outbound.metadata.headers == {"x-test": "v"} + assert outbound.metadata.resumption_token is None + await s2c_send.send( + SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=outbound.message.id, result={})) + ) + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + @pytest.mark.anyio async def test_response_with_string_id_correlates_to_int_keyed_pending_request(): """A peer that echoes the request ID as a JSON string still resolves the waiter.""" diff --git a/tests/shared/test_peer.py b/tests/shared/test_peer.py index d17af88520..89e931b3b1 100644 --- a/tests/shared/test_peer.py +++ b/tests/shared/test_peer.py @@ -185,7 +185,7 @@ async def send_raw_request( ) -> dict[str, Any]: raise NotImplementedError - async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + async def notify(self, method: str, params: Mapping[str, Any] | None, opts: Any = None) -> None: sent.append((method, params)) await ClientPeer(_Out()).notify("n", {"x": 1}) From 52546fa1b7f43512790818b61e808d6c95980242 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 22 Jun 2026 15:03:41 +0000 Subject: [PATCH 03/22] ClientSession: install per-request stamp at connect time; transport becomes pv-agnostic MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The era difference is now which stamp closure was installed, not a flag send_request reads: - Three stamp builders: _preconnect_stamp (cancel-suppressed only), _make_handshake_stamp (pv header), _make_modern_stamp (_meta envelope + cancel-suppressed + pv/method/name headers) - ClientSession.adopt(InitializeResult | DiscoverResult) installs negotiated state without wire traffic; .initialize() now calls .adopt(result) so the handshake stamp is installed before notifications/initialized goes out - send_request and send_notification call self._stamp(data, opts) unconditionally — _stateless_pinned, _pinned_version, and the inline envelope branch are deleted - ClientSession(protocol_version=) and Client.protocol_version removed - StreamableHTTPTransport drops protocol_version, _per_message_headers, _maybe_extract_protocol_version_from_message; _prepare_headers no longer derives the pv header. The transport caches the pv header from the first stamped POST's metadata and reuses it on transport-internal GET/DELETE - streamable_http_client(protocol_version=) removed - Interaction-suite [streamable-http-2026-07-28] arm now drives via ClientSession + .adopt(DiscoverResult); pagination/cancellation tests adapted to the Client|ClientSession common subset - migration.md documents the removals --- docs/migration.md | 8 + src/mcp/client/client.py | 14 +- src/mcp/client/session.py | 131 ++++++++++------- src/mcp/client/streamable_http.py | 83 ++--------- tests/client/test_client.py | 2 +- tests/client/test_session.py | 79 +++------- tests/client/test_streamable_http.py | 137 +----------------- tests/interaction/_connect.py | 71 ++++++--- tests/interaction/conftest.py | 2 +- .../interaction/lowlevel/test_cancellation.py | 3 +- tests/interaction/lowlevel/test_pagination.py | 30 +++- .../transports/test_hosting_http_modern.py | 16 +- 12 files changed, 215 insertions(+), 361 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index 19c534aeb8..91d24f354e 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -159,6 +159,10 @@ The `headers`, `timeout`, `sse_read_timeout`, and `auth` parameters have been re Note: `sse_client` retains its `headers`, `timeout`, `sse_read_timeout`, and `auth` parameters — only the streamable HTTP transport changed. +### `protocol_version` removed from `StreamableHTTPTransport` and `streamable_http_client` + +The `protocol_version` attribute on `StreamableHTTPTransport` and the `protocol_version` parameter on `streamable_http_client` have been removed. The transport no longer holds per-connection protocol state; era-dependent headers (e.g. `MCP-Protocol-Version`) are supplied per-message by the session, so the transport never needs to know the negotiated version. + ### `terminate_windows_process` removed The deprecated `mcp.os.win32.utilities.terminate_windows_process` function has been @@ -350,6 +354,10 @@ if result is not None: The high-level `Client.initialize_result` returns the same `InitializeResult` but is non-nullable — initialization is guaranteed inside the context manager, so no `None` check is needed. This replaces v1's `Client.server_capabilities`; use `client.initialize_result.capabilities` instead. +### `ClientSession(protocol_version=)` removed + +The `protocol_version` constructor parameter on `ClientSession` has been removed. To install a known protocol version without performing the `initialize` handshake (e.g. when reconnecting to an existing session), call `session.adopt(result)` after construction with a stored `InitializeResult`. + ### `McpError` renamed to `MCPError` The `McpError` exception class has been renamed to `MCPError` for consistent naming with the MCP acronym style used throughout the SDK. diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index b69c3e5101..3476ada1f6 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -95,17 +95,6 @@ async def main(): client_info: Implementation | None = None """Client implementation info to send to server.""" - protocol_version: str | None = None - """Pin the protocol version instead of negotiating it. - - Pinning to ``2026-07-28`` or later selects the stateless transport era: no initialize - handshake is sent on the wire (the session synthesizes its `InitializeResult` locally), - and for HTTP the ``MCP-Protocol-Version`` header is set from the first request. A modern - pin currently requires a URL or `Transport`; the in-memory `Server`/`MCPServer` path - does not yet have a modern entry point. - Leave as ``None`` to negotiate the version via the initialize handshake. - """ - elicitation_callback: ElicitationFnT | None = None """Callback for handling elicitation requests.""" @@ -117,7 +106,7 @@ def __post_init__(self) -> None: if isinstance(self.server, Server | MCPServer): self._transport = InMemoryTransport(self.server, raise_exceptions=self.raise_exceptions) elif isinstance(self.server, str): - self._transport = streamable_http_client(self.server, protocol_version=self.protocol_version) + self._transport = streamable_http_client(self.server) else: self._transport = self.server @@ -140,7 +129,6 @@ async def __aenter__(self) -> Client: message_handler=self.message_handler, client_info=self.client_info, elicitation_callback=self.elicitation_callback, - protocol_version=self.protocol_version, ) ) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index b9ebf6c9e1..ab162c4d8c 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from collections.abc import Mapping +from collections.abc import Callable, Mapping from dataclasses import dataclass from types import TracebackType from typing import Any, Protocol, cast @@ -17,6 +17,12 @@ from mcp.shared._compat import resync_tracer from mcp.shared.dispatcher import CallOptions, DispatchContext, Dispatcher, ProgressFnT from mcp.shared.exceptions import MCPDeprecationWarning, MCPError +from mcp.shared.inbound import ( + MCP_METHOD_HEADER, + MCP_NAME_HEADER, + MCP_PROTOCOL_VERSION_HEADER, + encode_header_value, +) from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder @@ -37,6 +43,38 @@ logger = logging.getLogger("client") + +def _preconnect_stamp(data: dict[str, Any], opts: CallOptions) -> None: + # Only initialize/discover go out before connect; both forbid cancellation. + opts["cancel_on_abandon"] = False + + +def _make_handshake_stamp(protocol_version: str) -> Callable[[dict[str, Any], CallOptions], None]: + def stamp(data: dict[str, Any], opts: CallOptions) -> None: + opts.setdefault("headers", {})[MCP_PROTOCOL_VERSION_HEADER] = protocol_version + + return stamp + + +def _make_modern_stamp( + protocol_version: str, client_info: dict[str, Any], capabilities: dict[str, Any] +) -> Callable[[dict[str, Any], CallOptions], None]: + def stamp(data: dict[str, Any], opts: CallOptions) -> None: + params = data.setdefault("params", {}) + meta = params.setdefault("_meta", {}) + meta[PROTOCOL_VERSION_META_KEY] = protocol_version + meta[CLIENT_INFO_META_KEY] = client_info + meta[CLIENT_CAPABILITIES_META_KEY] = capabilities + opts["cancel_on_abandon"] = False + headers = opts.setdefault("headers", {}) + headers[MCP_PROTOCOL_VERSION_HEADER] = protocol_version + headers[MCP_METHOD_HEADER] = data["method"] + if data["method"] == "tools/call" and isinstance(name := params.get("name"), str): + headers[MCP_NAME_HEADER] = encode_header_value(name) + + return stamp + + ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) @@ -149,14 +187,11 @@ def __init__( message_handler: MessageHandlerFnT | None = None, client_info: types.Implementation | None = None, *, - protocol_version: str | None = None, sampling_capabilities: types.SamplingCapability | None = None, dispatcher: Dispatcher[Any] | None = None, ) -> None: self._session_read_timeout_seconds = read_timeout_seconds self._client_info = client_info or DEFAULT_CLIENT_INFO - self._pinned_version = protocol_version - self._stateless_pinned = protocol_version in MODERN_PROTOCOL_VERSIONS self._sampling_callback = sampling_callback or _default_sampling_callback self._sampling_capabilities = sampling_capabilities self._elicitation_callback = elicitation_callback or _default_elicitation_callback @@ -164,19 +199,8 @@ def __init__( self._logging_callback = logging_callback or _default_logging_callback self._message_handler = message_handler or _default_message_handler self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} - self._initialize_result: types.InitializeResult | None - if self._stateless_pinned: - assert protocol_version is not None - # A stateless-pinned session is born initialized: there is no handshake - # at 2026-07-28+, so we synthesize the result locally. `server_info` is a - # placeholder until `server/discover` is implemented to populate it. - self._initialize_result = types.InitializeResult( - protocol_version=protocol_version, - capabilities=types.ServerCapabilities(), - server_info=types.Implementation(name="", version=""), - ) - else: - self._initialize_result = None + self._initialize_result: types.InitializeResult | None = None + self._stamp: Callable[[dict[str, Any], CallOptions], None] = _preconnect_stamp self._task_group: anyio.abc.TaskGroup | None = None if dispatcher is not None: if read_stream is not None or write_stream is not None: @@ -242,19 +266,7 @@ async def send_request( data = request.model_dump(by_alias=True, mode="json", exclude_none=True) method: str = data["method"] opts: CallOptions = {} - if self._stateless_pinned: - params = data.setdefault("params", {}) - envelope_meta = params.setdefault("_meta", {}) - envelope_meta[PROTOCOL_VERSION_META_KEY] = self._pinned_version - envelope_meta[CLIENT_INFO_META_KEY] = self._client_info.model_dump( - by_alias=True, mode="json", exclude_none=True - ) - envelope_meta[CLIENT_CAPABILITIES_META_KEY] = self._build_capabilities().model_dump( - by_alias=True, mode="json", exclude_none=True - ) - # Stateless pinned mode: disconnect-as-cancel is the spec mechanism, so the - # dispatcher must not emit notifications/cancelled when the caller abandons. - opts["cancel_on_abandon"] = False + self._stamp(data, opts) timeout = ( request_read_timeout_seconds if request_read_timeout_seconds is not None @@ -269,9 +281,6 @@ async def send_request( opts["resumption_token"] = metadata.resumption_token if metadata.on_resumption_token_update is not None: opts["on_resumption_token"] = metadata.on_resumption_token_update - if method == "initialize": - # The spec forbids cancelling initialize. - opts["cancel_on_abandon"] = False raw = await self._dispatcher.send_raw_request(method, data.get("params"), opts) # Literal fallback covers pre-handshake and stateless; matches runner.py. version = self.protocol_version or "2025-11-25" @@ -288,7 +297,9 @@ async def send_notification(self, notification: types.ClientNotification) -> Non dropped with a debug log instead of raising. """ data = notification.model_dump(by_alias=True, mode="json", exclude_none=True) - await self._dispatcher.notify(data["method"], data.get("params")) + opts: CallOptions = {} + self._stamp(data, opts) + await self._dispatcher.notify(data["method"], data.get("params"), opts) def _build_capabilities(self) -> types.ClientCapabilities: sampling = ( @@ -314,14 +325,11 @@ def _build_capabilities(self) -> types.ClientCapabilities: async def initialize(self) -> types.InitializeResult: if self._initialize_result is not None: return self._initialize_result - capabilities = self._build_capabilities() result = await self.send_request( types.InitializeRequest( params=types.InitializeRequestParams( - protocol_version=self._pinned_version - if self._pinned_version is not None - else HANDSHAKE_PROTOCOL_VERSIONS[-1], - capabilities=capabilities, + protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], + capabilities=self._build_capabilities(), client_info=self._client_info, ), ), @@ -331,34 +339,49 @@ async def initialize(self) -> types.InitializeResult: if result.protocol_version not in HANDSHAKE_PROTOCOL_VERSIONS: raise RuntimeError(f"Unsupported protocol version from the server: {result.protocol_version}") - self._initialize_result = result + self.adopt(result) await self.send_notification(types.InitializedNotification()) return result + def adopt(self, result: types.InitializeResult | types.DiscoverResult) -> None: + """Install negotiated state from a result the caller already holds (no wire traffic).""" + if isinstance(result, types.DiscoverResult): + mutual = [v for v in MODERN_PROTOCOL_VERSIONS if v in result.supported_versions] + if not mutual: + raise RuntimeError( + f"No mutually supported modern protocol version " + f"(server: {result.supported_versions}, client: {list(MODERN_PROTOCOL_VERSIONS)})" + ) + protocol_version = mutual[-1] + client_info = self._client_info.model_dump(by_alias=True, mode="json", exclude_none=True) + capabilities = self._build_capabilities().model_dump(by_alias=True, mode="json", exclude_none=True) + self._stamp = _make_modern_stamp(protocol_version, client_info, capabilities) + self._initialize_result = types.InitializeResult( + protocol_version=protocol_version, + capabilities=result.capabilities, + server_info=result.server_info, + instructions=result.instructions, + ) + else: + self._stamp = _make_handshake_stamp(result.protocol_version) + self._initialize_result = result + @property def initialize_result(self) -> types.InitializeResult | None: - """The server's InitializeResult. None until initialize() has been called. + """The server's InitializeResult. None until `initialize()` or `adopt()`. - A stateless-pinned session (protocol_version >= 2026-07-28) is born - initialized: this property is populated at construction with a - synthesized result and `initialize()` returns it without touching the - wire. Contains server_info, capabilities, instructions, and the - negotiated protocol_version. + Contains server_info, capabilities, instructions, and the negotiated + protocol_version. For a modern session adopted from a DiscoverResult, + this is synthesized locally with the chosen protocol version. """ return self._initialize_result @property def protocol_version(self) -> str | None: - """Negotiated or pinned protocol version. None until initialize() unless pinned at construction. - - Once `initialize()` has completed, this is the version the server actually - negotiated (which can differ from a stateful pin); before that, the pin. - """ - if self._initialize_result is not None: - return self._initialize_result.protocol_version - return self._pinned_version + """Negotiated protocol version. None until `initialize()` or `adopt()`.""" + return self._initialize_result.protocol_version if self._initialize_result is not None else None async def send_ping(self, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult: """Send a ping request.""" diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index b6e6e105aa..962e104e1e 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -18,15 +18,13 @@ from mcp.shared._compat import resync_tracer from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams from mcp.shared._httpx_utils import create_mcp_http_client -from mcp.shared.inbound import MCP_METHOD_HEADER, MCP_NAME_HEADER, MCP_PROTOCOL_VERSION_HEADER, encode_header_value +from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER from mcp.shared.message import ClientMessageMetadata, SessionMessage -from mcp.shared.version import MODERN_PROTOCOL_VERSIONS from mcp.types import ( INTERNAL_ERROR, INVALID_REQUEST, PARSE_ERROR, ErrorData, - InitializeResult, JSONRPCError, JSONRPCMessage, JSONRPCNotification, @@ -74,41 +72,17 @@ class RequestContext: class StreamableHTTPTransport: """StreamableHTTP client transport implementation.""" - def __init__(self, url: str, protocol_version: str | None = None) -> None: + def __init__(self, url: str) -> None: """Initialize the StreamableHTTP transport. Args: url: The endpoint URL. - protocol_version: Pin the MCP-Protocol-Version header from the first request. - Only honoured for stateless 2026-07-28+ sessions that never send - initialize; for earlier (stateful) versions the header is populated - from the negotiated InitializeResult, so a pre-2026 value is ignored. """ self.url = url self.session_id: str | None = None - self.protocol_version: str | None = protocol_version if protocol_version in MODERN_PROTOCOL_VERSIONS else None - - def _per_message_headers(self, message: JSONRPCMessage) -> dict[str, str]: - """Per-POST routing headers (Mcp-Method, Mcp-Name) for 2026-07-28+ pinned transports. - - MCP-Protocol-Version is not emitted here — `_prepare_headers()` already adds it - from `self.protocol_version` for every request. - """ - if self.protocol_version not in MODERN_PROTOCOL_VERSIONS: - return {} - if not isinstance(message, JSONRPCRequest | JSONRPCNotification): - return {} - headers: dict[str, str] = {MCP_METHOD_HEADER: message.method} - # TODO: Mcp-Name is also REQUIRED for prompts/get (params.name) and resources/read - # (params.uri); a method->param-key map replaces this gate when those land. - if ( - isinstance(message, JSONRPCRequest) - and message.method == "tools/call" - and message.params - and isinstance(name := message.params.get("name"), str) - ): - headers[MCP_NAME_HEADER] = encode_header_value(name) - return headers + # Captured from the first stamped POST's metadata; reused on transport-internal + # GET/DELETE that don't carry per-message metadata. + self._protocol_version_header: str | None = None def _prepare_headers(self) -> dict[str, str]: """Build MCP-specific request headers. @@ -123,8 +97,8 @@ def _prepare_headers(self) -> dict[str, str]: # Add session headers if available if self.session_id: headers[MCP_SESSION_ID] = self.session_id - if self.protocol_version: - headers[MCP_PROTOCOL_VERSION_HEADER] = self.protocol_version + if self._protocol_version_header: + headers[MCP_PROTOCOL_VERSION_HEADER] = self._protocol_version_header return headers def _is_initialization_request(self, message: JSONRPCMessage) -> bool: @@ -142,29 +116,12 @@ def _maybe_extract_session_id_from_response(self, response: httpx.Response) -> N self.session_id = new_session_id logger.info(f"Received session ID: {self.session_id}") - def _maybe_extract_protocol_version_from_message(self, message: JSONRPCMessage) -> None: - """Extract protocol version from initialization response message.""" - if self.protocol_version is not None: - # Only a modern constructor pin reaches here (pre-2026 values are dropped - # in __init__), and a modern pin never sends initialize. - return - if isinstance(message, JSONRPCResponse) and message.result: # pragma: no branch - try: - # Parse the result as InitializeResult for type safety - init_result = InitializeResult.model_validate(message.result, by_name=False) - self.protocol_version = init_result.protocol_version - logger.info(f"Negotiated protocol version: {self.protocol_version}") - except Exception: # pragma: no cover - logger.warning("Failed to parse initialization response as InitializeResult", exc_info=True) - logger.warning(f"Raw result: {message.result}") - async def _handle_sse_event( self, sse: ServerSentEvent, read_stream_writer: StreamWriter, original_request_id: RequestId | None = None, resumption_callback: Callable[[str], Awaitable[None]] | None = None, - is_initialization: bool = False, ) -> bool: """Handle an SSE event, returning True if the response is complete.""" if sse.event == "message": @@ -178,10 +135,6 @@ async def _handle_sse_event( message = jsonrpc_message_adapter.validate_json(sse.data, by_name=False) logger.debug(f"SSE message: {message}") - # Extract protocol version from initialization response - if is_initialization: - self._maybe_extract_protocol_version_from_message(message) - # If this is a response and we have original_request_id, replace it if original_request_id is not None and isinstance(message, JSONRPCResponse | JSONRPCError): message.id = original_request_id @@ -287,9 +240,10 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: """Handle a POST request with response processing.""" headers = self._prepare_headers() message = ctx.session_message.message - headers.update(self._per_message_headers(message)) if ctx.metadata is not None and ctx.metadata.headers is not None: headers.update(ctx.metadata.headers) + if MCP_PROTOCOL_VERSION_HEADER in ctx.metadata.headers: + self._protocol_version_header = ctx.metadata.headers[MCP_PROTOCOL_VERSION_HEADER] is_initialization = self._is_initialization_request(message) async with ctx.client.stream( @@ -337,11 +291,9 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: if isinstance(message, JSONRPCRequest): content_type = response.headers.get("content-type", "").lower() if content_type.startswith("application/json"): - await self._handle_json_response( - response, ctx.read_stream_writer, is_initialization, request_id=message.id - ) + await self._handle_json_response(response, ctx.read_stream_writer, request_id=message.id) elif content_type.startswith("text/event-stream"): - await self._handle_sse_response(response, ctx, is_initialization) + await self._handle_sse_response(response, ctx) else: logger.error(f"Unexpected content type: {content_type}") error_data = ErrorData(code=INVALID_REQUEST, message=f"Unexpected content type: {content_type}") @@ -352,7 +304,6 @@ async def _handle_json_response( self, response: httpx.Response, read_stream_writer: StreamWriter, - is_initialization: bool = False, *, request_id: RequestId, ) -> None: @@ -360,11 +311,6 @@ async def _handle_json_response( try: content = await response.aread() message = jsonrpc_message_adapter.validate_json(content, by_name=False) - - # Extract protocol version from initialization response - if is_initialization: - self._maybe_extract_protocol_version_from_message(message) - session_message = SessionMessage(message) await read_stream_writer.send(session_message) except (httpx.StreamError, ValidationError) as exc: @@ -377,7 +323,6 @@ async def _handle_sse_response( self, response: httpx.Response, ctx: RequestContext, - is_initialization: bool = False, ) -> None: """Handle SSE response from the server.""" last_event_id: str | None = None @@ -404,7 +349,6 @@ async def _handle_sse_response( ctx.read_stream_writer, original_request_id=original_request_id, resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None), - is_initialization=is_initialization, ) # If the SSE event indicates completion, like returning response/error # break the loop @@ -569,7 +513,6 @@ async def streamable_http_client( *, http_client: httpx.AsyncClient | None = None, terminate_on_close: bool = True, - protocol_version: str | None = None, ) -> AsyncGenerator[TransportStreams, None]: """Client transport for StreamableHTTP. @@ -579,8 +522,6 @@ async def streamable_http_client( client with recommended MCP timeouts will be created. To configure headers, authentication, or other HTTP settings, create an httpx.AsyncClient and pass it here. terminate_on_close: If True, send a DELETE request to terminate the session when the context exits. - protocol_version: Pin the MCP-Protocol-Version header for stateless 2026-07-28 sessions. - Tracer-bullet duplication — also pass to `ClientSession(protocol_version=...)`. Yields: Tuple containing: @@ -598,7 +539,7 @@ async def streamable_http_client( # Create default client with recommended MCP timeouts client = create_mcp_http_client() - transport = StreamableHTTPTransport(url, protocol_version=protocol_version) + transport = StreamableHTTPTransport(url) logger.debug(f"Connecting to StreamableHTTP endpoint: {url}") diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 3680639e0f..2accd093d6 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -316,7 +316,7 @@ async def test_complete_with_prompt_reference(simple_server: Server): def test_client_with_url_initializes_streamable_http_transport(): with patch("mcp.client.client.streamable_http_client") as mock: _ = Client("http://localhost:8000/mcp") - mock.assert_called_once_with("http://localhost:8000/mcp", protocol_version=None) + mock.assert_called_once_with("http://localhost:8000/mcp") async def test_client_uses_transport_directly(app: MCPServer): diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 0751b5b81d..d3732ec55c 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -1300,6 +1300,25 @@ async def test_dispatcher_keyword_request_timeout_bounds_wait_for_never_run_peer assert exc.value.error.code == REQUEST_TIMEOUT +def test_adopt_raises_when_no_mutual_modern_version_is_supported() -> None: + """SDK-defined: ``adopt(DiscoverResult)`` picks the newest version both sides support; an + empty intersection is unrecoverable and raises rather than installing a stamp.""" + client_d, _ = create_direct_dispatcher_pair() + session = ClientSession(dispatcher=client_d) + with pytest.raises(RuntimeError, match="No mutually supported modern protocol version"): + session.adopt( + types.DiscoverResult( + supported_versions=["1999-01-01"], + capabilities=types.ServerCapabilities(), + server_info=types.Implementation(name="s", version="0"), + result_type="complete", + ttl_ms=0, + cache_scope="public", + ) + ) + assert session.protocol_version is None + + @pytest.mark.anyio async def test_initialize_opts_out_of_cancel_on_abandon_while_other_requests_leave_it_unset(): """`send_request` passes `cancel_on_abandon=False` for `initialize` — the spec forbids @@ -1400,66 +1419,6 @@ async def notify(self, method: str, params: Mapping[str, Any] | None, opts: Call assert session._task_group is None -@pytest.mark.anyio -async def test_initialize_on_a_stateless_pinned_session_returns_the_synthesized_result_without_any_frame_sent(): - """A session pinned to the 2026-07-28 stateless protocol is born initialized. - - The 2026-07-28 lifecycle replaces the initialize handshake with a per-request ``_meta`` - envelope, so ``initialize()`` is idempotent and returns a locally-synthesized result - without ever touching the wire. - """ - async with raw_client_session(protocol_version="2026-07-28") as (session, _send, from_client): - result = await session.initialize() - assert result.protocol_version == "2026-07-28" - assert isinstance(result.capabilities, ServerCapabilities) - assert from_client.statistics().current_buffer_used == 0 - assert (await session.initialize()) is result - - -@pytest.mark.anyio -async def test_initialize_on_a_stateful_pin_requests_the_pinned_version(): - """A session pinned to a pre-2026 stateful version still runs the handshake, but the - outgoing ``initialize`` frame requests the pinned version rather than ``LATEST``.""" - async with raw_client_session(protocol_version="2025-06-18") as (session, to_client, from_client): - first: list[InitializeResult] = [] - - async def do_initialize() -> None: - first.append(await session.initialize()) - - async with anyio.create_task_group() as tg: - tg.start_soon(do_initialize) - out = await from_client.receive() - assert isinstance(out.message, JSONRPCRequest) - assert out.message.params is not None - assert out.message.params["protocolVersion"] == "2025-06-18" - assert session.protocol_version == "2025-06-18" - # Server negotiates a different (older) supported version than the pin requested. - result = InitializeResult( - protocol_version="2025-03-26", - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) - await to_client.send( - SessionMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=out.message.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - ) - ) - # Drain the notifications/initialized frame so the buffer-used assertion below - # measures only what the second initialize() emits. - notif = await from_client.receive() - assert isinstance(notif.message, JSONRPCNotification) - # The property reports the negotiated version, not the pin, once the handshake is done. - assert session.protocol_version == "2025-03-26" - # A second call returns the cached result without a second handshake frame. - again = await session.initialize() - assert again is first[0] - assert from_client.statistics().current_buffer_used == 0 - - @pytest.mark.anyio async def test_send_notification_after_close_is_dropped_silently(): """Post-close `send_notification` is fire-and-forget: the notification is dropped, diff --git a/tests/client/test_streamable_http.py b/tests/client/test_streamable_http.py index 7803cef70a..92f1fc8981 100644 --- a/tests/client/test_streamable_http.py +++ b/tests/client/test_streamable_http.py @@ -1,9 +1,9 @@ """Unit tests for the streamable-HTTP client transport. The full client<->server round trip is pinned by the interaction suite under -tests/interaction/transports/; these tests cover the transport's per-message header -derivation directly because the headers are an HTTP-seam observation the public client -never exposes. +tests/interaction/transports/; these tests cover the transport's header encoding and the +per-message metadata-headers merge directly because the headers are an HTTP-seam observation +the public client never exposes. """ import base64 @@ -14,56 +14,10 @@ import pytest from inline_snapshot import snapshot -from mcp.client import ClientSession -from mcp.client.streamable_http import ( - StreamableHTTPTransport, - streamable_http_client, -) -from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER, encode_header_value +from mcp.client.streamable_http import streamable_http_client +from mcp.shared.inbound import encode_header_value from mcp.shared.message import ClientMessageMetadata, SessionMessage -from mcp.types import JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse - - -@pytest.mark.parametrize( - ("message", "expected"), - [ - ( - JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/call", params={"name": "add", "arguments": {}}), - snapshot({"mcp-method": "tools/call", "mcp-name": "add"}), - ), - ( - JSONRPCRequest(jsonrpc="2.0", id=2, method="tools/list", params={}), - snapshot({"mcp-method": "tools/list"}), - ), - ( - JSONRPCNotification(jsonrpc="2.0", method="notifications/cancelled"), - snapshot({"mcp-method": "notifications/cancelled"}), - ), - ( - JSONRPCResponse(jsonrpc="2.0", id=3, result={}), - snapshot({}), - ), - ], -) -def test_per_message_headers_for_pinned_transport_carry_method_and_name( - message: JSONRPCMessage, expected: dict[str, str] -) -> None: - """A 2026-07-28-pinned transport derives ``Mcp-Method`` (and ``Mcp-Name`` for tools/call) from the body. - - ``MCP-Protocol-Version`` is not in the per-message set: ``_prepare_headers()`` adds it from the - pin for every request, so only the method/name advisory headers vary per POST. Responses yield - nothing because the spec only defines the headers for requests and notifications. - """ - transport = StreamableHTTPTransport("http://test/mcp", protocol_version="2026-07-28") - assert transport._per_message_headers(message) == expected # pyright: ignore[reportPrivateUsage] - - -@pytest.mark.parametrize("protocol_version", [None, "2025-11-25"]) -def test_per_message_headers_are_empty_for_legacy_or_unpinned_transport(protocol_version: str | None) -> None: - """An unpinned or 2025-era transport emits no per-message headers, keeping the wire byte-identical to v1.""" - transport = StreamableHTTPTransport("http://test/mcp", protocol_version=protocol_version) - message = JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/call", params={"name": "add", "arguments": {}}) - assert transport._per_message_headers(message) == {} # pyright: ignore[reportPrivateUsage] +from mcp.types import JSONRPCRequest @pytest.mark.parametrize( @@ -98,48 +52,6 @@ def test_mcp_name_header_values_are_base64_wrapped_when_unsafe_for_an_http_field assert encoded == raw -@pytest.mark.anyio -async def test_pinned_transport_ignores_returned_session_id_and_never_opens_get_or_delete() -> None: - """A server-issued ``Mcp-Session-Id`` never reaches a pinned client's wire: only POSTs are sent. - - The session-id capture, the standalone GET listening stream, and the DELETE-on-close are all - gated implicitly: a pinned ``ClientSession`` never sends ``initialize`` (no InitializeResult to - capture an id from) and never sends ``notifications/initialized`` (which is what triggers the - standalone GET), so even when a misbehaving peer volunteers a session id on every response the - recorded log stays POST-only and no request echoes the id back. The successful ``tools/call`` - triggers the client's implicit ``tools/list`` output-schema fetch so there is a second POST - after the id was offered. - """ - recorded: list[httpx.Request] = [] - - def handler(request: httpx.Request) -> httpx.Response: - recorded.append(request) - body = json.loads(request.content) - if body["method"] == "tools/list": - result: dict[str, object] = { - "tools": [{"name": "add", "inputSchema": {"type": "object"}}], - "resultType": "complete", - "ttlMs": 0, - "cacheScope": "public", - } - else: - result = {"content": [{"type": "text", "text": "5"}], "isError": False, "resultType": "complete"} - return httpx.Response( - 200, json={"jsonrpc": "2.0", "id": body["id"], "result": result}, headers={"mcp-session-id": "srv-123"} - ) - - with anyio.fail_after(5): - async with ( - httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http, - streamable_http_client("http://test/mcp", http_client=http, protocol_version="2026-07-28") as (read, write), - ClientSession(read, write, protocol_version="2026-07-28") as session, - ): - await session.call_tool("add", {"a": 2, "b": 3}) - - assert [r.method for r in recorded] == snapshot(["POST", "POST"]) - assert all("mcp-session-id" not in r.headers for r in recorded) - - @pytest.mark.anyio async def test_post_request_merges_per_message_metadata_headers() -> None: """`ClientMessageMetadata.headers` on a `SessionMessage` are merged into the outgoing POST headers @@ -154,7 +66,7 @@ def handler(request: httpx.Request) -> httpx.Response: with anyio.fail_after(5): async with ( httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http, - streamable_http_client("http://test/mcp", http_client=http, protocol_version="2026-07-28") as (read, write), + streamable_http_client("http://test/mcp", http_client=http) as (read, write), ): await write.send( SessionMessage( @@ -166,38 +78,3 @@ def handler(request: httpx.Request) -> httpx.Response: assert isinstance(reply, SessionMessage) assert [r.method for r in recorded] == ["POST"] assert recorded[0].headers["x-test"] == "v" - - -def test_modern_constructor_pin_is_not_overwritten_by_an_initialize_result() -> None: - """A 2026-07-28+ pin wins over the InitializeResult snoop (no initialize is ever sent).""" - transport = StreamableHTTPTransport("http://test/mcp", protocol_version="2026-07-28") - init = JSONRPCResponse( - jsonrpc="2.0", - id=1, - result={ - "protocolVersion": "2025-11-25", - "capabilities": {}, - "serverInfo": {"name": "s", "version": "0"}, - }, - ) - transport._maybe_extract_protocol_version_from_message(init) # pyright: ignore[reportPrivateUsage] - assert transport.protocol_version == "2026-07-28" - - -def test_stateful_constructor_pin_is_ignored_and_the_negotiated_version_wins() -> None: - """A pre-2026 pin is a session-layer concern; the transport must not stamp it on the - initialize request and must adopt the server's negotiated version for later headers.""" - transport = StreamableHTTPTransport("http://test/mcp", protocol_version="2025-06-18") - assert MCP_PROTOCOL_VERSION_HEADER not in transport._prepare_headers() # pyright: ignore[reportPrivateUsage] - init = JSONRPCResponse( - jsonrpc="2.0", - id=1, - result={ - "protocolVersion": "2025-03-26", - "capabilities": {}, - "serverInfo": {"name": "s", "version": "0"}, - }, - ) - transport._maybe_extract_protocol_version_from_message(init) # pyright: ignore[reportPrivateUsage] - assert transport.protocol_version == "2025-03-26" - assert transport._prepare_headers()[MCP_PROTOCOL_VERSION_HEADER] == "2025-03-26" # pyright: ignore[reportPrivateUsage] diff --git a/tests/interaction/_connect.py b/tests/interaction/_connect.py index 575a742632..391d7a8359 100644 --- a/tests/interaction/_connect.py +++ b/tests/interaction/_connect.py @@ -20,7 +20,7 @@ from starlette.routing import Mount, Route from mcp.client.client import Client -from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT +from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT from mcp.client.sse import sse_client from mcp.client.streamable_http import streamable_http_client from mcp.server import Server @@ -31,14 +31,17 @@ from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings +from mcp.shared.version import MODERN_PROTOCOL_VERSIONS from mcp.types import ( LATEST_PROTOCOL_VERSION, ClientCapabilities, + DiscoverResult, Implementation, InitializeRequestParams, JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, + ServerCapabilities, jsonrpc_message_adapter, ) from tests.interaction.transports._bridge import StreamingASGITransport @@ -70,7 +73,7 @@ def __call__( message_handler: MessageHandlerFnT | None = None, client_info: Implementation | None = None, elicitation_callback: ElicitationFnT | None = None, - protocol_version: str = LATEST_PROTOCOL_VERSION, + spec_version: str = LATEST_PROTOCOL_VERSION, ) -> AbstractAsyncContextManager[Client]: ... @@ -85,7 +88,7 @@ async def connect_in_memory( message_handler: MessageHandlerFnT | None = None, client_info: Implementation | None = None, elicitation_callback: ElicitationFnT | None = None, - protocol_version: str = LATEST_PROTOCOL_VERSION, + spec_version: str = LATEST_PROTOCOL_VERSION, ) -> AsyncIterator[Client]: """Yield a Client connected to the server over the in-memory transport.""" async with Client( @@ -97,7 +100,6 @@ async def connect_in_memory( message_handler=message_handler, client_info=client_info, elicitation_callback=elicitation_callback, - protocol_version=protocol_version, ) as client: yield client @@ -117,7 +119,7 @@ async def connect_over_streamable_http( message_handler: MessageHandlerFnT | None = None, client_info: Implementation | None = None, elicitation_callback: ElicitationFnT | None = None, - protocol_version: str = LATEST_PROTOCOL_VERSION, + spec_version: str = LATEST_PROTOCOL_VERSION, ) -> AsyncIterator[Client]: """Yield a Client connected to the server's streamable HTTP app, entirely in process. @@ -126,6 +128,10 @@ async def connect_over_streamable_http( transport-specific tests pass `json_response` to select the other server mode, and the resumability tests pass an `event_store` (with `retry_interval=0` so the client's reconnection wait is a no-op). + + When `spec_version` is a modern (2026-07-28+) revision, the modern path is exercised: a bare + `ClientSession` is built over the streams and adopted from a synthesized `DiscoverResult` + instead of negotiating via `Client`'s legacy initialize handshake. """ app = server.streamable_http_app( stateless_http=stateless_http, @@ -137,19 +143,48 @@ async def connect_over_streamable_http( async with ( server.session_manager.run(), httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http_client, - Client( - streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client, protocol_version=protocol_version), - read_timeout_seconds=read_timeout_seconds, - sampling_callback=sampling_callback, - list_roots_callback=list_roots_callback, - logging_callback=logging_callback, - message_handler=message_handler, - client_info=client_info, - elicitation_callback=elicitation_callback, - protocol_version=protocol_version, - ) as client, ): - yield client + if spec_version in MODERN_PROTOCOL_VERSIONS: + async with ( + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read, write), + ClientSession( + read, + write, + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, + logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + elicitation_callback=elicitation_callback, + ) as session, + ): + session.adopt( + DiscoverResult( + supported_versions=[spec_version], + capabilities=ServerCapabilities(), + server_info=Implementation(name="test", version="0"), + result_type="complete", + ttl_ms=0, + cache_scope="public", + ) + ) + # ClientSession quacks as Client for every modern-arm requirement; the surfaces + # diverge (cursor= vs params=, .session, nullable initialize_result) so widening + # Connect to the union cascades ~58 errors. Contained here until the two converge. + yield session # pyright: ignore[reportReturnType] + else: + async with Client( + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client), + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, + logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + elicitation_callback=elicitation_callback, + ) as client: + yield client connect_over_streamable_http_stateless: Connect = partial(connect_over_streamable_http, stateless_http=True) @@ -345,7 +380,7 @@ async def connect_over_sse( message_handler: MessageHandlerFnT | None = None, client_info: Implementation | None = None, elicitation_callback: ElicitationFnT | None = None, - protocol_version: str = LATEST_PROTOCOL_VERSION, + spec_version: str = LATEST_PROTOCOL_VERSION, ) -> AsyncIterator[Client]: """Yield a Client connected to the server's legacy SSE transport, entirely in process.""" app, _ = build_sse_app(server) diff --git a/tests/interaction/conftest.py b/tests/interaction/conftest.py index cc1ae5ee7a..b918daf008 100644 --- a/tests/interaction/conftest.py +++ b/tests/interaction/conftest.py @@ -45,4 +45,4 @@ def connect(request: pytest.FixtureRequest) -> Connect: transport, spec_version = request.param assert isinstance(transport, str) assert isinstance(spec_version, str) - return partial(_FACTORIES[transport], protocol_version=spec_version) + return partial(_FACTORIES[transport], spec_version=spec_version) diff --git a/tests/interaction/lowlevel/test_cancellation.py b/tests/interaction/lowlevel/test_cancellation.py index 6e6c2b6f60..e0eb958ad8 100644 --- a/tests/interaction/lowlevel/test_cancellation.py +++ b/tests/interaction/lowlevel/test_cancellation.py @@ -148,7 +148,8 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("calm", on_list_tools=list_tools, on_call_tool=call_tool) async with connect(server) as client: - await client.session.send_notification( + session = client if isinstance(client, ClientSession) else client.session # modern arm yields a bare session + await session.send_notification( types.CancelledNotification(params=types.CancelledNotificationParams(request_id=9999)) ) result = await client.call_tool("echo", {}) diff --git a/tests/interaction/lowlevel/test_pagination.py b/tests/interaction/lowlevel/test_pagination.py index 77db90401e..85ad1eef67 100644 --- a/tests/interaction/lowlevel/test_pagination.py +++ b/tests/interaction/lowlevel/test_pagination.py @@ -5,6 +5,9 @@ pagination scheme. """ +from collections.abc import Callable, Coroutine +from typing import Any, TypeVar + import pytest from inline_snapshot import snapshot @@ -26,6 +29,19 @@ pytestmark = pytest.mark.anyio +_ResultT = TypeVar("_ResultT") + + +def _page(list_fn: Callable[..., Coroutine[Any, Any, _ResultT]], cursor: str | None) -> Coroutine[Any, Any, _ResultT]: + """Call a paginated list method with a cursor on whichever client surface `connect` yielded. + + `Client.list_*` takes `cursor=`; `ClientSession.list_*` takes `params=PaginatedRequestParams(...)`. + """ + try: + return list_fn(cursor=cursor) + except TypeError: + return list_fn(params=types.PaginatedRequestParams(cursor=cursor)) + @requirement("tools:list:pagination") async def test_next_cursor_round_trips_through_the_client(connect: Connect) -> None: @@ -49,7 +65,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa async with connect(server) as client: first_page = await client.list_tools() - second_page = await client.list_tools(cursor=first_page.next_cursor) + second_page = await _page(client.list_tools, first_page.next_cursor) assert first_page.next_cursor == cursor assert seen_cursors == [None, cursor] @@ -79,7 +95,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa requests_made = 0 async with connect(server) as client: while True: - result = await client.list_tools(cursor=cursor) + result = await _page(client.list_tools, cursor) requests_made += 1 assert requests_made <= len(pages), "the server kept returning next_cursor past the last page" collected.extend(tool.name for tool in result.tools) @@ -122,7 +138,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa cursor: str | None = None async with connect(server) as client: while True: - result = await client.list_tools(cursor=cursor) + result = await _page(client.list_tools, cursor) page_sizes.append(len(result.tools)) if result.next_cursor is None: break @@ -150,7 +166,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa async with connect(server) as client: with pytest.raises(MCPError) as exc_info: - await client.list_tools(cursor="never-issued") + await _page(client.list_tools, "never-issued") assert exc_info.value.error.code == INVALID_PARAMS @@ -174,7 +190,7 @@ async def list_resources( async with connect(server) as client: first_page = await client.list_resources() - second_page = await client.list_resources(cursor=first_page.next_cursor) + second_page = await _page(client.list_resources, first_page.next_cursor) assert first_page.next_cursor == cursor assert seen_cursors == [None, cursor] @@ -207,7 +223,7 @@ async def list_resource_templates( async with connect(server) as client: first_page = await client.list_resource_templates() - second_page = await client.list_resource_templates(cursor=first_page.next_cursor) + second_page = await _page(client.list_resource_templates, first_page.next_cursor) assert first_page.next_cursor == cursor assert seen_cursors == [None, cursor] @@ -233,7 +249,7 @@ async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequest async with connect(server) as client: first_page = await client.list_prompts() - second_page = await client.list_prompts(cursor=first_page.next_cursor) + second_page = await _page(client.list_prompts, first_page.next_cursor) assert first_page.next_cursor == cursor assert seen_cursors == [None, cursor] diff --git a/tests/interaction/transports/test_hosting_http_modern.py b/tests/interaction/transports/test_hosting_http_modern.py index f943f9e89e..1ed8dea201 100644 --- a/tests/interaction/transports/test_hosting_http_modern.py +++ b/tests/interaction/transports/test_hosting_http_modern.py @@ -28,6 +28,7 @@ MISSING_REQUIRED_CLIENT_CAPABILITY, CallToolRequestParams, CallToolResult, + DiscoverResult, EmptyResult, Implementation, JSONRPCError, @@ -35,6 +36,7 @@ ListToolsResult, PaginatedRequestParams, RequestParams, + ServerCapabilities, TextContent, Tool, ) @@ -328,12 +330,16 @@ async def on_response(response: httpx.Response) -> None: with anyio.fail_after(5): async with ( mounted_app(server, on_request=on_request, on_response=on_response) as (http, _), - streamable_http_client(f"{BASE_URL}/mcp", http_client=http, protocol_version=MODERN_VERSION) as ( - read, - write, - ), - ClientSession(read, write, client_info=client_info, protocol_version=MODERN_VERSION) as session, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http) as (read, write), + ClientSession(read, write, client_info=client_info) as session, ): + session.adopt( + DiscoverResult( + supported_versions=[MODERN_VERSION], + capabilities=ServerCapabilities(), + server_info=Implementation(name="srv", version="0"), + ) + ) result = await session.call_tool( "add", {"a": 2, "b": 3}, From 1d33743e3636f3fabb347f421e6fc06bffecf2cf Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 22 Jun 2026 15:17:29 +0000 Subject: [PATCH 04/22] Client gains mode= and prior_discover= policy knobs (legacy and version-pin) - mode='legacy' (default) performs the initialize handshake; a version string (e.g. '2026-07-28') adopts that version directly via .adopt() - prior_discover= reuses a known DiscoverResult; omitting it synthesizes a minimal one - 'auto' (server/discover probe with fallback) follows once .discover() lands - Interaction-suite connect fixture passes mode= for the modern arm and yields Client for all arms again; the W1b-era ClientSession adapter and type suppression are removed --- src/mcp/client/client.py | 28 ++++++++- tests/interaction/_connect.py | 63 +++++-------------- .../interaction/lowlevel/test_cancellation.py | 3 +- tests/interaction/lowlevel/test_pagination.py | 30 +++------ 4 files changed, 50 insertions(+), 74 deletions(-) diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 3476ada1f6..cf049e7d9e 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -4,10 +4,11 @@ from contextlib import AsyncExitStack from dataclasses import KW_ONLY, dataclass, field -from typing import Any +from typing import Any, Literal from typing_extensions import deprecated +from mcp import types from mcp.client._memory import InMemoryTransport from mcp.client._transport import Transport from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT @@ -36,6 +37,17 @@ ) +def _synthesize_discover(protocol_version: str) -> types.DiscoverResult: + return types.DiscoverResult( + supported_versions=[protocol_version], + capabilities=types.ServerCapabilities(), + server_info=types.Implementation(name="", version=""), + result_type="complete", + ttl_ms=0, + cache_scope="public", + ) + + @dataclass class Client: """A high-level MCP client for connecting to MCP servers. @@ -95,6 +107,15 @@ async def main(): client_info: Implementation | None = None """Client implementation info to send to server.""" + mode: Literal["legacy"] | str = "legacy" + """'legacy' performs the initialize handshake. A protocol-version string (e.g. '2026-07-28') adopts that + version directly without a handshake — supply prior_discover to reuse a known DiscoverResult, or omit it + to synthesize a minimal one.""" + + prior_discover: types.DiscoverResult | None = None + """A previously-obtained DiscoverResult to install via .adopt() when mode is a version pin. + Ignored when mode='legacy'.""" + elicitation_callback: ElicitationFnT | None = None """Callback for handling elicitation requests.""" @@ -132,7 +153,10 @@ async def __aenter__(self) -> Client: ) ) - await self._session.initialize() + if self.mode == "legacy": + await self._session.initialize() + else: + self._session.adopt(self.prior_discover or _synthesize_discover(self.mode)) # Transfer ownership to self for __aexit__ to handle self._exit_stack = exit_stack.pop_all() diff --git a/tests/interaction/_connect.py b/tests/interaction/_connect.py index 391d7a8359..e01f0e9dd6 100644 --- a/tests/interaction/_connect.py +++ b/tests/interaction/_connect.py @@ -20,7 +20,7 @@ from starlette.routing import Mount, Route from mcp.client.client import Client -from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT +from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT from mcp.client.sse import sse_client from mcp.client.streamable_http import streamable_http_client from mcp.server import Server @@ -35,13 +35,11 @@ from mcp.types import ( LATEST_PROTOCOL_VERSION, ClientCapabilities, - DiscoverResult, Implementation, InitializeRequestParams, JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, - ServerCapabilities, jsonrpc_message_adapter, ) from tests.interaction.transports._bridge import StreamingASGITransport @@ -129,9 +127,9 @@ async def connect_over_streamable_http( resumability tests pass an `event_store` (with `retry_interval=0` so the client's reconnection wait is a no-op). - When `spec_version` is a modern (2026-07-28+) revision, the modern path is exercised: a bare - `ClientSession` is built over the streams and adopted from a synthesized `DiscoverResult` - instead of negotiating via `Client`'s legacy initialize handshake. + When `spec_version` is a modern (2026-07-28+) revision the Client is opened with + `mode=`, which adopts a synthesized DiscoverResult instead of running the legacy + initialize handshake. """ app = server.streamable_http_app( stateless_http=stateless_http, @@ -143,48 +141,19 @@ async def connect_over_streamable_http( async with ( server.session_manager.run(), httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http_client, + Client( + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client), + mode=spec_version if spec_version in MODERN_PROTOCOL_VERSIONS else "legacy", + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, + logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + elicitation_callback=elicitation_callback, + ) as client, ): - if spec_version in MODERN_PROTOCOL_VERSIONS: - async with ( - streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read, write), - ClientSession( - read, - write, - read_timeout_seconds=read_timeout_seconds, - sampling_callback=sampling_callback, - list_roots_callback=list_roots_callback, - logging_callback=logging_callback, - message_handler=message_handler, - client_info=client_info, - elicitation_callback=elicitation_callback, - ) as session, - ): - session.adopt( - DiscoverResult( - supported_versions=[spec_version], - capabilities=ServerCapabilities(), - server_info=Implementation(name="test", version="0"), - result_type="complete", - ttl_ms=0, - cache_scope="public", - ) - ) - # ClientSession quacks as Client for every modern-arm requirement; the surfaces - # diverge (cursor= vs params=, .session, nullable initialize_result) so widening - # Connect to the union cascades ~58 errors. Contained here until the two converge. - yield session # pyright: ignore[reportReturnType] - else: - async with Client( - streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client), - read_timeout_seconds=read_timeout_seconds, - sampling_callback=sampling_callback, - list_roots_callback=list_roots_callback, - logging_callback=logging_callback, - message_handler=message_handler, - client_info=client_info, - elicitation_callback=elicitation_callback, - ) as client: - yield client + yield client connect_over_streamable_http_stateless: Connect = partial(connect_over_streamable_http, stateless_http=True) diff --git a/tests/interaction/lowlevel/test_cancellation.py b/tests/interaction/lowlevel/test_cancellation.py index e0eb958ad8..6e6c2b6f60 100644 --- a/tests/interaction/lowlevel/test_cancellation.py +++ b/tests/interaction/lowlevel/test_cancellation.py @@ -148,8 +148,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("calm", on_list_tools=list_tools, on_call_tool=call_tool) async with connect(server) as client: - session = client if isinstance(client, ClientSession) else client.session # modern arm yields a bare session - await session.send_notification( + await client.session.send_notification( types.CancelledNotification(params=types.CancelledNotificationParams(request_id=9999)) ) result = await client.call_tool("echo", {}) diff --git a/tests/interaction/lowlevel/test_pagination.py b/tests/interaction/lowlevel/test_pagination.py index 85ad1eef67..77db90401e 100644 --- a/tests/interaction/lowlevel/test_pagination.py +++ b/tests/interaction/lowlevel/test_pagination.py @@ -5,9 +5,6 @@ pagination scheme. """ -from collections.abc import Callable, Coroutine -from typing import Any, TypeVar - import pytest from inline_snapshot import snapshot @@ -29,19 +26,6 @@ pytestmark = pytest.mark.anyio -_ResultT = TypeVar("_ResultT") - - -def _page(list_fn: Callable[..., Coroutine[Any, Any, _ResultT]], cursor: str | None) -> Coroutine[Any, Any, _ResultT]: - """Call a paginated list method with a cursor on whichever client surface `connect` yielded. - - `Client.list_*` takes `cursor=`; `ClientSession.list_*` takes `params=PaginatedRequestParams(...)`. - """ - try: - return list_fn(cursor=cursor) - except TypeError: - return list_fn(params=types.PaginatedRequestParams(cursor=cursor)) - @requirement("tools:list:pagination") async def test_next_cursor_round_trips_through_the_client(connect: Connect) -> None: @@ -65,7 +49,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa async with connect(server) as client: first_page = await client.list_tools() - second_page = await _page(client.list_tools, first_page.next_cursor) + second_page = await client.list_tools(cursor=first_page.next_cursor) assert first_page.next_cursor == cursor assert seen_cursors == [None, cursor] @@ -95,7 +79,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa requests_made = 0 async with connect(server) as client: while True: - result = await _page(client.list_tools, cursor) + result = await client.list_tools(cursor=cursor) requests_made += 1 assert requests_made <= len(pages), "the server kept returning next_cursor past the last page" collected.extend(tool.name for tool in result.tools) @@ -138,7 +122,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa cursor: str | None = None async with connect(server) as client: while True: - result = await _page(client.list_tools, cursor) + result = await client.list_tools(cursor=cursor) page_sizes.append(len(result.tools)) if result.next_cursor is None: break @@ -166,7 +150,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa async with connect(server) as client: with pytest.raises(MCPError) as exc_info: - await _page(client.list_tools, "never-issued") + await client.list_tools(cursor="never-issued") assert exc_info.value.error.code == INVALID_PARAMS @@ -190,7 +174,7 @@ async def list_resources( async with connect(server) as client: first_page = await client.list_resources() - second_page = await _page(client.list_resources, first_page.next_cursor) + second_page = await client.list_resources(cursor=first_page.next_cursor) assert first_page.next_cursor == cursor assert seen_cursors == [None, cursor] @@ -223,7 +207,7 @@ async def list_resource_templates( async with connect(server) as client: first_page = await client.list_resource_templates() - second_page = await _page(client.list_resource_templates, first_page.next_cursor) + second_page = await client.list_resource_templates(cursor=first_page.next_cursor) assert first_page.next_cursor == cursor assert seen_cursors == [None, cursor] @@ -249,7 +233,7 @@ async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequest async with connect(server) as client: first_page = await client.list_prompts() - second_page = await _page(client.list_prompts, first_page.next_cursor) + second_page = await client.list_prompts(cursor=first_page.next_cursor) assert first_page.next_cursor == cursor assert seen_cursors == [None, cursor] From 1fbd0758e67b1323a0b6085f8032e8901ad11eeb Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 22 Jun 2026 15:34:28 +0000 Subject: [PATCH 05/22] ClientSession.discover() with the error ladder; Client mode='auto' - discover() probes server/discover via the dispatcher (bypassing the stamp), validates the response as DiscoverResult before reading any field, then .adopt()s it - Error ladder: -32022 retries once with the intersection of MODERN and data.supported (re-raises if empty or on second failure); -32601 and REQUEST_TIMEOUT fall back to .initialize(); anything else propagates - Idempotent (mirrors .initialize()) - Client.mode gains 'auto' which calls .discover() in __aenter__ - 9 unit tests cover each ladder rung, idempotency, malformed -32022 data, and the response-validation gate; 1 end-to-end test drives mode='auto' over the in-process ASGI bridge --- src/mcp/client/client.py | 10 +- src/mcp/client/session.py | 64 ++++++++++++ tests/client/test_client.py | 16 +++ tests/client/test_session.py | 194 +++++++++++++++++++++++++++++++++++ 4 files changed, 280 insertions(+), 4 deletions(-) diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index cf049e7d9e..2c9b862856 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -107,10 +107,10 @@ async def main(): client_info: Implementation | None = None """Client implementation info to send to server.""" - mode: Literal["legacy"] | str = "legacy" - """'legacy' performs the initialize handshake. A protocol-version string (e.g. '2026-07-28') adopts that - version directly without a handshake — supply prior_discover to reuse a known DiscoverResult, or omit it - to synthesize a minimal one.""" + mode: Literal["legacy", "auto"] | str = "legacy" + """'legacy' performs the initialize handshake. 'auto' probes server/discover and falls back to initialize() + on legacy servers. A protocol-version string (e.g. '2026-07-28') adopts that version directly without a + handshake — supply prior_discover to reuse a known DiscoverResult, or omit it to synthesize a minimal one.""" prior_discover: types.DiscoverResult | None = None """A previously-obtained DiscoverResult to install via .adopt() when mode is a version pin. @@ -155,6 +155,8 @@ async def __aenter__(self) -> Client: if self.mode == "legacy": await self._session.initialize() + elif self.mode == "auto": + await self._session.discover() else: self._session.adopt(self.prior_discover or _synthesize_discover(self.mode)) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index ab162c4d8c..f2b4c985ae 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -34,12 +34,15 @@ INTERNAL_ERROR, METHOD_NOT_FOUND, PROTOCOL_VERSION_META_KEY, + REQUEST_TIMEOUT, + UNSUPPORTED_PROTOCOL_VERSION, RequestId, RequestParamsMeta, ) from mcp.types import methods as _methods DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") +DISCOVER_TIMEOUT_SECONDS = 10.0 logger = logging.getLogger("client") @@ -368,6 +371,67 @@ def adopt(self, result: types.InitializeResult | types.DiscoverResult) -> None: self._stamp = _make_handshake_stamp(result.protocol_version) self._initialize_result = result + async def discover(self) -> types.InitializeResult: + """Probe `server/discover` and adopt the result, falling back to `initialize()`. + + Sends a single `server/discover` proposing the newest modern protocol + version. The error ladder, in order: + + - `UNSUPPORTED_PROTOCOL_VERSION` (-32022): the server's `supported` + list is intersected with `MODERN_PROTOCOL_VERSIONS` and the probe is + retried once at the highest mutual version. No mutual version, or a + second failure, raises the server's `MCPError`. + - `METHOD_NOT_FOUND` (-32601) or `REQUEST_TIMEOUT` (-32001): the server + is treated as legacy and `initialize()` runs instead — exactly as + ``mode='legacy'`` would. + - Any other error: re-raised. + + Returns the synthesized `InitializeResult` (also available afterwards + via `initialize_result`). + """ + if self._initialize_result is not None: + return self._initialize_result + + client_info = self._client_info.model_dump(by_alias=True, mode="json", exclude_none=True) + capabilities = self._build_capabilities().model_dump(by_alias=True, mode="json", exclude_none=True) + + async def probe(version: str) -> dict[str, Any]: + params = { + "_meta": { + PROTOCOL_VERSION_META_KEY: version, + CLIENT_INFO_META_KEY: client_info, + CLIENT_CAPABILITIES_META_KEY: capabilities, + } + } + opts: CallOptions = { + "timeout": DISCOVER_TIMEOUT_SECONDS, + "cancel_on_abandon": False, + "headers": {MCP_PROTOCOL_VERSION_HEADER: version}, + } + return await self._dispatcher.send_raw_request("server/discover", params, opts) + + try: + raw = await probe(MODERN_PROTOCOL_VERSIONS[-1]) + except MCPError as e: + if e.code == UNSUPPORTED_PROTOCOL_VERSION: + try: + data = types.UnsupportedProtocolVersionErrorData.model_validate(e.error.data) + except ValidationError: + raise e from None + mutual = [v for v in MODERN_PROTOCOL_VERSIONS if v in data.supported] + if not mutual: + raise + raw = await probe(mutual[-1]) + elif e.code in (METHOD_NOT_FOUND, REQUEST_TIMEOUT): + return await self.initialize() + else: + raise + + result = types.DiscoverResult.model_validate(raw) + self.adopt(result) + assert self._initialize_result is not None + return self._initialize_result + @property def initialize_result(self) -> types.InitializeResult | None: """The server's InitializeResult. None until `initialize()` or `adopt()`. diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 2accd093d6..003749c380 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -14,6 +14,7 @@ from mcp import MCPError, types from mcp.client._memory import InMemoryTransport from mcp.client.client import Client +from mcp.client.streamable_http import streamable_http_client from mcp.server import Server, ServerRequestContext from mcp.server.mcpserver import MCPServer from mcp.types import ( @@ -37,6 +38,7 @@ Tool, ToolsCapability, ) +from tests.interaction._connect import BASE_URL, mounted_app pytestmark = pytest.mark.anyio @@ -359,3 +361,17 @@ async def check_context() -> str: assert result.content[0].text == "client_value", ( # type: ignore[union-attr] "Server handler did not see the sender's contextvars.Context" ) + + +async def test_client_auto_mode_probes_discover_then_adopts(simple_server: Server) -> None: + """`mode='auto'` over an in-process HTTP transport: the `server/discover` probe + reaches the modern entry and the negotiated protocol version is adopted without + an `initialize` handshake. Runs over HTTP because the in-memory runner gates + `server/discover` behind the init handshake.""" + with anyio.fail_after(5): + async with ( + mounted_app(simple_server) as (http, _), + Client(streamable_http_client(f"{BASE_URL}/mcp", http_client=http), mode="auto") as client, + ): + assert client.initialize_result.protocol_version == "2026-07-28" + assert (await client.list_resources()).resources[0].name == "Test Resource" diff --git a/tests/client/test_session.py b/tests/client/test_session.py index d3732ec55c..a4bb1b2ddc 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -25,7 +25,9 @@ INVALID_PARAMS, LATEST_PROTOCOL_VERSION, METHOD_NOT_FOUND, + PROTOCOL_VERSION_META_KEY, REQUEST_TIMEOUT, + UNSUPPORTED_PROTOCOL_VERSION, CallToolResult, Implementation, InitializedNotification, @@ -1435,3 +1437,195 @@ async def test_send_notification_after_close_is_dropped_silently(): finally: for s in (s2c_send, s2c_recv, c2s_send, c2s_recv): s.close() + + +# --- discover() ladder --- + + +class _ScriptedDispatcher: + """Records every `send_raw_request` and plays back scripted answers in order. + + A script entry that is an `Exception` is raised; a dict is returned.""" + + def __init__(self, *script: dict[str, Any] | Exception) -> None: + self.calls: list[tuple[str, Mapping[str, Any] | None]] = [] + self.notifies: list[str] = [] + self._script: list[dict[str, Any] | Exception] = list(script) + + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: + task_status.started() + await anyio.sleep_forever() + + async def send_raw_request( + self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None + ) -> dict[str, Any]: + self.calls.append((method, params)) + item = self._script.pop(0) + if isinstance(item, Exception): + raise item + return item + + async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: + self.notifies.append(method) + + +def _discover_result_dict() -> dict[str, Any]: + return types.DiscoverResult( + supported_versions=["2026-07-28"], + capabilities=ServerCapabilities(), + server_info=Implementation(name="stub", version="0"), + ).model_dump(by_alias=True, mode="json", exclude_none=True) + + +def _initialize_result_dict() -> dict[str, Any]: + return InitializeResult( + protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], + capabilities=ServerCapabilities(), + server_info=Implementation(name="stub", version="0"), + ).model_dump(by_alias=True, mode="json", exclude_none=True) + + +@pytest.mark.anyio +async def test_discover_adopts_the_returned_result_and_installs_the_modern_stamp() -> None: + """SDK-defined: a successful `server/discover` is adopted and subsequent requests + carry the modern `_meta` envelope (protocol version + client info + capabilities).""" + dispatcher = _ScriptedDispatcher(_discover_result_dict(), {}) + with anyio.fail_after(5): + async with ClientSession(dispatcher=dispatcher) as session: + await session.discover() + assert session.protocol_version == "2026-07-28" + await session.send_ping() + ping_method, ping_params = dispatcher.calls[-1] + assert ping_method == "ping" + assert ping_params is not None + assert ping_params["_meta"][PROTOCOL_VERSION_META_KEY] == "2026-07-28" + + +@pytest.mark.anyio +async def test_discover_retries_once_on_unsupported_version_then_adopts() -> None: + """Spec SHOULD: a -32022 reply that names a mutually-supported version + triggers exactly one retry at that version, and the retry's result is adopted.""" + dispatcher = _ScriptedDispatcher( + MCPError( + UNSUPPORTED_PROTOCOL_VERSION, + "unsupported", + data={"supported": ["2026-07-28"], "requested": "2026-07-28"}, + ), + _discover_result_dict(), + ) + with anyio.fail_after(5): + async with ClientSession(dispatcher=dispatcher) as session: + await session.discover() + assert session.protocol_version == "2026-07-28" + assert [m for m, _ in dispatcher.calls] == ["server/discover", "server/discover"] + + +@pytest.mark.anyio +async def test_discover_raises_when_retry_intersection_is_empty() -> None: + """Spec SHOULD: a -32022 reply whose `supported` list shares nothing with the + client's modern versions is unrecoverable — the original error is re-raised + without a second probe.""" + dispatcher = _ScriptedDispatcher( + MCPError( + UNSUPPORTED_PROTOCOL_VERSION, + "unsupported", + data={"supported": ["1999-01-01"], "requested": "2026-07-28"}, + ), + ) + with anyio.fail_after(5): + async with ClientSession(dispatcher=dispatcher) as session: + with pytest.raises(MCPError) as exc: + await session.discover() + assert exc.value.error.code == UNSUPPORTED_PROTOCOL_VERSION + assert [m for m, _ in dispatcher.calls] == ["server/discover"] + + +@pytest.mark.anyio +async def test_discover_falls_back_to_initialize_on_method_not_found() -> None: + """Spec SHOULD: a legacy server that answers -32601 to `server/discover` is + transparently driven through the handshake instead.""" + dispatcher = _ScriptedDispatcher( + MCPError(METHOD_NOT_FOUND, "Method not found"), + _initialize_result_dict(), + ) + with anyio.fail_after(5): + async with ClientSession(dispatcher=dispatcher) as session: + await session.discover() + assert session.protocol_version in HANDSHAKE_PROTOCOL_VERSIONS + assert [m for m, _ in dispatcher.calls] == ["server/discover", "initialize"] + assert dispatcher.notifies == ["notifications/initialized"] + + +@pytest.mark.anyio +async def test_discover_falls_back_to_initialize_on_timeout() -> None: + """Spec SHOULD: a `REQUEST_TIMEOUT` from the dispatcher is treated the same as + method-not-found — the server is presumed legacy and `initialize()` runs.""" + dispatcher = _ScriptedDispatcher( + MCPError(REQUEST_TIMEOUT, "timed out"), + _initialize_result_dict(), + ) + with anyio.fail_after(5): + async with ClientSession(dispatcher=dispatcher) as session: + await session.discover() + assert session.protocol_version in HANDSHAKE_PROTOCOL_VERSIONS + assert [m for m, _ in dispatcher.calls] == ["server/discover", "initialize"] + + +@pytest.mark.anyio +async def test_discover_reraises_on_other_errors() -> None: + """SDK-defined: any error outside the retry/fallback ladder propagates verbatim + — `discover()` does not mask server failures by falling back to `initialize()`.""" + dispatcher = _ScriptedDispatcher(MCPError(INTERNAL_ERROR, "boom")) + with anyio.fail_after(5): + async with ClientSession(dispatcher=dispatcher) as session: + with pytest.raises(MCPError) as exc: + await session.discover() + assert exc.value.error.code == INTERNAL_ERROR + assert [m for m, _ in dispatcher.calls] == ["server/discover"] + + +@pytest.mark.anyio +async def test_discover_validates_the_response_shape_before_adopting() -> None: + """SDK-defined: the raw response is run through `DiscoverResult` validation + before any state is installed, so a malformed reply leaves the session + un-adopted rather than half-configured.""" + dispatcher = _ScriptedDispatcher({"supportedVersions": ["2026-07-28"]}) + session = ClientSession(dispatcher=dispatcher) + with anyio.fail_after(5): + async with session: + with pytest.raises(ValidationError): + await session.discover() + assert session.protocol_version is None + + +@pytest.mark.anyio +async def test_discover_is_idempotent_and_returns_the_cached_result() -> None: + """SDK-defined: a second `discover()` returns the already-adopted result without + re-probing — the script holds exactly one entry, so a second wire call would + `IndexError` on the empty script.""" + dispatcher = _ScriptedDispatcher(_discover_result_dict()) + with anyio.fail_after(5): + async with ClientSession(dispatcher=dispatcher) as session: + first = await session.discover() + assert await session.discover() is first + assert [m for m, _ in dispatcher.calls] == ["server/discover"] + + +@pytest.mark.anyio +async def test_discover_reraises_unsupported_version_with_malformed_error_data() -> None: + """SDK-defined: a -32022 reply whose `data` is not a valid + `UnsupportedProtocolVersionErrorData` payload is unrecoverable — the original + error is re-raised without a retry probe.""" + dispatcher = _ScriptedDispatcher(MCPError(UNSUPPORTED_PROTOCOL_VERSION, "unsupported", data="not-an-object")) + with anyio.fail_after(5): + async with ClientSession(dispatcher=dispatcher) as session: + with pytest.raises(MCPError) as exc: + await session.discover() + assert exc.value.error.code == UNSUPPORTED_PROTOCOL_VERSION + assert [m for m, _ in dispatcher.calls] == ["server/discover"] From 46c0742010ed7d97be856de48349171ccfa3012c Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 22 Jun 2026 16:14:24 +0000 Subject: [PATCH 06/22] modern_on_request driver + Client in-process modern path via DirectDispatcher peer-pair MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - modern_on_request(server, lifespan_state) returns an OnRequest callback that builds Connection.from_envelope per call and drives serve_one — wire it into the server side of a DirectDispatcher peer-pair for an in-process server on the modern per-request path - Client(Server|MCPServer, mode!=legacy) enters lifespan once, creates a peer-pair, runs the server side with modern_on_request, and hands the client side to ClientSession; legacy in-process keeps InMemoryTransport - Interaction-suite in-memory transport unlocked for 2026-07-28: 71 tests now run on [in-memory-2026-07-28], 67 pass; the 5 streamable-http-only notify-drop xfails are scoped to that transport; 4 progress-notification tests still xfail (peer-pair progress wiring tracked separately) --- src/mcp/client/client.py | 56 +++++++++++++++++++++++++----- src/mcp/server/runner.py | 37 +++++++++++++++++++- tests/interaction/_connect.py | 8 ++++- tests/interaction/_requirements.py | 21 +++++++---- tests/interaction/test_coverage.py | 1 + 5 files changed, 106 insertions(+), 17 deletions(-) diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 2c9b862856..9e046ae1e0 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -2,10 +2,12 @@ from __future__ import annotations +from collections.abc import Mapping from contextlib import AsyncExitStack from dataclasses import KW_ONLY, dataclass, field from typing import Any, Literal +import anyio from typing_extensions import deprecated from mcp import types @@ -15,6 +17,8 @@ from mcp.client.streamable_http import streamable_http_client from mcp.server import Server from mcp.server.mcpserver import MCPServer +from mcp.server.runner import modern_on_request +from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair from mcp.shared.dispatcher import ProgressFnT from mcp.shared.exceptions import MCPDeprecationWarning from mcp.types import ( @@ -48,6 +52,14 @@ def _synthesize_discover(protocol_version: str) -> types.DiscoverResult: ) +async def _drop_notify(_dctx: Any, _method: str, _params: Mapping[str, Any] | None) -> None: + """Server-side ``OnNotify`` for the modern in-process path: client→server notifications are dropped. + + The per-request driver (`serve_one`) has no notification dispatch table; progress and + cancellation travel via `CallOptions` on the `DirectDispatcher`, not as JSON-RPC notifies. + """ + + @dataclass class Client: """A high-level MCP client for connecting to MCP servers. @@ -121,11 +133,14 @@ async def main(): _session: ClientSession | None = field(init=False, default=None) _exit_stack: AsyncExitStack | None = field(init=False, default=None) - _transport: Transport = field(init=False) + _transport: Transport | None = field(init=False, default=None) + _inproc_server: Server[Any] | None = field(init=False, default=None) def __post_init__(self) -> None: - if isinstance(self.server, Server | MCPServer): - self._transport = InMemoryTransport(self.server, raise_exceptions=self.raise_exceptions) + if isinstance(self.server, MCPServer): + self._inproc_server = self.server._lowlevel_server # pyright: ignore[reportPrivateUsage] + elif isinstance(self.server, Server): + self._inproc_server = self.server elif isinstance(self.server, str): self._transport = streamable_http_client(self.server) else: @@ -137,10 +152,34 @@ async def __aenter__(self) -> Client: raise RuntimeError("Client is already entered; cannot reenter") async with AsyncExitStack() as exit_stack: - read_stream, write_stream = await exit_stack.enter_async_context(self._transport) - - self._session = await exit_stack.enter_async_context( - ClientSession( + if self._inproc_server is not None and self.mode != "legacy": + # Modern in-process path: drive the server through a DirectDispatcher peer-pair + # with one `serve_one` per request — no streams, no initialize handshake. + lifespan_state = await exit_stack.enter_async_context(self._inproc_server.lifespan(self._inproc_server)) + client_disp, server_disp = create_direct_dispatcher_pair() + tg = await exit_stack.enter_async_context(anyio.create_task_group()) + exit_stack.callback(server_disp.close) + await tg.start(server_disp.run, modern_on_request(self._inproc_server, lifespan_state), _drop_notify) + session = ClientSession( + dispatcher=client_disp, + read_timeout_seconds=self.read_timeout_seconds, + sampling_callback=self.sampling_callback, + list_roots_callback=self.list_roots_callback, + logging_callback=self.logging_callback, + message_handler=self.message_handler, + client_info=self.client_info, + elicitation_callback=self.elicitation_callback, + ) + else: + if self._inproc_server is not None: + transport: Transport = InMemoryTransport( + self._inproc_server, raise_exceptions=self.raise_exceptions + ) + else: + assert self._transport is not None + transport = self._transport + read_stream, write_stream = await exit_stack.enter_async_context(transport) + session = ClientSession( read_stream=read_stream, write_stream=write_stream, read_timeout_seconds=self.read_timeout_seconds, @@ -151,7 +190,8 @@ async def __aenter__(self) -> Client: client_info=self.client_info, elicitation_callback=self.elicitation_callback, ) - ) + + self._session = await exit_stack.enter_async_context(session) if self.mode == "legacy": await self._session.initialize() diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index f8c8d2ae44..4cf508363a 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -36,11 +36,14 @@ from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher, handler_exception_to_error_data from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.transport_context import TransportContext -from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS +from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS, MODERN_PROTOCOL_VERSIONS from mcp.types import ( + CLIENT_CAPABILITIES_META_KEY, + CLIENT_INFO_META_KEY, INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND, + PROTOCOL_VERSION_META_KEY, ErrorData, Implementation, InitializeRequestParams, @@ -62,6 +65,7 @@ "ServerMiddleware", "ServerRunner", "aclose_shielded", + "modern_on_request", "otel_middleware", "serve_connection", "serve_loop", @@ -512,3 +516,34 @@ async def serve_one( return await to_jsonrpc_response(request.id, runner.on_request(dctx, request.method, request.params)) finally: await aclose_shielded(connection) + + +def modern_on_request(server: Server[LifespanT], lifespan_state: LifespanT) -> OnRequest: + """Return an `OnRequest` callback that serves each call via `serve_one` with a fresh per-request `Connection`. + + Wire this into the server side of a `DirectDispatcher` peer-pair to drive an + in-process server on the modern per-request-envelope path (each request + carries protocol version, client info, and capabilities in `params._meta`; + no `initialize` handshake). + """ + + async def handle( + dctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + meta = (params or {}).get("_meta", {}) + connection = Connection.from_envelope( + meta.get(PROTOCOL_VERSION_META_KEY, MODERN_PROTOCOL_VERSIONS[-1]), + meta.get(CLIENT_INFO_META_KEY), + meta.get(CLIENT_CAPABILITIES_META_KEY), + ) + # `OnRequest` is invoked for requests only, so `request_id` is always set. + assert dctx.request_id is not None + req = JSONRPCRequest( + jsonrpc="2.0", id=dctx.request_id, method=method, params=dict(params) if params is not None else None + ) + msg = await serve_one(server, req, connection=connection, dctx=dctx, lifespan_state=lifespan_state) + if isinstance(msg, JSONRPCError): + raise MCPError(code=msg.error.code, message=msg.error.message, data=msg.error.data) + return msg.result + + return handle diff --git a/tests/interaction/_connect.py b/tests/interaction/_connect.py index e01f0e9dd6..05d2132c0b 100644 --- a/tests/interaction/_connect.py +++ b/tests/interaction/_connect.py @@ -88,9 +88,15 @@ async def connect_in_memory( elicitation_callback: ElicitationFnT | None = None, spec_version: str = LATEST_PROTOCOL_VERSION, ) -> AsyncIterator[Client]: - """Yield a Client connected to the server over the in-memory transport.""" + """Yield a Client connected to the server over the in-memory transport. + + When `spec_version` is a modern (2026-07-28+) revision the Client is opened with + `mode=`, which drives the server through the DirectDispatcher peer-pair + (per-request `serve_one`, no initialize handshake) instead of the legacy stream pair. + """ async with Client( server, + mode=spec_version if spec_version in MODERN_PROTOCOL_VERSIONS else "legacy", read_timeout_seconds=read_timeout_seconds, sampling_callback=sampling_callback, list_roots_callback=list_roots_callback, diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 9aee73b29b..a64a17c55b 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -63,9 +63,7 @@ TRANSPORT_SPEC_VERSIONS: dict[Transport, tuple[SpecVersion, ...]] = { "sse": ("2025-11-25",), - # Temporary lock: the in-memory transport has no modern entry point yet, so it cannot - # negotiate the newer revision. Remove once an in-memory factory for the modern path lands. - "in-memory": ("2025-11-25",), + "in-memory": ("2025-11-25", "2026-07-28"), # At the newer revision the protocol-version header check runs before the stateless branch is # taken, so a stateless connection at that revision behaves identically to the stateful one. # Locked to avoid a redundant matrix column; revisit if the header/stateless ordering changes. @@ -745,7 +743,9 @@ def __post_init__(self) -> None: "Log notifications emitted by a tool handler during execution reach the client's logging " "callback before the tool result returns." ), - known_failures=(KnownFailure(spec_version="2026-07-28", note=_MODERN_NOTIFY_DROP, issue=None),), + known_failures=( + KnownFailure(spec_version="2026-07-28", transport="streamable-http", note=_MODERN_NOTIFY_DROP, issue=None), + ), ), "tools:call:progress": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", @@ -974,7 +974,9 @@ def __post_init__(self) -> None: "The Context logging helpers (debug/info/warning/error) send log message notifications at the " "corresponding severity." ), - known_failures=(KnownFailure(spec_version="2026-07-28", note=_MODERN_NOTIFY_DROP, issue=None),), + known_failures=( + KnownFailure(spec_version="2026-07-28", transport="streamable-http", note=_MODERN_NOTIFY_DROP, issue=None), + ), ), "mcpserver:context:progress": Requirement( source="sdk", @@ -1339,7 +1341,9 @@ def __post_init__(self) -> None: "logging:message:all-levels": Requirement( source=f"{SPEC_BASE_URL}/server/utilities/logging#log-levels", behavior="All eight RFC 5424 severity levels are deliverable as log message notifications.", - known_failures=(KnownFailure(spec_version="2026-07-28", note=_MODERN_NOTIFY_DROP, issue=None),), + known_failures=( + KnownFailure(spec_version="2026-07-28", transport="streamable-http", note=_MODERN_NOTIFY_DROP, issue=None), + ), ), "logging:message:fields": Requirement( source=f"{SPEC_BASE_URL}/server/utilities/logging#log-message-notifications", @@ -1347,7 +1351,9 @@ def __post_init__(self) -> None: "A log message sent by a server handler is delivered to the client's logging callback with its " "severity level, logger name, and data." ), - known_failures=(KnownFailure(spec_version="2026-07-28", note=_MODERN_NOTIFY_DROP, issue=None),), + known_failures=( + KnownFailure(spec_version="2026-07-28", transport="streamable-http", note=_MODERN_NOTIFY_DROP, issue=None), + ), ), "logging:message:filtered": Requirement( source=f"{SPEC_BASE_URL}/server/utilities/logging#setting-log-level", @@ -1970,6 +1976,7 @@ def __post_init__(self) -> None: known_failures=( KnownFailure( spec_version="2026-07-28", + transport="streamable-http", note=( "List-mutation assertions hold; only the sentinel ctx.info() never reaches the client. " + _MODERN_NOTIFY_DROP diff --git a/tests/interaction/test_coverage.py b/tests/interaction/test_coverage.py index 2c7e486ab3..26e697c3ba 100644 --- a/tests/interaction/test_coverage.py +++ b/tests/interaction/test_coverage.py @@ -301,6 +301,7 @@ def test_compute_cells_drops_era_locked_transport_outside_its_versions() -> None "sse-2025-11-25", "streamable-http-2025-11-25", "streamable-http-stateless-2025-11-25", + "in-memory-2026-07-28", "streamable-http-2026-07-28", ] From dcbe6e8fb6d442e5d44b5420c9327bf7e635c9e5 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 22 Jun 2026 16:44:27 +0000 Subject: [PATCH 07/22] Sweep: route report_progress through DispatchContext.progress; bump LATEST_PROTOCOL_VERSION; orphan cleanup - Context.report_progress now delegates to DispatchContext.progress() via ServerSession.report_progress (was: token-gated send_notification, which only worked under JSONRPCDispatcher). Progress now reaches the client on the in-process modern path; 4 progress-notification xfails flip to pass. ServerSession's request_outbound is typed DispatchContext (it always was one at runtime). - LATEST_PROTOCOL_VERSION bumped to '2026-07-28' (the newest revision the SDK supports). Handshake-outcome assertions and mock-InitializeResult fixtures switched to HANDSHAKE_PROTOCOL_VERSIONS[-1]. migration.md entry. - ServerMessageMetadata.protocol_version deleted (no readers, no writers). - ClientSession.send_progress_notification and Client.send_progress_notification deprecated (client-to-server progress is server-to-client only at 2026-07-28). - Mcp-Name TODO re-anchored on _make_modern_stamp. --- docs/migration.md | 8 ++++ pyproject.toml | 3 ++ src/mcp/client/client.py | 6 ++- src/mcp/client/session.py | 5 ++ src/mcp/server/mcpserver/context.py | 13 +----- src/mcp/server/session.py | 14 +++++- src/mcp/shared/message.py | 3 -- src/mcp/types/_types.py | 2 +- tests/client/test_client.py | 5 +- tests/client/test_session.py | 27 ++++++----- tests/interaction/_connect.py | 17 ++++--- tests/interaction/_requirements.py | 25 ++++++++-- tests/interaction/auth/test_flow.py | 2 +- tests/interaction/lowlevel/test_progress.py | 46 +++++-------------- .../transports/test_hosting_resume.py | 4 +- tests/issues/test_176_progress_token.py | 38 +++++++-------- tests/issues/test_552_windows_hang.py | 5 +- tests/server/mcpserver/test_server.py | 22 ++++----- tests/server/test_runner.py | 20 ++++---- tests/server/test_session.py | 36 +++++++++++++-- tests/server/test_stateless_mode.py | 14 +++++- tests/shared/test_inbound.py | 13 +++--- 22 files changed, 182 insertions(+), 146 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index 91d24f354e..65fb9d090b 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -776,6 +776,8 @@ async def my_tool(ctx: Context[MyLifespanState]) -> str: ... `SUPPORTED_PROTOCOL_VERSIONS` is deprecated — it's now the union of `HANDSHAKE_PROTOCOL_VERSIONS` (initialize-handshake versions) and `MODERN_PROTOCOL_VERSIONS` (per-request-envelope versions). If you were using it to mean "versions the initialize handshake accepts", switch to `HANDSHAKE_PROTOCOL_VERSIONS`. +`LATEST_PROTOCOL_VERSION` now reflects the newest protocol revision the SDK supports (`2026-07-28`). Code that used it to mean "the version `.initialize()` offers" should switch to `HANDSHAKE_PROTOCOL_VERSIONS[-1]`. + ### `ProgressContext` and `progress()` context manager removed The `mcp.shared.progress` module (`ProgressContext`, `Progress`, and the `progress()` context manager) has been removed. This module had no real-world adoption — all users send progress notifications via `Context.report_progress()` or `session.send_progress_notification()` directly. @@ -1301,6 +1303,12 @@ warnings.filterwarnings("ignore", category=MCPDeprecationWarning) No migration is required during the deprecation window. New code should avoid building on these features, since they may be removed in a future spec version. +### Client-to-server progress deprecated (2026-07-28) + +The 2026-07-28 spec restricts `notifications/progress` to the server-to-client direction only — `ProgressNotification` is no longer in `ClientNotification`. `Client.send_progress_notification()` and `ClientSession.send_progress_notification()` now carry `typing_extensions.deprecated` and emit `mcp.MCPDeprecationWarning` at runtime. They continue to work against servers negotiating 2025-11-25 or earlier. + +On the server side, prefer the new dispatcher-agnostic `ServerSession.report_progress(progress, total, message)` (and `Context.report_progress()` on `MCPServer`) over the raw `ServerSession.send_progress_notification(progress_token, …)`. `report_progress` encapsulates the "no-op when the caller did not request progress" rule and works on every dispatcher; the raw token-taking form remains for handlers that read `_meta.progressToken` directly. + ## Bug Fixes ### OAuth metadata URLs no longer gain a trailing slash diff --git a/pyproject.toml b/pyproject.toml index 07bfff740e..e02c727be2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -217,6 +217,9 @@ filterwarnings = [ # them internally (e.g. `ctx.debug` -> `log` -> `send_log_message`), so the # advisory warning is silenced. Tests asserting it opt back in with pytest.warns. "ignore:.*is deprecated as of 2026-07-28 \\(SEP-2577\\).:mcp.MCPDeprecationWarning", + # 2026-07-28 restricts progress to server->client; the client send path is + # advisory-deprecated and a handful of tests still exercise it. + "ignore:Client-to-server progress is deprecated as of 2026-07-28.*:mcp.MCPDeprecationWarning", ] [tool.markdown.lint] diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 9e046ae1e0..b5b234e3f5 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -239,6 +239,10 @@ async def send_ping(self, *, meta: RequestParamsMeta | None = None) -> EmptyResu """Send a ping request to the server.""" return await self.session.send_ping(meta=meta) + @deprecated( + "Client-to-server progress is deprecated as of 2026-07-28; progress is server-to-client only.", + category=MCPDeprecationWarning, + ) async def send_progress_notification( self, progress_token: str | int, @@ -247,7 +251,7 @@ async def send_progress_notification( message: str | None = None, ) -> None: """Send a progress notification to the server.""" - await self.session.send_progress_notification( + await self.session.send_progress_notification( # pyright: ignore[reportDeprecated] progress_token=progress_token, progress=progress, total=total, diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index f2b4c985ae..5f4d795211 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -72,6 +72,7 @@ def stamp(data: dict[str, Any], opts: CallOptions) -> None: headers = opts.setdefault("headers", {}) headers[MCP_PROTOCOL_VERSION_HEADER] = protocol_version headers[MCP_METHOD_HEADER] = data["method"] + # TODO: also emit Mcp-Name for prompts/get (params.name) and resources/read (params.uri) if data["method"] == "tools/call" and isinstance(name := params.get("name"), str): headers[MCP_NAME_HEADER] = encode_header_value(name) @@ -451,6 +452,10 @@ async def send_ping(self, *, meta: RequestParamsMeta | None = None) -> types.Emp """Send a ping request.""" return await self.send_request(types.PingRequest(params=types.RequestParams(_meta=meta)), types.EmptyResult) + @deprecated( + "Client-to-server progress is deprecated as of 2026-07-28; progress is server-to-client only.", + category=MCPDeprecationWarning, + ) async def send_progress_notification( self, progress_token: str | int, diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index 0bf0b7ebfd..7856e32185 100644 --- a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -94,18 +94,7 @@ async def report_progress(self, progress: float, total: float | None = None, mes total: Optional total value (e.g., 100) message: Optional message (e.g., "Starting render...") """ - progress_token = self.request_context.meta.get("progress_token") if self.request_context.meta else None - - if progress_token is None: - return - - await self.request_context.session.send_progress_notification( - progress_token=progress_token, - progress=progress, - total=total, - message=message, - related_request_id=self.request_id, - ) + await self.request_context.session.report_progress(progress, total, message) async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]: """Read a resource by URI. diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index be4c1805a9..aa84ad37b8 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -14,7 +14,7 @@ from mcp import types from mcp.server.connection import Connection from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages -from mcp.shared.dispatcher import CallOptions, Outbound, ProgressFnT +from mcp.shared.dispatcher import CallOptions, DispatchContext, ProgressFnT from mcp.shared.exceptions import MCPDeprecationWarning from mcp.shared.message import ServerMessageMetadata from mcp.types import methods as _methods @@ -36,7 +36,7 @@ class ServerSession: never crosses the `Outbound` Protocol. """ - def __init__(self, request_outbound: Outbound, connection: Connection) -> None: + def __init__(self, request_outbound: DispatchContext[Any], connection: Connection) -> None: self._request_outbound = request_outbound self._connection = connection @@ -353,6 +353,16 @@ async def send_ping(self) -> types.EmptyResult: types.EmptyResult, ) + async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + """Report progress for the inbound request this session is scoped to. + + A no-op when the caller did not request progress. Dispatcher-agnostic: + on JSON-RPC the held `DispatchContext` emits ``notifications/progress`` + against the caller's token; on the in-process direct dispatcher it + invokes the caller's callback directly. + """ + await self._request_outbound.progress(progress, total, message) + async def send_progress_notification( self, progress_token: str | int, diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py index fda0fb8cc3..a0b6561151 100644 --- a/src/mcp/shared/message.py +++ b/src/mcp/shared/message.py @@ -37,9 +37,6 @@ class ServerMessageMetadata: # transports, None for stdio). Typed as Any because the server layer is # transport-agnostic. request_context: Any = None - # Per-message protocol version observed by the transport (e.g. the - # validated MCP-Protocol-Version header). - protocol_version: str | None = None # Callback to close SSE stream for the current request without terminating close_sse_stream: CloseSSEStreamCallback | None = None # Callback to close the standalone GET SSE stream (for unsolicited notifications) diff --git a/src/mcp/types/_types.py b/src/mcp/types/_types.py index 82b4a084d5..bf260b67b2 100644 --- a/src/mcp/types/_types.py +++ b/src/mcp/types/_types.py @@ -22,7 +22,7 @@ from mcp.types.jsonrpc import RequestId -LATEST_PROTOCOL_VERSION: Final[str] = "2025-11-25" +LATEST_PROTOCOL_VERSION: Final[str] = "2026-07-28" """The newest protocol version this SDK can negotiate. See https://modelcontextprotocol.io/specification/latest. diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 003749c380..6ebdf9553e 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -17,6 +17,7 @@ from mcp.client.streamable_http import streamable_http_client from mcp.server import Server, ServerRequestContext from mcp.server.mcpserver import MCPServer +from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS from mcp.types import ( CallToolResult, EmptyResult, @@ -118,7 +119,7 @@ async def test_client_is_initialized(app: MCPServer): async def test_client_initialize_result_exposes_negotiated_protocol_version(app: MCPServer): """The negotiated protocol version is readable after initialization.""" async with Client(app) as client: - assert client.initialize_result.protocol_version == types.LATEST_PROTOCOL_VERSION + assert client.initialize_result.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] async def test_client_with_simple_server(simple_server: Server): @@ -241,7 +242,7 @@ async def handle_progress(ctx: ServerRequestContext, params: types.ProgressNotif server = Server(name="test_server", on_progress=handle_progress) async with Client(server) as client: - await client.send_progress_notification(progress_token="token123", progress=50.0) + await client.send_progress_notification(progress_token="token123", progress=50.0) # pyright: ignore[reportDeprecated] await event.wait() assert received_from_client == snapshot({"progress_token": "token123", "progress": 50.0}) diff --git a/tests/client/test_session.py b/tests/client/test_session.py index a4bb1b2ddc..54e17e57d3 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -23,7 +23,6 @@ CONNECTION_CLOSED, INTERNAL_ERROR, INVALID_PARAMS, - LATEST_PROTOCOL_VERSION, METHOD_NOT_FOUND, PROTOCOL_VERSION_META_KEY, REQUEST_TIMEOUT, @@ -88,7 +87,7 @@ async def mock_server(): assert isinstance(request, InitializeRequest) result = InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, + protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], capabilities=ServerCapabilities( logging=None, resources=None, @@ -141,7 +140,7 @@ async def message_handler( # pragma: no cover # Assert the result assert isinstance(result, InitializeResult) - assert result.protocol_version == LATEST_PROTOCOL_VERSION + assert result.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] assert isinstance(result.capabilities, ServerCapabilities) assert result.server_info == Implementation(name="mock-server", version="0.1.0") assert result.instructions == "The server instructions." @@ -172,7 +171,7 @@ async def mock_server(): received_client_info = request.params.client_info result = InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, + protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], capabilities=ServerCapabilities(), server_info=Implementation(name="mock-server", version="0.1.0"), ) @@ -229,7 +228,7 @@ async def mock_server(): received_client_info = request.params.client_info result = InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, + protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], capabilities=ServerCapabilities(), server_info=Implementation(name="mock-server", version="0.1.0"), ) @@ -278,8 +277,8 @@ async def mock_server(): ) assert isinstance(request, InitializeRequest) - # Verify client sent the latest protocol version - assert request.params.protocol_version == LATEST_PROTOCOL_VERSION + # Verify client offers the newest handshake protocol version + assert request.params.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] # Server responds with a supported older version result = InitializeResult( @@ -387,7 +386,7 @@ async def mock_server(): received_capabilities = request.params.capabilities result = InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, + protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], capabilities=ServerCapabilities(), server_info=Implementation(name="mock-server", version="0.1.0"), ) @@ -458,7 +457,7 @@ async def mock_server(): received_capabilities = request.params.capabilities result = InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, + protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], capabilities=ServerCapabilities(), server_info=Implementation(name="mock-server", version="0.1.0"), ) @@ -537,7 +536,7 @@ async def mock_server(): received_capabilities = request.params.capabilities result = InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, + protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], capabilities=ServerCapabilities(), server_info=Implementation(name="mock-server", version="0.1.0"), ) @@ -605,7 +604,7 @@ async def mock_server(): assert isinstance(request, InitializeRequest) result = InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, + protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], capabilities=expected_capabilities, server_info=expected_server_info, instructions=expected_instructions, @@ -644,7 +643,7 @@ async def mock_server(): assert result.server_info == expected_server_info assert result.capabilities == expected_capabilities assert result.instructions == expected_instructions - assert result.protocol_version == LATEST_PROTOCOL_VERSION + assert result.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] @pytest.mark.anyio @@ -667,7 +666,7 @@ async def mock_server(): assert isinstance(request, InitializeRequest) result = InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, + protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], capabilities=ServerCapabilities(), server_info=Implementation(name="mock-server", version="0.1.0"), ) @@ -1348,7 +1347,7 @@ async def send_raw_request( self.calls.append((method, opts or {})) if method == "initialize": return InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, + protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], capabilities=ServerCapabilities(), server_info=Implementation(name="mock-server", version="0.1.0"), ).model_dump(by_alias=True, mode="json", exclude_none=True) diff --git a/tests/interaction/_connect.py b/tests/interaction/_connect.py index 05d2132c0b..654c74add2 100644 --- a/tests/interaction/_connect.py +++ b/tests/interaction/_connect.py @@ -31,9 +31,8 @@ from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings -from mcp.shared.version import MODERN_PROTOCOL_VERSIONS +from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS, MODERN_PROTOCOL_VERSIONS from mcp.types import ( - LATEST_PROTOCOL_VERSION, ClientCapabilities, Implementation, InitializeRequestParams, @@ -71,7 +70,7 @@ def __call__( message_handler: MessageHandlerFnT | None = None, client_info: Implementation | None = None, elicitation_callback: ElicitationFnT | None = None, - spec_version: str = LATEST_PROTOCOL_VERSION, + spec_version: str = HANDSHAKE_PROTOCOL_VERSIONS[-1], ) -> AbstractAsyncContextManager[Client]: ... @@ -86,7 +85,7 @@ async def connect_in_memory( message_handler: MessageHandlerFnT | None = None, client_info: Implementation | None = None, elicitation_callback: ElicitationFnT | None = None, - spec_version: str = LATEST_PROTOCOL_VERSION, + spec_version: str = HANDSHAKE_PROTOCOL_VERSIONS[-1], ) -> AsyncIterator[Client]: """Yield a Client connected to the server over the in-memory transport. @@ -123,7 +122,7 @@ async def connect_over_streamable_http( message_handler: MessageHandlerFnT | None = None, client_info: Implementation | None = None, elicitation_callback: ElicitationFnT | None = None, - spec_version: str = LATEST_PROTOCOL_VERSION, + spec_version: str = HANDSHAKE_PROTOCOL_VERSIONS[-1], ) -> AsyncIterator[Client]: """Yield a Client connected to the server's streamable HTTP app, entirely in process. @@ -270,14 +269,14 @@ def base_headers(*, session_id: str | None = None) -> dict[str, str]: """Standard request headers for raw-httpx streamable-HTTP tests. Every well-formed request carries these (Accept covering both response representations, - Content-Type for POST bodies, MCP-Protocol-Version at the latest revision, and the session + Content-Type for POST bodies, MCP-Protocol-Version at the newest handshake revision, and the session ID once one exists), so a test that wants to assert a specific rejection only varies the one header under test. """ headers = { "accept": "application/json, text/event-stream", "content-type": "application/json", - "mcp-protocol-version": LATEST_PROTOCOL_VERSION, + "mcp-protocol-version": HANDSHAKE_PROTOCOL_VERSIONS[-1], } if session_id is not None: headers["mcp-session-id"] = session_id @@ -287,7 +286,7 @@ def base_headers(*, session_id: str | None = None) -> dict[str, str]: def initialize_body(request_id: int = 1) -> dict[str, object]: """A wire-level initialize JSON-RPC request body, exactly as an SDK client would send it.""" params = InitializeRequestParams( - protocol_version=LATEST_PROTOCOL_VERSION, + protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], capabilities=ClientCapabilities(), client_info=Implementation(name="raw", version="0.0.0"), ) @@ -355,7 +354,7 @@ async def connect_over_sse( message_handler: MessageHandlerFnT | None = None, client_info: Implementation | None = None, elicitation_callback: ElicitationFnT | None = None, - spec_version: str = LATEST_PROTOCOL_VERSION, + spec_version: str = HANDSHAKE_PROTOCOL_VERSIONS[-1], ) -> AsyncIterator[Client]: """Yield a Client connected to the server's legacy SSE transport, entirely in process.""" app, _ = build_sse_app(server) diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index a64a17c55b..ce646ad598 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -579,7 +579,9 @@ def __post_init__(self) -> None: "Progress notifications emitted by a handler during a request are delivered to the caller's " "progress callback, in order, with their progress, total, and message." ), - known_failures=(KnownFailure(spec_version="2026-07-28", note=_MODERN_NOTIFY_DROP, issue=None),), + known_failures=( + KnownFailure(spec_version="2026-07-28", transport="streamable-http", note=_MODERN_NOTIFY_DROP, issue=None), + ), ), "protocol:progress:token-injected": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", @@ -592,7 +594,14 @@ def __post_init__(self) -> None: "protocol:progress:token-unique": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", behavior=("Concurrent in-flight requests that each supply a progress callback carry distinct progress tokens."), - known_failures=(KnownFailure(spec_version="2026-07-28", note=_MODERN_NOTIFY_DROP, issue=None),), + note=( + "Tested as the consequence: each callback receives only its own request's progress under " + "interleaved emission. Token distinctness is the JSON-RPC mechanism for that; the in-process " + "direct dispatcher carries the callback per-request without a wire-level token." + ), + known_failures=( + KnownFailure(spec_version="2026-07-28", transport="streamable-http", note=_MODERN_NOTIFY_DROP, issue=None), + ), ), "protocol:progress:monotonic": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", @@ -605,7 +614,9 @@ def __post_init__(self) -> None: "handler that emits non-increasing values has them forwarded to the callback unchanged." ), ), - known_failures=(KnownFailure(spec_version="2026-07-28", note=_MODERN_NOTIFY_DROP, issue=None),), + known_failures=( + KnownFailure(spec_version="2026-07-28", transport="streamable-http", note=_MODERN_NOTIFY_DROP, issue=None), + ), ), "protocol:progress:stops-after-completion": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/progress#behavior-requirements", @@ -753,7 +764,9 @@ def __post_init__(self) -> None: "Progress notifications emitted by a tool handler reach the caller's progress callback before " "the tool result returns." ), - known_failures=(KnownFailure(spec_version="2026-07-28", note=_MODERN_NOTIFY_DROP, issue=None),), + known_failures=( + KnownFailure(spec_version="2026-07-28", transport="streamable-http", note=_MODERN_NOTIFY_DROP, issue=None), + ), ), "tools:call:sampling-roundtrip": Requirement( source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", @@ -983,7 +996,9 @@ def __post_init__(self) -> None: behavior=( "Context.report_progress sends a progress notification against the requesting client's progress token." ), - known_failures=(KnownFailure(spec_version="2026-07-28", note=_MODERN_NOTIFY_DROP, issue=None),), + known_failures=( + KnownFailure(spec_version="2026-07-28", transport="streamable-http", note=_MODERN_NOTIFY_DROP, issue=None), + ), ), "mcpserver:context:elicit": Requirement( source="sdk", diff --git a/tests/interaction/auth/test_flow.py b/tests/interaction/auth/test_flow.py index ab96185796..96903ef910 100644 --- a/tests/interaction/auth/test_flow.py +++ b/tests/interaction/auth/test_flow.py @@ -104,7 +104,7 @@ async def test_an_unauthenticated_request_is_challenged_then_the_full_oauth_flow # The first PRM discovery GET carries the protocol-version header (an SDK behaviour, not a # spec requirement on discovery requests). prm_get = next(r for r in requests if r.url.path == "/.well-known/oauth-protected-resource/mcp") - assert prm_get.headers.get("mcp-protocol-version") == snapshot("2025-11-25") + assert prm_get.headers.get("mcp-protocol-version") == snapshot("2026-07-28") authorize = parse_qs(urlsplit(headless.authorize_url).query) assert authorize["response_type"] == ["code"] diff --git a/tests/interaction/lowlevel/test_progress.py b/tests/interaction/lowlevel/test_progress.py index a89039b99e..4fb2c7c224 100644 --- a/tests/interaction/lowlevel/test_progress.py +++ b/tests/interaction/lowlevel/test_progress.py @@ -41,18 +41,9 @@ async def list_tools( async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "download" - assert ctx.meta is not None - token = ctx.meta.get("progress_token") - assert token is not None - await ctx.session.send_progress_notification( - token, 1.0, total=3.0, message="first chunk", related_request_id=str(ctx.request_id) - ) - await ctx.session.send_progress_notification( - token, 2.0, total=3.0, message="second chunk", related_request_id=str(ctx.request_id) - ) - await ctx.session.send_progress_notification( - token, 3.0, total=3.0, message="done", related_request_id=str(ctx.request_id) - ) + await ctx.session.report_progress(1.0, total=3.0, message="first chunk") + await ctx.session.report_progress(2.0, total=3.0, message="second chunk") + await ctx.session.report_progress(3.0, total=3.0, message="done") return CallToolResult(content=[TextContent(text="downloaded")]) server = Server("downloader", on_list_tools=list_tools, on_call_tool=call_tool) @@ -130,7 +121,7 @@ async def on_progress(ctx: ServerRequestContext, params: ProgressNotificationPar server = Server("observer", on_progress=on_progress) async with connect(server) as client: - await client.send_progress_notification("upload-1", 0.5, total=1.0, message="halfway") + await client.send_progress_notification("upload-1", 0.5, total=1.0, message="halfway") # pyright: ignore[reportDeprecated] with anyio.fail_after(5): await delivered.wait() @@ -147,12 +138,11 @@ async def test_concurrent_requests_carry_distinct_progress_tokens(connect: Conne token would be live at a time and the demultiplexing would never be exercised. The handlers each block until both have started and then hand control back and forth so the four progress notifications are emitted in strict a, b, a, b order on the wire. The two handlers send different - progress values so a stream swap (token A delivered to callback B and vice versa) would fail: each - callback receiving exactly its own values proves notifications are routed by token, not by arrival - order or by chance. + progress values so a stream swap (request A's progress delivered to callback B and vice versa) + would fail: each callback receiving exactly its own values proves notifications are routed + per-request, not by arrival order or by chance. """ progress_values = {"a": (1.0, 2.0), "b": (10.0, 20.0)} - tokens: dict[str, ProgressToken] = {} entered = {"a": anyio.Event(), "b": anyio.Event()} # turns[n] is set to release the nth emission; each emission releases the next. turns = [anyio.Event() for _ in range(4)] @@ -165,23 +155,15 @@ async def list_tools( async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "report" assert params.arguments is not None - assert ctx.meta is not None - token = ctx.meta.get("progress_token") - assert token is not None label = params.arguments["label"] - tokens[label] = token entered[label].set() # The two handlers interleave by waiting on alternating turns: a takes 0 and 2, b takes 1 and 3. first, second = (0, 2) if label == "a" else (1, 3) await turns[first].wait() - await ctx.session.send_progress_notification( - token, progress_values[label][0], related_request_id=str(ctx.request_id) - ) + await ctx.session.report_progress(progress_values[label][0]) turns[first + 1].set() await turns[second].wait() - await ctx.session.send_progress_notification( - token, progress_values[label][1], related_request_id=str(ctx.request_id) - ) + await ctx.session.report_progress(progress_values[label][1]) if second + 1 < len(turns): turns[second + 1].set() return CallToolResult(content=[TextContent(text="done")]) @@ -210,7 +192,6 @@ async def call(label: str, collect: ProgressFnT) -> None: await entered["b"].wait() turns[0].set() - assert tokens["a"] != tokens["b"] assert received_a == [1.0, 2.0] assert received_b == [10.0, 20.0] @@ -285,12 +266,9 @@ async def list_tools( async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "zigzag" - assert ctx.meta is not None - token = ctx.meta.get("progress_token") - assert token is not None - await ctx.session.send_progress_notification(token, 0.5, related_request_id=str(ctx.request_id)) - await ctx.session.send_progress_notification(token, 0.3, related_request_id=str(ctx.request_id)) - await ctx.session.send_progress_notification(token, 0.9, related_request_id=str(ctx.request_id)) + await ctx.session.report_progress(0.5) + await ctx.session.report_progress(0.3) + await ctx.session.report_progress(0.9) return CallToolResult(content=[TextContent(text="done")]) server = Server("zigzagger", on_list_tools=list_tools, on_call_tool=call_tool) diff --git a/tests/interaction/transports/test_hosting_resume.py b/tests/interaction/transports/test_hosting_resume.py index f88521dbb0..3fa0da44e8 100644 --- a/tests/interaction/transports/test_hosting_resume.py +++ b/tests/interaction/transports/test_hosting_resume.py @@ -21,8 +21,8 @@ from mcp.client.streamable_http import streamable_http_client from mcp.server.mcpserver import Context, MCPServer from mcp.shared.message import ClientMessageMetadata +from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS from mcp.types import ( - LATEST_PROTOCOL_VERSION, CallToolRequest, CallToolRequestParams, CallToolResult, @@ -431,7 +431,7 @@ async def collect(params: LoggingMessageNotificationParams) -> None: # The session id is only observable via the manager (the client transport does not expose it). (session_id,) = manager._server_instances http.headers["mcp-session-id"] = session_id - http.headers["mcp-protocol-version"] = LATEST_PROTOCOL_VERSION + http.headers["mcp-protocol-version"] = HANDSHAKE_PROTOCOL_VERSIONS[-1] tg.cancel_scope.cancel() with anyio.fail_after(5): # pragma: no branch diff --git a/tests/issues/test_176_progress_token.py b/tests/issues/test_176_progress_token.py index 1ba2c8e118..5e62e9c692 100644 --- a/tests/issues/test_176_progress_token.py +++ b/tests/issues/test_176_progress_token.py @@ -9,13 +9,17 @@ async def test_progress_token_zero_first_call(): - """Test that progress notifications work when progress_token is 0 on first call.""" - - # Create mock session with progress notification tracking + """Regression: progress reporting must not be gated on a falsy token. + + Issue #176: the original Context.report_progress treated token 0 as "no token" and + silently dropped progress. Context now delegates unconditionally to + ServerSession.report_progress (which calls DispatchContext.progress, whose JSONRPC + implementation gates on `is None`, not truthiness), so a request whose meta carries + a 0-valued token still emits all three reports. + """ mock_session = AsyncMock() - mock_session.send_progress_notification = AsyncMock() + mock_session.report_progress = AsyncMock() - # Create request context with progress token 0 request_context = ServerRequestContext( request_id="test-request", session=mock_session, @@ -25,22 +29,14 @@ async def test_progress_token_zero_first_call(): protocol_version="2025-11-25", ) - # Create context with our mocks ctx = Context(request_context=request_context, mcp_server=MagicMock()) - # Test progress reporting - await ctx.report_progress(0, 10) # First call with 0 - await ctx.report_progress(5, 10) # Middle progress - await ctx.report_progress(10, 10) # Complete + await ctx.report_progress(0, 10) + await ctx.report_progress(5, 10) + await ctx.report_progress(10, 10) - # Verify progress notifications - assert mock_session.send_progress_notification.call_count == 3, "All progress notifications should be sent" - mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=0.0, total=10.0, message=None, related_request_id="test-request" - ) - mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=5.0, total=10.0, message=None, related_request_id="test-request" - ) - mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=10.0, total=10.0, message=None, related_request_id="test-request" - ) + assert mock_session.report_progress.await_args_list == [ + ((0, 10, None),), + ((5, 10, None),), + ((10, 10, None),), + ] diff --git a/tests/issues/test_552_windows_hang.py b/tests/issues/test_552_windows_hang.py index 371d033c2b..e2210a6d65 100644 --- a/tests/issues/test_552_windows_hang.py +++ b/tests/issues/test_552_windows_hang.py @@ -9,7 +9,8 @@ from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client -from mcp.types import LATEST_PROTOCOL_VERSION, InitializeResult +from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS +from mcp.types import InitializeResult @pytest.mark.skipif(sys.platform != "win32", reason="Windows-specific test") # pragma: no cover @@ -32,7 +33,7 @@ async def test_initialize_succeeds_and_shutdown_returns_after_the_server_exits_m "jsonrpc": "2.0", "id": request["id"], "result": {{ - "protocolVersion": {json.dumps(LATEST_PROTOCOL_VERSION)}, + "protocolVersion": {json.dumps(HANDSHAKE_PROTOCOL_VERSIONS[-1])}, "capabilities": {{}}, "serverInfo": {{"name": "test-server", "version": "1.0"}} }} diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 554fe50215..d44a4df42f 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -1515,21 +1515,21 @@ def test_streamable_http_no_redirect() -> None: assert streamable_routes[0].path == "/mcp", "Streamable route path should be /mcp" -async def test_report_progress_passes_related_request_id(): - """Test that report_progress passes the request_id as related_request_id. +async def test_report_progress_delegates_to_session_report_progress(): + """Context.report_progress delegates to ServerSession.report_progress unconditionally. - Without related_request_id, the streamable HTTP transport cannot route - progress notifications to the correct SSE stream, causing them to be - silently dropped. See #953 and #2001. + Stream routing (related_request_id, progress-token gating) is encapsulated in the + per-request DispatchContext that ServerSession holds, so Context never inspects + request metadata itself. See #953 and #2001 for the original streamable-HTTP routing bug. """ mock_session = AsyncMock() - mock_session.send_progress_notification = AsyncMock() + mock_session.report_progress = AsyncMock() request_context = ServerRequestContext( request_id="req-abc-123", session=mock_session, method="tools/call", - meta={"progress_token": "tok-1"}, + meta=None, lifespan_context=None, protocol_version="2025-11-25", ) @@ -1538,13 +1538,7 @@ async def test_report_progress_passes_related_request_id(): await ctx.report_progress(50, 100, message="halfway") - mock_session.send_progress_notification.assert_awaited_once_with( - progress_token="tok-1", - progress=50, - total=100, - message="halfway", - related_request_id="req-abc-123", - ) + mock_session.report_progress.assert_awaited_once_with(50, 100, "halfway") async def test_read_resource_template_error(): diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 1987a05ce1..8707dc077e 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -71,7 +71,7 @@ def _initialize_params() -> dict[str, Any]: return InitializeRequestParams( - protocol_version=LATEST_PROTOCOL_VERSION, + protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], capabilities=ClientCapabilities(), client_info=Implementation(name="test-client", version="1.0"), ).model_dump(by_alias=True, exclude_none=True) @@ -168,7 +168,7 @@ async def test_runner_handles_initialize_and_populates_connection(server: SrvT): assert "tools" in result["capabilities"] assert runner.connection.client_params is not None assert runner.connection.client_params.client_info.name == "test-client" - assert runner.connection.protocol_version == LATEST_PROTOCOL_VERSION + assert runner.connection.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] assert runner.connection.initialize_accepted is True @@ -245,7 +245,7 @@ async def test_runner_routes_to_handler_and_builds_context(server: SrvT): assert isinstance(ctx.session, ServerSession) assert ctx.session.protocol_version == runner.connection.protocol_version assert ctx.request_id is not None - assert ctx.protocol_version == LATEST_PROTOCOL_VERSION + assert ctx.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] @pytest.mark.anyio @@ -815,7 +815,7 @@ async def test_runner_with_born_ready_connection_skips_init_gate(server: SrvT): """A `Connection.from_envelope` connection is born ready: the kernel's init-gate is open without any handshake. The kernel is mode-agnostic - the same `on_request` reads `connection.initialize_accepted` as a fact.""" - born_ready = Connection.from_envelope(LATEST_PROTOCOL_VERSION, None, None) + born_ready = Connection.from_envelope(HANDSHAKE_PROTOCOL_VERSIONS[-1], None, None) async with connected_runner(server, initialized=False, connection=born_ready) as (client, runner): assert runner.connection.initialize_accepted is True assert runner.connection.initialized.is_set() @@ -848,7 +848,7 @@ async def greet(ctx: Ctx, params: GreetParams) -> dict[str, Any]: @pytest.mark.anyio async def test_runner_spec_method_with_invalid_params_is_invalid_params_at_the_negotiated_version(server: SrvT): async with connected_runner(server) as (client, runner): - assert runner.connection.protocol_version == LATEST_PROTOCOL_VERSION + assert runner.connection.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] with pytest.raises(MCPError) as exc: await client.send_raw_request("tools/call", {"name": 42}) assert exc.value.error.code == INVALID_PARAMS @@ -1006,7 +1006,7 @@ async def test_runner_initialize_echoes_supported_version_and_falls_back_to_late async with connected_runner(server, initialized=False) as (client, _): params = {**_initialize_params(), "protocolVersion": "1999-01-01"} result = await client.send_raw_request("initialize", params) - assert result["protocolVersion"] == LATEST_PROTOCOL_VERSION + assert result["protocolVersion"] == HANDSHAKE_PROTOCOL_VERSIONS[-1] @pytest.mark.anyio @@ -1328,7 +1328,7 @@ async def _append_async(dst: list[int], v: int) -> None: async def test_serve_one_runs_handler_and_returns_jsonrpc_response(server: SrvT): """The single-exchange driver: builds the kernel, runs `on_request` once, wraps via `to_jsonrpc_response`, and tears down `connection.exit_stack`.""" - conn = Connection.from_envelope(LATEST_PROTOCOL_VERSION, None, None) + conn = Connection.from_envelope(HANDSHAKE_PROTOCOL_VERSIONS[-1], None, None) cleaned: list[int] = [] conn.exit_stack.push_async_callback(_append_async, cleaned, 1) request = JSONRPCRequest(jsonrpc="2.0", id=9, method="tools/list", params=None) @@ -1338,7 +1338,7 @@ async def test_serve_one_runs_handler_and_returns_jsonrpc_response(server: SrvT) assert reply.result["tools"][0]["name"] == "t" assert cleaned == [1] ctx = _seen_ctx[0] - assert ctx.protocol_version == LATEST_PROTOCOL_VERSION + assert ctx.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] @pytest.mark.anyio @@ -1346,7 +1346,7 @@ async def test_serve_one_maps_error_to_jsonrpc_error_and_still_closes_exit_stack """SDK-defined: a kernel-produced error (here `METHOD_NOT_FOUND` for an unregistered method) is wrapped as a `JSONRPCError`, and the per-request exit stack is closed on the error path too.""" - conn = Connection.from_envelope(LATEST_PROTOCOL_VERSION, None, None) + conn = Connection.from_envelope(HANDSHAKE_PROTOCOL_VERSIONS[-1], None, None) cleaned: list[int] = [] conn.exit_stack.push_async_callback(_append_async, cleaned, 1) request = JSONRPCRequest(jsonrpc="2.0", id=2, method="resources/list", params=None) @@ -1389,5 +1389,5 @@ async def test_serve_connection_drives_dispatcher_loop_and_tears_down(server: Sr assert cleaned == [] close() assert cleaned == [1] - assert conn.protocol_version == LATEST_PROTOCOL_VERSION + assert conn.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] assert conn.client_params is not None diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 2379ccb8bb..f6f2a61e18 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -15,9 +15,9 @@ from mcp import types from mcp.server.connection import Connection from mcp.server.session import ServerSession -from mcp.shared.dispatcher import CallOptions, Outbound +from mcp.shared.dispatcher import CallOptions from mcp.shared.message import ServerMessageMetadata -from mcp.shared.version import MODERN_PROTOCOL_VERSIONS +from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS, MODERN_PROTOCOL_VERSIONS from mcp.types import ( LATEST_PROTOCOL_VERSION, ClientCapabilities, @@ -28,11 +28,21 @@ class StubOutbound: - """Records `send_raw_request` / `notify` calls and returns a canned result.""" + """Records `send_raw_request` / `notify` / `progress` calls and returns a canned result. + + Structurally a `DispatchContext[Any]` so it can stand in for the per-request channel. + """ + + transport: Any = None + can_send_request: bool = True + request_id: Any = None + message_metadata: Any = None + cancel_requested: Any = None def __init__(self, result: dict[str, Any] | None = None) -> None: self.requests: list[tuple[str, Mapping[str, Any] | None, CallOptions | None]] = [] self.notifications: list[tuple[str, Mapping[str, Any] | None]] = [] + self.progress_calls: list[tuple[float, float | None, str | None]] = [] self.result = result if result is not None else {} async def send_raw_request( @@ -47,12 +57,15 @@ async def send_raw_request( async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: self.notifications.append((method, params)) + async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + self.progress_calls.append((progress, total, message)) + def _make_session( outbound: StubOutbound, *, capabilities: ClientCapabilities | None = None, - protocol_version: str = LATEST_PROTOCOL_VERSION, + protocol_version: str = HANDSHAKE_PROTOCOL_VERSIONS[-1], ) -> ServerSession: """Single-channel session: the stub is both request and standalone outbound.""" client_info = Implementation(name="c", version="0") if capabilities is not None else None @@ -60,7 +73,7 @@ def _make_session( return ServerSession(outbound, conn) -def _two_channel_session(request_ch: Outbound, standalone_ch: Outbound) -> ServerSession: +def _two_channel_session(request_ch: StubOutbound, standalone_ch: StubOutbound) -> ServerSession: """Distinct request/standalone outbounds so routing assertions can tell the channels apart.""" conn = Connection.from_envelope(LATEST_PROTOCOL_VERSION, None, None, outbound=standalone_ch) return ServerSession(request_ch, conn) @@ -145,6 +158,19 @@ async def test_send_notification_routes_by_related_request_id(): assert [m for m, _ in request_ch.notifications] == ["notifications/progress"] +@pytest.mark.anyio +async def test_report_progress_delegates_to_the_request_dispatch_context(): + """`report_progress` calls the per-request `DispatchContext.progress` seam, never the + standalone channel: token gating and routing live in the dispatcher, not here.""" + request_ch = StubOutbound() + standalone_ch = StubOutbound() + session = _two_channel_session(request_ch, standalone_ch) + await session.report_progress(0.5, total=1.0, message="halfway") + assert request_ch.progress_calls == [(0.5, 1.0, "halfway")] + assert standalone_ch.progress_calls == [] + assert request_ch.notifications == [] + + @pytest.mark.anyio async def test_send_request_validates_the_client_result_against_the_surface_schema(): """A spec-method result that fails the per-version surface schema raises diff --git a/tests/server/test_stateless_mode.py b/tests/server/test_stateless_mode.py index 49c04dd847..7b252ef536 100644 --- a/tests/server/test_stateless_mode.py +++ b/tests/server/test_stateless_mode.py @@ -21,7 +21,16 @@ class StubOutbound: - """Records `send_raw_request` / `notify` calls and returns a canned result.""" + """Records `send_raw_request` / `notify` calls and returns a canned result. + + Structurally a `DispatchContext[Any]` so it can stand in for the per-request channel. + """ + + transport: Any = None + can_send_request: bool = True + request_id: Any = None + message_metadata: Any = None + cancel_requested: Any = None def __init__(self, result: dict[str, Any] | None = None) -> None: self.requests: list[tuple[str, Mapping[str, Any] | None, CallOptions | None]] = [] @@ -40,6 +49,9 @@ async def send_raw_request( async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: self.notifications.append((method, params)) + async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + raise NotImplementedError # pragma: no cover + def _no_channel_session(request_ch: StubOutbound | None = None) -> tuple[ServerSession, StubOutbound]: """A session whose standalone channel is the connection's no-channel diff --git a/tests/shared/test_inbound.py b/tests/shared/test_inbound.py index 75ed93e99a..eaa2a59bf4 100644 --- a/tests/shared/test_inbound.py +++ b/tests/shared/test_inbound.py @@ -17,11 +17,10 @@ InboundModernRoute, classify_inbound_request, ) -from mcp.shared.version import MODERN_PROTOCOL_VERSIONS +from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS, MODERN_PROTOCOL_VERSIONS from mcp.types import ( CLIENT_CAPABILITIES_META_KEY, CLIENT_INFO_META_KEY, - LATEST_PROTOCOL_VERSION, PROTOCOL_VERSION_META_KEY, ) from mcp.types.jsonrpc import ( @@ -108,18 +107,18 @@ def test_envelope_rung_rejects_non_mapping_shapes(body: dict[str, Any]) -> None: def test_version_rung_rejects_unsupported_with_data_shape() -> None: """Spec-mandated: an envelope version outside the modern set rejects with the ``supported``/``requested`` data.""" rejection = assert_rejected( - classify_inbound_request(envelope(version=LATEST_PROTOCOL_VERSION)), + classify_inbound_request(envelope(version=HANDSHAKE_PROTOCOL_VERSIONS[-1])), UNSUPPORTED_PROTOCOL_VERSION, ) assert rejection.data == { "supported": list(MODERN_PROTOCOL_VERSIONS), - "requested": LATEST_PROTOCOL_VERSION, + "requested": HANDSHAKE_PROTOCOL_VERSIONS[-1], } def test_version_rung_data_reflects_supplied_supported_list() -> None: """SDK-defined: the caller-supplied ``supported_modern_versions`` is what rejection ``data.supported`` echoes.""" - custom = (LATEST_PROTOCOL_VERSION,) + custom = (HANDSHAKE_PROTOCOL_VERSIONS[-1],) rejection = assert_rejected( classify_inbound_request(envelope(), supported_modern_versions=custom), UNSUPPORTED_PROTOCOL_VERSION, @@ -145,7 +144,7 @@ def test_header_rung_passes_when_header_matches_envelope() -> None: @pytest.mark.parametrize( "headers", [ - pytest.param({MCP_PROTOCOL_VERSION_HEADER: LATEST_PROTOCOL_VERSION}, id="mismatch"), + pytest.param({MCP_PROTOCOL_VERSION_HEADER: HANDSHAKE_PROTOCOL_VERSIONS[-1]}, id="mismatch"), pytest.param({}, id="header-absent"), ], ) @@ -177,7 +176,7 @@ def test_ladder_first_failure_wins() -> None: """Spec-mandated: rungs evaluate in order — header-mismatch and version-unsupported would both fail; the header rung fires first so an inconsistent client is told it disagrees with itself rather than that its body version is unsupported.""" - body = envelope(version=LATEST_PROTOCOL_VERSION) + body = envelope(version=HANDSHAKE_PROTOCOL_VERSIONS[-1]) result = classify_inbound_request(body, headers={MCP_PROTOCOL_VERSION_HEADER: MODERN}) assert_rejected(result, HEADER_MISMATCH) From a97459d916413bf21fd8e63a7f65f1ffdd0f4c6a Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 22 Jun 2026 16:57:56 +0000 Subject: [PATCH 08/22] Add lifecycle:envelope/discover/mode requirement entries and interaction tests - 9 new requirement IDs in the Lifecycle section covering the per-request envelope, server/discover behaviour, and Client mode= policy - 10 interaction tests in tests/interaction/lowlevel/test_client_connect.py driving each via Client(server, mode=...) over in-memory and in-process ASGI --- tests/interaction/_requirements.py | 76 ++++ .../lowlevel/test_client_connect.py | 362 ++++++++++++++++++ 2 files changed, 438 insertions(+) create mode 100644 tests/interaction/lowlevel/test_client_connect.py diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index ce646ad598..afb9e2b8a1 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -250,6 +250,7 @@ def __post_init__(self) -> None: source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", behavior="The client's name, version, and title are visible to server handlers after initialization.", removed_in="2026-07-28", + superseded_by="lifecycle:envelope:stamped-on-every-request", note="initialize handshake removed at 2026-07-28; per-request _meta envelope replaces it.", arm_exclusions=(ArmExclusion(reason="requires-session", transport="streamable-http-stateless"),), ), @@ -260,6 +261,7 @@ def __post_init__(self) -> None: "(sampling, elicitation, roots)." ), removed_in="2026-07-28", + superseded_by="lifecycle:envelope:stamped-on-every-request", note="initialize handshake removed at 2026-07-28; per-request _meta envelope replaces it.", arm_exclusions=(ArmExclusion(reason="requires-session", transport="streamable-http-stateless"),), ), @@ -393,6 +395,80 @@ def __post_init__(self) -> None: "hosting:http:legacy-no-modern-vocabulary covers the same vocabulary set" ), ), + "lifecycle:envelope:stamped-on-every-request": Requirement( + source=f"{SPEC_2026_BASE_URL}/basic#_meta", + behavior=( + "Every client→server request on a modern-negotiated session carries " + "_meta.{protocolVersion,clientInfo,clientCapabilities}; notifications do not." + ), + added_in="2026-07-28", + supersedes=("lifecycle:initialize:client-info", "lifecycle:initialize:client-capabilities"), + ), + "lifecycle:envelope:header-matches-meta": Requirement( + source=f"{SPEC_2026_BASE_URL}/basic/transports/streamable-http#headers", + behavior="On HTTP, the MCP-Protocol-Version header on every POST matches _meta.protocolVersion in the body.", + transports=("streamable-http", "streamable-http-stateless"), + added_in="2026-07-28", + note="HTTP-only: the header is a streamable-http transport concern; stdio and in-memory carry no headers.", + ), + "lifecycle:discover:basic": Requirement( + source=f"{SPEC_2026_BASE_URL}/basic/lifecycle#discover", + behavior=( + "Calling discover() sends server/discover with no params and returns a typed DiscoverResult " + "carrying protocolVersion, capabilities, serverInfo and the cache hint fields." + ), + added_in="2026-07-28", + ), + "lifecycle:discover:retry-on-32022": Requirement( + source=f"{SPEC_2026_BASE_URL}/basic/lifecycle#version-errors", + behavior=( + "When server/discover returns -32022 UnsupportedProtocolVersion, the client retries once with " + "the intersection of error.data.supported and its own modern versions; an empty intersection raises." + ), + added_in="2026-07-28", + ), + "lifecycle:discover:fallback-method-not-found": Requirement( + source=f"{SPEC_2026_BASE_URL}/basic/transports/stdio#backward-compatibility", + behavior=( + "When server/discover returns -32601 (or HTTP 404), an auto-negotiating client falls back to " + "the legacy initialize handshake and the connection succeeds at a handshake-era version." + ), + added_in="2026-07-28", + ), + "lifecycle:discover:network-error-raises": Requirement( + source="sdk", + behavior=( + "An HTTP timeout, connection error, or non-404 4xx/5xx during server/discover raises to the " + "caller without falling back to initialize." + ), + transports=("streamable-http", "streamable-http-stateless"), + added_in="2026-07-28", + note="HTTP-only: distinguishes transport-level failures from the -32601 fallback signal.", + ), + "lifecycle:mode:legacy-never-probes": Requirement( + source="sdk", + behavior=( + "A Client constructed with mode='legacy' (the default) sends initialize as its first request " + "and never sends server/discover." + ), + added_in="2026-07-28", + ), + "lifecycle:mode:pin-never-handshakes": Requirement( + source="sdk", + behavior=( + "A Client constructed with mode='2026-07-28' sends no initialize and no server/discover; its " + "first wire request is the caller's first call, carrying the full _meta envelope." + ), + added_in="2026-07-28", + ), + "lifecycle:mode:prior-discover-zero-rtt": Requirement( + source="sdk", + behavior=( + "A Client constructed with prior_discover= sends no negotiation traffic; " + "server_info and capabilities are populated from the prior result." + ), + added_in="2026-07-28", + ), # ═══════════════════════════════════════════════════════════════════════════ # Protocol primitives: cancellation, timeout, progress, errors, _meta # ═══════════════════════════════════════════════════════════════════════════ diff --git a/tests/interaction/lowlevel/test_client_connect.py b/tests/interaction/lowlevel/test_client_connect.py new file mode 100644 index 0000000000..19c0c6605b --- /dev/null +++ b/tests/interaction/lowlevel/test_client_connect.py @@ -0,0 +1,362 @@ +"""Client connect-time negotiation: mode selection, server/discover, and the per-request envelope. + +These tests pin what `Client(..., mode=...)` puts on the wire BEFORE the caller's first call -- +the legacy initialize handshake, the modern `server/discover` probe, or nothing at all -- and +that a modern-negotiated session stamps the three-key `io.modelcontextprotocol/*` `_meta` +envelope on every subsequent request. Each test drives the highest public surface (`Client`) +and observes traffic at a recording seam: `RecordingTransport` for the legacy stream pair, and +`mounted_app`'s httpx event hook for the in-process streamable-HTTP transport. + +The fallback test alone hand-plays the server's side of the wire, because no real `Server` +answers `server/discover` with -32601. +""" + +import json +from collections.abc import AsyncIterator, Awaitable, Callable +from contextlib import asynccontextmanager + +import anyio +import httpx +import pytest + +from mcp import MCPError, types +from mcp.client._memory import InMemoryTransport +from mcp.client._transport import TransportStreams +from mcp.client.client import Client +from mcp.client.streamable_http import streamable_http_client +from mcp.server import Server, ServerRequestContext +from mcp.shared.memory import MessageStream, create_client_server_memory_streams +from mcp.shared.message import SessionMessage +from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS, MODERN_PROTOCOL_VERSIONS +from mcp.types import ( + CLIENT_CAPABILITIES_META_KEY, + CLIENT_INFO_META_KEY, + INTERNAL_ERROR, + METHOD_NOT_FOUND, + PROTOCOL_VERSION_META_KEY, + UNSUPPORTED_PROTOCOL_VERSION, + DiscoverResult, + Implementation, + InitializeResult, + JSONRPCError, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ServerCapabilities, + ToolsCapability, +) +from tests.interaction._connect import BASE_URL, Connect, mounted_app +from tests.interaction._helpers import RecordingTransport +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + +MODERN_VERSION = "2026-07-28" + + +def _tools_server(name: str = "negotiator") -> Server: + """A low-level server with one list-tools handler, so a feature request has something to reach.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="noop", input_schema={"type": "object"})]) + + return Server(name, on_list_tools=list_tools) + + +def _request_recorder() -> tuple[list[httpx.Request], Callable[[httpx.Request], Awaitable[None]]]: + """Return a list and an `on_request` hook that appends each outgoing httpx request to it.""" + captured: list[httpx.Request] = [] + + async def on_request(request: httpx.Request) -> None: + captured.append(request) + + return captured, on_request + + +@requirement("lifecycle:mode:legacy-never-probes") +async def test_legacy_mode_sends_initialize_and_never_probes_discover() -> None: + """`Client(server)` (mode defaults to 'legacy') opens with `initialize` and never sends `server/discover`. + + Requirement `lifecycle:mode:legacy-never-probes` (sdk-defined): the default mode must remain + byte-identical to the pre-2026 client so a 2025-era server never observes modern vocabulary. + """ + recording = RecordingTransport(InMemoryTransport(_tools_server())) + + with anyio.fail_after(5): + async with Client(recording) as client: + await client.list_tools() + + sent = [m.message for m in recording.sent] + methods = [m.method for m in sent if isinstance(m, JSONRPCRequest | JSONRPCNotification)] + assert methods[0] == "initialize" + assert "server/discover" not in methods + assert "notifications/initialized" in methods + + +@requirement("lifecycle:mode:pin-never-handshakes") +async def test_pinned_mode_sends_no_connect_time_traffic() -> None: + """`Client(..., mode='2026-07-28')` sends nothing on entry; the caller's first call is the first wire request. + + Requirement `lifecycle:mode:pin-never-handshakes` (sdk-defined): a version pin adopts a + synthesized DiscoverResult locally, so no `initialize` and no `server/discover` ever cross + the wire. Asserted at the in-process streamable-HTTP seam via the httpx event hook. + """ + requests, on_request = _request_recorder() + + with anyio.fail_after(5): + async with ( + mounted_app(_tools_server(), on_request=on_request) as (http, _), + Client(streamable_http_client(f"{BASE_URL}/mcp", http_client=http), mode=MODERN_VERSION) as client, + ): + assert requests == [] # entering the Client produced zero HTTP traffic + result = await client.list_tools() + + bodies = [json.loads(r.content) for r in requests] + assert [b["method"] for b in bodies] == ["tools/list"] + assert PROTOCOL_VERSION_META_KEY in bodies[0]["params"]["_meta"] + assert [t.name for t in result.tools] == ["noop"] + + +@requirement("lifecycle:mode:prior-discover-zero-rtt") +async def test_prior_discover_populates_state_with_zero_connect_time_traffic() -> None: + """`Client(..., mode=, prior_discover=...)` sends nothing on entry and exposes the prior server_info. + + Requirement `lifecycle:mode:prior-discover-zero-rtt` (sdk-defined): a previously-obtained + DiscoverResult is installed via `adopt()` so server_info and capabilities are available + immediately with zero round trips. + """ + prior = DiscoverResult( + supported_versions=[MODERN_VERSION], + capabilities=ServerCapabilities(tools=ToolsCapability(list_changed=False)), + server_info=Implementation(name="cached-server", version="9.9.9"), + ) + requests, on_request = _request_recorder() + + with anyio.fail_after(5): + async with ( + mounted_app(_tools_server(), on_request=on_request) as (http, _), + Client( + streamable_http_client(f"{BASE_URL}/mcp", http_client=http), + mode=MODERN_VERSION, + prior_discover=prior, + ) as client, + ): + assert requests == [] + assert client.initialize_result.server_info == Implementation(name="cached-server", version="9.9.9") + assert client.initialize_result.capabilities.tools == ToolsCapability(list_changed=False) + await client.list_tools() + + assert [json.loads(r.content)["method"] for r in requests] == ["tools/list"] + + +@requirement("lifecycle:discover:basic") +async def test_auto_mode_probes_server_discover_and_adopts_the_result() -> None: + """`Client(..., mode='auto')` sends `server/discover` first and adopts the returned version and server_info. + + Requirement `lifecycle:discover:basic` (spec basic/lifecycle#discover): the probe is a + single `server/discover` request whose result carries supported versions, capabilities, + server_info and the cache-hint fields, after which the session is modern-negotiated. + """ + requests, on_request = _request_recorder() + server = _tools_server("discoverable") + + with anyio.fail_after(5): + async with ( + mounted_app(server, on_request=on_request) as (http, _), + Client(streamable_http_client(f"{BASE_URL}/mcp", http_client=http), mode="auto") as client, + ): + assert client.initialize_result.protocol_version == MODERN_VERSION + assert client.initialize_result.server_info.name == "discoverable" + await client.list_tools() + + bodies = [json.loads(r.content) for r in requests] + assert bodies[0]["method"] == "server/discover" + assert "initialize" not in [b["method"] for b in bodies] + + +@requirement("lifecycle:discover:retry-on-32022") +async def test_auto_mode_retries_discover_once_on_unsupported_protocol_version() -> None: + """A -32022 from `server/discover` triggers exactly one retry at the highest mutual modern version. + + Requirement `lifecycle:discover:retry-on-32022` (spec basic/lifecycle#version-errors): the + client intersects `error.data.supported` with its own modern versions and re-probes once; + the second success is adopted. The server's `server/discover` handler is overridden to fail + the first call and succeed on the second. + """ + calls: list[str | None] = [] + + async def discover(ctx: ServerRequestContext, params: types.RequestParams | None) -> DiscoverResult: + proposed = ctx.meta.get(PROTOCOL_VERSION_META_KEY) if ctx.meta else None + calls.append(proposed) + if len(calls) == 1: + raise MCPError( + code=UNSUPPORTED_PROTOCOL_VERSION, + message="unsupported protocol version", + data={"supported": list(MODERN_PROTOCOL_VERSIONS), "requested": proposed}, + ) + return DiscoverResult( + supported_versions=list(MODERN_PROTOCOL_VERSIONS), + capabilities=ServerCapabilities(), + server_info=Implementation(name="picky", version="1.0.0"), + ) + + server = _tools_server("picky") + server.add_request_handler("server/discover", types.RequestParams, discover) + requests, on_request = _request_recorder() + + with anyio.fail_after(5): + async with ( + mounted_app(server, on_request=on_request) as (http, _), + Client(streamable_http_client(f"{BASE_URL}/mcp", http_client=http), mode="auto") as client, + ): + assert client.initialize_result.protocol_version == MODERN_VERSION + + assert calls == [MODERN_VERSION, MODERN_VERSION] + assert [json.loads(r.content)["method"] for r in requests][:2] == ["server/discover", "server/discover"] + + +@requirement("lifecycle:discover:network-error-raises") +async def test_auto_mode_reraises_a_non_fallback_discover_error_without_initializing() -> None: + """A `server/discover` failure outside the {-32601, -32001, -32022} ladder raises without falling back. + + Requirement `lifecycle:discover:network-error-raises` (sdk-defined): a 5xx-class error from + the probe is surfaced to the caller; the client never sends `initialize`. Exercised here as + the JSON-RPC INTERNAL_ERROR branch (which the modern HTTP entry maps to a 5xx). The error + reaches the test wrapped in the streamable-http transport's task-group teardown, so + `pytest.RaisesGroup` flattens before matching. + """ + + async def discover(ctx: ServerRequestContext, params: types.RequestParams | None) -> DiscoverResult: + raise MCPError(code=INTERNAL_ERROR, message="storage unavailable") + + server = _tools_server() + server.add_request_handler("server/discover", types.RequestParams, discover) + requests, on_request = _request_recorder() + + def is_internal_error(exc: MCPError) -> bool: + return exc.code == INTERNAL_ERROR + + with anyio.fail_after(5): + async with mounted_app(server, on_request=on_request) as (http, _): + with pytest.RaisesGroup( + pytest.RaisesExc(MCPError, check=is_internal_error), flatten_subgroups=True + ): # pragma: no branch + async with Client(streamable_http_client(f"{BASE_URL}/mcp", http_client=http), mode="auto"): + pytest.fail("entering the Client should have raised") # pragma: no cover + + assert [json.loads(r.content)["method"] for r in requests] == ["server/discover"] + + +@requirement("lifecycle:discover:fallback-method-not-found") +async def test_auto_mode_falls_back_to_initialize_when_discover_is_method_not_found() -> None: + """A -32601 from `server/discover` makes an auto-negotiating client run the legacy `initialize` handshake. + + Requirement `lifecycle:discover:fallback-method-not-found` (spec stdio#backward-compatibility): + a legacy-era server that does not implement `server/discover` is connected to via the + handshake, and the session lands at a handshake-era protocol version. A real `Server` always + implements `server/discover`, so this test plays the server's side of the wire by hand. + Reserve this pattern for behaviour no real server can be made to produce. + """ + methods_seen: list[str] = [] + + async def scripted_server(streams: MessageStream) -> None: + server_read, server_write = streams + async for message in server_read: + assert isinstance(message, SessionMessage) + frame = message.message + assert isinstance(frame, JSONRPCRequest | JSONRPCNotification) + methods_seen.append(frame.method) + if isinstance(frame, JSONRPCRequest) and frame.method == "server/discover": + error = types.ErrorData(code=METHOD_NOT_FOUND, message="Method not found") + await server_write.send(SessionMessage(JSONRPCError(jsonrpc="2.0", id=frame.id, error=error))) + elif isinstance(frame, JSONRPCRequest) and frame.method == "initialize": + result = InitializeResult( + protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], + capabilities=ServerCapabilities(), + server_info=Implementation(name="legacy-only", version="0.0.1"), + ) + await server_write.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=frame.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + # notifications/initialized (and anything else) is observed and ignored. + + @asynccontextmanager + async def scripted_transport() -> AsyncIterator[TransportStreams]: + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as tg, + ): + tg.start_soon(scripted_server, server_streams) + yield client_read, client_write + tg.cancel_scope.cancel() + + with anyio.fail_after(5): + async with Client(scripted_transport(), mode="auto") as client: + assert client.initialize_result.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] + assert client.initialize_result.server_info.name == "legacy-only" + + assert methods_seen == ["server/discover", "initialize", "notifications/initialized"] + + +@requirement("lifecycle:envelope:stamped-on-every-request") +async def test_every_request_on_a_modern_session_carries_the_three_key_meta_envelope(connect: Connect) -> None: + """Each modern-session request's `params._meta` carries protocolVersion, clientInfo and clientCapabilities. + + Requirement `lifecycle:envelope:stamped-on-every-request` (spec basic#_meta): the per-request + envelope replaces the initialize handshake's once-per-session exchange. Asserted server-side + by capturing `ctx.meta` inside the handler. + """ + observed: list[dict[str, object]] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + assert ctx.meta is not None + observed.append(dict(ctx.meta)) + return types.ListToolsResult(tools=[]) + + server = Server("stamped", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect(server, client_info=Implementation(name="enveloper", version="1.2.3")) as client: + await client.list_tools() + await client.list_tools() + + assert len(observed) == 2 + for meta in observed: + assert meta[PROTOCOL_VERSION_META_KEY] == MODERN_VERSION + assert meta[CLIENT_INFO_META_KEY] == {"name": "enveloper", "version": "1.2.3"} + assert CLIENT_CAPABILITIES_META_KEY in meta + + +@requirement("lifecycle:envelope:header-matches-meta") +async def test_http_protocol_version_header_matches_meta_protocol_version_on_every_post() -> None: + """On streamable-HTTP, the `MCP-Protocol-Version` header on each POST equals `_meta.protocolVersion` in its body. + + Requirement `lifecycle:envelope:header-matches-meta` (spec streamable-http#headers): the + body-derived header and the envelope's protocol version are kept in lockstep so the server's + header-based routing and body-based validation never disagree. + """ + requests, on_request = _request_recorder() + + with anyio.fail_after(5): + async with ( + mounted_app(_tools_server(), on_request=on_request) as (http, _), + Client(streamable_http_client(f"{BASE_URL}/mcp", http_client=http), mode=MODERN_VERSION) as client, + ): + await client.list_tools() + await client.list_tools() + + assert requests, "no HTTP traffic recorded" + for request in requests: + body = json.loads(request.content) + assert request.headers["mcp-protocol-version"] == body["params"]["_meta"][PROTOCOL_VERSION_META_KEY] + assert request.headers["mcp-protocol-version"] == MODERN_VERSION From 52f200b13266d24d9e03648ab56b677091cf7784 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 22 Jun 2026 17:04:25 +0000 Subject: [PATCH 09/22] Conformance client fixture: drive Client(mode='auto') for the modern leg - client.py reads MCP_CONFORMANCE_PROTOCOL_VERSION and passes mode='auto' (modern) or 'legacy' (handshake-era) to the high-level Client; auth flows wrap the OAuth-authed httpx client in streamable_http_client and hand that as a Transport - New fixture handlers for request-metadata and http-standard-headers - json-schema-ref-no-deref pinned to legacy (its mock only speaks the handshake-era lifecycle; the check is lifecycle-agnostic) - Baselines: request-metadata + auth/authorization-server-migration removed from expected-failures.yml; tools_call + auth/scope-step-up + auth/scope-retry-limit + the two above removed from expected-failures.2026-07-28.yml. http-custom-headers / http-invalid-tool-headers (Mcp-Param-* headers) and sep-2322-client-request-state (multi-round-trip) stay waived. --- .github/actions/conformance/client.py | 205 +++++++++++------- .../expected-failures.2026-07-28.yml | 25 +-- .../actions/conformance/expected-failures.yml | 10 +- 3 files changed, 134 insertions(+), 106 deletions(-) diff --git a/.github/actions/conformance/client.py b/.github/actions/conformance/client.py index 2a7fd14681..4f3a93978b 100644 --- a/.github/actions/conformance/client.py +++ b/.github/actions/conformance/client.py @@ -17,6 +17,9 @@ initialize - Connect, initialize, list tools, close tools_call - Connect, call add_numbers(a=5, b=3), close sse-retry - Connect, call test_reconnection, close + json-schema-ref-no-deref - Connect, list tools (no $ref deref) + request-metadata - Connect with all callbacks; client stamps _meta + http-standard-headers - Connect, call a tool (Mcp-* headers checked) elicitation-sep1034-client-defaults - Elicitation with default accept callback auth/client-credentials-jwt - Client credentials with private_key_jwt auth/client-credentials-basic - Client credentials with client_secret_basic @@ -35,16 +38,18 @@ import httpx from pydantic import AnyUrl -from mcp import ClientSession, types +from mcp import types from mcp.client.auth import OAuthClientProvider, TokenStorage from mcp.client.auth.extensions.client_credentials import ( ClientCredentialsOAuthProvider, PrivateKeyJWTOAuthProvider, SignedJWTParameters, ) +from mcp.client.client import Client from mcp.client.context import ClientRequestContext from mcp.client.streamable_http import streamable_http_client from mcp.shared.auth import AuthorizationCodeResult, OAuthClientInformationFull, OAuthClientMetadata, OAuthToken +from mcp.shared.version import MODERN_PROTOCOL_VERSIONS # Set up logging to stderr (stdout is for conformance test output) logging.basicConfig( @@ -58,10 +63,24 @@ #: "2026-07-28"). The harness always sets this (when --spec-version is omitted #: it picks per-scenario: LATEST_SPEC_VERSION for active scenarios, #: DRAFT_PROTOCOL_VERSION for draft-only ones), so None means we were invoked -#: outside the harness. Handlers that need to take the stateless 2026 path will -#: branch on this once the SDK has one; today it is logged only. +#: outside the harness. PROTOCOL_VERSION: str | None = os.environ.get("MCP_CONFORMANCE_PROTOCOL_VERSION") + +def client_mode() -> str: + """Pick the Client(mode=) for the harness leg. + + On a modern leg (2026-07-28+) -> 'auto' so Client.discover() runs and the + _meta envelope + MCP-Protocol-Version header are stamped on every request. + On a handshake-era leg -> 'legacy' so the initialize handshake runs exactly + as before (no server/discover probe is sent against a mock that would 400 it). + Outside the harness -> 'auto' (probe + fallback). + """ + if PROTOCOL_VERSION is None or PROTOCOL_VERSION in MODERN_PROTOCOL_VERSIONS: + return "auto" + return "legacy" + + # Type for async scenario handler functions ScenarioHandler = Callable[[str], Coroutine[Any, None, None]] @@ -165,52 +184,22 @@ async def handle_callback(self) -> AuthorizationCodeResult: return result -# --- Scenario Handlers --- - - -@register("initialize") -async def run_initialize(server_url: str) -> None: - """Connect, initialize, list tools, close.""" - async with streamable_http_client(url=server_url) as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - logger.debug("Initialized successfully") - await session.list_tools() - logger.debug("Listed tools successfully") - - -@register("json-schema-ref-no-deref") -async def run_json_schema_ref_no_deref(server_url: str) -> None: - """Initialize and list tools; the scenario fails only if the client fetches a network $ref. - - ClientSession never walks inputSchema or resolves $refs, so listing is enough (SEP-2106). - """ - async with streamable_http_client(url=server_url) as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - await session.list_tools() +# --- Stub callbacks (declare capabilities in _meta without doing real work) --- -@register("tools_call") -async def run_tools_call(server_url: str) -> None: - """Connect, initialize, list tools, call add_numbers(a=5, b=3), close.""" - async with streamable_http_client(url=server_url) as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - await session.list_tools() - result = await session.call_tool("add_numbers", {"a": 5, "b": 3}) - logger.debug(f"add_numbers result: {result}") +async def stub_sampling_callback( + context: ClientRequestContext, + params: types.CreateMessageRequestParams, +) -> types.CreateMessageResult | types.ErrorData: + return types.CreateMessageResult( + role="assistant", + content=types.TextContent(type="text", text=""), + model="conformance-stub", + ) -@register("sse-retry") -async def run_sse_retry(server_url: str) -> None: - """Connect, initialize, list tools, call test_reconnection, close.""" - async with streamable_http_client(url=server_url) as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - await session.list_tools() - result = await session.call_tool("test_reconnection", {}) - logger.debug(f"test_reconnection result: {result}") +async def stub_list_roots_callback(context: ClientRequestContext) -> types.ListRootsResult | types.ErrorData: + return types.ListRootsResult(roots=[]) async def default_elicitation_callback( @@ -233,17 +222,87 @@ async def default_elicitation_callback( return types.ElicitResult(action="accept", content=content) +# --- Scenario Handlers --- + + +@register("initialize") +async def run_initialize(server_url: str) -> None: + """Connect, initialize, list tools, close.""" + async with Client(server_url, mode=client_mode()) as client: + logger.debug("Initialized successfully") + await client.list_tools() + logger.debug("Listed tools successfully") + + +@register("json-schema-ref-no-deref") +async def run_json_schema_ref_no_deref(server_url: str) -> None: + """Initialize and list tools; the scenario fails only if the client fetches a network $ref. + + The client never walks inputSchema or resolves $refs, so listing is enough (SEP-2106). + Pinned to mode='legacy': the harness reports PROTOCOL_VERSION=2026-07-28 for this + scenario but its mock server only speaks the handshake-era lifecycle and 400s a + modern-stamped tools/list. The check is lifecycle-agnostic so this is harmless. + """ + async with Client(server_url, mode="legacy") as client: + await client.list_tools() + + +@register("tools_call") +async def run_tools_call(server_url: str) -> None: + """Connect, list tools, call add_numbers(a=5, b=3), close.""" + async with Client(server_url, mode=client_mode()) as client: + await client.list_tools() + result = await client.call_tool("add_numbers", {"a": 5, "b": 3}) + logger.debug(f"add_numbers result: {result}") + + +@register("sse-retry") +async def run_sse_retry(server_url: str) -> None: + """Connect, list tools, call test_reconnection, close.""" + async with Client(server_url, mode=client_mode()) as client: + await client.list_tools() + result = await client.call_tool("test_reconnection", {}) + logger.debug(f"test_reconnection result: {result}") + + +@register("request-metadata") +async def run_request_metadata(server_url: str) -> None: + """Connect on the modern path with every client capability declared. + + The scenario inspects every request's `_meta` envelope (SEP-2575) for + protocolVersion / clientInfo / clientCapabilities, and the matching + MCP-Protocol-Version header. mode='auto' makes the SDK send + server/discover (covering the unsupported-version retry check), then adopt + and stamp the envelope on the follow-up requests. + """ + async with Client( + server_url, + mode=client_mode(), + sampling_callback=stub_sampling_callback, + list_roots_callback=stub_list_roots_callback, + elicitation_callback=default_elicitation_callback, + ) as client: + await client.list_tools() + result = await client.call_tool("add_numbers", {"a": 5, "b": 3}) + logger.debug(f"add_numbers result: {result}") + + +@register("http-standard-headers") +async def run_http_standard_headers(server_url: str) -> None: + """Connect on the modern path so Mcp-Method / Mcp-Name / MCP-Protocol-Version are sent (SEP-2243).""" + async with Client(server_url, mode=client_mode()) as client: + await client.list_tools() + result = await client.call_tool("add_numbers", {"a": 5, "b": 3}) + logger.debug(f"add_numbers result: {result}") + + @register("elicitation-sep1034-client-defaults") async def run_elicitation_defaults(server_url: str) -> None: """Connect with elicitation callback that applies schema defaults.""" - async with streamable_http_client(url=server_url) as (read_stream, write_stream): - async with ClientSession( - read_stream, write_stream, elicitation_callback=default_elicitation_callback - ) as session: - await session.initialize() - await session.list_tools() - result = await session.call_tool("test_client_elicitation_defaults", {}) - logger.debug(f"test_client_elicitation_defaults result: {result}") + async with Client(server_url, mode=client_mode(), elicitation_callback=default_elicitation_callback) as client: + await client.list_tools() + result = await client.call_tool("test_client_elicitation_defaults", {}) + logger.debug(f"test_client_elicitation_defaults result: {result}") @register("auth/client-credentials-jwt") @@ -343,25 +402,22 @@ async def run_auth_code_client(server_url: str) -> None: async def _run_auth_session(server_url: str, oauth_auth: OAuthClientProvider) -> None: """Common session logic for all OAuth flows.""" - client = httpx.AsyncClient(auth=oauth_auth, timeout=30.0) - async with streamable_http_client(url=server_url, http_client=client) as (read_stream, write_stream): - async with ClientSession( - read_stream, write_stream, elicitation_callback=default_elicitation_callback - ) as session: - await session.initialize() - logger.debug("Initialized successfully") - - tools_result = await session.list_tools() - logger.debug(f"Listed tools: {[t.name for t in tools_result.tools]}") - - # Call the first available tool (different tests have different tools) - if tools_result.tools: - tool_name = tools_result.tools[0].name - try: - result = await session.call_tool(tool_name, {}) - logger.debug(f"Called {tool_name}, result: {result}") - except Exception as e: - logger.debug(f"Tool call result/error: {e}") + http_client = httpx.AsyncClient(auth=oauth_auth, timeout=30.0) + transport = streamable_http_client(url=server_url, http_client=http_client) + async with Client(transport, mode=client_mode(), elicitation_callback=default_elicitation_callback) as client: + logger.debug("Initialized successfully") + + tools_result = await client.list_tools() + logger.debug(f"Listed tools: {[t.name for t in tools_result.tools]}") + + # Call the first available tool (different tests have different tools) + if tools_result.tools: + tool_name = tools_result.tools[0].name + try: + result = await client.call_tool(tool_name, {}) + logger.debug(f"Called {tool_name}, result: {result}") + except Exception as e: + logger.debug(f"Tool call result/error: {e}") logger.debug("Connection closed successfully") @@ -374,7 +430,7 @@ def main() -> None: server_url = sys.argv[1] scenario = os.environ.get("MCP_CONFORMANCE_SCENARIO") - logger.debug(f"Conformance protocol version: {PROTOCOL_VERSION!r}") + logger.debug(f"Conformance protocol version: {PROTOCOL_VERSION!r} -> mode={client_mode()!r}") if scenario: logger.debug(f"Running explicit scenario '{scenario}' against {server_url}") @@ -384,6 +440,9 @@ def main() -> None: elif scenario.startswith("auth/"): asyncio.run(run_auth_code_client(server_url)) else: + # Unhandled scenarios: + # - sep-2322-client-request-state (SEP-2322 / S6: MRTR client loop) + # - http-custom-headers, http-invalid-tool-headers (SEP-2243 / S8: Mcp-Param-* headers) print(f"Unknown scenario: {scenario}", file=sys.stderr) sys.exit(1) else: diff --git a/.github/actions/conformance/expected-failures.2026-07-28.yml b/.github/actions/conformance/expected-failures.2026-07-28.yml index b49626d0d6..529eb8babe 100644 --- a/.github/actions/conformance/expected-failures.2026-07-28.yml +++ b/.github/actions/conformance/expected-failures.2026-07-28.yml @@ -21,36 +21,13 @@ # milestone. client: - # --- No stateless client path on main yet --- - # client.py drives the 2025 stateful lifecycle (initialize handshake + - # session). The 2026-mode mock server is stateless, so the call sequence - # never reaches the assertion. Unblocks when client.py's is_modern_protocol() - # branch takes the per-request _meta path. - - tools_call - - # --- Auth scenarios cut short by the 2026 connection lifecycle --- - # The auth fixture flow drives the 2025 stateful lifecycle; the 2026-mode - # mock rejects the MCP POST before the scope-escalation behaviour these - # scenarios measure, so no authorization requests are observed. Unblocks - # when client.py's auth flow speaks the 2026 per-request lifecycle. - - auth/scope-step-up - - auth/scope-retry-limit - # --- Same gaps as the 2025 baseline (fail identically when forced to 2026-07-28) --- - # SEP-2575 (request metadata / _meta envelope): client does not populate the - # _meta envelope or the MCP-Protocol-Version header semantics yet. - - request-metadata # SEP-2322 (multi-round-trip requests): client does not echo requestState / # handle IncompleteResult yet. - sep-2322-client-request-state - # SEP-2243 (HTTP standardization): no fixture handler / client header support yet. + # SEP-2243 (HTTP standardization): no fixture handler / client Mcp-Param-* support yet. - http-custom-headers - http-invalid-tool-headers - # SEP-2352 (authorization server migration): the client re-registers and does not reuse the old - # AS credentials, but the 2026-mode mock rejects the MCP POST before the migration 401 fires - # (client.py drives the 2025 stateful lifecycle), so the re-register check is never reached. - # Unblocks with the 2026 stateless client lifecycle. - - auth/authorization-server-migration # auth/enterprise-managed-authorization (SEP-990) is in the 2025 baseline but # NOT here: the harness skips it as inapplicable at --spec-version 2026-07-28 # (it is an extension scenario not carried into the 2026 wire), so it is diff --git a/.github/actions/conformance/expected-failures.yml b/.github/actions/conformance/expected-failures.yml index 4234a6d4aa..2a411b4cde 100644 --- a/.github/actions/conformance/expected-failures.yml +++ b/.github/actions/conformance/expected-failures.yml @@ -12,20 +12,12 @@ client: # --- Draft-spec scenarios (in `--suite draft`, also part of `--suite all`) --- - # SEP-2575 (request metadata / _meta envelope): client does not populate the - # _meta envelope or the MCP-Protocol-Version header semantics yet. - - request-metadata # SEP-2322 (multi-round-trip requests): client does not echo requestState / # handle IncompleteResult yet. - sep-2322-client-request-state - # SEP-2243 (HTTP standardization): no fixture handler / client header support yet. + # SEP-2243 (HTTP standardization): no fixture handler / client Mcp-Param-* support yet. - http-custom-headers - http-invalid-tool-headers - # SEP-2352 (authorization server migration): the client re-registers and does not reuse the old - # AS credentials, but this 2026-introduced scenario runs at 2026-07-28, where client.py's 2025 - # stateful lifecycle is rejected (400 on initialize) before the migration 401 fires, so the - # re-register check is never reached. Unblocks with the 2026 stateless client lifecycle. - - auth/authorization-server-migration # --- Pre-existing scenarios that fail on checks added after conformance 0.1.15 --- # SEP-990 (enterprise-managed authorization extension): no fixture handler / From b6be755548a0d96d0d9008e1879d6bb0d8006b3b Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 23 Jun 2026 11:27:05 +0000 Subject: [PATCH 10/22] Transport: pre-session 404 maps to METHOD_NOT_FOUND; POSTs never read cached pv header. serve_one reshape + raise_exceptions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Bare HTTP 404 before a session is established now maps to METHOD_NOT_FOUND (was INVALID_REQUEST/"Session terminated", which is meaningless pre-session); with a session_id, 404 keeps the session-terminated mapping - _prepare_headers split: _base_headers (POST) vs _prepare_headers (GET/DELETE). POSTs never read the cached MCP-Protocol-Version header — they get it via per-message metadata only. Prevents the discover probe's header from leaking onto a fallback initialize POST. - serve_one reshaped to (server, dctx, method, params, *, ..., raise_exceptions); modern_on_request drops the JSONRPCRequest round-trip and threads raise_exceptions through to to_jsonrpc_response(raise_unhandled=). Client's modern in-process branch now honors raise_exceptions (handler exceptions chain via __cause__ instead of being sanitized to INTERNAL_ERROR). --- src/mcp/client/client.py | 5 +- src/mcp/client/streamable_http.py | 29 +++++++--- src/mcp/server/_streamable_http_modern.py | 2 +- src/mcp/server/runner.py | 59 +++++++++++++------- tests/client/test_client.py | 17 ++++++ tests/client/test_notification_response.py | 11 +++- tests/client/test_streamable_http.py | 62 +++++++++++++++++++++- tests/server/test_runner.py | 48 ++++++++++++++--- 8 files changed, 194 insertions(+), 39 deletions(-) diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index b5b234e3f5..a50ef5d403 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -159,7 +159,10 @@ async def __aenter__(self) -> Client: client_disp, server_disp = create_direct_dispatcher_pair() tg = await exit_stack.enter_async_context(anyio.create_task_group()) exit_stack.callback(server_disp.close) - await tg.start(server_disp.run, modern_on_request(self._inproc_server, lifespan_state), _drop_notify) + on_request = modern_on_request( + self._inproc_server, lifespan_state, raise_exceptions=self.raise_exceptions + ) + await tg.start(server_disp.run, on_request, _drop_notify) session = ClientSession( dispatcher=client_disp, read_timeout_seconds=self.read_timeout_seconds, diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 962e104e1e..d4c4d3995c 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -23,6 +23,7 @@ from mcp.types import ( INTERNAL_ERROR, INVALID_REQUEST, + METHOD_NOT_FOUND, PARSE_ERROR, ErrorData, JSONRPCError, @@ -84,19 +85,29 @@ def __init__(self, url: str) -> None: # GET/DELETE that don't carry per-message metadata. self._protocol_version_header: str | None = None - def _prepare_headers(self) -> dict[str, str]: - """Build MCP-specific request headers. + def _base_headers(self) -> dict[str, str]: + """Build MCP-specific request headers (accept / content-type / session-id). These headers will be merged with the httpx.AsyncClient's default headers, - with these MCP-specific headers taking precedence. + with these MCP-specific headers taking precedence. POSTs use this directly: + their protocol-version header arrives per-message via ``metadata.headers``, + so they must never read the cached value. """ headers: dict[str, str] = { "accept": "application/json, text/event-stream", "content-type": "application/json", } - # Add session headers if available if self.session_id: headers[MCP_SESSION_ID] = self.session_id + return headers + + def _prepare_headers(self) -> dict[str, str]: + """Base headers plus the cached protocol-version header. + + Used by transport-internal GET/DELETE (listen stream, resumption, + reconnect, terminate) which don't carry per-message metadata. + """ + headers = self._base_headers() if self._protocol_version_header: headers[MCP_PROTOCOL_VERSION_HEADER] = self._protocol_version_header return headers @@ -238,7 +249,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: async def _handle_post_request(self, ctx: RequestContext) -> None: """Handle a POST request with response processing.""" - headers = self._prepare_headers() + headers = self._base_headers() message = ctx.session_message.message if ctx.metadata is not None and ctx.metadata.headers is not None: headers.update(ctx.metadata.headers) @@ -276,7 +287,13 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: pass logger.debug("Non-2xx body was not a JSON-RPC error; using fallback") if response.status_code == 404: - error_data = ErrorData(code=INVALID_REQUEST, message="Session terminated") + if self.session_id is None: + # No session yet → 404 is the HTTP-level spelling of + # METHOD_NOT_FOUND (gateway / legacy server doesn't know + # this method); "Session terminated" would be a lie here. + error_data = ErrorData(code=METHOD_NOT_FOUND, message="Not Found") + else: + error_data = ErrorData(code=INVALID_REQUEST, message="Session terminated") else: error_data = ErrorData(code=INTERNAL_ERROR, message="Server returned an error response") session_message = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data)) diff --git a/src/mcp/server/_streamable_http_modern.py b/src/mcp/server/_streamable_http_modern.py index 2ddeb78ed3..732656fd8b 100644 --- a/src/mcp/server/_streamable_http_modern.py +++ b/src/mcp/server/_streamable_http_modern.py @@ -193,5 +193,5 @@ async def handle_modern_request( request_id=req.id, message_metadata=ServerMessageMetadata(request_context=request), ) - msg = await serve_one(app, req, connection=connection, dctx=dctx, lifespan_state=lifespan_state) + msg = await serve_one(app, dctx, req.method, req.params, connection=connection, lifespan_state=lifespan_state) await _write(msg, scope, receive, send) diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 4cf508363a..8b71edb69f 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -49,7 +49,6 @@ InitializeRequestParams, InitializeResult, JSONRPCError, - JSONRPCRequest, JSONRPCResponse, RequestId, RequestParams, @@ -183,20 +182,26 @@ async def aclose_shielded(connection: Connection) -> None: ) -async def to_jsonrpc_response(request_id: RequestId, coro: Awaitable[dict[str, Any]]) -> JSONRPCResponse | JSONRPCError: +async def to_jsonrpc_response( + request_id: RequestId, coro: Awaitable[dict[str, Any]], *, raise_unhandled: bool = False +) -> JSONRPCResponse | JSONRPCError: """Await ``coro`` and wrap its outcome as the JSON-RPC reply for ``request_id``. The exception-to-wire boundary for the request-per-call drivers (`serve_one`, the modern HTTP entry). `MCPError` and `ValidationError` map via the shared `handler_exception_to_error_data` ladder; any other exception is logged and surfaced as `INTERNAL_ERROR` so handler internals - never reach the wire. + never reach the wire. Set ``raise_unhandled`` to let unmapped exceptions + propagate instead of being sanitized — used by the in-process test path so + handler tracebacks reach the caller. """ try: result = await coro except Exception as exc: error = handler_exception_to_error_data(exc) if error is None: + if raise_unhandled: + raise logger.exception("request handler raised") error = ErrorData(code=INTERNAL_ERROR, message="Internal server error") return JSONRPCError(jsonrpc="2.0", id=request_id, error=error) @@ -497,34 +502,45 @@ async def serve_loop( async def serve_one( server: Server[LifespanT], - request: JSONRPCRequest, + dctx: DispatchContext[TransportContext], + method: str, + params: Mapping[str, Any] | None, *, connection: Connection, - dctx: DispatchContext[TransportContext], lifespan_state: LifespanT, + raise_exceptions: bool = False, ) -> JSONRPCResponse | JSONRPCError: - """Handle a single ``request`` and return its JSON-RPC reply. - - The single-exchange driver: builds the kernel, runs `on_request` once for - `request` under `dctx`, maps the outcome to a `JSONRPCResponse` / - `JSONRPCError` via `to_jsonrpc_response`, and tears down - `connection.exit_stack` (shielded) on the way out. The entry constructs - the (born-ready) `Connection` and the `dctx`; this only consumes them. + """Handle a single request ``(method, params)`` and return its JSON-RPC reply. + + The single-exchange driver: builds the kernel, runs `on_request` once under + `dctx`, maps the outcome to a `JSONRPCResponse` / `JSONRPCError` via + `to_jsonrpc_response`, and tears down `connection.exit_stack` (shielded) on + the way out. The entry constructs the (born-ready) `Connection` and the + `dctx`; this only consumes them. ``raise_exceptions`` lets unmapped handler + exceptions propagate instead of being sanitized to `INTERNAL_ERROR`. """ runner = ServerRunner(server, connection, lifespan_state) try: - return await to_jsonrpc_response(request.id, runner.on_request(dctx, request.method, request.params)) + # Single-exchange driver only handles requests; both entries populate `request_id`. + # TODO(L54): drop once `DispatchContext` is split so `OnRequest` carries a non-Optional id. + assert dctx.request_id is not None + return await to_jsonrpc_response( + dctx.request_id, runner.on_request(dctx, method, params), raise_unhandled=raise_exceptions + ) finally: await aclose_shielded(connection) -def modern_on_request(server: Server[LifespanT], lifespan_state: LifespanT) -> OnRequest: +def modern_on_request( + server: Server[LifespanT], lifespan_state: LifespanT, *, raise_exceptions: bool = False +) -> OnRequest: """Return an `OnRequest` callback that serves each call via `serve_one` with a fresh per-request `Connection`. Wire this into the server side of a `DirectDispatcher` peer-pair to drive an in-process server on the modern per-request-envelope path (each request carries protocol version, client info, and capabilities in `params._meta`; - no `initialize` handshake). + no `initialize` handshake). ``raise_exceptions`` lets unmapped handler + exceptions propagate to the caller for debuggable in-process testing. """ async def handle( @@ -536,12 +552,15 @@ async def handle( meta.get(CLIENT_INFO_META_KEY), meta.get(CLIENT_CAPABILITIES_META_KEY), ) - # `OnRequest` is invoked for requests only, so `request_id` is always set. - assert dctx.request_id is not None - req = JSONRPCRequest( - jsonrpc="2.0", id=dctx.request_id, method=method, params=dict(params) if params is not None else None + msg = await serve_one( + server, + dctx, + method, + params, + connection=connection, + lifespan_state=lifespan_state, + raise_exceptions=raise_exceptions, ) - msg = await serve_one(server, req, connection=connection, dctx=dctx, lifespan_state=lifespan_state) if isinstance(msg, JSONRPCError): raise MCPError(code=msg.error.code, message=msg.error.message, data=msg.error.data) return msg.result diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 6ebdf9553e..bd33b39ab8 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -203,6 +203,23 @@ async def handle_read_resource( assert exc_info.value.error.code == 404 +async def test_raise_exceptions_propagates_handler_error_on_modern_inproc_path(): + """`raise_exceptions=True` on the modern in-process path: an unmapped handler + exception reaches the client with its original type chained, instead of being + sanitized to an opaque `INTERNAL_ERROR`.""" + + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + raise ValueError("boom") + + server = Server("test", on_call_tool=handle_call_tool) + async with Client(server, mode="2026-07-28", raise_exceptions=True) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("explode", {}) + # The original exception is chained — not swallowed into a generic "Internal server error". + assert isinstance(exc_info.value.__cause__, ValueError) + assert str(exc_info.value.__cause__) == "boom" + + async def test_get_prompt(app: MCPServer): """Test getting a prompt.""" async with Client(app) as client: diff --git a/tests/client/test_notification_response.py b/tests/client/test_notification_response.py index 4dbd78dbbe..bd85cd074a 100644 --- a/tests/client/test_notification_response.py +++ b/tests/client/test_notification_response.py @@ -204,12 +204,19 @@ async def test_invalid_json_response_sends_jsonrpc_error() -> None: def _create_non_2xx_json_body_app(status: int, body: bytes) -> Starlette: - """Server that returns a fixed non-2xx status + ``application/json`` body for non-init requests.""" + """Server that returns a fixed non-2xx status + ``application/json`` body for non-init requests. + + The initialize response carries an ``mcp-session-id`` so the client treats subsequent + requests as part of an established session (needed for the 404 → session-terminated mapping). + """ async def handle_mcp_request(request: Request) -> Response: data = json.loads(await request.body()) if data.get("method") == "initialize": - return _init_json_response(data) + return JSONResponse( + {"jsonrpc": "2.0", "id": data["id"], "result": INIT_RESPONSE}, + headers={"mcp-session-id": "test-session"}, + ) if "id" not in data: return Response(status_code=202) return Response(content=body, status_code=status, media_type="application/json") diff --git a/tests/client/test_streamable_http.py b/tests/client/test_streamable_http.py index 92f1fc8981..77b1fdc061 100644 --- a/tests/client/test_streamable_http.py +++ b/tests/client/test_streamable_http.py @@ -15,9 +15,9 @@ from inline_snapshot import snapshot from mcp.client.streamable_http import streamable_http_client -from mcp.shared.inbound import encode_header_value +from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER, encode_header_value from mcp.shared.message import ClientMessageMetadata, SessionMessage -from mcp.types import JSONRPCRequest +from mcp.types import METHOD_NOT_FOUND, JSONRPCError, JSONRPCRequest @pytest.mark.parametrize( @@ -78,3 +78,61 @@ def handler(request: httpx.Request) -> httpx.Response: assert isinstance(reply, SessionMessage) assert [r.method for r in recorded] == ["POST"] assert recorded[0].headers["x-test"] == "v" + + +@pytest.mark.anyio +async def test_pre_session_bare_404_maps_to_method_not_found() -> None: + """A bare HTTP 404 (no JSON-RPC body) before any session-id is held maps to METHOD_NOT_FOUND. + + Gateways and legacy servers 404 at the HTTP layer for unknown methods; with no session yet, + "Session terminated" is meaningless, and the discover→initialize fallback ladder keys on -32601. + """ + + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(404) + + with anyio.fail_after(5): + async with ( + httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http, + streamable_http_client("http://test/mcp", http_client=http) as (read, write), + ): + await write.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="server/discover", params={}))) + reply = await read.receive() + assert isinstance(reply, SessionMessage) + assert isinstance(reply.message, JSONRPCError) + assert reply.message.error.code == METHOD_NOT_FOUND + + +@pytest.mark.anyio +async def test_post_does_not_read_cached_protocol_version_header() -> None: + """A POST's protocol-version header comes only from its own ``metadata.headers``. + + The first POST carries (and caches) a pv header; the second POST sends no metadata + and must therefore carry no pv header — a stale cached value would poison the + fallback ``initialize`` after a failed discover probe. The cache exists for + transport-internal GET/DELETE only. + """ + recorded: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + recorded.append(request) + body = json.loads(request.content) + return httpx.Response(200, json={"jsonrpc": "2.0", "id": body["id"], "result": {}}) + + with anyio.fail_after(5): + async with ( + httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http, + streamable_http_client("http://test/mcp", http_client=http) as (read, write), + ): + await write.send( + SessionMessage( + message=JSONRPCRequest(jsonrpc="2.0", id=1, method="server/discover", params={}), + metadata=ClientMessageMetadata(headers={MCP_PROTOCOL_VERSION_HEADER: "2026-07-28"}), + ) + ) + await read.receive() + await write.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=2, method="initialize", params={}))) + await read.receive() + assert [r.method for r in recorded] == ["POST", "POST"] + assert recorded[0].headers[MCP_PROTOCOL_VERSION_HEADER] == "2026-07-28" + assert MCP_PROTOCOL_VERSION_HEADER not in recorded[1].headers diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 8707dc077e..2bf7dae983 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -51,7 +51,6 @@ Implementation, InitializeRequestParams, JSONRPCError, - JSONRPCRequest, JSONRPCResponse, ListToolsResult, NotificationParams, @@ -1267,6 +1266,33 @@ async def fail() -> dict[str, Any]: assert "request handler raised" in caplog.text +@pytest.mark.anyio +async def test_to_jsonrpc_response_raise_unhandled_propagates_unmapped_exception(): + """SDK-defined: ``raise_unhandled=True`` lets an unmapped exception escape + instead of being sanitized to `INTERNAL_ERROR` — used by the in-process test + path so the original traceback reaches the caller.""" + + async def fail() -> dict[str, Any]: + raise RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + await to_jsonrpc_response(1, fail(), raise_unhandled=True) + + +@pytest.mark.anyio +async def test_to_jsonrpc_response_raise_unhandled_still_maps_mcp_error(): + """SDK-defined: ``raise_unhandled`` only affects unmapped exceptions; an + `MCPError` is still converted to a `JSONRPCError` (it is protocol-level, not + a crash).""" + + async def fail() -> dict[str, Any]: + raise MCPError(code=METHOD_NOT_FOUND, message="nope") + + reply = await to_jsonrpc_response(1, fail(), raise_unhandled=True) + assert isinstance(reply, JSONRPCError) + assert reply.error.code == METHOD_NOT_FOUND + + # --- aclose_shielded ----------------------------------------------------------- @@ -1331,8 +1357,9 @@ async def test_serve_one_runs_handler_and_returns_jsonrpc_response(server: SrvT) conn = Connection.from_envelope(HANDSHAKE_PROTOCOL_VERSIONS[-1], None, None) cleaned: list[int] = [] conn.exit_stack.push_async_callback(_append_async, cleaned, 1) - request = JSONRPCRequest(jsonrpc="2.0", id=9, method="tools/list", params=None) - reply = await serve_one(server, request, connection=conn, dctx=_StubDispatchContext(9), lifespan_state=_LIFESPAN) + reply = await serve_one( + server, _StubDispatchContext(9), "tools/list", None, connection=conn, lifespan_state=_LIFESPAN + ) assert isinstance(reply, JSONRPCResponse) assert reply.id == 9 assert reply.result["tools"][0]["name"] == "t" @@ -1349,8 +1376,9 @@ async def test_serve_one_maps_error_to_jsonrpc_error_and_still_closes_exit_stack conn = Connection.from_envelope(HANDSHAKE_PROTOCOL_VERSIONS[-1], None, None) cleaned: list[int] = [] conn.exit_stack.push_async_callback(_append_async, cleaned, 1) - request = JSONRPCRequest(jsonrpc="2.0", id=2, method="resources/list", params=None) - reply = await serve_one(server, request, connection=conn, dctx=_StubDispatchContext(2), lifespan_state=_LIFESPAN) + reply = await serve_one( + server, _StubDispatchContext(2), "resources/list", None, connection=conn, lifespan_state=_LIFESPAN + ) assert isinstance(reply, JSONRPCError) assert reply.error.code == METHOD_NOT_FOUND assert cleaned == [1] @@ -1362,8 +1390,14 @@ async def test_serve_one_reads_connection_protocol_version_as_a_fact(server: Srv reads `connection.protocol_version` for the version gate. A `from_envelope` connection at a modern version rejects a method absent there.""" conn = Connection.from_envelope(MODERN_PROTOCOL_VERSIONS[0], None, None) - request = JSONRPCRequest(jsonrpc="2.0", id=1, method="logging/setLevel", params={"level": "info"}) - reply = await serve_one(server, request, connection=conn, dctx=_StubDispatchContext(1), lifespan_state=_LIFESPAN) + reply = await serve_one( + server, + _StubDispatchContext(1), + "logging/setLevel", + {"level": "info"}, + connection=conn, + lifespan_state=_LIFESPAN, + ) assert isinstance(reply, JSONRPCError) assert reply.error.code == METHOD_NOT_FOUND From a82042a37805cab6ab6e86fbe0f0c4bd9aaaf07c Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 23 Jun 2026 11:39:52 +0000 Subject: [PATCH 11/22] discover() returns DiscoverResult; fallback ladder moves to Client; era-neutral accessors - ClientSession.discover() -> DiscoverResult: no fallback (METHOD_NOT_FOUND/ REQUEST_TIMEOUT propagate; Client owns that policy), no InitializeResult synthesis. Separate _discover_result/_initialize_result/_negotiated_version slots. .adopt() sets the matching slot; no more synthesis. - Era-neutral properties on ClientSession and Client: .server_info, .server_capabilities, .instructions, .protocol_version read from whichever result is set. ClientSession.discover_result for prior_discover round-trip. - Client.__aenter__: mode='auto' wraps discover() with the fallback ladder (METHOD_NOT_FOUND | REQUEST_TIMEOUT -> initialize()). _build_session helper consolidates the dispatcher/transport branching to one ClientSession() site. - Client.initialize_result removed (use the era-neutral accessors). - mode= validated in __post_init__: ValueError on unknown values, with a redirect hint for handshake-era versions. - adopt()/discover() docstrings gain Raises: sections. --- src/mcp/client/client.py | 148 +++++++++++------- src/mcp/client/session.py | 131 ++++++++++------ tests/client/test_client.py | 82 +++++++++- tests/client/test_session.py | 111 +++++++------ tests/client/transports/test_memory.py | 2 +- .../lowlevel/test_client_connect.py | 14 +- tests/interaction/lowlevel/test_completion.py | 2 +- tests/interaction/lowlevel/test_initialize.py | 10 +- .../interaction/mcpserver/test_completion.py | 4 +- tests/interaction/mcpserver/test_context.py | 2 +- tests/interaction/transports/test_stdio.py | 2 +- .../transports/test_streamable_http.py | 2 +- tests/server/mcpserver/test_integration.py | 12 +- tests/server/mcpserver/test_server.py | 2 +- 14 files changed, 332 insertions(+), 192 deletions(-) diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index a50ef5d403..8bfeeb6b9a 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -19,15 +19,17 @@ from mcp.server.mcpserver import MCPServer from mcp.server.runner import modern_on_request from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair -from mcp.shared.dispatcher import ProgressFnT -from mcp.shared.exceptions import MCPDeprecationWarning +from mcp.shared.dispatcher import Dispatcher, ProgressFnT +from mcp.shared.exceptions import MCPDeprecationWarning, MCPError +from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS, MODERN_PROTOCOL_VERSIONS from mcp.types import ( + METHOD_NOT_FOUND, + REQUEST_TIMEOUT, CallToolResult, CompleteResult, EmptyResult, GetPromptResult, Implementation, - InitializeResult, ListPromptsResult, ListResourcesResult, ListResourceTemplatesResult, @@ -38,8 +40,14 @@ ReadResourceResult, RequestParamsMeta, ResourceTemplateReference, + ServerCapabilities, ) +ConnectMode = Literal["legacy", "auto"] | str +"""``mode=`` value: ``"legacy"`` (initialize handshake), ``"auto"`` (discover, fall back to +initialize), or a modern protocol-version string (adopt directly). The ``str`` arm is for +forward-compat; ``Client.__post_init__`` rejects anything outside that set at construction.""" + def _synthesize_discover(protocol_version: str) -> types.DiscoverResult: return types.DiscoverResult( @@ -119,10 +127,10 @@ async def main(): client_info: Implementation | None = None """Client implementation info to send to server.""" - mode: Literal["legacy", "auto"] | str = "legacy" + mode: ConnectMode = "legacy" """'legacy' performs the initialize handshake. 'auto' probes server/discover and falls back to initialize() - on legacy servers. A protocol-version string (e.g. '2026-07-28') adopts that version directly without a - handshake — supply prior_discover to reuse a known DiscoverResult, or omit it to synthesize a minimal one.""" + on legacy servers. A modern protocol-version string (e.g. '2026-07-28') adopts that version directly without + a handshake — supply prior_discover to reuse a known DiscoverResult, or omit it to synthesize a minimal one.""" prior_discover: types.DiscoverResult | None = None """A previously-obtained DiscoverResult to install via .adopt() when mode is a version pin. @@ -146,60 +154,70 @@ def __post_init__(self) -> None: else: self._transport = self.server + if self.mode not in ("legacy", "auto") and self.mode not in MODERN_PROTOCOL_VERSIONS: + hint = ( + f" ({self.mode!r} is a handshake-era version — use mode='legacy')" + if self.mode in HANDSHAKE_PROTOCOL_VERSIONS + else "" + ) + raise ValueError( + f"mode must be 'legacy', 'auto', or one of {list(MODERN_PROTOCOL_VERSIONS)}; got {self.mode!r}{hint}" + ) + + async def _build_session(self, exit_stack: AsyncExitStack) -> ClientSession: + """Set up the dispatcher/transport and return an un-entered ClientSession.""" + dispatcher: Dispatcher[Any] | None + if self._inproc_server is not None and self.mode != "legacy": + # Modern in-process path: drive the server through a DirectDispatcher peer-pair + # with one `serve_one` per request — no streams, no initialize handshake. + lifespan_state = await exit_stack.enter_async_context(self._inproc_server.lifespan(self._inproc_server)) + client_disp, server_disp = create_direct_dispatcher_pair() + tg = await exit_stack.enter_async_context(anyio.create_task_group()) + exit_stack.callback(server_disp.close) + on_request = modern_on_request(self._inproc_server, lifespan_state, raise_exceptions=self.raise_exceptions) + await tg.start(server_disp.run, on_request, _drop_notify) + dispatcher = client_disp + read_stream = write_stream = None + else: + if self._inproc_server is not None: + transport: Transport = InMemoryTransport(self._inproc_server, raise_exceptions=self.raise_exceptions) + else: + assert self._transport is not None + transport = self._transport + read_stream, write_stream = await exit_stack.enter_async_context(transport) + dispatcher = None + return ClientSession( + read_stream=read_stream, + write_stream=write_stream, + dispatcher=dispatcher, + read_timeout_seconds=self.read_timeout_seconds, + sampling_callback=self.sampling_callback, + list_roots_callback=self.list_roots_callback, + logging_callback=self.logging_callback, + message_handler=self.message_handler, + client_info=self.client_info, + elicitation_callback=self.elicitation_callback, + ) + async def __aenter__(self) -> Client: """Enter the async context manager.""" if self._session is not None: raise RuntimeError("Client is already entered; cannot reenter") async with AsyncExitStack() as exit_stack: - if self._inproc_server is not None and self.mode != "legacy": - # Modern in-process path: drive the server through a DirectDispatcher peer-pair - # with one `serve_one` per request — no streams, no initialize handshake. - lifespan_state = await exit_stack.enter_async_context(self._inproc_server.lifespan(self._inproc_server)) - client_disp, server_disp = create_direct_dispatcher_pair() - tg = await exit_stack.enter_async_context(anyio.create_task_group()) - exit_stack.callback(server_disp.close) - on_request = modern_on_request( - self._inproc_server, lifespan_state, raise_exceptions=self.raise_exceptions - ) - await tg.start(server_disp.run, on_request, _drop_notify) - session = ClientSession( - dispatcher=client_disp, - read_timeout_seconds=self.read_timeout_seconds, - sampling_callback=self.sampling_callback, - list_roots_callback=self.list_roots_callback, - logging_callback=self.logging_callback, - message_handler=self.message_handler, - client_info=self.client_info, - elicitation_callback=self.elicitation_callback, - ) - else: - if self._inproc_server is not None: - transport: Transport = InMemoryTransport( - self._inproc_server, raise_exceptions=self.raise_exceptions - ) - else: - assert self._transport is not None - transport = self._transport - read_stream, write_stream = await exit_stack.enter_async_context(transport) - session = ClientSession( - read_stream=read_stream, - write_stream=write_stream, - read_timeout_seconds=self.read_timeout_seconds, - sampling_callback=self.sampling_callback, - list_roots_callback=self.list_roots_callback, - logging_callback=self.logging_callback, - message_handler=self.message_handler, - client_info=self.client_info, - elicitation_callback=self.elicitation_callback, - ) - + session = await self._build_session(exit_stack) self._session = await exit_stack.enter_async_context(session) if self.mode == "legacy": await self._session.initialize() elif self.mode == "auto": - await self._session.discover() + try: + await self._session.discover() + except MCPError as e: + if e.code in (METHOD_NOT_FOUND, REQUEST_TIMEOUT): + await self._session.initialize() + else: + raise else: self._session.adopt(self.prior_discover or _synthesize_discover(self.mode)) @@ -227,16 +245,30 @@ def session(self) -> ClientSession: return self._session @property - def initialize_result(self) -> InitializeResult: - """The server's InitializeResult. + def protocol_version(self) -> str: + """Negotiated protocol version (set by initialize/discover/adopt during ``__aenter__``).""" + version = self.session.protocol_version + assert version is not None + return version - Contains server_info, capabilities, instructions, and the negotiated protocol_version. - Raises RuntimeError if accessed outside the context manager. - """ - result = self.session.initialize_result - if result is None: # pragma: no cover - raise RuntimeError("Client must be used within an async context manager") - return result + @property + def server_info(self) -> Implementation: + """Server name/version (set by initialize/discover/adopt during ``__aenter__``).""" + info = self.session.server_info + assert info is not None + return info + + @property + def server_capabilities(self) -> ServerCapabilities: + """Server capabilities (set by initialize/discover/adopt during ``__aenter__``).""" + caps = self.session.server_capabilities + assert caps is not None + return caps + + @property + def instructions(self) -> str | None: + """Server-provided instructions text, if any.""" + return self.session.instructions async def send_ping(self, *, meta: RequestParamsMeta | None = None) -> EmptyResult: """Send a ping request to the server.""" diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 5f4d795211..cc3d7cbba0 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -34,7 +34,6 @@ INTERNAL_ERROR, METHOD_NOT_FOUND, PROTOCOL_VERSION_META_KEY, - REQUEST_TIMEOUT, UNSUPPORTED_PROTOCOL_VERSION, RequestId, RequestParamsMeta, @@ -204,6 +203,8 @@ def __init__( self._message_handler = message_handler or _default_message_handler self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} self._initialize_result: types.InitializeResult | None = None + self._discover_result: types.DiscoverResult | None = None + self._negotiated_version: str | None = None self._stamp: Callable[[dict[str, Any], CallOptions], None] = _preconnect_stamp self._task_group: anyio.abc.TaskGroup | None = None if dispatcher is not None: @@ -287,7 +288,7 @@ async def send_request( opts["on_resumption_token"] = metadata.on_resumption_token_update raw = await self._dispatcher.send_raw_request(method, data.get("params"), opts) # Literal fallback covers pre-handshake and stateless; matches runner.py. - version = self.protocol_version or "2025-11-25" + version = self._negotiated_version or "2025-11-25" try: _methods.validate_server_result(method, version, raw) except KeyError: @@ -350,48 +351,48 @@ async def initialize(self) -> types.InitializeResult: return result def adopt(self, result: types.InitializeResult | types.DiscoverResult) -> None: - """Install negotiated state from a result the caller already holds (no wire traffic).""" + """Install negotiated state from a result the caller already holds (no wire traffic). + + Raises: + RuntimeError: `result` is a `DiscoverResult` whose `supported_versions` + shares nothing with this client's `MODERN_PROTOCOL_VERSIONS`. + """ if isinstance(result, types.DiscoverResult): + # ordered oldest→newest via MODERN_PROTOCOL_VERSIONS mutual = [v for v in MODERN_PROTOCOL_VERSIONS if v in result.supported_versions] if not mutual: raise RuntimeError( f"No mutually supported modern protocol version " f"(server: {result.supported_versions}, client: {list(MODERN_PROTOCOL_VERSIONS)})" ) - protocol_version = mutual[-1] client_info = self._client_info.model_dump(by_alias=True, mode="json", exclude_none=True) capabilities = self._build_capabilities().model_dump(by_alias=True, mode="json", exclude_none=True) - self._stamp = _make_modern_stamp(protocol_version, client_info, capabilities) - self._initialize_result = types.InitializeResult( - protocol_version=protocol_version, - capabilities=result.capabilities, - server_info=result.server_info, - instructions=result.instructions, - ) + self._stamp = _make_modern_stamp(mutual[-1], client_info, capabilities) + self._discover_result = result + self._negotiated_version = mutual[-1] else: self._stamp = _make_handshake_stamp(result.protocol_version) self._initialize_result = result + self._negotiated_version = result.protocol_version - async def discover(self) -> types.InitializeResult: - """Probe `server/discover` and adopt the result, falling back to `initialize()`. + async def discover(self) -> types.DiscoverResult: + """Probe `server/discover` and adopt the result. Sends a single `server/discover` proposing the newest modern protocol - version. The error ladder, in order: - - - `UNSUPPORTED_PROTOCOL_VERSION` (-32022): the server's `supported` - list is intersected with `MODERN_PROTOCOL_VERSIONS` and the probe is - retried once at the highest mutual version. No mutual version, or a - second failure, raises the server's `MCPError`. - - `METHOD_NOT_FOUND` (-32601) or `REQUEST_TIMEOUT` (-32001): the server - is treated as legacy and `initialize()` runs instead — exactly as - ``mode='legacy'`` would. - - Any other error: re-raised. - - Returns the synthesized `InitializeResult` (also available afterwards - via `initialize_result`). + version. On `UNSUPPORTED_PROTOCOL_VERSION` (-32022) the server's + `supported` list is intersected with `MODERN_PROTOCOL_VERSIONS` and the + probe is retried once at the highest mutual version. Any other error — + including `METHOD_NOT_FOUND` (-32601) and `REQUEST_TIMEOUT` (-32001) — + propagates; the legacy `initialize()` fallback is the caller's policy. + + Raises: + MCPError: The server rejected `server/discover`, the probe timed + out, or the -32022 retry found no mutual version / failed again. + RuntimeError: `adopt()` found no mutual version in the returned + `supported_versions`. """ - if self._initialize_result is not None: - return self._initialize_result + if self._discover_result is not None: + return self._discover_result client_info = self._client_info.model_dump(by_alias=True, mode="json", exclude_none=True) capabilities = self._build_capabilities().model_dump(by_alias=True, mode="json", exclude_none=True) @@ -414,39 +415,67 @@ async def probe(version: str) -> dict[str, Any]: try: raw = await probe(MODERN_PROTOCOL_VERSIONS[-1]) except MCPError as e: - if e.code == UNSUPPORTED_PROTOCOL_VERSION: - try: - data = types.UnsupportedProtocolVersionErrorData.model_validate(e.error.data) - except ValidationError: - raise e from None - mutual = [v for v in MODERN_PROTOCOL_VERSIONS if v in data.supported] - if not mutual: - raise - raw = await probe(mutual[-1]) - elif e.code in (METHOD_NOT_FOUND, REQUEST_TIMEOUT): - return await self.initialize() - else: + if e.code != UNSUPPORTED_PROTOCOL_VERSION: + raise + try: + data = types.UnsupportedProtocolVersionErrorData.model_validate(e.error.data) + except ValidationError: + raise e from None + # ordered oldest→newest via MODERN_PROTOCOL_VERSIONS + mutual = [v for v in MODERN_PROTOCOL_VERSIONS if v in data.supported] + if not mutual: raise + raw = await probe(mutual[-1]) result = types.DiscoverResult.model_validate(raw) self.adopt(result) - assert self._initialize_result is not None - return self._initialize_result + return result @property def initialize_result(self) -> types.InitializeResult | None: - """The server's InitializeResult. None until `initialize()` or `adopt()`. + """The server's InitializeResult. None unless `initialize()` ran (or was adopted).""" + return self._initialize_result + + @property + def discover_result(self) -> types.DiscoverResult | None: + """The server's DiscoverResult. None unless `discover()` ran (or was adopted). - Contains server_info, capabilities, instructions, and the negotiated - protocol_version. For a modern session adopted from a DiscoverResult, - this is synthesized locally with the chosen protocol version. + Retained intact (supported_versions, ttl_ms, cache_scope) so callers + can round-trip it as ``prior_discover=``. """ - return self._initialize_result + return self._discover_result @property def protocol_version(self) -> str | None: - """Negotiated protocol version. None until `initialize()` or `adopt()`.""" - return self._initialize_result.protocol_version if self._initialize_result is not None else None + """Negotiated protocol version. None until `initialize()`, `discover()`, or `adopt()`.""" + return self._negotiated_version + + @property + def server_info(self) -> types.Implementation | None: + """Server name/version. None until `initialize()`, `discover()`, or `adopt()`.""" + if self._discover_result is not None: + return self._discover_result.server_info + if self._initialize_result is not None: + return self._initialize_result.server_info + return None + + @property + def server_capabilities(self) -> types.ServerCapabilities | None: + """Server capabilities. None until `initialize()`, `discover()`, or `adopt()`.""" + if self._discover_result is not None: + return self._discover_result.capabilities + if self._initialize_result is not None: + return self._initialize_result.capabilities + return None + + @property + def instructions(self) -> str | None: + """Server-provided instructions text, if any.""" + if self._discover_result is not None: + return self._discover_result.instructions + if self._initialize_result is not None: + return self._initialize_result.instructions + return None async def send_ping(self, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult: """Send a ping request.""" @@ -655,7 +684,7 @@ async def _on_request( # Literal, not LATEST_PROTOCOL_VERSION: the fallback covers the initialize # handshake (which only exists at <=2025) and stateless until the header # is plumbed; its meaning is fixed regardless of LATEST bumps. - version = self.protocol_version or "2025-11-25" + version = self._negotiated_version or "2025-11-25" try: request = cast(types.ServerRequest, _methods.parse_server_request(method, version, params)) except KeyError: @@ -693,7 +722,7 @@ async def _on_notify( ) -> None: """Route a server notification: validate, run the typed callback, tee to message_handler.""" # Same fallback as `_on_request`: covers pre-handshake and stateless. - version = self.protocol_version or "2025-11-25" + version = self._negotiated_version or "2025-11-25" try: notification = cast(types.ServerNotification, _methods.parse_server_notification(method, version, params)) except KeyError: diff --git a/tests/client/test_client.py b/tests/client/test_client.py index bd33b39ab8..e6464fde11 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -3,8 +3,8 @@ from __future__ import annotations import contextvars -from collections.abc import Iterator -from contextlib import contextmanager +from collections.abc import AsyncIterator, Iterator +from contextlib import asynccontextmanager, contextmanager from unittest.mock import patch import anyio @@ -13,10 +13,13 @@ from mcp import MCPError, types from mcp.client._memory import InMemoryTransport +from mcp.client._transport import TransportStreams from mcp.client.client import Client from mcp.client.streamable_http import streamable_http_client from mcp.server import Server, ServerRequestContext from mcp.server.mcpserver import MCPServer +from mcp.shared.memory import MessageStream, create_client_server_memory_streams +from mcp.shared.message import SessionMessage from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS from mcp.types import ( CallToolResult, @@ -105,7 +108,7 @@ def greeting_prompt(name: str) -> str: async def test_client_is_initialized(app: MCPServer): """Test that the client is initialized after entering context.""" async with Client(app) as client: - assert client.initialize_result.capabilities == snapshot( + assert client.server_capabilities == snapshot( ServerCapabilities( experimental={}, prompts=PromptsCapability(list_changed=False), @@ -113,13 +116,13 @@ async def test_client_is_initialized(app: MCPServer): tools=ToolsCapability(list_changed=False), ) ) - assert client.initialize_result.server_info.name == "test" + assert client.server_info.name == "test" -async def test_client_initialize_result_exposes_negotiated_protocol_version(app: MCPServer): +async def test_client_exposes_negotiated_protocol_version(app: MCPServer): """The negotiated protocol version is readable after initialization.""" async with Client(app) as client: - assert client.initialize_result.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] + assert client.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] async def test_client_with_simple_server(simple_server: Server): @@ -391,5 +394,70 @@ async def test_client_auto_mode_probes_discover_then_adopts(simple_server: Serve mounted_app(simple_server) as (http, _), Client(streamable_http_client(f"{BASE_URL}/mcp", http_client=http), mode="auto") as client, ): - assert client.initialize_result.protocol_version == "2026-07-28" + assert client.protocol_version == "2026-07-28" assert (await client.list_resources()).resources[0].name == "Test Resource" + + +@pytest.mark.parametrize("code", [types.METHOD_NOT_FOUND, types.REQUEST_TIMEOUT]) +async def test_client_auto_mode_falls_back_to_initialize_on_legacy_signal(code: int) -> None: + """`mode='auto'`: when `server/discover` is rejected with -32601 or -32001, + `Client.__aenter__` runs the legacy `initialize()` handshake and lands at a + handshake-era protocol version. The session itself does not fall back — + that policy lives here. A real `Server` always implements `server/discover`, + so the server side is hand-played.""" + methods_seen: list[str] = [] + + async def scripted_server(streams: MessageStream) -> None: + server_read, server_write = streams + async for message in server_read: + assert isinstance(message, SessionMessage) + frame = message.message + assert isinstance(frame, types.JSONRPCRequest | types.JSONRPCNotification) + methods_seen.append(frame.method) + if isinstance(frame, types.JSONRPCNotification): + continue + if frame.method == "server/discover": + error = types.ErrorData(code=code, message="nope") + await server_write.send(SessionMessage(types.JSONRPCError(jsonrpc="2.0", id=frame.id, error=error))) + elif frame.method == "initialize": # pragma: no branch + result = types.InitializeResult( + protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], + capabilities=ServerCapabilities(), + server_info=types.Implementation(name="legacy-only", version="0.0.1"), + ) + await server_write.send( + SessionMessage( + types.JSONRPCResponse( + jsonrpc="2.0", + id=frame.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + + @asynccontextmanager + async def scripted_transport() -> AsyncIterator[TransportStreams]: + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as tg, + ): + tg.start_soon(scripted_server, server_streams) + yield client_read, client_write + tg.cancel_scope.cancel() + + with anyio.fail_after(5): + async with Client(scripted_transport(), mode="auto") as client: + assert client.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] + assert client.server_info.name == "legacy-only" + assert methods_seen == ["server/discover", "initialize", "notifications/initialized"] + + +def test_client_rejects_handshake_era_mode_at_construction() -> None: + """A handshake-era protocol-version string passed as `mode=` is rejected by + `__post_init__` with a hint to use `mode='legacy'` — the version-pin path is + modern-only.""" + server = MCPServer("test") + with pytest.raises(ValueError, match=r"handshake-era version — use mode='legacy'"): + Client(server, mode="2025-06-18") + with pytest.raises(ValueError, match=r"mode must be 'legacy', 'auto', or one of"): + Client(server, mode="not-a-version") diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 54e17e57d3..dabc27d7c4 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -644,6 +644,11 @@ async def mock_server(): assert result.capabilities == expected_capabilities assert result.instructions == expected_instructions assert result.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] + # Era-neutral accessors are populated from the InitializeResult. + assert session.server_info == expected_server_info + assert session.server_capabilities == expected_capabilities + assert session.instructions == expected_instructions + assert session.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] @pytest.mark.anyio @@ -785,10 +790,12 @@ async def test_receive_loop_drops_unknown_notification_method_without_response() def _set_negotiated_version(session: ClientSession, version: str) -> None: """Force `session.protocol_version` without running the handshake.""" - session._initialize_result = InitializeResult( - protocol_version=version, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), + session.adopt( + InitializeResult( + protocol_version=version, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), + ) ) @@ -1471,7 +1478,9 @@ async def send_raw_request( return item async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: - self.notifies.append(method) + self.notifies.append( + method + ) # pragma: no cover — recorded so a wrongly-sent notification fails the == [] assert def _discover_result_dict() -> dict[str, Any]: @@ -1482,14 +1491,6 @@ def _discover_result_dict() -> dict[str, Any]: ).model_dump(by_alias=True, mode="json", exclude_none=True) -def _initialize_result_dict() -> dict[str, Any]: - return InitializeResult( - protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], - capabilities=ServerCapabilities(), - server_info=Implementation(name="stub", version="0"), - ).model_dump(by_alias=True, mode="json", exclude_none=True) - - @pytest.mark.anyio async def test_discover_adopts_the_returned_result_and_installs_the_modern_stamp() -> None: """SDK-defined: a successful `server/discover` is adopted and subsequent requests @@ -1497,7 +1498,8 @@ async def test_discover_adopts_the_returned_result_and_installs_the_modern_stamp dispatcher = _ScriptedDispatcher(_discover_result_dict(), {}) with anyio.fail_after(5): async with ClientSession(dispatcher=dispatcher) as session: - await session.discover() + result = await session.discover() + assert isinstance(result, types.DiscoverResult) assert session.protocol_version == "2026-07-28" await session.send_ping() ping_method, ping_params = dispatcher.calls[-1] @@ -1546,47 +1548,20 @@ async def test_discover_raises_when_retry_intersection_is_empty() -> None: @pytest.mark.anyio -async def test_discover_falls_back_to_initialize_on_method_not_found() -> None: - """Spec SHOULD: a legacy server that answers -32601 to `server/discover` is - transparently driven through the handshake instead.""" - dispatcher = _ScriptedDispatcher( - MCPError(METHOD_NOT_FOUND, "Method not found"), - _initialize_result_dict(), - ) - with anyio.fail_after(5): - async with ClientSession(dispatcher=dispatcher) as session: - await session.discover() - assert session.protocol_version in HANDSHAKE_PROTOCOL_VERSIONS - assert [m for m, _ in dispatcher.calls] == ["server/discover", "initialize"] - assert dispatcher.notifies == ["notifications/initialized"] - - -@pytest.mark.anyio -async def test_discover_falls_back_to_initialize_on_timeout() -> None: - """Spec SHOULD: a `REQUEST_TIMEOUT` from the dispatcher is treated the same as - method-not-found — the server is presumed legacy and `initialize()` runs.""" - dispatcher = _ScriptedDispatcher( - MCPError(REQUEST_TIMEOUT, "timed out"), - _initialize_result_dict(), - ) - with anyio.fail_after(5): - async with ClientSession(dispatcher=dispatcher) as session: - await session.discover() - assert session.protocol_version in HANDSHAKE_PROTOCOL_VERSIONS - assert [m for m, _ in dispatcher.calls] == ["server/discover", "initialize"] - - -@pytest.mark.anyio -async def test_discover_reraises_on_other_errors() -> None: - """SDK-defined: any error outside the retry/fallback ladder propagates verbatim - — `discover()` does not mask server failures by falling back to `initialize()`.""" - dispatcher = _ScriptedDispatcher(MCPError(INTERNAL_ERROR, "boom")) +@pytest.mark.parametrize("code", [METHOD_NOT_FOUND, REQUEST_TIMEOUT, INTERNAL_ERROR]) +async def test_discover_reraises_non_retry_errors_without_falling_back(code: int) -> None: + """SDK-defined: any error outside the -32022 retry rung propagates verbatim + — `discover()` does not fall back to `initialize()` itself; that is the + caller's policy (`Client.__aenter__`).""" + dispatcher = _ScriptedDispatcher(MCPError(code, "nope")) with anyio.fail_after(5): async with ClientSession(dispatcher=dispatcher) as session: with pytest.raises(MCPError) as exc: await session.discover() - assert exc.value.error.code == INTERNAL_ERROR + assert exc.value.error.code == code + assert session.protocol_version is None assert [m for m, _ in dispatcher.calls] == ["server/discover"] + assert dispatcher.notifies == [] @pytest.mark.anyio @@ -1612,10 +1587,46 @@ async def test_discover_is_idempotent_and_returns_the_cached_result() -> None: with anyio.fail_after(5): async with ClientSession(dispatcher=dispatcher) as session: first = await session.discover() + assert isinstance(first, types.DiscoverResult) assert await session.discover() is first + assert session.discover_result is first assert [m for m, _ in dispatcher.calls] == ["server/discover"] +def test_era_neutral_properties_are_none_before_any_handshake() -> None: + """SDK-defined: the era-neutral accessors all read as None on a fresh session.""" + client_d, _ = create_direct_dispatcher_pair() + session = ClientSession(dispatcher=client_d) + assert session.protocol_version is None + assert session.server_info is None + assert session.server_capabilities is None + assert session.instructions is None + assert session.discover_result is None + assert session.initialize_result is None + + +@pytest.mark.anyio +async def test_era_neutral_properties_after_discover() -> None: + """SDK-defined: after `discover()` the era-neutral accessors read from the + DiscoverResult; `initialize_result` stays None.""" + raw = types.DiscoverResult( + supported_versions=["2026-07-28"], + capabilities=ServerCapabilities(tools=types.ToolsCapability(list_changed=True)), + server_info=Implementation(name="discovered", version="2.0"), + instructions="hello", + ).model_dump(by_alias=True, mode="json", exclude_none=True) + dispatcher = _ScriptedDispatcher(raw) + with anyio.fail_after(5): + async with ClientSession(dispatcher=dispatcher) as session: + await session.discover() + assert session.protocol_version == "2026-07-28" + assert session.server_info == Implementation(name="discovered", version="2.0") + assert session.server_capabilities == ServerCapabilities(tools=types.ToolsCapability(list_changed=True)) + assert session.instructions == "hello" + assert session.initialize_result is None + assert isinstance(session.discover_result, types.DiscoverResult) + + @pytest.mark.anyio async def test_discover_reraises_unsupported_version_with_malformed_error_data() -> None: """SDK-defined: a -32022 reply whose `data` is not a valid diff --git a/tests/client/transports/test_memory.py b/tests/client/transports/test_memory.py index 8baee128b5..51a026c138 100644 --- a/tests/client/transports/test_memory.py +++ b/tests/client/transports/test_memory.py @@ -76,7 +76,7 @@ async def test_with_mcpserver(mcpserver_server: MCPServer): async def test_server_is_running(mcpserver_server: MCPServer): """Test that the server is running and responding to requests.""" async with Client(mcpserver_server) as client: - assert client.initialize_result.capabilities.tools is not None + assert client.server_capabilities.tools is not None async def test_list_tools(mcpserver_server: MCPServer): diff --git a/tests/interaction/lowlevel/test_client_connect.py b/tests/interaction/lowlevel/test_client_connect.py index 19c0c6605b..111dd1cff9 100644 --- a/tests/interaction/lowlevel/test_client_connect.py +++ b/tests/interaction/lowlevel/test_client_connect.py @@ -144,8 +144,8 @@ async def test_prior_discover_populates_state_with_zero_connect_time_traffic() - ) as client, ): assert requests == [] - assert client.initialize_result.server_info == Implementation(name="cached-server", version="9.9.9") - assert client.initialize_result.capabilities.tools == ToolsCapability(list_changed=False) + assert client.server_info == Implementation(name="cached-server", version="9.9.9") + assert client.server_capabilities.tools == ToolsCapability(list_changed=False) await client.list_tools() assert [json.loads(r.content)["method"] for r in requests] == ["tools/list"] @@ -167,8 +167,8 @@ async def test_auto_mode_probes_server_discover_and_adopts_the_result() -> None: mounted_app(server, on_request=on_request) as (http, _), Client(streamable_http_client(f"{BASE_URL}/mcp", http_client=http), mode="auto") as client, ): - assert client.initialize_result.protocol_version == MODERN_VERSION - assert client.initialize_result.server_info.name == "discoverable" + assert client.protocol_version == MODERN_VERSION + assert client.server_info.name == "discoverable" await client.list_tools() bodies = [json.loads(r.content) for r in requests] @@ -211,7 +211,7 @@ async def discover(ctx: ServerRequestContext, params: types.RequestParams | None mounted_app(server, on_request=on_request) as (http, _), Client(streamable_http_client(f"{BASE_URL}/mcp", http_client=http), mode="auto") as client, ): - assert client.initialize_result.protocol_version == MODERN_VERSION + assert client.protocol_version == MODERN_VERSION assert calls == [MODERN_VERSION, MODERN_VERSION] assert [json.loads(r.content)["method"] for r in requests][:2] == ["server/discover", "server/discover"] @@ -300,8 +300,8 @@ async def scripted_transport() -> AsyncIterator[TransportStreams]: with anyio.fail_after(5): async with Client(scripted_transport(), mode="auto") as client: - assert client.initialize_result.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] - assert client.initialize_result.server_info.name == "legacy-only" + assert client.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] + assert client.server_info.name == "legacy-only" assert methods_seen == ["server/discover", "initialize", "notifications/initialized"] diff --git a/tests/interaction/lowlevel/test_completion.py b/tests/interaction/lowlevel/test_completion.py index f12671d935..8478831f31 100644 --- a/tests/interaction/lowlevel/test_completion.py +++ b/tests/interaction/lowlevel/test_completion.py @@ -123,7 +123,7 @@ async def test_complete_without_handler_is_method_not_found(connect: Connect) -> server = Server("incomplete") async with connect(server) as client: - assert client.initialize_result.capabilities.completions is None + assert client.server_capabilities.completions is None with pytest.raises(MCPError) as exc_info: await client.complete(PromptReference(name="anything"), argument={"name": "topic", "value": ""}) diff --git a/tests/interaction/lowlevel/test_initialize.py b/tests/interaction/lowlevel/test_initialize.py index 91adbf5611..d1f79c0cb7 100644 --- a/tests/interaction/lowlevel/test_initialize.py +++ b/tests/interaction/lowlevel/test_initialize.py @@ -60,7 +60,7 @@ async def test_initialize_returns_server_info(connect: Connect) -> None: ) async with connect(server) as client: - server_info = client.initialize_result.server_info + server_info = client.server_info assert server_info == snapshot( Implementation( @@ -78,10 +78,10 @@ async def test_initialize_returns_server_info(connect: Connect) -> None: async def test_initialize_returns_instructions(connect: Connect) -> None: """Instructions are returned when the server declares them and omitted when it does not.""" async with connect(Server("guided", instructions="Call the add tool.")) as client: - assert client.initialize_result.instructions == snapshot("Call the add tool.") + assert client.instructions == snapshot("Call the add tool.") async with connect(Server("unguided")) as client: - assert client.initialize_result.instructions is None + assert client.instructions is None @requirement("lifecycle:initialize:capabilities:from-handlers") @@ -137,7 +137,7 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar ) async with connect(server) as client: - capabilities = client.initialize_result.capabilities + capabilities = client.server_capabilities assert capabilities == snapshot( ServerCapabilities( @@ -155,7 +155,7 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar async def test_initialize_minimal_server_advertises_no_capabilities(connect: Connect) -> None: """A server with no feature handlers advertises no feature capabilities.""" async with connect(Server("bare")) as client: - capabilities = client.initialize_result.capabilities + capabilities = client.server_capabilities assert capabilities == snapshot(ServerCapabilities(experimental={})) diff --git a/tests/interaction/mcpserver/test_completion.py b/tests/interaction/mcpserver/test_completion.py index 7761066e94..30ff9613e3 100644 --- a/tests/interaction/mcpserver/test_completion.py +++ b/tests/interaction/mcpserver/test_completion.py @@ -32,7 +32,7 @@ async def complete( raise NotImplementedError async with connect(with_handler) as client: - assert client.initialize_result.capabilities.completions == CompletionsCapability() + assert client.server_capabilities.completions == CompletionsCapability() async with connect(MCPServer("plain")) as client: - assert client.initialize_result.capabilities.completions is None + assert client.server_capabilities.completions is None diff --git a/tests/interaction/mcpserver/test_context.py b/tests/interaction/mcpserver/test_context.py index f3ee3f52e4..edbbc94467 100644 --- a/tests/interaction/mcpserver/test_context.py +++ b/tests/interaction/mcpserver/test_context.py @@ -52,7 +52,7 @@ async def collect(params: LoggingMessageNotificationParams) -> None: async with connect(mcp, logging_callback=collect) as client: result = await client.call_tool("narrate", {}) - advertised_logging = client.initialize_result.capabilities.logging + advertised_logging = client.server_capabilities.logging assert result == snapshot(CallToolResult(content=[TextContent(text="done")], structured_content={"result": "done"})) assert received == snapshot( diff --git a/tests/interaction/transports/test_stdio.py b/tests/interaction/transports/test_stdio.py index 8aac551c67..60a9b93981 100644 --- a/tests/interaction/transports/test_stdio.py +++ b/tests/interaction/transports/test_stdio.py @@ -90,7 +90,7 @@ async def collect(params: LoggingMessageNotificationParams) -> None: # Must exceed session time plus the patched PROCESS_TERMINATION_TIMEOUT (20s). with anyio.fail_after(30): async with Client(transport, logging_callback=collect) as client: - assert client.initialize_result.server_info.name == "stdio-echo" + assert client.server_info.name == "stdio-echo" result = await client.call_tool("echo", {"text": "across\nprocesses"}) errlog.seek(0) diff --git a/tests/interaction/transports/test_streamable_http.py b/tests/interaction/transports/test_streamable_http.py index cb63e389ca..79aace2639 100644 --- a/tests/interaction/transports/test_streamable_http.py +++ b/tests/interaction/transports/test_streamable_http.py @@ -69,7 +69,7 @@ async def announce(ctx: Context) -> str: async def test_tool_call_over_streamable_http_with_json_responses() -> None: """The round trip works when the server answers with a single JSON body instead of an SSE stream.""" async with connect_over_streamable_http(_smoke_server(), json_response=True) as client: - assert client.initialize_result.server_info.name == "smoke" + assert client.server_info.name == "smoke" result = await client.call_tool("echo", {"text": "as json"}) assert result == snapshot( diff --git a/tests/server/mcpserver/test_integration.py b/tests/server/mcpserver/test_integration.py index 5bac39dfee..a5388b17a4 100644 --- a/tests/server/mcpserver/test_integration.py +++ b/tests/server/mcpserver/test_integration.py @@ -105,7 +105,7 @@ async def elicitation_callback(context: ClientRequestContext, params: ElicitRequ async def test_basic_tools() -> None: """Test basic tool functionality.""" async with Client(basic_tool.mcp) as client: - assert client.initialize_result.capabilities.tools is not None + assert client.server_capabilities.tools is not None # Test sum tool tool_result = await client.call_tool("sum", {"a": 5, "b": 3}) @@ -123,7 +123,7 @@ async def test_basic_tools() -> None: async def test_basic_resources() -> None: """Test basic resource functionality.""" async with Client(basic_resource.mcp) as client: - assert client.initialize_result.capabilities.resources is not None + assert client.server_capabilities.resources is not None # Test document resource doc_content = await client.read_resource("file://documents/readme") @@ -145,7 +145,7 @@ async def test_basic_resources() -> None: async def test_basic_prompts() -> None: """Test basic prompt functionality.""" async with Client(basic_prompt.mcp) as client: - assert client.initialize_result.capabilities.prompts is not None + assert client.server_capabilities.prompts is not None # Test review_code prompt prompts = await client.list_prompts() @@ -216,7 +216,7 @@ async def progress_callback(progress: float, total: float | None, message: str | async def test_sampling() -> None: """Test sampling (LLM interaction) functionality.""" async with Client(sampling.mcp, sampling_callback=sampling_callback) as client: - assert client.initialize_result.capabilities.tools is not None + assert client.server_capabilities.tools is not None # Test sampling tool sampling_result = await client.call_tool("generate_poem", {"topic": "nature"}) @@ -286,8 +286,8 @@ async def message_handler(message: RequestResponder[ServerRequest, ClientResult] async def test_completion() -> None: """Test completion (autocomplete) functionality.""" async with Client(completion.mcp) as client: - assert client.initialize_result.capabilities.resources is not None - assert client.initialize_result.capabilities.prompts is not None + assert client.server_capabilities.resources is not None + assert client.server_capabilities.prompts is not None # Test resource completion completion_result = await client.complete( diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index d44a4df42f..a48bd7ae47 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -720,7 +720,7 @@ def get_text() -> str: mcp = MCPServer(resources=[resource]) async with Client(mcp) as client: - assert client.initialize_result.capabilities.resources is not None + assert client.server_capabilities.resources is not None resources = await client.list_resources() assert len(resources.resources) == 1 From cdfdfd11c72243cbde4c8c163a90f46a75864368 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 23 Jun 2026 11:48:01 +0000 Subject: [PATCH 12/22] migration.md: drop v2-only churn entries; document ctx.report_progress() preference - StreamableHTTPTransport.protocol_version section: attribute-only (the constructor param was v2-only churn, never on v1.x) - Delete ClientSession(protocol_version=) section (param never on v1.x) - Fix v1 surface reference: ClientSession.get_server_capabilities() (Client class did not exist in v1) - New section on handler progress reporting: ctx.report_progress() is dispatcher-agnostic; reading meta['progress_token'] + send_progress_notification is JSONRPC-specific and won't work on the in-process modern path - test_client_connect.py: pytest.fail -> raise NotImplementedError --- docs/migration.md | 16 +++++++++------- .../interaction/lowlevel/test_client_connect.py | 2 +- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index 65fb9d090b..cb2cc4dce4 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -159,9 +159,9 @@ The `headers`, `timeout`, `sse_read_timeout`, and `auth` parameters have been re Note: `sse_client` retains its `headers`, `timeout`, `sse_read_timeout`, and `auth` parameters — only the streamable HTTP transport changed. -### `protocol_version` removed from `StreamableHTTPTransport` and `streamable_http_client` +### `StreamableHTTPTransport.protocol_version` attribute removed -The `protocol_version` attribute on `StreamableHTTPTransport` and the `protocol_version` parameter on `streamable_http_client` have been removed. The transport no longer holds per-connection protocol state; era-dependent headers (e.g. `MCP-Protocol-Version`) are supplied per-message by the session, so the transport never needs to know the negotiated version. +The transport no longer holds per-connection protocol state; era-dependent headers (e.g. `MCP-Protocol-Version`) are now supplied per-message by the session. If you were reading `transport.protocol_version` to learn the negotiated version, read it from `session.initialize_result.protocol_version` instead. ### `terminate_windows_process` removed @@ -352,11 +352,7 @@ if result is not None: version = result.protocol_version ``` -The high-level `Client.initialize_result` returns the same `InitializeResult` but is non-nullable — initialization is guaranteed inside the context manager, so no `None` check is needed. This replaces v1's `Client.server_capabilities`; use `client.initialize_result.capabilities` instead. - -### `ClientSession(protocol_version=)` removed - -The `protocol_version` constructor parameter on `ClientSession` has been removed. To install a known protocol version without performing the `initialize` handshake (e.g. when reconnecting to an existing session), call `session.adopt(result)` after construction with a stored `InitializeResult`. +The high-level `Client.initialize_result` returns the same `InitializeResult` but is non-nullable — initialization is guaranteed inside the context manager, so no `None` check is needed. Like `session.initialize_result`, this replaces v1's `ClientSession.get_server_capabilities()`; use `client.initialize_result.capabilities` instead. ### `McpError` renamed to `MCPError` @@ -810,6 +806,12 @@ await session.send_progress_notification( ) ``` +### Handler progress reporting: prefer `ctx.report_progress()` over manual `progress_token` + +Reading `ctx.meta["progress_token"]` and calling `session.send_progress_notification(token, ...)` is specific to the JSON-RPC transport path. On the in-process modern path (`DirectDispatcher` / `Client(server)`), there is no wire token in `_meta`, so handlers that gate progress on the token's presence go silent. + +`ctx.report_progress(progress, total, message)` works on every dispatcher: it sends a progress notification when a token is present and routes the update through the dispatcher's progress channel otherwise, no-opping only when the caller did not request progress at all. `session.send_progress_notification(progress_token, ...)` is unchanged and still works on JSON-RPC transports for code that already holds a token. + ### `create_connected_server_and_client_session` removed The `create_connected_server_and_client_session` helper in `mcp.shared.memory` has been removed. Use `mcp.client.Client` instead — it accepts a `Server` or `MCPServer` instance directly and handles the in-memory transport and session setup for you. diff --git a/tests/interaction/lowlevel/test_client_connect.py b/tests/interaction/lowlevel/test_client_connect.py index 111dd1cff9..aa6625bc2e 100644 --- a/tests/interaction/lowlevel/test_client_connect.py +++ b/tests/interaction/lowlevel/test_client_connect.py @@ -244,7 +244,7 @@ def is_internal_error(exc: MCPError) -> bool: pytest.RaisesExc(MCPError, check=is_internal_error), flatten_subgroups=True ): # pragma: no branch async with Client(streamable_http_client(f"{BASE_URL}/mcp", http_client=http), mode="auto"): - pytest.fail("entering the Client should have raised") # pragma: no cover + raise NotImplementedError("entering the Client should have raised") # pragma: no cover assert [json.loads(r.content)["method"] for r in requests] == ["server/discover"] From b3263478ba24fe5971bbad2c0ee13d2fd0e03ee0 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 23 Jun 2026 16:38:57 +0000 Subject: [PATCH 13/22] Add named protocol-version scalars and replace tuple indexing shared/version.py gains four derived constants alongside the existing tuples: LATEST_PROTOCOL_VERSION (now derived here instead of a duplicate literal in types/_types.py), LATEST_HANDSHAKE_VERSION, LATEST_MODERN_VERSION, and OLDEST_SUPPORTED_VERSION. Call sites that previously wrote HANDSHAKE_PROTOCOL_VERSIONS[-1] / MODERN_PROTOCOL_VERSIONS[-1] / [0] now import the named scalar so the meaning is explicit at the use site and a future version bump is one edit. This also fixes a quiet drift: a handful of tests were passing LATEST_PROTOCOL_VERSION (now "2026-07-28") into InitializeRequest on the legacy handshake path; those now use LATEST_HANDSHAKE_VERSION. --- src/mcp/client/session.py | 11 ++++++-- src/mcp/server/connection.py | 8 ++---- src/mcp/shared/version.py | 12 ++++++++ src/mcp/types/_types.py | 7 +---- tests/client/test_session.py | 28 +++++++++---------- tests/interaction/_connect.py | 14 +++++----- .../transports/test_hosting_resume.py | 4 +-- tests/issues/test_192_request_id.py | 4 +-- tests/issues/test_552_windows_hang.py | 4 +-- tests/server/test_cancel_handling.py | 6 ++-- tests/server/test_connection.py | 10 +++---- tests/server/test_session.py | 9 +++--- tests/server/test_streamable_http_modern.py | 6 ++-- tests/shared/test_inbound.py | 14 +++++----- 14 files changed, 73 insertions(+), 64 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index cc3d7cbba0..94a1b2e59b 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -27,7 +27,12 @@ from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder from mcp.shared.transport_context import TransportContext -from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS, MODERN_PROTOCOL_VERSIONS +from mcp.shared.version import ( + HANDSHAKE_PROTOCOL_VERSIONS, + LATEST_HANDSHAKE_VERSION, + LATEST_MODERN_VERSION, + MODERN_PROTOCOL_VERSIONS, +) from mcp.types import ( CLIENT_CAPABILITIES_META_KEY, CLIENT_INFO_META_KEY, @@ -333,7 +338,7 @@ async def initialize(self) -> types.InitializeResult: result = await self.send_request( types.InitializeRequest( params=types.InitializeRequestParams( - protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], + protocol_version=LATEST_HANDSHAKE_VERSION, capabilities=self._build_capabilities(), client_info=self._client_info, ), @@ -413,7 +418,7 @@ async def probe(version: str) -> dict[str, Any]: return await self._dispatcher.send_raw_request("server/discover", params, opts) try: - raw = await probe(MODERN_PROTOCOL_VERSIONS[-1]) + raw = await probe(LATEST_MODERN_VERSION) except MCPError as e: if e.code != UNSUPPORTED_PROTOCOL_VERSION: raise diff --git a/src/mcp/server/connection.py b/src/mcp/server/connection.py index 06691535c0..933bd5c6b6 100644 --- a/src/mcp/server/connection.py +++ b/src/mcp/server/connection.py @@ -31,7 +31,7 @@ from mcp.shared.dispatcher import CallOptions, Outbound from mcp.shared.exceptions import MCPDeprecationWarning, NoBackChannelError from mcp.shared.peer import Meta, dump_params -from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS +from mcp.shared.version import LATEST_HANDSHAKE_VERSION from mcp.types import ( ClientCapabilities, CreateMessageRequest, @@ -192,14 +192,12 @@ def for_loop( Not born-ready: `initialized` is set later by the kernel when `notifications/initialized` arrives. `protocol_version` is seeded from - the transport hint (or `HANDSHAKE_PROTOCOL_VERSIONS[-1]`) so it's never `None`; + the transport hint (or `LATEST_HANDSHAKE_VERSION`) so it's never `None`; the handshake overwrites it once negotiated. """ return cls( outbound, - protocol_version=protocol_version_hint - if protocol_version_hint is not None - else HANDSHAKE_PROTOCOL_VERSIONS[-1], + protocol_version=protocol_version_hint if protocol_version_hint is not None else LATEST_HANDSHAKE_VERSION, session_id=session_id, ) diff --git a/src/mcp/shared/version.py b/src/mcp/shared/version.py index 2232c1d3c4..c5c2233274 100644 --- a/src/mcp/shared/version.py +++ b/src/mcp/shared/version.py @@ -35,6 +35,18 @@ Kept as the union for v1.x compatibility. """ +LATEST_PROTOCOL_VERSION: Final[str] = KNOWN_PROTOCOL_VERSIONS[-1] +"""Newest protocol revision this SDK speaks (any era).""" + +LATEST_HANDSHAKE_VERSION: Final[str] = HANDSHAKE_PROTOCOL_VERSIONS[-1] +"""Newest revision reachable via the ``initialize`` handshake; the client's offer and server's counter-offer default.""" + +LATEST_MODERN_VERSION: Final[str] = MODERN_PROTOCOL_VERSIONS[-1] +"""Newest per-request-envelope revision; the ``server/discover`` probe default.""" + +OLDEST_SUPPORTED_VERSION: Final[str] = HANDSHAKE_PROTOCOL_VERSIONS[0] +"""Oldest revision this SDK still negotiates via the ``initialize`` handshake.""" + def is_version_at_least(version: str, minimum: str) -> bool: """Return True if `version` is a known revision at least as new as `minimum`. diff --git a/src/mcp/types/_types.py b/src/mcp/types/_types.py index bf260b67b2..a08b4a3e59 100644 --- a/src/mcp/types/_types.py +++ b/src/mcp/types/_types.py @@ -20,14 +20,9 @@ from pydantic.alias_generators import to_camel from typing_extensions import NotRequired, TypedDict +from mcp.shared.version import LATEST_PROTOCOL_VERSION as LATEST_PROTOCOL_VERSION from mcp.types.jsonrpc import RequestId -LATEST_PROTOCOL_VERSION: Final[str] = "2026-07-28" -"""The newest protocol version this SDK can negotiate. - -See https://modelcontextprotocol.io/specification/latest. -""" - DEFAULT_NEGOTIATED_VERSION: Final[str] = "2025-03-26" """The default negotiated version of the Model Context Protocol when no version is specified. diff --git a/tests/client/test_session.py b/tests/client/test_session.py index dabc27d7c4..933171eabd 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -18,7 +18,7 @@ from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder from mcp.shared.transport_context import TransportContext -from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS +from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS, LATEST_HANDSHAKE_VERSION from mcp.types import ( CONNECTION_CLOSED, INTERNAL_ERROR, @@ -87,7 +87,7 @@ async def mock_server(): assert isinstance(request, InitializeRequest) result = InitializeResult( - protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], + protocol_version=LATEST_HANDSHAKE_VERSION, capabilities=ServerCapabilities( logging=None, resources=None, @@ -140,7 +140,7 @@ async def message_handler( # pragma: no cover # Assert the result assert isinstance(result, InitializeResult) - assert result.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] + assert result.protocol_version == LATEST_HANDSHAKE_VERSION assert isinstance(result.capabilities, ServerCapabilities) assert result.server_info == Implementation(name="mock-server", version="0.1.0") assert result.instructions == "The server instructions." @@ -171,7 +171,7 @@ async def mock_server(): received_client_info = request.params.client_info result = InitializeResult( - protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], + protocol_version=LATEST_HANDSHAKE_VERSION, capabilities=ServerCapabilities(), server_info=Implementation(name="mock-server", version="0.1.0"), ) @@ -228,7 +228,7 @@ async def mock_server(): received_client_info = request.params.client_info result = InitializeResult( - protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], + protocol_version=LATEST_HANDSHAKE_VERSION, capabilities=ServerCapabilities(), server_info=Implementation(name="mock-server", version="0.1.0"), ) @@ -278,7 +278,7 @@ async def mock_server(): assert isinstance(request, InitializeRequest) # Verify client offers the newest handshake protocol version - assert request.params.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] + assert request.params.protocol_version == LATEST_HANDSHAKE_VERSION # Server responds with a supported older version result = InitializeResult( @@ -386,7 +386,7 @@ async def mock_server(): received_capabilities = request.params.capabilities result = InitializeResult( - protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], + protocol_version=LATEST_HANDSHAKE_VERSION, capabilities=ServerCapabilities(), server_info=Implementation(name="mock-server", version="0.1.0"), ) @@ -457,7 +457,7 @@ async def mock_server(): received_capabilities = request.params.capabilities result = InitializeResult( - protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], + protocol_version=LATEST_HANDSHAKE_VERSION, capabilities=ServerCapabilities(), server_info=Implementation(name="mock-server", version="0.1.0"), ) @@ -536,7 +536,7 @@ async def mock_server(): received_capabilities = request.params.capabilities result = InitializeResult( - protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], + protocol_version=LATEST_HANDSHAKE_VERSION, capabilities=ServerCapabilities(), server_info=Implementation(name="mock-server", version="0.1.0"), ) @@ -604,7 +604,7 @@ async def mock_server(): assert isinstance(request, InitializeRequest) result = InitializeResult( - protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], + protocol_version=LATEST_HANDSHAKE_VERSION, capabilities=expected_capabilities, server_info=expected_server_info, instructions=expected_instructions, @@ -643,12 +643,12 @@ async def mock_server(): assert result.server_info == expected_server_info assert result.capabilities == expected_capabilities assert result.instructions == expected_instructions - assert result.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] + assert result.protocol_version == LATEST_HANDSHAKE_VERSION # Era-neutral accessors are populated from the InitializeResult. assert session.server_info == expected_server_info assert session.server_capabilities == expected_capabilities assert session.instructions == expected_instructions - assert session.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] + assert session.protocol_version == LATEST_HANDSHAKE_VERSION @pytest.mark.anyio @@ -671,7 +671,7 @@ async def mock_server(): assert isinstance(request, InitializeRequest) result = InitializeResult( - protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], + protocol_version=LATEST_HANDSHAKE_VERSION, capabilities=ServerCapabilities(), server_info=Implementation(name="mock-server", version="0.1.0"), ) @@ -1354,7 +1354,7 @@ async def send_raw_request( self.calls.append((method, opts or {})) if method == "initialize": return InitializeResult( - protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], + protocol_version=LATEST_HANDSHAKE_VERSION, capabilities=ServerCapabilities(), server_info=Implementation(name="mock-server", version="0.1.0"), ).model_dump(by_alias=True, mode="json", exclude_none=True) diff --git a/tests/interaction/_connect.py b/tests/interaction/_connect.py index 654c74add2..a9383837d2 100644 --- a/tests/interaction/_connect.py +++ b/tests/interaction/_connect.py @@ -31,7 +31,7 @@ from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings -from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS, MODERN_PROTOCOL_VERSIONS +from mcp.shared.version import LATEST_HANDSHAKE_VERSION, MODERN_PROTOCOL_VERSIONS from mcp.types import ( ClientCapabilities, Implementation, @@ -70,7 +70,7 @@ def __call__( message_handler: MessageHandlerFnT | None = None, client_info: Implementation | None = None, elicitation_callback: ElicitationFnT | None = None, - spec_version: str = HANDSHAKE_PROTOCOL_VERSIONS[-1], + spec_version: str = LATEST_HANDSHAKE_VERSION, ) -> AbstractAsyncContextManager[Client]: ... @@ -85,7 +85,7 @@ async def connect_in_memory( message_handler: MessageHandlerFnT | None = None, client_info: Implementation | None = None, elicitation_callback: ElicitationFnT | None = None, - spec_version: str = HANDSHAKE_PROTOCOL_VERSIONS[-1], + spec_version: str = LATEST_HANDSHAKE_VERSION, ) -> AsyncIterator[Client]: """Yield a Client connected to the server over the in-memory transport. @@ -122,7 +122,7 @@ async def connect_over_streamable_http( message_handler: MessageHandlerFnT | None = None, client_info: Implementation | None = None, elicitation_callback: ElicitationFnT | None = None, - spec_version: str = HANDSHAKE_PROTOCOL_VERSIONS[-1], + spec_version: str = LATEST_HANDSHAKE_VERSION, ) -> AsyncIterator[Client]: """Yield a Client connected to the server's streamable HTTP app, entirely in process. @@ -276,7 +276,7 @@ def base_headers(*, session_id: str | None = None) -> dict[str, str]: headers = { "accept": "application/json, text/event-stream", "content-type": "application/json", - "mcp-protocol-version": HANDSHAKE_PROTOCOL_VERSIONS[-1], + "mcp-protocol-version": LATEST_HANDSHAKE_VERSION, } if session_id is not None: headers["mcp-session-id"] = session_id @@ -286,7 +286,7 @@ def base_headers(*, session_id: str | None = None) -> dict[str, str]: def initialize_body(request_id: int = 1) -> dict[str, object]: """A wire-level initialize JSON-RPC request body, exactly as an SDK client would send it.""" params = InitializeRequestParams( - protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], + protocol_version=LATEST_HANDSHAKE_VERSION, capabilities=ClientCapabilities(), client_info=Implementation(name="raw", version="0.0.0"), ) @@ -354,7 +354,7 @@ async def connect_over_sse( message_handler: MessageHandlerFnT | None = None, client_info: Implementation | None = None, elicitation_callback: ElicitationFnT | None = None, - spec_version: str = HANDSHAKE_PROTOCOL_VERSIONS[-1], + spec_version: str = LATEST_HANDSHAKE_VERSION, ) -> AsyncIterator[Client]: """Yield a Client connected to the server's legacy SSE transport, entirely in process.""" app, _ = build_sse_app(server) diff --git a/tests/interaction/transports/test_hosting_resume.py b/tests/interaction/transports/test_hosting_resume.py index 3fa0da44e8..77ef087cd6 100644 --- a/tests/interaction/transports/test_hosting_resume.py +++ b/tests/interaction/transports/test_hosting_resume.py @@ -21,7 +21,7 @@ from mcp.client.streamable_http import streamable_http_client from mcp.server.mcpserver import Context, MCPServer from mcp.shared.message import ClientMessageMetadata -from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS +from mcp.shared.version import LATEST_HANDSHAKE_VERSION from mcp.types import ( CallToolRequest, CallToolRequestParams, @@ -431,7 +431,7 @@ async def collect(params: LoggingMessageNotificationParams) -> None: # The session id is only observable via the manager (the client transport does not expose it). (session_id,) = manager._server_instances http.headers["mcp-session-id"] = session_id - http.headers["mcp-protocol-version"] = HANDSHAKE_PROTOCOL_VERSIONS[-1] + http.headers["mcp-protocol-version"] = LATEST_HANDSHAKE_VERSION tg.cancel_scope.cancel() with anyio.fail_after(5): # pragma: no branch diff --git a/tests/issues/test_192_request_id.py b/tests/issues/test_192_request_id.py index de96dbe23a..9f935ae088 100644 --- a/tests/issues/test_192_request_id.py +++ b/tests/issues/test_192_request_id.py @@ -4,8 +4,8 @@ from mcp.server.lowlevel import NotificationOptions, Server from mcp.server.models import InitializationOptions from mcp.shared.message import SessionMessage +from mcp.shared.version import LATEST_HANDSHAKE_VERSION from mcp.types import ( - LATEST_PROTOCOL_VERSION, ClientCapabilities, Implementation, InitializeRequestParams, @@ -59,7 +59,7 @@ async def run_server(): id="init-1", method="initialize", params=InitializeRequestParams( - protocol_version=LATEST_PROTOCOL_VERSION, + protocol_version=LATEST_HANDSHAKE_VERSION, capabilities=ClientCapabilities(), client_info=Implementation(name="test-client", version="1.0.0"), ).model_dump(by_alias=True, exclude_none=True), diff --git a/tests/issues/test_552_windows_hang.py b/tests/issues/test_552_windows_hang.py index e2210a6d65..82d4074e91 100644 --- a/tests/issues/test_552_windows_hang.py +++ b/tests/issues/test_552_windows_hang.py @@ -9,7 +9,7 @@ from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client -from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS +from mcp.shared.version import LATEST_HANDSHAKE_VERSION from mcp.types import InitializeResult @@ -33,7 +33,7 @@ async def test_initialize_succeeds_and_shutdown_returns_after_the_server_exits_m "jsonrpc": "2.0", "id": request["id"], "result": {{ - "protocolVersion": {json.dumps(HANDSHAKE_PROTOCOL_VERSIONS[-1])}, + "protocolVersion": {json.dumps(LATEST_HANDSHAKE_VERSION)}, "capabilities": {{}}, "serverInfo": {{"name": "test-server", "version": "1.0"}} }} diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index 0744e63022..cc157247c9 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -7,8 +7,8 @@ from mcp.server import Server, ServerRequestContext from mcp.shared.exceptions import MCPError from mcp.shared.message import SessionMessage +from mcp.shared.version import LATEST_HANDSHAKE_VERSION from mcp.types import ( - LATEST_PROTOCOL_VERSION, CallToolRequest, CallToolRequestParams, CallToolResult, @@ -138,7 +138,7 @@ async def run_server(): id=1, method="initialize", params=InitializeRequestParams( - protocol_version=LATEST_PROTOCOL_VERSION, + protocol_version=LATEST_HANDSHAKE_VERSION, capabilities=ClientCapabilities(), client_info=Implementation(name="test", version="1.0"), ).model_dump(by_alias=True, mode="json", exclude_none=True), @@ -212,7 +212,7 @@ async def run_server(): id=1, method="initialize", params=InitializeRequestParams( - protocol_version=LATEST_PROTOCOL_VERSION, + protocol_version=LATEST_HANDSHAKE_VERSION, capabilities=ClientCapabilities(), client_info=Implementation(name="test", version="1.0"), ).model_dump(by_alias=True, mode="json", exclude_none=True), diff --git a/tests/server/test_connection.py b/tests/server/test_connection.py index 9c2a1f5c46..d0c4dfa559 100644 --- a/tests/server/test_connection.py +++ b/tests/server/test_connection.py @@ -19,7 +19,7 @@ from mcp.server.connection import Connection from mcp.shared.dispatcher import CallOptions from mcp.shared.exceptions import NoBackChannelError -from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS, MODERN_PROTOCOL_VERSIONS +from mcp.shared.version import LATEST_HANDSHAKE_VERSION, LATEST_MODERN_VERSION from mcp.types import ( LATEST_PROTOCOL_VERSION, ClientCapabilities, @@ -40,7 +40,7 @@ ) _CLIENT_INFO = Implementation(name="t", version="0") -_MODERN = MODERN_PROTOCOL_VERSIONS[0] +_MODERN = LATEST_MODERN_VERSION class StubOutbound: @@ -115,10 +115,10 @@ def test_from_envelope_with_explicit_outbound_has_standalone_channel(): def test_for_loop_seeds_version_from_hint_or_latest_and_is_not_born_ready(): """SDK-defined: `for_loop` seeds `protocol_version` from the hint when given, - else `HANDSHAKE_PROTOCOL_VERSIONS[-1]`; the connection awaits the initialize handshake.""" + else `LATEST_HANDSHAKE_VERSION`; the connection awaits the initialize handshake.""" out = StubOutbound() conn = Connection.for_loop(out) - assert conn.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] + assert conn.protocol_version == LATEST_HANDSHAKE_VERSION assert conn.has_standalone_channel is True assert not conn.initialized.is_set() assert conn.initialize_accepted is False @@ -229,7 +229,7 @@ async def test_send_request_validates_the_client_result_against_the_surface_sche async def test_send_request_passes_a_spec_valid_client_result(): """A spec-valid client result passes the surface gate and parses to the typed model.""" conn = Connection.for_loop(StubOutbound(result={"roots": [{"uri": "file:///ws"}]})) - assert conn.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] + assert conn.protocol_version == LATEST_HANDSHAKE_VERSION result = await conn.send_request(ListRootsRequest()) assert isinstance(result, ListRootsResult) assert str(result.roots[0].uri) == "file:///ws" diff --git a/tests/server/test_session.py b/tests/server/test_session.py index f6f2a61e18..ea57441dcc 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -17,9 +17,8 @@ from mcp.server.session import ServerSession from mcp.shared.dispatcher import CallOptions from mcp.shared.message import ServerMessageMetadata -from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS, MODERN_PROTOCOL_VERSIONS +from mcp.shared.version import LATEST_HANDSHAKE_VERSION, LATEST_MODERN_VERSION from mcp.types import ( - LATEST_PROTOCOL_VERSION, ClientCapabilities, Implementation, SamplingCapability, @@ -65,7 +64,7 @@ def _make_session( outbound: StubOutbound, *, capabilities: ClientCapabilities | None = None, - protocol_version: str = HANDSHAKE_PROTOCOL_VERSIONS[-1], + protocol_version: str = LATEST_HANDSHAKE_VERSION, ) -> ServerSession: """Single-channel session: the stub is both request and standalone outbound.""" client_info = Implementation(name="c", version="0") if capabilities is not None else None @@ -75,7 +74,7 @@ def _make_session( def _two_channel_session(request_ch: StubOutbound, standalone_ch: StubOutbound) -> ServerSession: """Distinct request/standalone outbounds so routing assertions can tell the channels apart.""" - conn = Connection.from_envelope(LATEST_PROTOCOL_VERSION, None, None, outbound=standalone_ch) + conn = Connection.from_envelope(LATEST_HANDSHAKE_VERSION, None, None, outbound=standalone_ch) return ServerSession(request_ch, conn) @@ -193,7 +192,7 @@ async def test_send_request_passes_a_spec_valid_client_result(): async def test_send_request_skips_the_surface_gate_when_method_absent_at_version(): """Surface row absent for the connection's version: gate is bypassed and only `result_type` validates.""" - session = _make_session(StubOutbound(result={}), protocol_version=MODERN_PROTOCOL_VERSIONS[0]) + session = _make_session(StubOutbound(result={}), protocol_version=LATEST_MODERN_VERSION) result = await session.send_request(types.PingRequest(), types.EmptyResult) assert isinstance(result, types.EmptyResult) diff --git a/tests/server/test_streamable_http_modern.py b/tests/server/test_streamable_http_modern.py index 35ee17f3d6..92b0729601 100644 --- a/tests/server/test_streamable_http_modern.py +++ b/tests/server/test_streamable_http_modern.py @@ -20,7 +20,7 @@ from mcp.shared.exceptions import NoBackChannelError from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER from mcp.shared.transport_context import TransportContext -from mcp.shared.version import MODERN_PROTOCOL_VERSIONS +from mcp.shared.version import LATEST_MODERN_VERSION from mcp.types import ( CLIENT_CAPABILITIES_META_KEY, CLIENT_INFO_META_KEY, @@ -56,7 +56,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: return httpx.AsyncClient( transport=httpx.ASGITransport(app=app), base_url="http://testserver", - headers={MCP_PROTOCOL_VERSION_HEADER: MODERN_PROTOCOL_VERSIONS[0]}, + headers={MCP_PROTOCOL_VERSION_HEADER: LATEST_MODERN_VERSION}, ) @@ -114,7 +114,7 @@ async def test_handle_modern_request_returns_transport_security_error_response() def _list_tools_body() -> dict[str, Any]: """A minimal valid 2026-07-28 ``tools/list`` request body, including the required ``_meta`` envelope.""" meta = { - PROTOCOL_VERSION_META_KEY: MODERN_PROTOCOL_VERSIONS[0], + PROTOCOL_VERSION_META_KEY: LATEST_MODERN_VERSION, CLIENT_INFO_META_KEY: {"name": "raw", "version": "0.0.0"}, CLIENT_CAPABILITIES_META_KEY: {}, } diff --git a/tests/shared/test_inbound.py b/tests/shared/test_inbound.py index eaa2a59bf4..dcf0490c83 100644 --- a/tests/shared/test_inbound.py +++ b/tests/shared/test_inbound.py @@ -17,7 +17,7 @@ InboundModernRoute, classify_inbound_request, ) -from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS, MODERN_PROTOCOL_VERSIONS +from mcp.shared.version import LATEST_HANDSHAKE_VERSION, LATEST_MODERN_VERSION, MODERN_PROTOCOL_VERSIONS from mcp.types import ( CLIENT_CAPABILITIES_META_KEY, CLIENT_INFO_META_KEY, @@ -33,7 +33,7 @@ UNSUPPORTED_PROTOCOL_VERSION, ) -MODERN = MODERN_PROTOCOL_VERSIONS[0] +MODERN = LATEST_MODERN_VERSION """The modern protocol-version string, read from the registry — never inlined here.""" CLIENT_INFO = {"name": "t", "version": "0"} @@ -107,18 +107,18 @@ def test_envelope_rung_rejects_non_mapping_shapes(body: dict[str, Any]) -> None: def test_version_rung_rejects_unsupported_with_data_shape() -> None: """Spec-mandated: an envelope version outside the modern set rejects with the ``supported``/``requested`` data.""" rejection = assert_rejected( - classify_inbound_request(envelope(version=HANDSHAKE_PROTOCOL_VERSIONS[-1])), + classify_inbound_request(envelope(version=LATEST_HANDSHAKE_VERSION)), UNSUPPORTED_PROTOCOL_VERSION, ) assert rejection.data == { "supported": list(MODERN_PROTOCOL_VERSIONS), - "requested": HANDSHAKE_PROTOCOL_VERSIONS[-1], + "requested": LATEST_HANDSHAKE_VERSION, } def test_version_rung_data_reflects_supplied_supported_list() -> None: """SDK-defined: the caller-supplied ``supported_modern_versions`` is what rejection ``data.supported`` echoes.""" - custom = (HANDSHAKE_PROTOCOL_VERSIONS[-1],) + custom = (LATEST_HANDSHAKE_VERSION,) rejection = assert_rejected( classify_inbound_request(envelope(), supported_modern_versions=custom), UNSUPPORTED_PROTOCOL_VERSION, @@ -144,7 +144,7 @@ def test_header_rung_passes_when_header_matches_envelope() -> None: @pytest.mark.parametrize( "headers", [ - pytest.param({MCP_PROTOCOL_VERSION_HEADER: HANDSHAKE_PROTOCOL_VERSIONS[-1]}, id="mismatch"), + pytest.param({MCP_PROTOCOL_VERSION_HEADER: LATEST_HANDSHAKE_VERSION}, id="mismatch"), pytest.param({}, id="header-absent"), ], ) @@ -176,7 +176,7 @@ def test_ladder_first_failure_wins() -> None: """Spec-mandated: rungs evaluate in order — header-mismatch and version-unsupported would both fail; the header rung fires first so an inconsistent client is told it disagrees with itself rather than that its body version is unsupported.""" - body = envelope(version=HANDSHAKE_PROTOCOL_VERSIONS[-1]) + body = envelope(version=LATEST_HANDSHAKE_VERSION) result = classify_inbound_request(body, headers={MCP_PROTOCOL_VERSION_HEADER: MODERN}) assert_rejected(result, HEADER_MISMATCH) From bdcdeb008538d665925ebceaa60d6cb0b7925989 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 23 Jun 2026 16:39:19 +0000 Subject: [PATCH 14/22] serve_one returns dict; Client accessor and re-entry guard fixes serve_one now returns the kernel's dict result and lets exceptions propagate; the modern HTTP entry composes to_jsonrpc_response around it directly, and modern_on_request no longer round-trips through JSONRPCError on the in-process path. The dctx.request_id assert drops. Client: the protocol_version/server_info/server_capabilities accessors now raise the same RuntimeError as .session instead of bare assert. __aenter__ publishes self._session only after the handshake succeeds, and a separate _entered flag makes the one-shot re-entry guard explicit. _drop_notify is renamed to say which direction it sinks. Adds a TODO at mode='legacy' for the eventual default flip and a TODO above the accessors for the connected-view shape. migration.md: point at the era-neutral session/client accessors instead of initialize_result, and stop listing client.instructions as non-nullable. The legacy-mode connect test now passes mode='legacy' explicitly so it asserts what that mode does, not what the default is. --- docs/migration.md | 24 +++--- src/mcp/client/client.py | 66 ++++++++++----- src/mcp/server/_streamable_http_modern.py | 6 +- src/mcp/server/runner.py | 64 +++++++-------- tests/client/test_client.py | 25 +++++- tests/interaction/_requirements.py | 2 +- .../lowlevel/test_client_connect.py | 14 ++-- tests/server/test_runner.py | 80 +++++++++---------- 8 files changed, 158 insertions(+), 123 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index cb2cc4dce4..acc0fd1078 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -161,7 +161,7 @@ Note: `sse_client` retains its `headers`, `timeout`, `sse_read_timeout`, and `au ### `StreamableHTTPTransport.protocol_version` attribute removed -The transport no longer holds per-connection protocol state; era-dependent headers (e.g. `MCP-Protocol-Version`) are now supplied per-message by the session. If you were reading `transport.protocol_version` to learn the negotiated version, read it from `session.initialize_result.protocol_version` instead. +The transport no longer holds per-connection protocol state; era-dependent headers (e.g. `MCP-Protocol-Version`) are now supplied per-message by the session. If you were reading `transport.protocol_version` to learn the negotiated version, read `session.protocol_version` (or `client.protocol_version` on the high-level `Client`) instead. ### `terminate_windows_process` removed @@ -330,9 +330,9 @@ result = await session.list_resources(params=PaginatedRequestParams(cursor="next result = await session.list_tools(params=PaginatedRequestParams(cursor="next_page_token")) ``` -### `ClientSession.get_server_capabilities()` replaced by `initialize_result` property +### `ClientSession.get_server_capabilities()` replaced by era-neutral accessors -`ClientSession` now stores the full `InitializeResult` via an `initialize_result` property. This provides access to `server_info`, `capabilities`, `instructions`, and the negotiated `protocol_version` through a single property. The `get_server_capabilities()` method has been removed. +`ClientSession` now exposes the negotiated server metadata as properties: `server_capabilities`, `server_info`, `instructions`, and `protocol_version`. These are populated by whichever connection step ran (`initialize()` for ≤2025-11-25 servers, `discover()` for 2026-07-28+). The `get_server_capabilities()` method has been removed. **Before (v1):** @@ -344,15 +344,15 @@ capabilities = session.get_server_capabilities() **After (v2):** ```python -result = session.initialize_result -if result is not None: - capabilities = result.capabilities - server_info = result.server_info - instructions = result.instructions - version = result.protocol_version +capabilities = session.server_capabilities +server_info = session.server_info +instructions = session.instructions +version = session.protocol_version ``` -The high-level `Client.initialize_result` returns the same `InitializeResult` but is non-nullable — initialization is guaranteed inside the context manager, so no `None` check is needed. Like `session.initialize_result`, this replaces v1's `ClientSession.get_server_capabilities()`; use `client.initialize_result.capabilities` instead. +The raw handshake result is also retained as `session.initialize_result` (legacy path) or `session.discover_result` (modern path) — exactly one is non-`None`. + +On the high-level `Client`, `client.server_capabilities`, `client.server_info`, and `client.protocol_version` are non-nullable inside the context manager. `client.instructions` remains `str | None` since the server may omit it. ### `McpError` renamed to `MCPError` @@ -770,9 +770,9 @@ async def my_tool(ctx: Context[MyLifespanState]) -> str: ... ### Version constants -`SUPPORTED_PROTOCOL_VERSIONS` is deprecated — it's now the union of `HANDSHAKE_PROTOCOL_VERSIONS` (initialize-handshake versions) and `MODERN_PROTOCOL_VERSIONS` (per-request-envelope versions). If you were using it to mean "versions the initialize handshake accepts", switch to `HANDSHAKE_PROTOCOL_VERSIONS`. +`SUPPORTED_PROTOCOL_VERSIONS` is deprecated — it's now the union of `HANDSHAKE_PROTOCOL_VERSIONS` (initialize-handshake versions) and `MODERN_PROTOCOL_VERSIONS` (per-request-envelope versions). If you were using it to mean "versions the initialize handshake accepts", switch to `HANDSHAKE_PROTOCOL_VERSIONS`. Named scalars derived from these tuples are now exported alongside them — `LATEST_HANDSHAKE_VERSION`, `LATEST_MODERN_VERSION`, `OLDEST_SUPPORTED_VERSION` — so prefer those over indexing the tuples directly. -`LATEST_PROTOCOL_VERSION` now reflects the newest protocol revision the SDK supports (`2026-07-28`). Code that used it to mean "the version `.initialize()` offers" should switch to `HANDSHAKE_PROTOCOL_VERSIONS[-1]`. +`LATEST_PROTOCOL_VERSION` now reflects the newest protocol revision the SDK supports (`2026-07-28`). Code that used it to mean "the version `.initialize()` offers" should switch to `LATEST_HANDSHAKE_VERSION`. ### `ProgressContext` and `progress()` context manager removed diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 8bfeeb6b9a..0248618712 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -5,7 +5,7 @@ from collections.abc import Mapping from contextlib import AsyncExitStack from dataclasses import KW_ONLY, dataclass, field -from typing import Any, Literal +from typing import Any, Literal, TypeVar import anyio from typing_extensions import deprecated @@ -48,6 +48,20 @@ initialize), or a modern protocol-version string (adopt directly). The ``str`` arm is for forward-compat; ``Client.__post_init__`` rejects anything outside that set at construction.""" +_T = TypeVar("_T") + + +def _connected(value: _T | None) -> _T: + """Narrow a post-handshake session attribute from ``T | None`` to ``T``. + + ``Client.__aenter__`` only assigns ``_session`` after the handshake succeeds, so inside + ``async with Client(...)`` these attributes are always populated; the ``.session`` gate + raises before this is reached otherwise. The guard exists for pyright, not runtime. + """ + if value is None: # pragma: no cover + raise RuntimeError("Client must be used within an async context manager") + return value + def _synthesize_discover(protocol_version: str) -> types.DiscoverResult: return types.DiscoverResult( @@ -60,11 +74,14 @@ def _synthesize_discover(protocol_version: str) -> types.DiscoverResult: ) -async def _drop_notify(_dctx: Any, _method: str, _params: Mapping[str, Any] | None) -> None: - """Server-side ``OnNotify`` for the modern in-process path: client→server notifications are dropped. +async def _no_inbound_client_notifications(_dctx: Any, _method: str, _params: Mapping[str, Any] | None) -> None: + """Server-side inbound ``OnNotify`` for the modern in-process path — receives nothing. - The per-request driver (`serve_one`) has no notification dispatch table; progress and - cancellation travel via `CallOptions` on the `DirectDispatcher`, not as JSON-RPC notifies. + At 2026-07-28 the spec defines no client→server notifications: ``initialized`` and + ``roots/list_changed`` are removed, and cancellation is structural (anyio scope cancel + through the direct await, not a notify). Server→client notifications (progress, log + messages) flow the other way via the per-request ``DispatchContext`` into the client's + callbacks, and are not seen here. """ @@ -127,6 +144,8 @@ async def main(): client_info: Implementation | None = None """Client implementation info to send to server.""" + # TODO(maxisbey): flip default to 'auto' once the in-proc test suite is era-decoupled + # and the probe-timeout fallback is transport-aware (stdio→fallback / HTTP→reject). mode: ConnectMode = "legacy" """'legacy' performs the initialize handshake. 'auto' probes server/discover and falls back to initialize() on legacy servers. A modern protocol-version string (e.g. '2026-07-28') adopts that version directly without @@ -139,6 +158,7 @@ async def main(): elicitation_callback: ElicitationFnT | None = None """Callback for handling elicitation requests.""" + _entered: bool = field(init=False, default=False) _session: ClientSession | None = field(init=False, default=None) _exit_stack: AsyncExitStack | None = field(init=False, default=None) _transport: Transport | None = field(init=False, default=None) @@ -175,7 +195,7 @@ async def _build_session(self, exit_stack: AsyncExitStack) -> ClientSession: tg = await exit_stack.enter_async_context(anyio.create_task_group()) exit_stack.callback(server_disp.close) on_request = modern_on_request(self._inproc_server, lifespan_state, raise_exceptions=self.raise_exceptions) - await tg.start(server_disp.run, on_request, _drop_notify) + await tg.start(server_disp.run, on_request, _no_inbound_client_notifications) dispatcher = client_disp read_stream = write_stream = None else: @@ -201,27 +221,31 @@ async def _build_session(self, exit_stack: AsyncExitStack) -> ClientSession: async def __aenter__(self) -> Client: """Enter the async context manager.""" - if self._session is not None: + if self._entered: raise RuntimeError("Client is already entered; cannot reenter") + self._entered = True async with AsyncExitStack() as exit_stack: session = await self._build_session(exit_stack) - self._session = await exit_stack.enter_async_context(session) + session = await exit_stack.enter_async_context(session) if self.mode == "legacy": - await self._session.initialize() + await session.initialize() elif self.mode == "auto": try: - await self._session.discover() + await session.discover() except MCPError as e: if e.code in (METHOD_NOT_FOUND, REQUEST_TIMEOUT): - await self._session.initialize() + await session.initialize() else: raise else: - self._session.adopt(self.prior_discover or _synthesize_discover(self.mode)) + session.adopt(self.prior_discover or _synthesize_discover(self.mode)) - # Transfer ownership to self for __aexit__ to handle + # Only publish the session after the handshake succeeds, so `_session is not None` + # implies the protocol_version/server_info/server_capabilities are populated. If the + # handshake raised above, the local exit_stack unwinds the transport for us. + self._session = session self._exit_stack = exit_stack.pop_all() return self @@ -244,26 +268,24 @@ def session(self) -> ClientSession: raise RuntimeError("Client must be used within an async context manager") return self._session + # TODO(maxisbey): the by-construction shape is for __aenter__ to return a connected-view + # type whose protocol_version/server_info/server_capabilities are non-Optional fields, + # eliminating these guards (and the one in .session). Same family as resolving the + # transport/connector at __post_init__ so the Optional internal fields disappear. @property def protocol_version(self) -> str: """Negotiated protocol version (set by initialize/discover/adopt during ``__aenter__``).""" - version = self.session.protocol_version - assert version is not None - return version + return _connected(self.session.protocol_version) @property def server_info(self) -> Implementation: """Server name/version (set by initialize/discover/adopt during ``__aenter__``).""" - info = self.session.server_info - assert info is not None - return info + return _connected(self.session.server_info) @property def server_capabilities(self) -> ServerCapabilities: """Server capabilities (set by initialize/discover/adopt during ``__aenter__``).""" - caps = self.session.server_capabilities - assert caps is not None - return caps + return _connected(self.session.server_capabilities) @property def instructions(self) -> str | None: diff --git a/src/mcp/server/_streamable_http_modern.py b/src/mcp/server/_streamable_http_modern.py index 732656fd8b..07e6cfd5a0 100644 --- a/src/mcp/server/_streamable_http_modern.py +++ b/src/mcp/server/_streamable_http_modern.py @@ -25,7 +25,7 @@ from starlette.types import Receive, Scope, Send from mcp.server.connection import Connection -from mcp.server.runner import serve_one +from mcp.server.runner import serve_one, to_jsonrpc_response from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings from mcp.shared.dispatcher import CallOptions from mcp.shared.exceptions import NoBackChannelError @@ -193,5 +193,7 @@ async def handle_modern_request( request_id=req.id, message_metadata=ServerMessageMetadata(request_context=request), ) - msg = await serve_one(app, dctx, req.method, req.params, connection=connection, lifespan_state=lifespan_state) + msg = await to_jsonrpc_response( + req.id, serve_one(app, dctx, req.method, req.params, connection=connection, lifespan_state=lifespan_state) + ) await _write(msg, scope, receive, send) diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 8b71edb69f..b162ca587a 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -36,7 +36,7 @@ from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher, handler_exception_to_error_data from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.transport_context import TransportContext -from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS, MODERN_PROTOCOL_VERSIONS +from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS, LATEST_HANDSHAKE_VERSION, LATEST_MODERN_VERSION from mcp.types import ( CLIENT_CAPABILITIES_META_KEY, CLIENT_INFO_META_KEY, @@ -187,13 +187,12 @@ async def to_jsonrpc_response( ) -> JSONRPCResponse | JSONRPCError: """Await ``coro`` and wrap its outcome as the JSON-RPC reply for ``request_id``. - The exception-to-wire boundary for the request-per-call drivers - (`serve_one`, the modern HTTP entry). `MCPError` and `ValidationError` + The exception-to-wire boundary for the modern HTTP entry, which composes + this around `serve_one` directly. `MCPError` and `ValidationError` map via the shared `handler_exception_to_error_data` ladder; any other exception is logged and surfaced as `INTERNAL_ERROR` so handler internals never reach the wire. Set ``raise_unhandled`` to let unmapped exceptions - propagate instead of being sanitized — used by the in-process test path so - handler tracebacks reach the caller. + propagate instead of being sanitized. """ try: result = await coro @@ -422,7 +421,7 @@ def _negotiate_initialize(params: Mapping[str, Any] | None) -> tuple[InitializeR """Validate `initialize` params and pick the protocol version.""" init = InitializeRequestParams.model_validate(params or {}, by_name=False) requested = init.protocol_version - negotiated = requested if requested in HANDSHAKE_PROTOCOL_VERSIONS else HANDSHAKE_PROTOCOL_VERSIONS[-1] + negotiated = requested if requested in HANDSHAKE_PROTOCOL_VERSIONS else LATEST_HANDSHAKE_VERSION return init, negotiated def _handle_initialize(self, params: Mapping[str, Any] | None) -> InitializeResult: @@ -508,25 +507,21 @@ async def serve_one( *, connection: Connection, lifespan_state: LifespanT, - raise_exceptions: bool = False, -) -> JSONRPCResponse | JSONRPCError: - """Handle a single request ``(method, params)`` and return its JSON-RPC reply. +) -> dict[str, Any]: + """Handle a single request ``(method, params)`` and return its result dict. The single-exchange driver: builds the kernel, runs `on_request` once under - `dctx`, maps the outcome to a `JSONRPCResponse` / `JSONRPCError` via - `to_jsonrpc_response`, and tears down `connection.exit_stack` (shielded) on - the way out. The entry constructs the (born-ready) `Connection` and the - `dctx`; this only consumes them. ``raise_exceptions`` lets unmapped handler - exceptions propagate instead of being sanitized to `INTERNAL_ERROR`. + `dctx`, and tears down `connection.exit_stack` (shielded) on the way out. + The entry constructs the (born-ready) `Connection` and the `dctx`; this + only consumes them. + + Raises whatever the handler chain raises (`MCPError` / `ValidationError` / + unmapped); callers own the exception-to-wire mapping. The HTTP entry + composes this with `to_jsonrpc_response`. """ runner = ServerRunner(server, connection, lifespan_state) try: - # Single-exchange driver only handles requests; both entries populate `request_id`. - # TODO(L54): drop once `DispatchContext` is split so `OnRequest` carries a non-Optional id. - assert dctx.request_id is not None - return await to_jsonrpc_response( - dctx.request_id, runner.on_request(dctx, method, params), raise_unhandled=raise_exceptions - ) + return await runner.on_request(dctx, method, params) finally: await aclose_shielded(connection) @@ -540,7 +535,9 @@ def modern_on_request( in-process server on the modern per-request-envelope path (each request carries protocol version, client info, and capabilities in `params._meta`; no `initialize` handshake). ``raise_exceptions`` lets unmapped handler - exceptions propagate to the caller for debuggable in-process testing. + exceptions propagate to the caller for debuggable in-process testing; + otherwise they are sanitized to `MCPError(INTERNAL_ERROR)` so the in-process + path matches the wire path's leak guard. """ async def handle( @@ -548,21 +545,20 @@ async def handle( ) -> dict[str, Any]: meta = (params or {}).get("_meta", {}) connection = Connection.from_envelope( - meta.get(PROTOCOL_VERSION_META_KEY, MODERN_PROTOCOL_VERSIONS[-1]), + meta.get(PROTOCOL_VERSION_META_KEY, LATEST_MODERN_VERSION), meta.get(CLIENT_INFO_META_KEY), meta.get(CLIENT_CAPABILITIES_META_KEY), ) - msg = await serve_one( - server, - dctx, - method, - params, - connection=connection, - lifespan_state=lifespan_state, - raise_exceptions=raise_exceptions, - ) - if isinstance(msg, JSONRPCError): - raise MCPError(code=msg.error.code, message=msg.error.message, data=msg.error.data) - return msg.result + try: + return await serve_one(server, dctx, method, params, connection=connection, lifespan_state=lifespan_state) + except (MCPError, ValidationError): + # DirectDispatcher's ladder maps these onward; this layer only owns the raise_exceptions + # decision for unmapped exceptions, which DirectDispatcher would otherwise leak via str(exc). + raise + except Exception: + if raise_exceptions: + raise + logger.exception("request handler raised") + raise MCPError(code=INTERNAL_ERROR, message="Internal server error") from None return handle diff --git a/tests/client/test_client.py b/tests/client/test_client.py index e6464fde11..991cd1e5e1 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -20,7 +20,7 @@ from mcp.server.mcpserver import MCPServer from mcp.shared.memory import MessageStream, create_client_server_memory_streams from mcp.shared.message import SessionMessage -from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS +from mcp.shared.version import LATEST_HANDSHAKE_VERSION from mcp.types import ( CallToolResult, EmptyResult, @@ -122,7 +122,7 @@ async def test_client_is_initialized(app: MCPServer): async def test_client_exposes_negotiated_protocol_version(app: MCPServer): """The negotiated protocol version is readable after initialization.""" async with Client(app) as client: - assert client.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] + assert client.protocol_version == LATEST_HANDSHAKE_VERSION async def test_client_with_simple_server(simple_server: Server): @@ -223,6 +223,23 @@ async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequ assert str(exc_info.value.__cause__) == "boom" +async def test_raise_exceptions_false_sanitizes_handler_error_on_modern_inproc_path(): + """`raise_exceptions=False` (the default) on the modern in-process path: an + unmapped handler exception is sanitized to an opaque `INTERNAL_ERROR` so the + in-process path matches the wire path's leak guard.""" + + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + raise ValueError("boom") + + server = Server("test", on_call_tool=handle_call_tool) + async with Client(server, mode="2026-07-28", raise_exceptions=False) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("explode", {}) + assert exc_info.value.error.code == types.INTERNAL_ERROR + assert exc_info.value.error.message == "Internal server error" + assert exc_info.value.__cause__ is None + + async def test_get_prompt(app: MCPServer): """Test getting a prompt.""" async with Client(app) as client: @@ -421,7 +438,7 @@ async def scripted_server(streams: MessageStream) -> None: await server_write.send(SessionMessage(types.JSONRPCError(jsonrpc="2.0", id=frame.id, error=error))) elif frame.method == "initialize": # pragma: no branch result = types.InitializeResult( - protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], + protocol_version=LATEST_HANDSHAKE_VERSION, capabilities=ServerCapabilities(), server_info=types.Implementation(name="legacy-only", version="0.0.1"), ) @@ -447,7 +464,7 @@ async def scripted_transport() -> AsyncIterator[TransportStreams]: with anyio.fail_after(5): async with Client(scripted_transport(), mode="auto") as client: - assert client.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] + assert client.protocol_version == LATEST_HANDSHAKE_VERSION assert client.server_info.name == "legacy-only" assert methods_seen == ["server/discover", "initialize", "notifications/initialized"] diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index afb9e2b8a1..8714977824 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -448,7 +448,7 @@ def __post_init__(self) -> None: "lifecycle:mode:legacy-never-probes": Requirement( source="sdk", behavior=( - "A Client constructed with mode='legacy' (the default) sends initialize as its first request " + "A Client constructed with mode='legacy' sends initialize as its first request " "and never sends server/discover." ), added_in="2026-07-28", diff --git a/tests/interaction/lowlevel/test_client_connect.py b/tests/interaction/lowlevel/test_client_connect.py index aa6625bc2e..9ba3aa65e9 100644 --- a/tests/interaction/lowlevel/test_client_connect.py +++ b/tests/interaction/lowlevel/test_client_connect.py @@ -27,7 +27,7 @@ from mcp.server import Server, ServerRequestContext from mcp.shared.memory import MessageStream, create_client_server_memory_streams from mcp.shared.message import SessionMessage -from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS, MODERN_PROTOCOL_VERSIONS +from mcp.shared.version import LATEST_HANDSHAKE_VERSION, LATEST_MODERN_VERSION, MODERN_PROTOCOL_VERSIONS from mcp.types import ( CLIENT_CAPABILITIES_META_KEY, CLIENT_INFO_META_KEY, @@ -51,7 +51,7 @@ pytestmark = pytest.mark.anyio -MODERN_VERSION = "2026-07-28" +MODERN_VERSION = LATEST_MODERN_VERSION def _tools_server(name: str = "negotiator") -> Server: @@ -77,15 +77,15 @@ async def on_request(request: httpx.Request) -> None: @requirement("lifecycle:mode:legacy-never-probes") async def test_legacy_mode_sends_initialize_and_never_probes_discover() -> None: - """`Client(server)` (mode defaults to 'legacy') opens with `initialize` and never sends `server/discover`. + """`Client(server, mode='legacy')` opens with `initialize` and never sends `server/discover`. - Requirement `lifecycle:mode:legacy-never-probes` (sdk-defined): the default mode must remain + Requirement `lifecycle:mode:legacy-never-probes` (sdk-defined): ``mode='legacy'`` must remain byte-identical to the pre-2026 client so a 2025-era server never observes modern vocabulary. """ recording = RecordingTransport(InMemoryTransport(_tools_server())) with anyio.fail_after(5): - async with Client(recording) as client: + async with Client(recording, mode="legacy") as client: await client.list_tools() sent = [m.message for m in recording.sent] @@ -273,7 +273,7 @@ async def scripted_server(streams: MessageStream) -> None: await server_write.send(SessionMessage(JSONRPCError(jsonrpc="2.0", id=frame.id, error=error))) elif isinstance(frame, JSONRPCRequest) and frame.method == "initialize": result = InitializeResult( - protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], + protocol_version=LATEST_HANDSHAKE_VERSION, capabilities=ServerCapabilities(), server_info=Implementation(name="legacy-only", version="0.0.1"), ) @@ -300,7 +300,7 @@ async def scripted_transport() -> AsyncIterator[TransportStreams]: with anyio.fail_after(5): async with Client(scripted_transport(), mode="auto") as client: - assert client.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] + assert client.protocol_version == LATEST_HANDSHAKE_VERSION assert client.server_info.name == "legacy-only" assert methods_seen == ["server/discover", "initialize", "notifications/initialized"] diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 2bf7dae983..8b1eb22105 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -39,7 +39,7 @@ from mcp.shared.message import MessageMetadata from mcp.shared.peer import dump_params from mcp.shared.transport_context import TransportContext -from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS, MODERN_PROTOCOL_VERSIONS +from mcp.shared.version import LATEST_HANDSHAKE_VERSION, LATEST_MODERN_VERSION, OLDEST_SUPPORTED_VERSION from mcp.types import ( INTERNAL_ERROR, INVALID_PARAMS, @@ -70,7 +70,7 @@ def _initialize_params() -> dict[str, Any]: return InitializeRequestParams( - protocol_version=HANDSHAKE_PROTOCOL_VERSIONS[-1], + protocol_version=LATEST_HANDSHAKE_VERSION, capabilities=ClientCapabilities(), client_info=Implementation(name="test-client", version="1.0"), ).model_dump(by_alias=True, exclude_none=True) @@ -167,7 +167,7 @@ async def test_runner_handles_initialize_and_populates_connection(server: SrvT): assert "tools" in result["capabilities"] assert runner.connection.client_params is not None assert runner.connection.client_params.client_info.name == "test-client" - assert runner.connection.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] + assert runner.connection.protocol_version == LATEST_HANDSHAKE_VERSION assert runner.connection.initialize_accepted is True @@ -244,7 +244,7 @@ async def test_runner_routes_to_handler_and_builds_context(server: SrvT): assert isinstance(ctx.session, ServerSession) assert ctx.session.protocol_version == runner.connection.protocol_version assert ctx.request_id is not None - assert ctx.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] + assert ctx.protocol_version == LATEST_HANDSHAKE_VERSION @pytest.mark.anyio @@ -289,7 +289,7 @@ async def test_runner_rejects_snake_case_initialize_params(server: SrvT): """Inbound wire payloads validate alias-only; Python field names are not accepted (`protocol_version` must arrive as `protocolVersion`).""" snake = { - "protocol_version": LATEST_PROTOCOL_VERSION, + "protocol_version": LATEST_HANDSHAKE_VERSION, "capabilities": {}, "client_info": {"name": "c", "version": "0"}, } @@ -814,7 +814,7 @@ async def test_runner_with_born_ready_connection_skips_init_gate(server: SrvT): """A `Connection.from_envelope` connection is born ready: the kernel's init-gate is open without any handshake. The kernel is mode-agnostic - the same `on_request` reads `connection.initialize_accepted` as a fact.""" - born_ready = Connection.from_envelope(HANDSHAKE_PROTOCOL_VERSIONS[-1], None, None) + born_ready = Connection.from_envelope(LATEST_HANDSHAKE_VERSION, None, None) async with connected_runner(server, initialized=False, connection=born_ready) as (client, runner): assert runner.connection.initialize_accepted is True assert runner.connection.initialized.is_set() @@ -847,7 +847,7 @@ async def greet(ctx: Ctx, params: GreetParams) -> dict[str, Any]: @pytest.mark.anyio async def test_runner_spec_method_with_invalid_params_is_invalid_params_at_the_negotiated_version(server: SrvT): async with connected_runner(server) as (client, runner): - assert runner.connection.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] + assert runner.connection.protocol_version == LATEST_HANDSHAKE_VERSION with pytest.raises(MCPError) as exc: await client.send_raw_request("tools/call", {"name": 42}) assert exc.value.error.code == INVALID_PARAMS @@ -926,9 +926,9 @@ async def discover(ctx: Ctx, params: RequestParams) -> Any: async def test_on_request_rejects_initialize_at_modern_version_with_method_not_found(server: SrvT): """Spec-mandated: `initialize` has no `CLIENT_REQUESTS` row at the modern version; kernel dispatch (not the inbound classifier) rejects it.""" - born_ready = Connection.from_envelope(MODERN_PROTOCOL_VERSIONS[0], None, None) + born_ready = Connection.from_envelope(LATEST_MODERN_VERSION, None, None) async with connected_runner(server, initialized=False, connection=born_ready) as (client, runner): - assert runner.connection.protocol_version == MODERN_PROTOCOL_VERSIONS[0] + assert runner.connection.protocol_version == LATEST_MODERN_VERSION with pytest.raises(MCPError) as exc: await client.send_raw_request("initialize", _initialize_params()) assert exc.value.error.code == METHOD_NOT_FOUND @@ -943,7 +943,7 @@ async def echo(ctx: Ctx, params: RequestParams) -> dict[str, Any]: return {"echoed": True} server.add_request_handler("myorg/echo", RequestParams, echo) - born_ready = Connection.from_envelope(MODERN_PROTOCOL_VERSIONS[0], None, None) + born_ready = Connection.from_envelope(LATEST_MODERN_VERSION, None, None) async with connected_runner(server, initialized=False, connection=born_ready) as (client, _): result = await client.send_raw_request("myorg/echo", None) assert result == {"echoed": True} @@ -997,7 +997,7 @@ async def list_tools(ctx: Ctx, params: PaginatedRequestParams | None) -> ListToo @pytest.mark.anyio async def test_runner_initialize_echoes_supported_version_and_falls_back_to_latest(server: SrvT): - oldest = HANDSHAKE_PROTOCOL_VERSIONS[0] + oldest = OLDEST_SUPPORTED_VERSION async with connected_runner(server, initialized=False) as (client, _): params = {**_initialize_params(), "protocolVersion": oldest} result = await client.send_raw_request("initialize", params) @@ -1005,7 +1005,7 @@ async def test_runner_initialize_echoes_supported_version_and_falls_back_to_late async with connected_runner(server, initialized=False) as (client, _): params = {**_initialize_params(), "protocolVersion": "1999-01-01"} result = await client.send_raw_request("initialize", params) - assert result["protocolVersion"] == HANDSHAKE_PROTOCOL_VERSIONS[-1] + assert result["protocolVersion"] == LATEST_HANDSHAKE_VERSION @pytest.mark.anyio @@ -1351,36 +1351,34 @@ async def _append_async(dst: list[int], v: int) -> None: @pytest.mark.anyio -async def test_serve_one_runs_handler_and_returns_jsonrpc_response(server: SrvT): +async def test_serve_one_runs_handler_and_returns_result_dict(server: SrvT): """The single-exchange driver: builds the kernel, runs `on_request` once, - wraps via `to_jsonrpc_response`, and tears down `connection.exit_stack`.""" - conn = Connection.from_envelope(HANDSHAKE_PROTOCOL_VERSIONS[-1], None, None) + returns the agnostic result dict, and tears down `connection.exit_stack`.""" + conn = Connection.from_envelope(LATEST_HANDSHAKE_VERSION, None, None) cleaned: list[int] = [] conn.exit_stack.push_async_callback(_append_async, cleaned, 1) - reply = await serve_one( + result = await serve_one( server, _StubDispatchContext(9), "tools/list", None, connection=conn, lifespan_state=_LIFESPAN ) - assert isinstance(reply, JSONRPCResponse) - assert reply.id == 9 - assert reply.result["tools"][0]["name"] == "t" + assert result["tools"][0]["name"] == "t" assert cleaned == [1] ctx = _seen_ctx[0] - assert ctx.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] + assert ctx.protocol_version == LATEST_HANDSHAKE_VERSION @pytest.mark.anyio -async def test_serve_one_maps_error_to_jsonrpc_error_and_still_closes_exit_stack(server: SrvT): +async def test_serve_one_propagates_error_and_still_closes_exit_stack(server: SrvT): """SDK-defined: a kernel-produced error (here `METHOD_NOT_FOUND` for an - unregistered method) is wrapped as a `JSONRPCError`, and the per-request - exit stack is closed on the error path too.""" - conn = Connection.from_envelope(HANDSHAKE_PROTOCOL_VERSIONS[-1], None, None) + unregistered method) propagates as `MCPError`, and the per-request exit + stack is closed on the error path too.""" + conn = Connection.from_envelope(LATEST_HANDSHAKE_VERSION, None, None) cleaned: list[int] = [] conn.exit_stack.push_async_callback(_append_async, cleaned, 1) - reply = await serve_one( - server, _StubDispatchContext(2), "resources/list", None, connection=conn, lifespan_state=_LIFESPAN - ) - assert isinstance(reply, JSONRPCError) - assert reply.error.code == METHOD_NOT_FOUND + with pytest.raises(MCPError) as exc_info: + await serve_one( + server, _StubDispatchContext(2), "resources/list", None, connection=conn, lifespan_state=_LIFESPAN + ) + assert exc_info.value.error.code == METHOD_NOT_FOUND assert cleaned == [1] @@ -1389,17 +1387,17 @@ async def test_serve_one_reads_connection_protocol_version_as_a_fact(server: Srv """`serve_one` builds the kernel over the entry's `Connection`; the kernel reads `connection.protocol_version` for the version gate. A `from_envelope` connection at a modern version rejects a method absent there.""" - conn = Connection.from_envelope(MODERN_PROTOCOL_VERSIONS[0], None, None) - reply = await serve_one( - server, - _StubDispatchContext(1), - "logging/setLevel", - {"level": "info"}, - connection=conn, - lifespan_state=_LIFESPAN, - ) - assert isinstance(reply, JSONRPCError) - assert reply.error.code == METHOD_NOT_FOUND + conn = Connection.from_envelope(LATEST_MODERN_VERSION, None, None) + with pytest.raises(MCPError) as exc_info: + await serve_one( + server, + _StubDispatchContext(1), + "logging/setLevel", + {"level": "info"}, + connection=conn, + lifespan_state=_LIFESPAN, + ) + assert exc_info.value.error.code == METHOD_NOT_FOUND @pytest.mark.anyio @@ -1423,5 +1421,5 @@ async def test_serve_connection_drives_dispatcher_loop_and_tears_down(server: Sr assert cleaned == [] close() assert cleaned == [1] - assert conn.protocol_version == HANDSHAKE_PROTOCOL_VERSIONS[-1] + assert conn.protocol_version == LATEST_HANDSHAKE_VERSION assert conn.client_params is not None From 78823d48c932e8257045f5a2f4fce619a84cec7c Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 23 Jun 2026 18:15:28 +0000 Subject: [PATCH 15/22] Move to_jsonrpc_response to the HTTP entry; tighten migration.md and version-constant aliases runner.py is now JSON-RPC-wire-model-free: to_jsonrpc_response moves into _streamable_http_modern.py (its only caller) as a private helper, and the unused raise_unhandled parameter is dropped. types/__init__.py imports LATEST_PROTOCOL_VERSION directly from shared.version instead of bouncing through _types.py. migration.md: the era-neutral-accessors section now says "at most one" of initialize_result/discover_result is non-None (both are None before any handshake on the lowlevel session) and names which slot a 2025 vs 2026 server fills; notes that lowlevel ClientSession still lets you call methods before any handshake, as in v1. Test files drop the per-file MODERN/_MODERN/MODERN_VERSION aliases in favour of LATEST_MODERN_VERSION directly. --- docs/migration.md | 8 +- src/mcp/server/_streamable_http_modern.py | 29 +++++- src/mcp/server/runner.py | 34 +------ src/mcp/types/__init__.py | 3 +- src/mcp/types/_types.py | 1 - .../lowlevel/test_client_connect.py | 20 ++-- .../transports/test_hosting_http_modern.py | 9 +- tests/server/test_connection.py | 21 ++-- tests/server/test_runner.py | 97 +------------------ tests/server/test_streamable_http_modern.py | 76 ++++++++++++++- tests/shared/test_inbound.py | 17 ++-- 11 files changed, 138 insertions(+), 177 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index acc0fd1078..f35721dcc9 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -332,7 +332,7 @@ result = await session.list_tools(params=PaginatedRequestParams(cursor="next_pag ### `ClientSession.get_server_capabilities()` replaced by era-neutral accessors -`ClientSession` now exposes the negotiated server metadata as properties: `server_capabilities`, `server_info`, `instructions`, and `protocol_version`. These are populated by whichever connection step ran (`initialize()` for ≤2025-11-25 servers, `discover()` for 2026-07-28+). The `get_server_capabilities()` method has been removed. +`ClientSession` now exposes the negotiated server metadata as properties: `server_capabilities`, `server_info`, `instructions`, and `protocol_version`. These are populated by whichever connection step ran (`initialize()` for ≤2025-11-25 servers, `discover()` for 2026-07-28+), and are `None` if none has — matching v1's `get_server_capabilities()`. The `get_server_capabilities()` method has been removed. **Before (v1):** @@ -350,9 +350,9 @@ instructions = session.instructions version = session.protocol_version ``` -The raw handshake result is also retained as `session.initialize_result` (legacy path) or `session.discover_result` (modern path) — exactly one is non-`None`. +The raw handshake result is also retained: `session.initialize_result` is set after `initialize()` (≤2025-11-25 servers — including `stateless_http=True` servers, which still answer `initialize`); `session.discover_result` is set after `discover()` (2026-07-28+ servers). At most one is non-`None`. -On the high-level `Client`, `client.server_capabilities`, `client.server_info`, and `client.protocol_version` are non-nullable inside the context manager. `client.instructions` remains `str | None` since the server may omit it. +On the high-level `Client`, `client.server_capabilities`, `client.server_info`, and `client.protocol_version` are non-nullable inside the context manager. `client.instructions` remains `str | None` since the server may omit it. (The lowlevel `ClientSession` still lets you call methods before any handshake, as in v1; `Client` always handshakes on enter.) ### `McpError` renamed to `MCPError` @@ -772,8 +772,6 @@ async def my_tool(ctx: Context[MyLifespanState]) -> str: ... `SUPPORTED_PROTOCOL_VERSIONS` is deprecated — it's now the union of `HANDSHAKE_PROTOCOL_VERSIONS` (initialize-handshake versions) and `MODERN_PROTOCOL_VERSIONS` (per-request-envelope versions). If you were using it to mean "versions the initialize handshake accepts", switch to `HANDSHAKE_PROTOCOL_VERSIONS`. Named scalars derived from these tuples are now exported alongside them — `LATEST_HANDSHAKE_VERSION`, `LATEST_MODERN_VERSION`, `OLDEST_SUPPORTED_VERSION` — so prefer those over indexing the tuples directly. -`LATEST_PROTOCOL_VERSION` now reflects the newest protocol revision the SDK supports (`2026-07-28`). Code that used it to mean "the version `.initialize()` offers" should switch to `LATEST_HANDSHAKE_VERSION`. - ### `ProgressContext` and `progress()` context manager removed The `mcp.shared.progress` module (`ProgressContext`, `Progress`, and the `progress()` context manager) has been removed. This module had no real-world adoption — all users send progress notifications via `Context.report_progress()` or `session.send_progress_notification()` directly. diff --git a/src/mcp/server/_streamable_http_modern.py b/src/mcp/server/_streamable_http_modern.py index 07e6cfd5a0..d151bc259c 100644 --- a/src/mcp/server/_streamable_http_modern.py +++ b/src/mcp/server/_streamable_http_modern.py @@ -14,7 +14,7 @@ import json import logging -from collections.abc import Mapping +from collections.abc import Awaitable, Mapping from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, TypeVar @@ -25,14 +25,16 @@ from starlette.types import Receive, Scope, Send from mcp.server.connection import Connection -from mcp.server.runner import serve_one, to_jsonrpc_response +from mcp.server.runner import serve_one from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings from mcp.shared.dispatcher import CallOptions from mcp.shared.exceptions import NoBackChannelError from mcp.shared.inbound import ERROR_CODE_HTTP_STATUS, InboundLadderRejection, classify_inbound_request +from mcp.shared.jsonrpc_dispatcher import handler_exception_to_error_data from mcp.shared.message import MessageMetadata, ServerMessageMetadata from mcp.shared.transport_context import TransportContext from mcp.types import ( + INTERNAL_ERROR, INVALID_REQUEST, PARSE_ERROR, ClientCapabilities, @@ -99,6 +101,27 @@ def _typed(model: type[_ModelT], raw: Any) -> _ModelT | None: return None +async def _to_jsonrpc_response( + request_id: RequestId, coro: Awaitable[dict[str, Any]] +) -> JSONRPCResponse | JSONRPCError: + """Await ``coro`` and wrap its outcome as the JSON-RPC reply for ``request_id``. + + The exception-to-wire boundary for the modern HTTP entry, composed around + `serve_one`. `MCPError` and `ValidationError` map via the shared + `handler_exception_to_error_data` ladder; any other exception is logged and + surfaced as `INTERNAL_ERROR` so handler internals never reach the wire. + """ + try: + result = await coro + except Exception as exc: + error = handler_exception_to_error_data(exc) + if error is None: + logger.exception("request handler raised") + error = ErrorData(code=INTERNAL_ERROR, message="Internal server error") + return JSONRPCError(jsonrpc="2.0", id=request_id, error=error) + return JSONRPCResponse(jsonrpc="2.0", id=request_id, result=result) + + async def _write( msg: JSONRPCResponse | JSONRPCError, scope: Scope, @@ -193,7 +216,7 @@ async def handle_modern_request( request_id=req.id, message_metadata=ServerMessageMetadata(request_context=request), ) - msg = await to_jsonrpc_response( + msg = await _to_jsonrpc_response( req.id, serve_one(app, dctx, req.method, req.params, connection=connection, lifespan_state=lifespan_state) ) await _write(msg, scope, receive, send) diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index b162ca587a..3fa8b3bc79 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -33,7 +33,7 @@ from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.dispatcher import DispatchContext, Dispatcher, DispatchMiddleware, OnNotify, OnRequest from mcp.shared.exceptions import MCPError -from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher, handler_exception_to_error_data +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.transport_context import TransportContext from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS, LATEST_HANDSHAKE_VERSION, LATEST_MODERN_VERSION @@ -48,9 +48,6 @@ Implementation, InitializeRequestParams, InitializeResult, - JSONRPCError, - JSONRPCResponse, - RequestId, RequestParams, RequestParamsMeta, ) @@ -69,7 +66,6 @@ "serve_connection", "serve_loop", "serve_one", - "to_jsonrpc_response", ] logger = logging.getLogger(__name__) @@ -182,31 +178,6 @@ async def aclose_shielded(connection: Connection) -> None: ) -async def to_jsonrpc_response( - request_id: RequestId, coro: Awaitable[dict[str, Any]], *, raise_unhandled: bool = False -) -> JSONRPCResponse | JSONRPCError: - """Await ``coro`` and wrap its outcome as the JSON-RPC reply for ``request_id``. - - The exception-to-wire boundary for the modern HTTP entry, which composes - this around `serve_one` directly. `MCPError` and `ValidationError` - map via the shared `handler_exception_to_error_data` ladder; any other - exception is logged and surfaced as `INTERNAL_ERROR` so handler internals - never reach the wire. Set ``raise_unhandled`` to let unmapped exceptions - propagate instead of being sanitized. - """ - try: - result = await coro - except Exception as exc: - error = handler_exception_to_error_data(exc) - if error is None: - if raise_unhandled: - raise - logger.exception("request handler raised") - error = ErrorData(code=INTERNAL_ERROR, message="Internal server error") - return JSONRPCError(jsonrpc="2.0", id=request_id, error=error) - return JSONRPCResponse(jsonrpc="2.0", id=request_id, result=result) - - def _apply_middleware( middleware: ServerMiddleware[Any], call_next: CallNext, ctx: ServerRequestContext[Any, Any] ) -> Awaitable[HandlerResult]: @@ -516,8 +487,7 @@ async def serve_one( only consumes them. Raises whatever the handler chain raises (`MCPError` / `ValidationError` / - unmapped); callers own the exception-to-wire mapping. The HTTP entry - composes this with `to_jsonrpc_response`. + unmapped); callers own the exception-to-wire mapping. """ runner = ServerRunner(server, connection, lifespan_state) try: diff --git a/src/mcp/types/__init__.py b/src/mcp/types/__init__.py index 992d584687..491047ae74 100644 --- a/src/mcp/types/__init__.py +++ b/src/mcp/types/__init__.py @@ -4,12 +4,13 @@ https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/draft/schema.json """ +from mcp.shared.version import LATEST_PROTOCOL_VERSION + # Re-export everything from _types for backward compatibility from mcp.types._types import ( CLIENT_CAPABILITIES_META_KEY, CLIENT_INFO_META_KEY, DEFAULT_NEGOTIATED_VERSION, - LATEST_PROTOCOL_VERSION, LOG_LEVEL_META_KEY, PROTOCOL_VERSION_META_KEY, Annotations, diff --git a/src/mcp/types/_types.py b/src/mcp/types/_types.py index a08b4a3e59..815803c34f 100644 --- a/src/mcp/types/_types.py +++ b/src/mcp/types/_types.py @@ -20,7 +20,6 @@ from pydantic.alias_generators import to_camel from typing_extensions import NotRequired, TypedDict -from mcp.shared.version import LATEST_PROTOCOL_VERSION as LATEST_PROTOCOL_VERSION from mcp.types.jsonrpc import RequestId DEFAULT_NEGOTIATED_VERSION: Final[str] = "2025-03-26" diff --git a/tests/interaction/lowlevel/test_client_connect.py b/tests/interaction/lowlevel/test_client_connect.py index 9ba3aa65e9..f508dfc90c 100644 --- a/tests/interaction/lowlevel/test_client_connect.py +++ b/tests/interaction/lowlevel/test_client_connect.py @@ -51,8 +51,6 @@ pytestmark = pytest.mark.anyio -MODERN_VERSION = LATEST_MODERN_VERSION - def _tools_server(name: str = "negotiator") -> Server: """A low-level server with one list-tools handler, so a feature request has something to reach.""" @@ -108,7 +106,7 @@ async def test_pinned_mode_sends_no_connect_time_traffic() -> None: with anyio.fail_after(5): async with ( mounted_app(_tools_server(), on_request=on_request) as (http, _), - Client(streamable_http_client(f"{BASE_URL}/mcp", http_client=http), mode=MODERN_VERSION) as client, + Client(streamable_http_client(f"{BASE_URL}/mcp", http_client=http), mode=LATEST_MODERN_VERSION) as client, ): assert requests == [] # entering the Client produced zero HTTP traffic result = await client.list_tools() @@ -128,7 +126,7 @@ async def test_prior_discover_populates_state_with_zero_connect_time_traffic() - immediately with zero round trips. """ prior = DiscoverResult( - supported_versions=[MODERN_VERSION], + supported_versions=[LATEST_MODERN_VERSION], capabilities=ServerCapabilities(tools=ToolsCapability(list_changed=False)), server_info=Implementation(name="cached-server", version="9.9.9"), ) @@ -139,7 +137,7 @@ async def test_prior_discover_populates_state_with_zero_connect_time_traffic() - mounted_app(_tools_server(), on_request=on_request) as (http, _), Client( streamable_http_client(f"{BASE_URL}/mcp", http_client=http), - mode=MODERN_VERSION, + mode=LATEST_MODERN_VERSION, prior_discover=prior, ) as client, ): @@ -167,7 +165,7 @@ async def test_auto_mode_probes_server_discover_and_adopts_the_result() -> None: mounted_app(server, on_request=on_request) as (http, _), Client(streamable_http_client(f"{BASE_URL}/mcp", http_client=http), mode="auto") as client, ): - assert client.protocol_version == MODERN_VERSION + assert client.protocol_version == LATEST_MODERN_VERSION assert client.server_info.name == "discoverable" await client.list_tools() @@ -211,9 +209,9 @@ async def discover(ctx: ServerRequestContext, params: types.RequestParams | None mounted_app(server, on_request=on_request) as (http, _), Client(streamable_http_client(f"{BASE_URL}/mcp", http_client=http), mode="auto") as client, ): - assert client.protocol_version == MODERN_VERSION + assert client.protocol_version == LATEST_MODERN_VERSION - assert calls == [MODERN_VERSION, MODERN_VERSION] + assert calls == [LATEST_MODERN_VERSION, LATEST_MODERN_VERSION] assert [json.loads(r.content)["method"] for r in requests][:2] == ["server/discover", "server/discover"] @@ -332,7 +330,7 @@ async def list_tools( assert len(observed) == 2 for meta in observed: - assert meta[PROTOCOL_VERSION_META_KEY] == MODERN_VERSION + assert meta[PROTOCOL_VERSION_META_KEY] == LATEST_MODERN_VERSION assert meta[CLIENT_INFO_META_KEY] == {"name": "enveloper", "version": "1.2.3"} assert CLIENT_CAPABILITIES_META_KEY in meta @@ -350,7 +348,7 @@ async def test_http_protocol_version_header_matches_meta_protocol_version_on_eve with anyio.fail_after(5): async with ( mounted_app(_tools_server(), on_request=on_request) as (http, _), - Client(streamable_http_client(f"{BASE_URL}/mcp", http_client=http), mode=MODERN_VERSION) as client, + Client(streamable_http_client(f"{BASE_URL}/mcp", http_client=http), mode=LATEST_MODERN_VERSION) as client, ): await client.list_tools() await client.list_tools() @@ -359,4 +357,4 @@ async def test_http_protocol_version_header_matches_meta_protocol_version_on_eve for request in requests: body = json.loads(request.content) assert request.headers["mcp-protocol-version"] == body["params"]["_meta"][PROTOCOL_VERSION_META_KEY] - assert request.headers["mcp-protocol-version"] == MODERN_VERSION + assert request.headers["mcp-protocol-version"] == LATEST_MODERN_VERSION diff --git a/tests/interaction/transports/test_hosting_http_modern.py b/tests/interaction/transports/test_hosting_http_modern.py index 1ed8dea201..52b20629e6 100644 --- a/tests/interaction/transports/test_hosting_http_modern.py +++ b/tests/interaction/transports/test_hosting_http_modern.py @@ -20,6 +20,7 @@ from mcp.client.session import ClientSession from mcp.client.streamable_http import streamable_http_client from mcp.server import Server, ServerRequestContext +from mcp.shared.version import LATEST_MODERN_VERSION from mcp.types import ( CLIENT_CAPABILITIES_META_KEY, INTERNAL_ERROR, @@ -45,8 +46,6 @@ pytestmark = pytest.mark.anyio -MODERN_VERSION = "2026-07-28" - def _modern_headers(*, method: str, name: str | None = None) -> dict[str, str]: """Request headers for a 2026-07-28 POST. @@ -54,7 +53,7 @@ def _modern_headers(*, method: str, name: str | None = None) -> dict[str, str]: The Accept/Content-Type baseline plus the ``MCP-Protocol-Version`` routing header and the ``Mcp-Method`` / ``Mcp-Name`` advisory headers a 2026-era client always sends. """ - headers = base_headers() | {"mcp-protocol-version": MODERN_VERSION, "mcp-method": method} + headers = base_headers() | {"mcp-protocol-version": LATEST_MODERN_VERSION, "mcp-method": method} if name is not None: headers["mcp-name"] = name return headers @@ -67,7 +66,7 @@ def _meta_envelope() -> dict[str, object]: capabilities travel on each request instead of once per session. """ return { - "io.modelcontextprotocol/protocolVersion": MODERN_VERSION, + "io.modelcontextprotocol/protocolVersion": LATEST_MODERN_VERSION, "io.modelcontextprotocol/clientInfo": {"name": "raw", "version": "0.0.0"}, "io.modelcontextprotocol/clientCapabilities": {}, } @@ -335,7 +334,7 @@ async def on_response(response: httpx.Response) -> None: ): session.adopt( DiscoverResult( - supported_versions=[MODERN_VERSION], + supported_versions=[LATEST_MODERN_VERSION], capabilities=ServerCapabilities(), server_info=Implementation(name="srv", version="0"), ) diff --git a/tests/server/test_connection.py b/tests/server/test_connection.py index d0c4dfa559..3a09aa15d7 100644 --- a/tests/server/test_connection.py +++ b/tests/server/test_connection.py @@ -40,7 +40,6 @@ ) _CLIENT_INFO = Implementation(name="t", version="0") -_MODERN = LATEST_MODERN_VERSION class StubOutbound: @@ -70,8 +69,8 @@ async def notify(self, method: str, params: Mapping[str, Any] | None, opts: Call def test_from_envelope_is_born_ready_with_no_back_channel(): """SDK-defined: `from_envelope` populates `protocol_version`, sets `initialized`, and holds the no-channel sentinel so `has_standalone_channel` derives False.""" - conn = Connection.from_envelope(_MODERN, None, None) - assert conn.protocol_version == _MODERN + conn = Connection.from_envelope(LATEST_MODERN_VERSION, None, None) + assert conn.protocol_version == LATEST_MODERN_VERSION assert conn.initialized.is_set() assert conn.initialize_accepted is True assert conn.has_standalone_channel is False @@ -83,11 +82,11 @@ def test_from_envelope_records_client_params_when_both_info_and_caps_supplied(): """SDK-defined: when both client info and capabilities are supplied, `from_envelope` synthesizes `client_params` so capability checks can run.""" caps = ClientCapabilities(sampling=SamplingCapability()) - conn = Connection.from_envelope(_MODERN, _CLIENT_INFO, caps) + conn = Connection.from_envelope(LATEST_MODERN_VERSION, _CLIENT_INFO, caps) assert conn.client_params is not None assert conn.client_params.client_info.name == "t" assert conn.client_params.capabilities.sampling is not None - assert conn.client_params.protocol_version == _MODERN + assert conn.client_params.protocol_version == LATEST_MODERN_VERSION @pytest.mark.parametrize( @@ -99,7 +98,7 @@ def test_from_envelope_leaves_client_params_none_when_either_is_missing( ): """SDK-defined: `client_params` is only synthesized when both info and caps are present; either missing leaves it `None`.""" - conn = Connection.from_envelope(_MODERN, info, caps) + conn = Connection.from_envelope(LATEST_MODERN_VERSION, info, caps) assert conn.client_params is None @@ -107,7 +106,7 @@ def test_from_envelope_with_explicit_outbound_has_standalone_channel(): """SDK-defined: duplex modern transports pass an outbound; `has_standalone_channel` derives True since the held outbound is not the no-channel sentinel.""" out = StubOutbound() - conn = Connection.from_envelope(_MODERN, None, None, outbound=out) + conn = Connection.from_envelope(LATEST_MODERN_VERSION, None, None, outbound=out) assert conn.has_standalone_channel is True assert conn.outbound is out assert conn.initialized.is_set() @@ -124,8 +123,8 @@ def test_for_loop_seeds_version_from_hint_or_latest_and_is_not_born_ready(): assert conn.initialize_accepted is False assert conn.client_params is None - hinted = Connection.for_loop(out, protocol_version_hint=_MODERN) - assert hinted.protocol_version == _MODERN + hinted = Connection.for_loop(out, protocol_version_hint=LATEST_MODERN_VERSION) + assert hinted.protocol_version == LATEST_MODERN_VERSION def test_for_loop_records_session_id_when_supplied(): @@ -248,7 +247,7 @@ class _CustomResult(BaseModel): async def test_send_request_skips_the_surface_gate_when_method_absent_at_version(): """Surface row absent for the negotiated version: gate is bypassed and only the inferred result type validates.""" - conn = Connection.for_loop(StubOutbound(result={}), protocol_version_hint=_MODERN) + conn = Connection.for_loop(StubOutbound(result={}), protocol_version_hint=LATEST_MODERN_VERSION) result = await conn.send_request(PingRequest()) assert isinstance(result, EmptyResult) @@ -328,7 +327,7 @@ def test_connection_check_capability_false_when_no_client_params_recorded(): conn = Connection.for_loop(StubOutbound()) assert conn.check_capability(ClientCapabilities(sampling=SamplingCapability())) is False # Same for a born-ready connection that supplied neither info nor caps. - assert Connection.from_envelope(_MODERN, None, None).check_capability(ClientCapabilities()) is False + assert Connection.from_envelope(LATEST_MODERN_VERSION, None, None).check_capability(ClientCapabilities()) is False @pytest.mark.parametrize( diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 8b1eb22105..30c611066c 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -4,7 +4,7 @@ `Server` as the registry. The `connected_runner` helper starts both sides and (by default) performs the initialize handshake, so each test exercises only the behaviour under test. Driver tests (`serve_connection`, `serve_one`, -`to_jsonrpc_response`, `aclose_shielded`) follow at the bottom. +`aclose_shielded`) follow at the bottom. """ from collections.abc import AsyncIterator, Mapping @@ -30,7 +30,6 @@ otel_middleware, serve_connection, serve_one, - to_jsonrpc_response, ) from mcp.server.session import ServerSession from mcp.shared.dispatcher import CallOptions, DispatchContext, DispatchMiddleware, OnRequest @@ -50,8 +49,6 @@ ErrorData, Implementation, InitializeRequestParams, - JSONRPCError, - JSONRPCResponse, ListToolsResult, NotificationParams, PaginatedRequestParams, @@ -1201,98 +1198,6 @@ async def _append(i: int) -> None: assert "abandoning remaining callbacks" not in caplog.text -# --- to_jsonrpc_response ------------------------------------------------------- - - -@pytest.mark.anyio -async def test_to_jsonrpc_response_wraps_success_as_jsonrpc_response(): - """SDK-defined: a handler coroutine resolving to a result dict is wrapped as a - `JSONRPCResponse` carrying the supplied id and the dict verbatim as `result`.""" - - async def ok() -> dict[str, Any]: - return {"k": "v"} - - reply = await to_jsonrpc_response(7, ok()) - assert isinstance(reply, JSONRPCResponse) - assert reply.id == 7 - assert reply.result == {"k": "v"} - - -@pytest.mark.anyio -async def test_to_jsonrpc_response_maps_mcp_error_to_jsonrpc_error(): - """SDK-defined: an `MCPError` raised by the handler coroutine is wrapped as a - `JSONRPCError` whose `error` carries the same code, message, and data.""" - - async def fail() -> dict[str, Any]: - raise MCPError(code=METHOD_NOT_FOUND, message="nope", data="x") - - reply = await to_jsonrpc_response("rid", fail()) - assert isinstance(reply, JSONRPCError) - assert reply.id == "rid" - assert reply.error == ErrorData(code=METHOD_NOT_FOUND, message="nope", data="x") - - -@pytest.mark.anyio -async def test_to_jsonrpc_response_maps_validation_error_to_invalid_params(): - """SDK-defined: a pydantic `ValidationError` escaping the handler coroutine is - mapped to `INVALID_PARAMS` with a generic message (validator detail does not - reach the wire).""" - - async def fail() -> dict[str, Any]: - Tool.model_validate({"name": 123}) # raises ValidationError - raise NotImplementedError - - reply = await to_jsonrpc_response(1, fail()) - assert isinstance(reply, JSONRPCError) - assert reply.error == ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="") - - -@pytest.mark.anyio -async def test_to_jsonrpc_response_maps_unmapped_exception_to_internal_error_and_logs( - caplog: pytest.LogCaptureFixture, -): - """SDK-defined: an unmapped exception is logged server-side and surfaced as - `INTERNAL_ERROR` with a generic message; the exception text never reaches the - wire.""" - - async def fail() -> dict[str, Any]: - raise RuntimeError("boom") - - reply = await to_jsonrpc_response(1, fail()) - assert isinstance(reply, JSONRPCError) - assert reply.error.code == INTERNAL_ERROR - # Handler internals never reach the wire. - assert "boom" not in reply.error.message - assert "request handler raised" in caplog.text - - -@pytest.mark.anyio -async def test_to_jsonrpc_response_raise_unhandled_propagates_unmapped_exception(): - """SDK-defined: ``raise_unhandled=True`` lets an unmapped exception escape - instead of being sanitized to `INTERNAL_ERROR` — used by the in-process test - path so the original traceback reaches the caller.""" - - async def fail() -> dict[str, Any]: - raise RuntimeError("boom") - - with pytest.raises(RuntimeError, match="boom"): - await to_jsonrpc_response(1, fail(), raise_unhandled=True) - - -@pytest.mark.anyio -async def test_to_jsonrpc_response_raise_unhandled_still_maps_mcp_error(): - """SDK-defined: ``raise_unhandled`` only affects unmapped exceptions; an - `MCPError` is still converted to a `JSONRPCError` (it is protocol-level, not - a crash).""" - - async def fail() -> dict[str, Any]: - raise MCPError(code=METHOD_NOT_FOUND, message="nope") - - reply = await to_jsonrpc_response(1, fail(), raise_unhandled=True) - assert isinstance(reply, JSONRPCError) - assert reply.error.code == METHOD_NOT_FOUND - - # --- aclose_shielded ----------------------------------------------------------- diff --git a/tests/server/test_streamable_http_modern.py b/tests/server/test_streamable_http_modern.py index 92b0729601..7b655bdd6f 100644 --- a/tests/server/test_streamable_http_modern.py +++ b/tests/server/test_streamable_http_modern.py @@ -15,20 +15,31 @@ from starlette.types import Receive, Scope, Send from mcp.server import Server, ServerRequestContext, runner -from mcp.server._streamable_http_modern import _SingleExchangeDispatchContext, handle_modern_request +from mcp.server._streamable_http_modern import ( + _SingleExchangeDispatchContext, + _to_jsonrpc_response, + handle_modern_request, +) from mcp.server.transport_security import TransportSecuritySettings -from mcp.shared.exceptions import NoBackChannelError +from mcp.shared.exceptions import MCPError, NoBackChannelError from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER from mcp.shared.transport_context import TransportContext from mcp.shared.version import LATEST_MODERN_VERSION from mcp.types import ( CLIENT_CAPABILITIES_META_KEY, CLIENT_INFO_META_KEY, + INTERNAL_ERROR, + INVALID_PARAMS, INVALID_REQUEST, + METHOD_NOT_FOUND, PARSE_ERROR, PROTOCOL_VERSION_META_KEY, + ErrorData, + JSONRPCError, + JSONRPCResponse, ListToolsResult, PaginatedRequestParams, + Tool, ) pytestmark = pytest.mark.anyio @@ -198,3 +209,64 @@ async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | assert response.status_code == 200 # pragma: lax no cover assert response.json()["result"]["tools"] == [] # pragma: lax no cover assert "abandoning remaining callbacks" in caplog.text # pragma: lax no cover + + +# --- _to_jsonrpc_response ------------------------------------------------------ + + +async def test_to_jsonrpc_response_wraps_success_as_jsonrpc_response() -> None: + """SDK-defined: a handler coroutine resolving to a result dict is wrapped as a + `JSONRPCResponse` carrying the supplied id and the dict verbatim as `result`.""" + + async def ok() -> dict[str, Any]: + return {"k": "v"} + + reply = await _to_jsonrpc_response(7, ok()) + assert isinstance(reply, JSONRPCResponse) + assert reply.id == 7 + assert reply.result == {"k": "v"} + + +async def test_to_jsonrpc_response_maps_mcp_error_to_jsonrpc_error() -> None: + """SDK-defined: an `MCPError` raised by the handler coroutine is wrapped as a + `JSONRPCError` whose `error` carries the same code, message, and data.""" + + async def fail() -> dict[str, Any]: + raise MCPError(code=METHOD_NOT_FOUND, message="nope", data="x") + + reply = await _to_jsonrpc_response("rid", fail()) + assert isinstance(reply, JSONRPCError) + assert reply.id == "rid" + assert reply.error == ErrorData(code=METHOD_NOT_FOUND, message="nope", data="x") + + +async def test_to_jsonrpc_response_maps_validation_error_to_invalid_params() -> None: + """SDK-defined: a pydantic `ValidationError` escaping the handler coroutine is + mapped to `INVALID_PARAMS` with a generic message (validator detail does not + reach the wire).""" + + async def fail() -> dict[str, Any]: + Tool.model_validate({"name": 123}) # raises ValidationError + raise NotImplementedError + + reply = await _to_jsonrpc_response(1, fail()) + assert isinstance(reply, JSONRPCError) + assert reply.error == ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="") + + +async def test_to_jsonrpc_response_maps_unmapped_exception_to_internal_error_and_logs( + caplog: pytest.LogCaptureFixture, +) -> None: + """SDK-defined: an unmapped exception is logged server-side and surfaced as + `INTERNAL_ERROR` with a generic message; the exception text never reaches the + wire.""" + + async def fail() -> dict[str, Any]: + raise RuntimeError("boom") + + reply = await _to_jsonrpc_response(1, fail()) + assert isinstance(reply, JSONRPCError) + assert reply.error.code == INTERNAL_ERROR + # Handler internals never reach the wire. + assert "boom" not in reply.error.message + assert "request handler raised" in caplog.text diff --git a/tests/shared/test_inbound.py b/tests/shared/test_inbound.py index dcf0490c83..a8e275a4fa 100644 --- a/tests/shared/test_inbound.py +++ b/tests/shared/test_inbound.py @@ -33,9 +33,6 @@ UNSUPPORTED_PROTOCOL_VERSION, ) -MODERN = LATEST_MODERN_VERSION -"""The modern protocol-version string, read from the registry — never inlined here.""" - CLIENT_INFO = {"name": "t", "version": "0"} CLIENT_CAPS: dict[str, Any] = {} @@ -43,7 +40,7 @@ def envelope( method: str = "tools/list", *, - version: str = MODERN, + version: str = LATEST_MODERN_VERSION, drop: frozenset[str] = frozenset(), ) -> dict[str, Any]: """Build a JSON-RPC body carrying a complete modern ``_meta`` envelope. @@ -123,7 +120,7 @@ def test_version_rung_data_reflects_supplied_supported_list() -> None: classify_inbound_request(envelope(), supported_modern_versions=custom), UNSUPPORTED_PROTOCOL_VERSION, ) - assert rejection.data == {"supported": list(custom), "requested": MODERN} + assert rejection.data == {"supported": list(custom), "requested": LATEST_MODERN_VERSION} # --- rung 3: header ↔ envelope agreement --------------------------------------- @@ -137,7 +134,7 @@ def test_header_rung_does_not_reject_when_headers_arg_is_none() -> None: def test_header_rung_passes_when_header_matches_envelope() -> None: """Spec-mandated: an HTTP version header equal to the envelope version passes rung 3.""" - result = classify_inbound_request(envelope(), headers={MCP_PROTOCOL_VERSION_HEADER: MODERN}) + result = classify_inbound_request(envelope(), headers={MCP_PROTOCOL_VERSION_HEADER: LATEST_MODERN_VERSION}) assert isinstance(result, InboundModernRoute) @@ -158,9 +155,9 @@ def test_header_rung_rejects_on_disagreement(headers: dict[str, str]) -> None: def test_all_rungs_pass_yields_route() -> None: """Spec-mandated: a complete envelope at a supported version with agreeing header routes, surfacing the envelope.""" - result = classify_inbound_request(envelope(), headers={MCP_PROTOCOL_VERSION_HEADER: MODERN}) + result = classify_inbound_request(envelope(), headers={MCP_PROTOCOL_VERSION_HEADER: LATEST_MODERN_VERSION}) assert isinstance(result, InboundModernRoute) - assert result.protocol_version == MODERN + assert result.protocol_version == LATEST_MODERN_VERSION assert result.client_info == CLIENT_INFO assert result.client_capabilities == CLIENT_CAPS @@ -168,7 +165,7 @@ def test_all_rungs_pass_yields_route() -> None: @pytest.mark.parametrize("method", ["initialize", "myorg/custom", "does/not/exist"]) def test_classifier_passes_unknown_method_through_to_route(method: str) -> None: """SDK-defined: the classifier does not gate on method — kernel dispatch is the single owner of that decision.""" - result = classify_inbound_request(envelope(method), headers={MCP_PROTOCOL_VERSION_HEADER: MODERN}) + result = classify_inbound_request(envelope(method), headers={MCP_PROTOCOL_VERSION_HEADER: LATEST_MODERN_VERSION}) assert isinstance(result, InboundModernRoute) @@ -177,7 +174,7 @@ def test_ladder_first_failure_wins() -> None: would both fail; the header rung fires first so an inconsistent client is told it disagrees with itself rather than that its body version is unsupported.""" body = envelope(version=LATEST_HANDSHAKE_VERSION) - result = classify_inbound_request(body, headers={MCP_PROTOCOL_VERSION_HEADER: MODERN}) + result = classify_inbound_request(body, headers={MCP_PROTOCOL_VERSION_HEADER: LATEST_MODERN_VERSION}) assert_rejected(result, HEADER_MISMATCH) From f9a15e15fa16620adba93a89419fe8397b4a3a8c Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 23 Jun 2026 18:15:49 +0000 Subject: [PATCH 16/22] Client: collapse _inproc_server/_transport into a single _connect closure __post_init__ now resolves a single _connect closure from the shape of the server argument alone (in-process vs URL vs Transport instance). mode and raise_exceptions are passed to the closure at enter time so they're read at the same moment __aenter__ reads them for the handshake step. _build_session collapses to one line of logic; the mutually-exclusive Optional fields and the assert that guarded them are gone. JSONRPCDispatcher.on_stream_exception is now public-mutable so ClientSession can install its message_handler routing after the dispatcher is built; the install only happens when no caller-supplied hook is already set. ClientSession.adopt() now clears the opposite result slot so at most one of initialize_result/discover_result is non-None by construction. --- src/mcp/client/client.py | 87 +++++++++++++++++----------- src/mcp/client/session.py | 13 +++++ src/mcp/shared/jsonrpc_dispatcher.py | 11 ++-- 3 files changed, 72 insertions(+), 39 deletions(-) diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 0248618712..e6524a862a 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -2,7 +2,7 @@ from __future__ import annotations -from collections.abc import Mapping +from collections.abc import Awaitable, Callable, Mapping from contextlib import AsyncExitStack from dataclasses import KW_ONLY, dataclass, field from typing import Any, Literal, TypeVar @@ -21,6 +21,7 @@ from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair from mcp.shared.dispatcher import Dispatcher, ProgressFnT from mcp.shared.exceptions import MCPDeprecationWarning, MCPError +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS, MODERN_PROTOCOL_VERSIONS from mcp.types import ( METHOD_NOT_FOUND, @@ -50,6 +51,43 @@ _T = TypeVar("_T") +_Connector = Callable[[AsyncExitStack, ConnectMode, bool], Awaitable["Dispatcher[Any]"]] +"""Resolved at ``__post_init__`` from the shape of ``server`` alone: enter whatever resources +are needed onto the exit stack and hand back the ``Dispatcher`` ``ClientSession`` will drive. +``mode`` and ``raise_exceptions`` are passed at call time so they're read at the same moment +``__aenter__`` reads them for the handshake step.""" + + +def _connect_transport(transport: Transport) -> _Connector: + """Connector for the stream-backed paths (URL, user-supplied ``Transport``).""" + + async def connect(exit_stack: AsyncExitStack, _mode: ConnectMode, _raise_exceptions: bool) -> Dispatcher[Any]: + read_stream, write_stream = await exit_stack.enter_async_context(transport) + return JSONRPCDispatcher(read_stream, write_stream) + + return connect + + +def _connect_inproc(server: Server[Any]) -> _Connector: + """Connector for an in-process ``Server``: legacy mode drives the stream loop via + ``InMemoryTransport``; any other mode drives the modern per-request path through a + ``DirectDispatcher`` peer pair (no streams, no JSON-RPC framing, no initialize handshake).""" + + async def connect(exit_stack: AsyncExitStack, mode: ConnectMode, raise_exceptions: bool) -> Dispatcher[Any]: + if mode == "legacy": + transport = InMemoryTransport(server, raise_exceptions=raise_exceptions) + read_stream, write_stream = await exit_stack.enter_async_context(transport) + return JSONRPCDispatcher(read_stream, write_stream) + lifespan_state = await exit_stack.enter_async_context(server.lifespan(server)) + client_disp, server_disp = create_direct_dispatcher_pair() + tg = await exit_stack.enter_async_context(anyio.create_task_group()) + exit_stack.callback(server_disp.close) + on_request = modern_on_request(server, lifespan_state, raise_exceptions=raise_exceptions) + await tg.start(server_disp.run, on_request, _no_inbound_client_notifications) + return client_disp + + return connect + def _connected(value: _T | None) -> _T: """Narrow a post-handshake session attribute from ``T | None`` to ``T``. @@ -161,19 +199,9 @@ async def main(): _entered: bool = field(init=False, default=False) _session: ClientSession | None = field(init=False, default=None) _exit_stack: AsyncExitStack | None = field(init=False, default=None) - _transport: Transport | None = field(init=False, default=None) - _inproc_server: Server[Any] | None = field(init=False, default=None) + _connect: _Connector = field(init=False, repr=False, compare=False) def __post_init__(self) -> None: - if isinstance(self.server, MCPServer): - self._inproc_server = self.server._lowlevel_server # pyright: ignore[reportPrivateUsage] - elif isinstance(self.server, Server): - self._inproc_server = self.server - elif isinstance(self.server, str): - self._transport = streamable_http_client(self.server) - else: - self._transport = self.server - if self.mode not in ("legacy", "auto") and self.mode not in MODERN_PROTOCOL_VERSIONS: hint = ( f" ({self.mode!r} is a handshake-era version — use mode='legacy')" @@ -184,31 +212,20 @@ def __post_init__(self) -> None: f"mode must be 'legacy', 'auto', or one of {list(MODERN_PROTOCOL_VERSIONS)}; got {self.mode!r}{hint}" ) - async def _build_session(self, exit_stack: AsyncExitStack) -> ClientSession: - """Set up the dispatcher/transport and return an un-entered ClientSession.""" - dispatcher: Dispatcher[Any] | None - if self._inproc_server is not None and self.mode != "legacy": - # Modern in-process path: drive the server through a DirectDispatcher peer-pair - # with one `serve_one` per request — no streams, no initialize handshake. - lifespan_state = await exit_stack.enter_async_context(self._inproc_server.lifespan(self._inproc_server)) - client_disp, server_disp = create_direct_dispatcher_pair() - tg = await exit_stack.enter_async_context(anyio.create_task_group()) - exit_stack.callback(server_disp.close) - on_request = modern_on_request(self._inproc_server, lifespan_state, raise_exceptions=self.raise_exceptions) - await tg.start(server_disp.run, on_request, _no_inbound_client_notifications) - dispatcher = client_disp - read_stream = write_stream = None + srv = self.server + if isinstance(srv, MCPServer): + srv = srv._lowlevel_server # pyright: ignore[reportPrivateUsage] + if isinstance(srv, Server): + self._connect = _connect_inproc(srv) + elif isinstance(srv, str): + self._connect = _connect_transport(streamable_http_client(srv)) else: - if self._inproc_server is not None: - transport: Transport = InMemoryTransport(self._inproc_server, raise_exceptions=self.raise_exceptions) - else: - assert self._transport is not None - transport = self._transport - read_stream, write_stream = await exit_stack.enter_async_context(transport) - dispatcher = None + self._connect = _connect_transport(srv) + + async def _build_session(self, exit_stack: AsyncExitStack) -> ClientSession: + """Enter the resolved connector and return an un-entered ClientSession.""" + dispatcher = await self._connect(exit_stack, self.mode, self.raise_exceptions) return ClientSession( - read_stream=read_stream, - write_stream=write_stream, dispatcher=dispatcher, read_timeout_seconds=self.read_timeout_seconds, sampling_callback=self.sampling_callback, diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 94a1b2e59b..c38ab44304 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -216,6 +216,14 @@ def __init__( if read_stream is not None or write_stream is not None: raise ValueError("pass read_stream/write_stream or dispatcher, not both") self._dispatcher: Dispatcher[Any] = dispatcher + if isinstance(dispatcher, JSONRPCDispatcher) and dispatcher.on_stream_exception is None: + # Route transport-level Exception items into message_handler — only + # stream-backed dispatchers carry these; DirectDispatcher has none. + # Don't clobber a caller-supplied hook. + # TODO(maxisbey): this leaves a bound-method ref on the dispatcher after + # the session exits (memory pin) and a second wrap of the same dispatcher + # would skip install. The Transport-as-Dispatcher rework removes this seam. + dispatcher.on_stream_exception = self._on_stream_exception else: if read_stream is None or write_stream is None: raise ValueError("read_stream and write_stream are required when no dispatcher is given") @@ -358,6 +366,9 @@ async def initialize(self) -> types.InitializeResult: def adopt(self, result: types.InitializeResult | types.DiscoverResult) -> None: """Install negotiated state from a result the caller already holds (no wire traffic). + Clears the opposite slot, so at most one of `initialize_result` / + `discover_result` is ever non-None. + Raises: RuntimeError: `result` is a `DiscoverResult` whose `supported_versions` shares nothing with this client's `MODERN_PROTOCOL_VERSIONS`. @@ -374,10 +385,12 @@ def adopt(self, result: types.InitializeResult | types.DiscoverResult) -> None: capabilities = self._build_capabilities().model_dump(by_alias=True, mode="json", exclude_none=True) self._stamp = _make_modern_stamp(mutual[-1], client_info, capabilities) self._discover_result = result + self._initialize_result = None self._negotiated_version = mutual[-1] else: self._stamp = _make_handshake_stamp(result.protocol_version) self._initialize_result = result + self._discover_result = None self._negotiated_version = result.protocol_version async def discover(self) -> types.DiscoverResult: diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 859fd5d7d9..24f1d3593a 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -75,7 +75,7 @@ def handler_exception_to_error_data(exc: BaseException) -> ErrorData | None: with empty ``data`` (no pydantic text on the wire). Returns ``None`` for any other exception so each caller applies its own catch-all - `JSONRPCDispatcher` currently pins ``code=0`` for v1 compat, - `to_jsonrpc_response` uses `INTERNAL_ERROR`. + the modern HTTP entry uses `INTERNAL_ERROR`. """ if isinstance(exc, MCPError): return exc.error @@ -268,7 +268,10 @@ def __init__( self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode self._raise_handler_exceptions = raise_handler_exceptions self._inline_methods = inline_methods - self._on_stream_exception = on_stream_exception + self.on_stream_exception = on_stream_exception + """Observer for ``Exception`` items on the read stream. Mutable so a session can + bind it after the dispatcher is built (e.g. ``ClientSession`` routing into + ``message_handler``); only consulted inside ``run()`` so pre-enter assignment is safe.""" self._next_id = 0 self._pending: dict[RequestId, _Pending] = {} @@ -484,11 +487,11 @@ async def _dispatch( are awaited; any other `await` would head-of-line block the read loop. """ if isinstance(item, Exception): - if self._on_stream_exception is None: + if self.on_stream_exception is None: logger.debug("transport yielded exception: %r", item) return try: - await self._on_stream_exception(item) + await self.on_stream_exception(item) except Exception: logger.exception("on_stream_exception observer raised") return From 91c0224b2ea1487035303c28132403b6c06e0e84 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 25 Jun 2026 09:08:15 +0000 Subject: [PATCH 17/22] Retag dispatcher-hook TODO with its ledger anchor --- src/mcp/client/session.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index c38ab44304..ee09bc5ae4 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -220,9 +220,9 @@ def __init__( # Route transport-level Exception items into message_handler — only # stream-backed dispatchers carry these; DirectDispatcher has none. # Don't clobber a caller-supplied hook. - # TODO(maxisbey): this leaves a bound-method ref on the dispatcher after - # the session exits (memory pin) and a second wrap of the same dispatcher - # would skip install. The Transport-as-Dispatcher rework removes this seam. + # TODO(L78): this leaves a bound-method ref on the dispatcher after the + # session exits (memory pin) and a second wrap of the same dispatcher would + # skip install. The Transport-as-Dispatcher rework (L77) removes this seam. dispatcher.on_stream_exception = self._on_stream_exception else: if read_stream is None or write_stream is None: From eab740b91eb3080a75b7b871d334da2c7c7c1d51 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 25 Jun 2026 10:26:14 +0000 Subject: [PATCH 18/22] Restore pv header on dispatcher-written POSTs; widen auto-mode probe fallback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The streamable-HTTP transport now clears its cached MCP-Protocol-Version header when an initialize POST goes out, then lets every other POST read the cache again (re-collapsing _base_headers into _prepare_headers). This restores the header on JSON-RPC response/error/cancelled POSTs the dispatcher writes without per-message metadata, while still preventing a discover-probe value from leaking onto a fallback initialize. Client(mode='auto') now also falls back to initialize() when the probe is rejected with INVALID_REQUEST — what a deployed v1.x stateful (or stateless) streamable-HTTP server returns for a session-id-less request or an unknown protocol-version header. The lifecycle:discover requirement text is updated to match. --- src/mcp/client/client.py | 5 ++- src/mcp/client/streamable_http.py | 39 +++++++++---------- tests/client/test_streamable_http.py | 36 ++++++++++++----- tests/interaction/_requirements.py | 7 ++-- .../lowlevel/test_client_connect.py | 28 +++++++++---- 5 files changed, 75 insertions(+), 40 deletions(-) diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index e6524a862a..809fd4172f 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -24,6 +24,7 @@ from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS, MODERN_PROTOCOL_VERSIONS from mcp.types import ( + INVALID_REQUEST, METHOD_NOT_FOUND, REQUEST_TIMEOUT, CallToolResult, @@ -252,7 +253,9 @@ async def __aenter__(self) -> Client: try: await session.discover() except MCPError as e: - if e.code in (METHOD_NOT_FOUND, REQUEST_TIMEOUT): + # TODO(L73): invert this allowlist into a `classify_probe_outcome` denylist — + # fall back on every rpc-error/4xx that isn't a recognized modern error. + if e.code in (METHOD_NOT_FOUND, INVALID_REQUEST, REQUEST_TIMEOUT): await session.initialize() else: raise diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index d4c4d3995c..78a8614765 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -81,17 +81,21 @@ def __init__(self, url: str) -> None: """ self.url = url self.session_id: str | None = None - # Captured from the first stamped POST's metadata; reused on transport-internal - # GET/DELETE that don't carry per-message metadata. + # Captured from each stamped POST's metadata. Reused on outbound HTTP that carries + # no per-message header (transport-internal GET/DELETE, and dispatcher-written + # response/error/cancel POSTs that bypass the session's stamp). Cleared when an + # `initialize` POST goes out so a probe-stamped value cannot leak onto the handshake. self._protocol_version_header: str | None = None - def _base_headers(self) -> dict[str, str]: - """Build MCP-specific request headers (accept / content-type / session-id). - - These headers will be merged with the httpx.AsyncClient's default headers, - with these MCP-specific headers taking precedence. POSTs use this directly: - their protocol-version header arrives per-message via ``metadata.headers``, - so they must never read the cached value. + def _prepare_headers(self) -> dict[str, str]: + """Build MCP-specific request headers for any outbound HTTP request. + + These are merged with the ``httpx.AsyncClient`` defaults (these take + precedence). The cached ``MCP-Protocol-Version`` is included whenever + present so messages that don't pass through the session's stamp — + response/error/cancel POSTs, transport-internal GET/DELETE — still + carry the negotiated version. Per-message headers are layered on top + by the caller. """ headers: dict[str, str] = { "accept": "application/json, text/event-stream", @@ -99,15 +103,6 @@ def _base_headers(self) -> dict[str, str]: } if self.session_id: headers[MCP_SESSION_ID] = self.session_id - return headers - - def _prepare_headers(self) -> dict[str, str]: - """Base headers plus the cached protocol-version header. - - Used by transport-internal GET/DELETE (listen stream, resumption, - reconnect, terminate) which don't carry per-message metadata. - """ - headers = self._base_headers() if self._protocol_version_header: headers[MCP_PROTOCOL_VERSION_HEADER] = self._protocol_version_header return headers @@ -249,13 +244,17 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: async def _handle_post_request(self, ctx: RequestContext) -> None: """Handle a POST request with response processing.""" - headers = self._base_headers() message = ctx.session_message.message + is_initialization = self._is_initialization_request(message) + if is_initialization: + # `initialize` is the negotiation, not a "subsequent request" — discard any + # probe-stamped value so the discover→fallback path can't leak it onto the handshake. + self._protocol_version_header = None + headers = self._prepare_headers() if ctx.metadata is not None and ctx.metadata.headers is not None: headers.update(ctx.metadata.headers) if MCP_PROTOCOL_VERSION_HEADER in ctx.metadata.headers: self._protocol_version_header = ctx.metadata.headers[MCP_PROTOCOL_VERSION_HEADER] - is_initialization = self._is_initialization_request(message) async with ctx.client.stream( "POST", diff --git a/tests/client/test_streamable_http.py b/tests/client/test_streamable_http.py index 77b1fdc061..086694aae4 100644 --- a/tests/client/test_streamable_http.py +++ b/tests/client/test_streamable_http.py @@ -17,7 +17,7 @@ from mcp.client.streamable_http import streamable_http_client from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER, encode_header_value from mcp.shared.message import ClientMessageMetadata, SessionMessage -from mcp.types import METHOD_NOT_FOUND, JSONRPCError, JSONRPCRequest +from mcp.types import METHOD_NOT_FOUND, JSONRPCError, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse @pytest.mark.parametrize( @@ -104,19 +104,25 @@ def handler(request: httpx.Request) -> httpx.Response: @pytest.mark.anyio -async def test_post_does_not_read_cached_protocol_version_header() -> None: - """A POST's protocol-version header comes only from its own ``metadata.headers``. - - The first POST carries (and caches) a pv header; the second POST sends no metadata - and must therefore carry no pv header — a stale cached value would poison the - fallback ``initialize`` after a failed discover probe. The cache exists for - transport-internal GET/DELETE only. +async def test_initialize_post_clears_cached_pv_header_and_unstamped_posts_read_it() -> None: + """``initialize`` discards the cached protocol-version header; every other POST reads it. + + Steps: + 1. A stamped probe POST caches its ``MCP-Protocol-Version`` header. + 2. An ``initialize`` POST clears that cache before building headers, so the fallback + handshake never carries a probe-stamped value. + 3. A subsequent stamped POST re-seeds the cache with the negotiated version. + 4. An unstamped POST (a JSON-RPC response written by the dispatcher, which never + passes through the session's stamp) then reads the cache and carries the + negotiated version — the spec MUST for all post-initialization HTTP requests. """ recorded: list[httpx.Request] = [] def handler(request: httpx.Request) -> httpx.Response: recorded.append(request) body = json.loads(request.content) + if "id" not in body or "result" in body: + return httpx.Response(202) return httpx.Response(200, json={"jsonrpc": "2.0", "id": body["id"], "result": {}}) with anyio.fail_after(5): @@ -133,6 +139,18 @@ def handler(request: httpx.Request) -> httpx.Response: await read.receive() await write.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=2, method="initialize", params={}))) await read.receive() - assert [r.method for r in recorded] == ["POST", "POST"] + await write.send( + SessionMessage( + message=JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized"), + metadata=ClientMessageMetadata(headers={MCP_PROTOCOL_VERSION_HEADER: "2025-11-25"}), + ) + ) + # An unstamped JSON-RPC response — what the dispatcher writes when answering + # a server-initiated request (sampling/elicitation/roots). + await write.send(SessionMessage(JSONRPCResponse(jsonrpc="2.0", id=99, result={}))) + + assert [r.method for r in recorded] == ["POST", "POST", "POST", "POST"] assert recorded[0].headers[MCP_PROTOCOL_VERSION_HEADER] == "2026-07-28" assert MCP_PROTOCOL_VERSION_HEADER not in recorded[1].headers + assert recorded[2].headers[MCP_PROTOCOL_VERSION_HEADER] == "2025-11-25" + assert recorded[3].headers[MCP_PROTOCOL_VERSION_HEADER] == "2025-11-25" diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 8714977824..7cad27b73a 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -438,12 +438,13 @@ def __post_init__(self) -> None: "lifecycle:discover:network-error-raises": Requirement( source="sdk", behavior=( - "An HTTP timeout, connection error, or non-404 4xx/5xx during server/discover raises to the " - "caller without falling back to initialize." + "A network/connection error or 5xx during server/discover raises to the caller without " + "falling back to initialize. A 4xx with a JSON-RPC error body is a server-side rejection " + "and falls back (legacy servers reject the probe with 400 INVALID_REQUEST)." ), transports=("streamable-http", "streamable-http-stateless"), added_in="2026-07-28", - note="HTTP-only: distinguishes transport-level failures from the -32601 fallback signal.", + note="HTTP-only: distinguishes transport-level failures from server-side rejection.", ), "lifecycle:mode:legacy-never-probes": Requirement( source="sdk", diff --git a/tests/interaction/lowlevel/test_client_connect.py b/tests/interaction/lowlevel/test_client_connect.py index f508dfc90c..f2c6b6393a 100644 --- a/tests/interaction/lowlevel/test_client_connect.py +++ b/tests/interaction/lowlevel/test_client_connect.py @@ -32,6 +32,7 @@ CLIENT_CAPABILITIES_META_KEY, CLIENT_INFO_META_KEY, INTERNAL_ERROR, + INVALID_REQUEST, METHOD_NOT_FOUND, PROTOCOL_VERSION_META_KEY, UNSUPPORTED_PROTOCOL_VERSION, @@ -217,7 +218,7 @@ async def discover(ctx: ServerRequestContext, params: types.RequestParams | None @requirement("lifecycle:discover:network-error-raises") async def test_auto_mode_reraises_a_non_fallback_discover_error_without_initializing() -> None: - """A `server/discover` failure outside the {-32601, -32001, -32022} ladder raises without falling back. + """A `server/discover` failure that is not a recognised legacy-server rejection raises without falling back. Requirement `lifecycle:discover:network-error-raises` (sdk-defined): a 5xx-class error from the probe is surfaced to the caller; the client never sends `initialize`. Exercised here as @@ -248,14 +249,27 @@ def is_internal_error(exc: MCPError) -> bool: @requirement("lifecycle:discover:fallback-method-not-found") -async def test_auto_mode_falls_back_to_initialize_when_discover_is_method_not_found() -> None: - """A -32601 from `server/discover` makes an auto-negotiating client run the legacy `initialize` handshake. +@pytest.mark.parametrize( + ("probe_code", "probe_message"), + [ + (METHOD_NOT_FOUND, "Method not found"), + (INVALID_REQUEST, "Bad Request: Missing session ID"), + ], + ids=["method-not-found", "invalid-request"], +) +async def test_auto_mode_falls_back_to_initialize_on_a_legacy_probe_rejection( + probe_code: int, probe_message: str +) -> None: + """A legacy server's rejection of `server/discover` makes an auto-negotiating client fall back to `initialize`. Requirement `lifecycle:discover:fallback-method-not-found` (spec stdio#backward-compatibility): a legacy-era server that does not implement `server/discover` is connected to via the - handshake, and the session lands at a handshake-era protocol version. A real `Server` always - implements `server/discover`, so this test plays the server's side of the wire by hand. - Reserve this pattern for behaviour no real server can be made to produce. + handshake, and the session lands at a handshake-era protocol version. The probe rejection + arrives as METHOD_NOT_FOUND from a server that routes the unknown method, or as + INVALID_REQUEST from a deployed v1.x stateful streamable-HTTP server that rejects the + session-id-less probe before dispatch. A real `Server` always implements `server/discover`, + so this test plays the server's side of the wire by hand. Reserve this pattern for behaviour + no real server can be made to produce. """ methods_seen: list[str] = [] @@ -267,7 +281,7 @@ async def scripted_server(streams: MessageStream) -> None: assert isinstance(frame, JSONRPCRequest | JSONRPCNotification) methods_seen.append(frame.method) if isinstance(frame, JSONRPCRequest) and frame.method == "server/discover": - error = types.ErrorData(code=METHOD_NOT_FOUND, message="Method not found") + error = types.ErrorData(code=probe_code, message=probe_message) await server_write.send(SessionMessage(JSONRPCError(jsonrpc="2.0", id=frame.id, error=error))) elif isinstance(frame, JSONRPCRequest) and frame.method == "initialize": result = InitializeResult( From a7d1275f0a187019d240171dacc76fdf838b62db Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 25 Jun 2026 12:41:07 +0000 Subject: [PATCH 19/22] Add negotiate_auto: denylist probe classifier for mode='auto' MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The mode='auto' connect path now goes through client/_probe.py's negotiate_auto, which inverts the previous allowlist into a denylist: every MCPError from the server/discover probe falls back to initialize(), the sole exception being -32022 with a disjoint modern-only supported list. An unparseable probe result also falls back. Any non-MCPError exception (network/connection errors, anyio resource errors) propagates — an outage is never an era verdict. ClientSession gains send_discover(version) (the raw probe, no retry, no adopt), and discover() is reimplemented on top of it. The __aenter__ mode='auto' arm collapses to a single negotiate_auto call. tests/client/test_probe.py covers the verdict table directly; the interaction-suite fallback test broadens to a parametrized rpc-error set, and the previous "INTERNAL_ERROR raises" assertion is replaced with a network-error case (under the denylist, INTERNAL_ERROR now falls back). --- src/mcp/client/_probe.py | 79 ++++++ src/mcp/client/client.py | 19 +- src/mcp/client/session.py | 51 ++-- tests/client/test_client.py | 13 +- tests/client/test_probe.py | 242 ++++++++++++++++++ tests/interaction/_requirements.py | 11 +- .../lowlevel/test_client_connect.py | 37 ++- 7 files changed, 385 insertions(+), 67 deletions(-) create mode 100644 src/mcp/client/_probe.py create mode 100644 tests/client/test_probe.py diff --git a/src/mcp/client/_probe.py b/src/mcp/client/_probe.py new file mode 100644 index 0000000000..1da4cfcbcb --- /dev/null +++ b/src/mcp/client/_probe.py @@ -0,0 +1,79 @@ +"""Connect-time era negotiation for ``mode='auto'``. + +The ``server/discover`` probe is sent at the newest modern version. Anything +that is not positive evidence the peer is a modern MCP server falls back to +the legacy ``initialize`` handshake — a *denylist* (only the disjoint-modern +case raises) rather than an allowlist of fallback codes. + +Every ``MCPError`` falls back except ``-32022`` with a disjoint modern-only +``supported`` list. The streamable-HTTP transport already maps HTTP-layer +4xx rejections (no JSON-RPC body) into ``MCPError`` codes, so those reach +the same path. Any non-``MCPError`` exception (network/connection errors, +anyio cancellation, the ``RuntimeError`` from ``adopt()`` on no-mutual) +propagates to the caller; an outage or in-process bug is never an era verdict. +""" + +from __future__ import annotations + +from typing import Any + +from pydantic import ValidationError + +from mcp import types +from mcp.client.session import ClientSession +from mcp.shared.exceptions import MCPError +from mcp.shared.version import ( + HANDSHAKE_PROTOCOL_VERSIONS, + LATEST_MODERN_VERSION, + MODERN_PROTOCOL_VERSIONS, +) +from mcp.types import UNSUPPORTED_PROTOCOL_VERSION + + +def _parse_supported(data: Any) -> list[str] | None: + """Pull ``data.supported`` off a -32022 error, or ``None`` if not actionable.""" + try: + return types.UnsupportedProtocolVersionErrorData.model_validate(data).supported + except ValidationError: + return None + + +async def negotiate_auto(session: ClientSession) -> None: + """Drive the ``mode='auto'`` connect-time policy on ``session``. + + Probes ``server/discover`` once (twice if the server names a mutual + modern version via -32022), then either ``adopt()``s the result or falls + back to ``initialize()``. Idempotent only in the sense that one of + ``session.discover_result`` / ``session.initialize_result`` is set on + return. + + Raises: + MCPError: The server is modern-only and shares no version with this + client (-32022 with a disjoint ``supported`` list). + Exception: Any transport/network error from the probe propagates as-is. + """ + version = LATEST_MODERN_VERSION + for attempt in range(2): + try: + raw = await session.send_discover(version) + except MCPError as e: + if e.code == UNSUPPORTED_PROTOCOL_VERSION: + supported = _parse_supported(e.error.data) + mutual = [v for v in MODERN_PROTOCOL_VERSIONS if v in (supported or ())] + if mutual and attempt == 0: + version = mutual[-1] + continue + if supported is not None and not any(v in HANDSHAKE_PROTOCOL_VERSIONS for v in supported): + raise # server is modern-only and disjoint — real incompatibility + await session.initialize() # every other rpc-error → legacy (the denylist) + return + # any other exception (httpx.TransportError, ConnectionError, anyio errors, + # RuntimeError from adopt) → propagate + try: + result = types.DiscoverResult.model_validate(raw) + except ValidationError: + await session.initialize() # unparseable result → not modern evidence + return + session.adopt(result) + return + raise AssertionError("unreachable") # pragma: no cover — loop body always returns or raises diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 809fd4172f..4822324c7d 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -12,6 +12,7 @@ from mcp import types from mcp.client._memory import InMemoryTransport +from mcp.client._probe import negotiate_auto from mcp.client._transport import Transport from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT from mcp.client.streamable_http import streamable_http_client @@ -20,13 +21,10 @@ from mcp.server.runner import modern_on_request from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair from mcp.shared.dispatcher import Dispatcher, ProgressFnT -from mcp.shared.exceptions import MCPDeprecationWarning, MCPError +from mcp.shared.exceptions import MCPDeprecationWarning from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS, MODERN_PROTOCOL_VERSIONS from mcp.types import ( - INVALID_REQUEST, - METHOD_NOT_FOUND, - REQUEST_TIMEOUT, CallToolResult, CompleteResult, EmptyResult, @@ -183,8 +181,7 @@ async def main(): client_info: Implementation | None = None """Client implementation info to send to server.""" - # TODO(maxisbey): flip default to 'auto' once the in-proc test suite is era-decoupled - # and the probe-timeout fallback is transport-aware (stdio→fallback / HTTP→reject). + # TODO(maxisbey): flip default to 'auto' once the in-proc test suite is era-decoupled. mode: ConnectMode = "legacy" """'legacy' performs the initialize handshake. 'auto' probes server/discover and falls back to initialize() on legacy servers. A modern protocol-version string (e.g. '2026-07-28') adopts that version directly without @@ -250,15 +247,7 @@ async def __aenter__(self) -> Client: if self.mode == "legacy": await session.initialize() elif self.mode == "auto": - try: - await session.discover() - except MCPError as e: - # TODO(L73): invert this allowlist into a `classify_probe_outcome` denylist — - # fall back on every rpc-error/4xx that isn't a recognized modern error. - if e.code in (METHOD_NOT_FOUND, INVALID_REQUEST, REQUEST_TIMEOUT): - await session.initialize() - else: - raise + await negotiate_auto(session) else: session.adopt(self.prior_discover or _synthesize_discover(self.mode)) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index ee09bc5ae4..c4445056a7 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -393,6 +393,35 @@ def adopt(self, result: types.InitializeResult | types.DiscoverResult) -> None: self._discover_result = None self._negotiated_version = result.protocol_version + async def send_discover(self, version: str) -> dict[str, Any]: + """Send a single ``server/discover`` at ``version`` and return the raw result dict. + + No retry, no ``adopt()``. The ``_meta`` envelope and the + ``Mcp-Protocol-Version`` header are stamped at ``version`` so the + server-side era router sees a coherent probe. Used by ``discover()`` and + the connect-time auto-negotiation policy. + + Raises: + MCPError: The server returned a JSON-RPC error. + ProbeNotRecognized: The transport bounced the request at its own + layer (HTTP 4xx without a JSON-RPC error body). + """ + client_info = self._client_info.model_dump(by_alias=True, mode="json", exclude_none=True) + capabilities = self._build_capabilities().model_dump(by_alias=True, mode="json", exclude_none=True) + params = { + "_meta": { + PROTOCOL_VERSION_META_KEY: version, + CLIENT_INFO_META_KEY: client_info, + CLIENT_CAPABILITIES_META_KEY: capabilities, + } + } + opts: CallOptions = { + "timeout": DISCOVER_TIMEOUT_SECONDS, + "cancel_on_abandon": False, + "headers": {MCP_PROTOCOL_VERSION_HEADER: version}, + } + return await self._dispatcher.send_raw_request("server/discover", params, opts) + async def discover(self) -> types.DiscoverResult: """Probe `server/discover` and adopt the result. @@ -412,26 +441,8 @@ async def discover(self) -> types.DiscoverResult: if self._discover_result is not None: return self._discover_result - client_info = self._client_info.model_dump(by_alias=True, mode="json", exclude_none=True) - capabilities = self._build_capabilities().model_dump(by_alias=True, mode="json", exclude_none=True) - - async def probe(version: str) -> dict[str, Any]: - params = { - "_meta": { - PROTOCOL_VERSION_META_KEY: version, - CLIENT_INFO_META_KEY: client_info, - CLIENT_CAPABILITIES_META_KEY: capabilities, - } - } - opts: CallOptions = { - "timeout": DISCOVER_TIMEOUT_SECONDS, - "cancel_on_abandon": False, - "headers": {MCP_PROTOCOL_VERSION_HEADER: version}, - } - return await self._dispatcher.send_raw_request("server/discover", params, opts) - try: - raw = await probe(LATEST_MODERN_VERSION) + raw = await self.send_discover(LATEST_MODERN_VERSION) except MCPError as e: if e.code != UNSUPPORTED_PROTOCOL_VERSION: raise @@ -443,7 +454,7 @@ async def probe(version: str) -> dict[str, Any]: mutual = [v for v in MODERN_PROTOCOL_VERSIONS if v in data.supported] if not mutual: raise - raw = await probe(mutual[-1]) + raw = await self.send_discover(mutual[-1]) result = types.DiscoverResult.model_validate(raw) self.adopt(result) diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 991cd1e5e1..2e3a360722 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -415,13 +415,14 @@ async def test_client_auto_mode_probes_discover_then_adopts(simple_server: Serve assert (await client.list_resources()).resources[0].name == "Test Resource" -@pytest.mark.parametrize("code", [types.METHOD_NOT_FOUND, types.REQUEST_TIMEOUT]) +@pytest.mark.parametrize("code", [types.METHOD_NOT_FOUND, types.REQUEST_TIMEOUT, types.INTERNAL_ERROR]) async def test_client_auto_mode_falls_back_to_initialize_on_legacy_signal(code: int) -> None: - """`mode='auto'`: when `server/discover` is rejected with -32601 or -32001, - `Client.__aenter__` runs the legacy `initialize()` handshake and lands at a - handshake-era protocol version. The session itself does not fall back — - that policy lives here. A real `Server` always implements `server/discover`, - so the server side is hand-played.""" + """`mode='auto'`: any JSON-RPC error from `server/discover` makes + `Client.__aenter__` run the legacy `initialize()` handshake and land at a + handshake-era protocol version. The denylist policy treats every server-sent + rpc-error as "not modern" — including INTERNAL_ERROR, since a legacy server + may crash on the unknown method before reaching its router. A real `Server` + always implements `server/discover`, so the server side is hand-played.""" methods_seen: list[str] = [] async def scripted_server(streams: MessageStream) -> None: diff --git a/tests/client/test_probe.py b/tests/client/test_probe.py new file mode 100644 index 0000000000..95c3dbac23 --- /dev/null +++ b/tests/client/test_probe.py @@ -0,0 +1,242 @@ +"""Unit tests for the connect-time auto-negotiation policy (`mcp.client._probe.negotiate_auto`). + +`negotiate_auto` is a small policy function that drives a `ClientSession` through the +``server/discover`` probe and decides between ``adopt()`` (modern), ``initialize()`` (legacy +fallback), or letting the probe's exception propagate. The policy is a *denylist*: every +``MCPError`` falls back to ``initialize()``, the sole exception being -32022 with a disjoint +modern-only ``supported`` list. Any non-``MCPError`` exception (network errors, anyio +resource errors) propagates untouched — an outage is never an era verdict. + +These tests pin the classifier in isolation with a stub session; the end-to-end wire shape is +covered by ``tests/interaction/lowlevel/test_client_connect.py``. +""" + +from __future__ import annotations + +from typing import Any, cast + +import anyio +import httpx +import pytest + +from mcp import types +from mcp.client._probe import _parse_supported, negotiate_auto +from mcp.client.session import ClientSession +from mcp.shared.exceptions import MCPError +from mcp.shared.version import ( + HANDSHAKE_PROTOCOL_VERSIONS, + LATEST_MODERN_VERSION, + MODERN_PROTOCOL_VERSIONS, +) +from mcp.types import ( + INTERNAL_ERROR, + INVALID_REQUEST, + METHOD_NOT_FOUND, + PARSE_ERROR, + UNSUPPORTED_PROTOCOL_VERSION, + Implementation, + ServerCapabilities, +) + +pytestmark = pytest.mark.anyio + + +class _StubSession: + """Minimal stand-in for `ClientSession` exposing only what `negotiate_auto` touches. + + `send_discover` plays back a script (raise an exception, or return a dict); + `initialize` and `adopt` just record that they were called. + """ + + def __init__(self, *script: dict[str, Any] | Exception) -> None: + self._script: list[dict[str, Any] | Exception] = list(script) + self.probed_at: list[str] = [] + self.initialized: bool = False + self.adopted: types.DiscoverResult | None = None + + async def send_discover(self, version: str) -> dict[str, Any]: + self.probed_at.append(version) + step = self._script.pop(0) + if isinstance(step, Exception): + raise step + return step + + async def initialize(self) -> None: + self.initialized = True + + def adopt(self, result: types.DiscoverResult) -> None: + self.adopted = result + + +async def _negotiate(session: _StubSession) -> None: + """Drive `negotiate_auto` against the stub; cast at one seam so the tests stay suppression-free.""" + await negotiate_auto(cast("ClientSession", session)) + + +def _discover_dict(versions: list[str] | None = None) -> dict[str, Any]: + return types.DiscoverResult( + supported_versions=versions or list(MODERN_PROTOCOL_VERSIONS), + capabilities=ServerCapabilities(), + server_info=Implementation(name="stub", version="0"), + ).model_dump(by_alias=True, mode="json", exclude_none=True) + + +def _err_32022(supported: Any) -> MCPError: + return MCPError( + code=UNSUPPORTED_PROTOCOL_VERSION, + message="unsupported protocol version", + data={"supported": supported, "requested": LATEST_MODERN_VERSION}, + ) + + +# --- happy path: modern server --- + + +async def test_a_valid_discover_result_is_adopted_without_initializing() -> None: + """A parseable `DiscoverResult` from the probe is adopted; `initialize()` is never called.""" + session = _StubSession(_discover_dict()) + await _negotiate(session) + assert session.adopted is not None + assert session.adopted.server_info.name == "stub" + assert not session.initialized + assert session.probed_at == [LATEST_MODERN_VERSION] + + +async def test_an_unparseable_discover_result_falls_back_to_initialize() -> None: + """A probe response that does not validate as `DiscoverResult` is not modern evidence, + so the policy falls back to the legacy handshake instead of adopting garbage.""" + session = _StubSession({"not": "a discover result"}) + await _negotiate(session) + assert session.initialized + assert session.adopted is None + + +# --- the denylist: every JSON-RPC error code falls back --- + + +@pytest.mark.parametrize( + "code", + [ + pytest.param(METHOD_NOT_FOUND, id="method-not-found-32601"), + pytest.param(INVALID_REQUEST, id="invalid-request-32600"), + pytest.param(INTERNAL_ERROR, id="internal-error-32603"), + pytest.param(PARSE_ERROR, id="parse-error-32700"), + ], +) +async def test_any_jsonrpc_error_from_the_probe_falls_back_to_initialize(code: int) -> None: + """The denylist: every server-sent JSON-RPC error code is treated as "not modern" and + triggers the legacy `initialize()` handshake. Legacy servers reject the unknown + ``server/discover`` method with various codes (-32601, -32600, -32603, -32700) depending + on where in their pipeline the request bounces.""" + session = _StubSession(MCPError(code=code, message="nope")) + await _negotiate(session) + assert session.initialized + assert session.adopted is None + assert session.probed_at == [LATEST_MODERN_VERSION] + + +# --- -32022 corrective retry --- + + +async def test_unsupported_version_with_a_mutual_modern_version_retries_once_then_adopts() -> None: + """-32022 with a `supported` list naming a modern version we speak: re-probe once at + the highest mutual version, then adopt the second response.""" + session = _StubSession(_err_32022(list(MODERN_PROTOCOL_VERSIONS)), _discover_dict()) + await _negotiate(session) + assert session.probed_at == [LATEST_MODERN_VERSION, MODERN_PROTOCOL_VERSIONS[-1]] + assert session.adopted is not None + assert not session.initialized + + +async def test_unsupported_version_naming_only_handshake_versions_falls_back_to_initialize() -> None: + """-32022 with `supported` naming only handshake-era versions: the server is reachable + via the legacy handshake, so fall back rather than raise.""" + session = _StubSession(_err_32022(list(HANDSHAKE_PROTOCOL_VERSIONS))) + await _negotiate(session) + assert session.initialized + assert session.adopted is None + assert session.probed_at == [LATEST_MODERN_VERSION] + + +async def test_unsupported_version_with_disjoint_modern_only_supported_reraises() -> None: + """-32022 with `supported` naming only modern versions we *don't* speak: this is the + one denylist exception — the server is modern-only and there is no mutual version, so + falling back to `initialize()` would also fail. The original `MCPError` re-raises.""" + session = _StubSession(_err_32022(["2099-01-01"])) + with pytest.raises(MCPError) as exc_info: + await _negotiate(session) + assert exc_info.value.code == UNSUPPORTED_PROTOCOL_VERSION + assert not session.initialized + assert session.adopted is None + + +@pytest.mark.parametrize( + "data", + [ + pytest.param(None, id="no-data"), + pytest.param({"supported": "not-a-list"}, id="malformed-supported"), + pytest.param({"requested": LATEST_MODERN_VERSION}, id="missing-supported"), + ], +) +async def test_unsupported_version_with_unparseable_data_falls_back_to_initialize(data: Any) -> None: + """-32022 with no/malformed `error.data`: nothing actionable, so fall through to the + denylist's `initialize()` fallback rather than guess or raise.""" + session = _StubSession(MCPError(code=UNSUPPORTED_PROTOCOL_VERSION, message="bad version", data=data)) + await _negotiate(session) + assert session.initialized + assert session.adopted is None + assert session.probed_at == [LATEST_MODERN_VERSION] + + +async def test_a_second_unsupported_version_after_the_corrective_retry_does_not_loop() -> None: + """The corrective -32022 retry happens at most once; a second -32022 naming a + modern-only `supported` list re-raises rather than re-probing forever (the loop + guard makes this the disjoint-modern case on attempt two).""" + session = _StubSession(_err_32022(list(MODERN_PROTOCOL_VERSIONS)), _err_32022(list(MODERN_PROTOCOL_VERSIONS))) + with pytest.raises(MCPError) as exc_info: + await _negotiate(session) + assert exc_info.value.code == UNSUPPORTED_PROTOCOL_VERSION + assert session.probed_at == [LATEST_MODERN_VERSION, MODERN_PROTOCOL_VERSIONS[-1]] + assert not session.initialized + assert session.adopted is None + + +# --- non-MCP errors propagate --- + + +@pytest.mark.parametrize( + "exc", + [ + pytest.param(httpx.ConnectError("connection refused"), id="httpx-connect-error"), + pytest.param(anyio.ClosedResourceError(), id="anyio-closed-resource"), + ], +) +async def test_a_network_or_resource_error_from_the_probe_propagates_unchanged(exc: Exception) -> None: + """Anything that is not an `MCPError` propagates as-is; an outage or in-process bug + is never an era verdict, and `initialize()` is not called.""" + session = _StubSession(exc) + with pytest.raises(type(exc)): + await _negotiate(session) + assert not session.initialized + assert session.adopted is None + + +# --- helper --- + + +@pytest.mark.parametrize( + ("data", "expected"), + [ + ({"supported": ["2026-07-28"], "requested": "x"}, ["2026-07-28"]), + ({"supported": [], "requested": "x"}, []), + (None, None), + ({"supported": 123, "requested": "x"}, None), + ("not a dict", None), + ], +) +def test_parse_supported_returns_none_for_anything_not_shaped_like_the_spec_error_data( + data: Any, expected: list[str] | None +) -> None: + """`_parse_supported` returns the `supported` list when `error.data` validates as + `UnsupportedProtocolVersionErrorData`, and `None` otherwise — never raises.""" + assert _parse_supported(data) == expected diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 7cad27b73a..a2b4178054 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -430,17 +430,18 @@ def __post_init__(self) -> None: "lifecycle:discover:fallback-method-not-found": Requirement( source=f"{SPEC_2026_BASE_URL}/basic/transports/stdio#backward-compatibility", behavior=( - "When server/discover returns -32601 (or HTTP 404), an auto-negotiating client falls back to " - "the legacy initialize handshake and the connection succeeds at a handshake-era version." + "When server/discover returns any JSON-RPC error or a bare HTTP 4xx, an auto-negotiating " + "client falls back to the legacy initialize handshake and the connection succeeds at a " + "handshake-era version (legacy servers reject the probe with various codes)." ), added_in="2026-07-28", ), "lifecycle:discover:network-error-raises": Requirement( source="sdk", behavior=( - "A network/connection error or 5xx during server/discover raises to the caller without " - "falling back to initialize. A 4xx with a JSON-RPC error body is a server-side rejection " - "and falls back (legacy servers reject the probe with 400 INVALID_REQUEST)." + "A network/connection error during server/discover propagates to the caller without " + "falling back to initialize; any rpc-error or 4xx falls back (legacy servers reject the " + "probe with various codes). An outage is never an era verdict." ), transports=("streamable-http", "streamable-http-stateless"), added_in="2026-07-28", diff --git a/tests/interaction/lowlevel/test_client_connect.py b/tests/interaction/lowlevel/test_client_connect.py index f2c6b6393a..33c988e14e 100644 --- a/tests/interaction/lowlevel/test_client_connect.py +++ b/tests/interaction/lowlevel/test_client_connect.py @@ -31,7 +31,6 @@ from mcp.types import ( CLIENT_CAPABILITIES_META_KEY, CLIENT_INFO_META_KEY, - INTERNAL_ERROR, INVALID_REQUEST, METHOD_NOT_FOUND, PROTOCOL_VERSION_META_KEY, @@ -217,31 +216,27 @@ async def discover(ctx: ServerRequestContext, params: types.RequestParams | None @requirement("lifecycle:discover:network-error-raises") -async def test_auto_mode_reraises_a_non_fallback_discover_error_without_initializing() -> None: - """A `server/discover` failure that is not a recognised legacy-server rejection raises without falling back. - - Requirement `lifecycle:discover:network-error-raises` (sdk-defined): a 5xx-class error from - the probe is surfaced to the caller; the client never sends `initialize`. Exercised here as - the JSON-RPC INTERNAL_ERROR branch (which the modern HTTP entry maps to a 5xx). The error +async def test_auto_mode_propagates_a_network_error_from_discover_without_initializing() -> None: + """A network/connection error during `server/discover` propagates to the caller without falling back. + + Requirement `lifecycle:discover:network-error-raises` (sdk-defined): under the denylist policy + every server-sent rpc-error and every transport-layer 4xx falls back to `initialize()`; the + only probe failures that reach the caller are real outages — network errors, anyio resource + errors, and the disjoint-modern -32022 case. Exercised here as an `httpx.ConnectError` from + the underlying transport, which the policy must not classify as an era verdict. The error reaches the test wrapped in the streamable-http transport's task-group teardown, so - `pytest.RaisesGroup` flattens before matching. + `pytest.RaisesGroup` flattens before matching. The probe POST is recorded before the + transport raises, so the `initialize` fallback observably did not happen. """ + requests: list[httpx.Request] = [] - async def discover(ctx: ServerRequestContext, params: types.RequestParams | None) -> DiscoverResult: - raise MCPError(code=INTERNAL_ERROR, message="storage unavailable") - - server = _tools_server() - server.add_request_handler("server/discover", types.RequestParams, discover) - requests, on_request = _request_recorder() - - def is_internal_error(exc: MCPError) -> bool: - return exc.code == INTERNAL_ERROR + def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + raise httpx.ConnectError("connection refused") with anyio.fail_after(5): - async with mounted_app(server, on_request=on_request) as (http, _): - with pytest.RaisesGroup( - pytest.RaisesExc(MCPError, check=is_internal_error), flatten_subgroups=True - ): # pragma: no branch + async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http: + with pytest.RaisesGroup(httpx.ConnectError, flatten_subgroups=True): # pragma: no branch async with Client(streamable_http_client(f"{BASE_URL}/mcp", http_client=http), mode="auto"): raise NotImplementedError("entering the Client should have raised") # pragma: no cover From 2d12b96b4d9eee482ffd09c1c3ac111af1821602 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 25 Jun 2026 13:16:21 +0000 Subject: [PATCH 20/22] Flip Client default mode to 'auto' MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The default Client(...) now probes server/discover and falls back to initialize() (via negotiate_auto's denylist). For an in-process Server, the default path is now DirectDispatcher per-request rather than the legacy InMemoryTransport stream loop. DirectDispatcher gains raise_handler_exceptions (matching JSONRPCDispatcher's knob): True chains the original exception via __cause__; False sanitizes to MCPError(INTERNAL_ERROR). modern_on_request collapses to a pure envelope-builder (no exception ladder of its own), and Client threads raise_exceptions into create_direct_dispatcher_pair. Tests that exercise legacy-specific semantics — server-initiated sampling/elicitation push, message_handler delivery, ping, InMemoryTransport mechanics, JSON-RPC wire-shape recording — are pinned to mode='legacy' explicitly (~64 sites across 26 test files plus the client_via_http / connect_over_sse / auth-harness helpers). These are census-driven, not failure-driven: ~23 sites would have passed under 'auto' but silently stopped testing their subject. Client.send_ping() is deprecated (ping is removed from 2026-07-28); it only works under mode='legacy'. docs/migration.md gains a section explaining the default change and when to pin mode='legacy'; docs/testing.md notes the same for test authors. --- docs/migration.md | 12 ++++++-- docs/testing.md | 5 ++++ pyproject.toml | 3 ++ src/mcp/client/client.py | 23 +++++++++------ src/mcp/server/runner.py | 22 +++------------ src/mcp/shared/direct_dispatcher.py | 28 +++++++++++++++---- tests/client/test_client.py | 25 +++++++++-------- tests/client/test_list_methods_cursor.py | 2 +- tests/client/test_list_roots_callback.py | 4 +-- tests/client/test_logging_callback.py | 1 + tests/client/test_sampling_callback.py | 6 ++-- tests/client/test_session.py | 23 +++++++++++++-- tests/client/test_session_concurrency.py | 2 +- tests/client/transports/test_memory.py | 6 ++-- tests/interaction/_connect.py | 5 ++++ tests/interaction/auth/_harness.py | 3 +- tests/interaction/lowlevel/test_ping.py | 2 +- tests/interaction/lowlevel/test_timeouts.py | 8 +++--- tests/interaction/lowlevel/test_wire.py | 10 +++---- .../transports/test_client_transport_http.py | 4 +-- .../transports/test_legacy_wire.py | 2 +- tests/interaction/transports/test_sse.py | 4 +-- tests/interaction/transports/test_stdio.py | 2 +- tests/server/mcpserver/test_elicitation.py | 4 +-- tests/server/mcpserver/test_integration.py | 8 +++--- tests/server/mcpserver/test_server.py | 6 ++-- tests/server/mcpserver/test_title.py | 10 +++---- .../server/mcpserver/test_url_elicitation.py | 22 +++++++-------- tests/server/test_cancel_handling.py | 2 +- tests/server/test_completion_with_context.py | 2 +- tests/server/test_streamable_http_manager.py | 2 +- tests/shared/test_jsonrpc_dispatcher.py | 2 +- tests/shared/test_otel.py | 2 +- 33 files changed, 155 insertions(+), 107 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index f35721dcc9..fb82495350 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -352,7 +352,15 @@ version = session.protocol_version The raw handshake result is also retained: `session.initialize_result` is set after `initialize()` (≤2025-11-25 servers — including `stateless_http=True` servers, which still answer `initialize`); `session.discover_result` is set after `discover()` (2026-07-28+ servers). At most one is non-`None`. -On the high-level `Client`, `client.server_capabilities`, `client.server_info`, and `client.protocol_version` are non-nullable inside the context manager. `client.instructions` remains `str | None` since the server may omit it. (The lowlevel `ClientSession` still lets you call methods before any handshake, as in v1; `Client` always handshakes on enter.) +On the high-level `Client`, `client.server_capabilities`, `client.server_info`, and `client.protocol_version` are non-nullable inside the context manager. `client.instructions` remains `str | None` since the server may omit it. (The lowlevel `ClientSession` still lets you call methods before any handshake, as in v1; `Client` always connects on enter — by default it probes `server/discover` and falls back to the initialize handshake.) + +### `Client` defaults to `mode='auto'` + +In v1, connecting to a server always performed the `initialize` handshake. In v2, `Client` defaults to `mode='auto'`: on enter it probes `server/discover` and, if the server doesn't support it, falls back to the `initialize` handshake. Pass `mode='legacy'` to force the initialize handshake and reproduce v1's byte-identical pre-2026 behavior, or pass a modern protocol-version string (e.g. `mode='2026-07-28'`) to pin a version without probing. + +For an in-process `Client(server)` (where `server` is a `Server` or `MCPServer` instance), `mode='auto'` dispatches calls directly through `DirectDispatcher` with no JSON-RPC framing. Pass `mode='legacy'` if you need the in-memory JSON-RPC transport that v1 used. + +`Client.send_ping()` is deprecated (ping is removed in 2026-07-28); pin `mode='legacy'` if you need it. ### `McpError` renamed to `MCPError` @@ -832,7 +840,7 @@ async with Client(server) as client: result = await client.call_tool("my_tool", {"x": 1}) ``` -`Client` accepts the same callback parameters the old helper did (`sampling_callback`, `list_roots_callback`, `logging_callback`, `message_handler`, `elicitation_callback`, `client_info`) plus `raise_exceptions` to surface server-side errors. +`Client` accepts the same callback parameters the old helper did (`sampling_callback`, `list_roots_callback`, `logging_callback`, `message_handler`, `elicitation_callback`, `client_info`) plus `raise_exceptions` to surface server-side errors and `mode` to control version negotiation (`'auto'` by default; `'legacy'` reproduces v1's initialize-only handshake). If you need direct access to the underlying `ClientSession` and memory streams (e.g., for low-level transport testing), `create_client_server_memory_streams` is still available in `mcp.shared.memory`: diff --git a/docs/testing.md b/docs/testing.md index 9a222c9067..c1d263b763 100644 --- a/docs/testing.md +++ b/docs/testing.md @@ -74,4 +74,9 @@ async def test_call_add_tool(client: Client): 1. If you are using `trio`, you should set `"trio"` as the `anyio_backend`. Check more information in the [anyio documentation](https://anyio.readthedocs.io/en/stable/testing.html#specifying-the-backends-to-run-on). 2. The `client` fixture creates a connected client that can be reused across multiple tests. +!!! note + `Client(app)` connects in-process and is era-neutral by default — it probes the server and picks the + appropriate protocol path. Pin `mode='legacy'` if your test exercises legacy-specific semantics + (sampling/elicitation push, `message_handler`). + There you go! You can now extend your tests to cover more scenarios. diff --git a/pyproject.toml b/pyproject.toml index e02c727be2..a36d152dc4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -220,6 +220,9 @@ filterwarnings = [ # 2026-07-28 restricts progress to server->client; the client send path is # advisory-deprecated and a handful of tests still exercise it. "ignore:Client-to-server progress is deprecated as of 2026-07-28.*:mcp.MCPDeprecationWarning", + # 2026-07-28 drops ping; Client.send_ping() is advisory-deprecated and the + # legacy interaction/transport tests still drive it. + "ignore:ping is removed as of 2026-07-28.*:mcp.MCPDeprecationWarning", ] [tool.markdown.lint] diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 4822324c7d..1ab8209b18 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -78,10 +78,10 @@ async def connect(exit_stack: AsyncExitStack, mode: ConnectMode, raise_exception read_stream, write_stream = await exit_stack.enter_async_context(transport) return JSONRPCDispatcher(read_stream, write_stream) lifespan_state = await exit_stack.enter_async_context(server.lifespan(server)) - client_disp, server_disp = create_direct_dispatcher_pair() + client_disp, server_disp = create_direct_dispatcher_pair(raise_handler_exceptions=raise_exceptions) tg = await exit_stack.enter_async_context(anyio.create_task_group()) exit_stack.callback(server_disp.close) - on_request = modern_on_request(server, lifespan_state, raise_exceptions=raise_exceptions) + on_request = modern_on_request(server, lifespan_state) await tg.start(server_disp.run, on_request, _no_inbound_client_notifications) return client_disp @@ -151,7 +151,7 @@ async def main(): server: Server[Any] | MCPServer | Transport | str """The MCP server to connect to. - If the server is a `Server` or `MCPServer` instance, it will be wrapped in an `InMemoryTransport`. + If the server is a `Server` or `MCPServer` instance, it will be connected in-process. If the server is a URL string, it will be used as the URL for a `streamable_http_client` transport. If the server is a `Transport` instance, it will be used directly. """ @@ -181,11 +181,14 @@ async def main(): client_info: Implementation | None = None """Client implementation info to send to server.""" - # TODO(maxisbey): flip default to 'auto' once the in-proc test suite is era-decoupled. - mode: ConnectMode = "legacy" - """'legacy' performs the initialize handshake. 'auto' probes server/discover and falls back to initialize() - on legacy servers. A modern protocol-version string (e.g. '2026-07-28') adopts that version directly without - a handshake — supply prior_discover to reuse a known DiscoverResult, or omit it to synthesize a minimal one.""" + mode: ConnectMode = "auto" + """How to negotiate the protocol version. + + 'auto' (the default) probes `server/discover` and falls back to the initialize handshake on legacy servers; + for an in-process `Server`/`MCPServer` it dispatches directly without JSON-RPC framing. 'legacy' forces the + initialize handshake (byte-identical pre-2026 behavior). A modern protocol-version string (e.g. '2026-07-28') + adopts that version directly without a probe — supply `prior_discover` to reuse a known DiscoverResult, or + omit it to synthesize a minimal one.""" prior_discover: types.DiscoverResult | None = None """A previously-obtained DiscoverResult to install via .adopt() when mode is a version pin. @@ -301,6 +304,10 @@ def instructions(self) -> str | None: """Server-provided instructions text, if any.""" return self.session.instructions + @deprecated( + "ping is removed as of 2026-07-28; the method only works under mode='legacy'.", + category=MCPDeprecationWarning, + ) async def send_ping(self, *, meta: RequestParamsMeta | None = None) -> EmptyResult: """Send a ping request to the server.""" return await self.session.send_ping(meta=meta) diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 3fa8b3bc79..e71a8a6e7a 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -496,18 +496,14 @@ async def serve_one( await aclose_shielded(connection) -def modern_on_request( - server: Server[LifespanT], lifespan_state: LifespanT, *, raise_exceptions: bool = False -) -> OnRequest: +def modern_on_request(server: Server[LifespanT], lifespan_state: LifespanT) -> OnRequest: """Return an `OnRequest` callback that serves each call via `serve_one` with a fresh per-request `Connection`. Wire this into the server side of a `DirectDispatcher` peer-pair to drive an in-process server on the modern per-request-envelope path (each request carries protocol version, client info, and capabilities in `params._meta`; - no `initialize` handshake). ``raise_exceptions`` lets unmapped handler - exceptions propagate to the caller for debuggable in-process testing; - otherwise they are sanitized to `MCPError(INTERNAL_ERROR)` so the in-process - path matches the wire path's leak guard. + no `initialize` handshake). Like `serve_one`, this raises whatever the + handler chain raises - the dispatcher owns the exception-to-error mapping. """ async def handle( @@ -519,16 +515,6 @@ async def handle( meta.get(CLIENT_INFO_META_KEY), meta.get(CLIENT_CAPABILITIES_META_KEY), ) - try: - return await serve_one(server, dctx, method, params, connection=connection, lifespan_state=lifespan_state) - except (MCPError, ValidationError): - # DirectDispatcher's ladder maps these onward; this layer only owns the raise_exceptions - # decision for unmapped exceptions, which DirectDispatcher would otherwise leak via str(exc). - raise - except Exception: - if raise_exceptions: - raise - logger.exception("request handler raised") - raise MCPError(code=INTERNAL_ERROR, message="Internal server error") from None + return await serve_one(server, dctx, method, params, connection=connection, lifespan_state=lifespan_state) return handle diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index d521840bef..de99739b1f 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -9,8 +9,10 @@ (`ServerRunner`, `Context`, `Connection`) without wire-level moving parts * embed a server in-process when the JSON-RPC overhead is unnecessary -Unlike `JSONRPCDispatcher`, exceptions raised in a handler propagate directly -to the caller - there is no exception-to-`ErrorData` boundary here. +Like `JSONRPCDispatcher`, this is an exception-to-error boundary: a handler +exception surfaces to the caller as `MCPError`. The `raise_handler_exceptions` +knob controls whether unmapped exceptions are sanitized (matching the wire +path) or chained as ``__cause__`` for in-process debugging. """ from __future__ import annotations @@ -96,8 +98,9 @@ class DirectDispatcher: they are silently dropped. """ - def __init__(self, transport_ctx: TransportContext): + def __init__(self, transport_ctx: TransportContext, *, raise_handler_exceptions: bool = True): self._transport_ctx = transport_ctx + self._raise_handler_exceptions = raise_handler_exceptions self._peer: DirectDispatcher | None = None self._on_request: OnRequest | None = None self._on_notify: OnNotify | None = None @@ -235,7 +238,14 @@ async def _dispatch_request( # tests see what runner-over-JSONRPC would. raise MCPError(code=INVALID_PARAMS, message="Invalid request parameters", data="") from e except Exception as e: - raise MCPError(code=INTERNAL_ERROR, message=str(e)) from e + # Single owner of the in-proc exception-to-error policy (mirrors + # JSONRPCDispatcher / `_streamable_http_modern._to_jsonrpc_response` + # for the wire paths). True chains the original for in-process + # debugging; False sanitizes to match the wire path's leak guard. + if self._raise_handler_exceptions: + raise MCPError(code=INTERNAL_ERROR, message=str(e)) from e + logger.exception("request handler raised") + raise MCPError(code=INTERNAL_ERROR, message="Internal server error") from None except TimeoutError: raise MCPError( code=REQUEST_TIMEOUT, @@ -259,6 +269,7 @@ def create_direct_dispatcher_pair( *, can_send_request: bool = True, headers: Mapping[str, str] | None = None, + raise_handler_exceptions: bool = True, ) -> tuple[DirectDispatcher, DirectDispatcher]: """Create two `DirectDispatcher` instances wired to each other. @@ -266,14 +277,19 @@ def create_direct_dispatcher_pair( can_send_request: Sets `TransportContext.can_send_request` on both sides. Pass `False` to simulate a transport with no back-channel. headers: Sets `TransportContext.headers` on both sides. + raise_handler_exceptions: When `True` (the default - this is an + in-process debugging substrate), an unmapped handler exception + reaches the caller as `MCPError` with the original chained as + ``__cause__``. When `False` it is sanitized to an opaque + `INTERNAL_ERROR` so the in-process path matches the wire. Returns: A `(client, server)` pair. The wiring is symmetric, so the roles are conventional only. """ ctx = TransportContext(kind=DIRECT_TRANSPORT_KIND, can_send_request=can_send_request, headers=headers) - client = DirectDispatcher(ctx) - server = DirectDispatcher(ctx) + client = DirectDispatcher(ctx, raise_handler_exceptions=raise_handler_exceptions) + server = DirectDispatcher(ctx, raise_handler_exceptions=raise_handler_exceptions) client.connect_to(server) server.connect_to(client) return client, server diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 2e3a360722..792bdc8a7c 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -107,7 +107,7 @@ def greeting_prompt(name: str) -> str: async def test_client_is_initialized(app: MCPServer): """Test that the client is initialized after entering context.""" - async with Client(app) as client: + async with Client(app, mode="legacy") as client: assert client.server_capabilities == snapshot( ServerCapabilities( experimental={}, @@ -121,7 +121,7 @@ async def test_client_is_initialized(app: MCPServer): async def test_client_exposes_negotiated_protocol_version(app: MCPServer): """The negotiated protocol version is readable after initialization.""" - async with Client(app) as client: + async with Client(app, mode="legacy") as client: assert client.protocol_version == LATEST_HANDSHAKE_VERSION @@ -137,8 +137,8 @@ async def test_client_with_simple_server(simple_server: Server): async def test_client_send_ping(app: MCPServer): - async with Client(app) as client: - result = await client.send_ping() + async with Client(app, mode="legacy") as client: + result = await client.send_ping() # pyright: ignore[reportDeprecated] assert result == snapshot(EmptyResult()) @@ -278,27 +278,28 @@ async def handle_progress(ctx: ServerRequestContext, params: types.ProgressNotif server = Server(name="test_server", on_progress=handle_progress) - async with Client(server) as client: - await client.send_progress_notification(progress_token="token123", progress=50.0) # pyright: ignore[reportDeprecated] - await event.wait() - assert received_from_client == snapshot({"progress_token": "token123", "progress": 50.0}) + with anyio.fail_after(5): + async with Client(server, mode="legacy") as client: + await client.send_progress_notification(progress_token="token123", progress=50.0) # pyright: ignore[reportDeprecated] + await event.wait() + assert received_from_client == snapshot({"progress_token": "token123", "progress": 50.0}) async def test_client_subscribe_resource(simple_server: Server): - async with Client(simple_server) as client: + async with Client(simple_server, mode="legacy") as client: result = await client.subscribe_resource("memory://test") assert result == snapshot(EmptyResult()) async def test_client_unsubscribe_resource(simple_server: Server): - async with Client(simple_server) as client: + async with Client(simple_server, mode="legacy") as client: result = await client.unsubscribe_resource("memory://test") assert result == snapshot(EmptyResult()) async def test_client_set_logging_level(simple_server: Server): """Test setting logging level.""" - async with Client(simple_server) as client: + async with Client(simple_server, mode="legacy") as client: result = await client.set_logging_level("debug") # pyright: ignore[reportDeprecated] assert result == snapshot(EmptyResult()) @@ -361,7 +362,7 @@ def test_client_with_url_initializes_streamable_http_transport(): async def test_client_uses_transport_directly(app: MCPServer): transport = InMemoryTransport(app) - async with Client(transport) as client: + async with Client(transport, mode="legacy") as client: result = await client.call_tool("greet", {"name": "Transport"}) assert result == snapshot( CallToolResult( diff --git a/tests/client/test_list_methods_cursor.py b/tests/client/test_list_methods_cursor.py index e7e63304fc..30955d4308 100644 --- a/tests/client/test_list_methods_cursor.py +++ b/tests/client/test_list_methods_cursor.py @@ -64,7 +64,7 @@ async def test_list_methods_params_parameter( See: https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/pagination#request-format """ - async with Client(full_featured_server) as client: + async with Client(full_featured_server, mode="legacy") as client: spies = stream_spy() # Test without params (omitted) diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index 423ab967bb..72119a8cfc 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -31,7 +31,7 @@ async def test_list_roots(context: Context, message: str): return True # Test with list_roots callback - async with Client(server, list_roots_callback=list_roots_callback) as client: + async with Client(server, list_roots_callback=list_roots_callback, mode="legacy") as client: # Make a request to trigger sampling callback result = await client.call_tool("test_list_roots", {"message": "test message"}) assert result.is_error is False @@ -41,7 +41,7 @@ async def test_list_roots(context: Context, message: str): # Without a list_roots callback the client responds with an MCPError, which the # tool body doesn't catch — the wrapper re-raises it as a top-level JSON-RPC # error rather than wrapping it as an isError result. - async with Client(server) as client: + async with Client(server, mode="legacy") as client: with pytest.raises(MCPError) as exc_info: await client.call_tool("test_list_roots", {"message": "test message"}) assert exc_info.value.error.code == INVALID_REQUEST diff --git a/tests/client/test_logging_callback.py b/tests/client/test_logging_callback.py index 7a870bcd55..0b7e363a17 100644 --- a/tests/client/test_logging_callback.py +++ b/tests/client/test_logging_callback.py @@ -64,6 +64,7 @@ async def message_handler( server, logging_callback=logging_collector, message_handler=message_handler, + mode="legacy", ) as client: # First verify our test tool works result = await client.call_tool("test_tool", {}) diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index 901caa69f8..af255ee249 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -42,7 +42,7 @@ async def test_sampling_tool(message: str, ctx: Context) -> bool: return True # Test with sampling callback - async with Client(server, sampling_callback=sampling_callback) as client: + async with Client(server, sampling_callback=sampling_callback, mode="legacy") as client: # Make a request to trigger sampling callback result = await client.call_tool("test_sampling", {"message": "Test message for sampling"}) assert result.is_error is False @@ -52,7 +52,7 @@ async def test_sampling_tool(message: str, ctx: Context) -> bool: # Without a sampling callback the client responds with an MCPError, which the # tool body doesn't catch — the wrapper re-raises it as a top-level JSON-RPC # error rather than wrapping it as an isError result. - async with Client(server) as client: + async with Client(server, mode="legacy") as client: with pytest.raises(MCPError) as exc_info: await client.call_tool("test_sampling", {"message": "Test message for sampling"}) assert exc_info.value.error.code == INVALID_REQUEST @@ -93,7 +93,7 @@ async def test_tool(message: str, ctx: Context) -> bool: assert not hasattr(result, "content_as_list") or not callable(getattr(result, "content_as_list", None)) return True - async with Client(server, sampling_callback=sampling_callback) as client: + async with Client(server, sampling_callback=sampling_callback, mode="legacy") as client: result = await client.call_tool("test_backwards_compat", {"message": "Test"}) assert result.is_error is False assert isinstance(result.content[0], TextContent) diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 933171eabd..c24a4569c5 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -1478,9 +1478,7 @@ async def send_raw_request( return item async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: - self.notifies.append( - method - ) # pragma: no cover — recorded so a wrongly-sent notification fails the == [] assert + self.notifies.append(method) def _discover_result_dict() -> dict[str, Any]: @@ -1491,6 +1489,25 @@ def _discover_result_dict() -> dict[str, Any]: ).model_dump(by_alias=True, mode="json", exclude_none=True) +@pytest.mark.anyio +async def test_initialize_is_idempotent_and_returns_the_cached_result() -> None: + """A second `initialize()` returns the first call's result by identity and sends nothing + over the wire — the early-return guard short-circuits before the dispatcher is touched.""" + init_result = InitializeResult( + protocol_version=LATEST_HANDSHAKE_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), + ).model_dump(by_alias=True, mode="json", exclude_none=True) + dispatcher = _ScriptedDispatcher(init_result) + with anyio.fail_after(5): + async with ClientSession(dispatcher=dispatcher) as session: + first = await session.initialize() + second = await session.initialize() + assert first is second + assert [method for method, _ in dispatcher.calls] == ["initialize"] + assert dispatcher.notifies == ["notifications/initialized"] + + @pytest.mark.anyio async def test_discover_adopts_the_returned_result_and_installs_the_modern_stamp() -> None: """SDK-defined: a successful `server/discover` is adopted and subsequent requests diff --git a/tests/client/test_session_concurrency.py b/tests/client/test_session_concurrency.py index 7072325104..512a876ee9 100644 --- a/tests/client/test_session_concurrency.py +++ b/tests/client/test_session_concurrency.py @@ -121,7 +121,7 @@ async def sampling_callback( stop_reason="endTurn", ) - async with Client(server, sampling_callback=sampling_callback) as client: + async with Client(server, sampling_callback=sampling_callback, mode="legacy") as client: with anyio.fail_after(5): async with anyio.create_task_group() as task_group: # pragma: no branch diff --git a/tests/client/transports/test_memory.py b/tests/client/transports/test_memory.py index 51a026c138..da9b95f721 100644 --- a/tests/client/transports/test_memory.py +++ b/tests/client/transports/test_memory.py @@ -75,13 +75,13 @@ async def test_with_mcpserver(mcpserver_server: MCPServer): async def test_server_is_running(mcpserver_server: MCPServer): """Test that the server is running and responding to requests.""" - async with Client(mcpserver_server) as client: + async with Client(mcpserver_server, mode="legacy") as client: assert client.server_capabilities.tools is not None async def test_list_tools(mcpserver_server: MCPServer): """Test listing tools through the transport.""" - async with Client(mcpserver_server) as client: + async with Client(mcpserver_server, mode="legacy") as client: tools_result = await client.list_tools() assert len(tools_result.tools) > 0 tool_names = [t.name for t in tools_result.tools] @@ -90,7 +90,7 @@ async def test_list_tools(mcpserver_server: MCPServer): async def test_call_tool(mcpserver_server: MCPServer): """Test calling a tool through the transport.""" - async with Client(mcpserver_server) as client: + async with Client(mcpserver_server, mode="legacy") as client: result = await client.call_tool("greet", {"name": "World"}) assert result is not None assert len(result.content) > 0 diff --git a/tests/interaction/_connect.py b/tests/interaction/_connect.py index a9383837d2..bb8e75ed43 100644 --- a/tests/interaction/_connect.py +++ b/tests/interaction/_connect.py @@ -238,6 +238,9 @@ async def client_via_http( transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) async with Client( transport, + # Callers assert the legacy HTTP wire shape (session-id header, standalone GET stream, + # closing DELETE); the modern flow is sessionless and would silently change the subject. + mode="legacy", logging_callback=logging_callback, message_handler=message_handler, elicitation_callback=elicitation_callback, @@ -378,6 +381,8 @@ def httpx_client_factory( transport = sse_client(f"{BASE_URL}/sse", httpx_client_factory=httpx_client_factory) async with Client( transport, + # SSE is a legacy-only transport; the modern path has no SSE story. + mode="legacy", read_timeout_seconds=read_timeout_seconds, sampling_callback=sampling_callback, list_roots_callback=list_roots_callback, diff --git a/tests/interaction/auth/_harness.py b/tests/interaction/auth/_harness.py index ab360addd4..b7ffc78c86 100644 --- a/tests/interaction/auth/_harness.py +++ b/tests/interaction/auth/_harness.py @@ -470,6 +470,7 @@ async def hook(request: httpx.Request) -> None: ) headless.bind(http_client) client = await stack.enter_async_context( - Client(streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client)) + # The auth flow tests snapshot the legacy initialize-handshake HTTP shape. + Client(streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client), mode="legacy") ) yield client, headless diff --git a/tests/interaction/lowlevel/test_ping.py b/tests/interaction/lowlevel/test_ping.py index 797e20dc35..e65bbe3ae3 100644 --- a/tests/interaction/lowlevel/test_ping.py +++ b/tests/interaction/lowlevel/test_ping.py @@ -19,7 +19,7 @@ async def test_client_ping_returns_empty_result(connect: Connect) -> None: server = Server("silent") async with connect(server) as client: - result = await client.send_ping() + result = await client.send_ping() # pyright: ignore[reportDeprecated] assert result == snapshot(EmptyResult()) diff --git a/tests/interaction/lowlevel/test_timeouts.py b/tests/interaction/lowlevel/test_timeouts.py index d73e66f32f..044c055032 100644 --- a/tests/interaction/lowlevel/test_timeouts.py +++ b/tests/interaction/lowlevel/test_timeouts.py @@ -48,7 +48,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("blocker", on_call_tool=call_tool) - async with Client(server) as client: + async with Client(server, mode="legacy") as client: with pytest.raises(MCPError) as exc_info: await client.call_tool("block", {}, read_timeout_seconds=0.000001) @@ -106,7 +106,7 @@ async def sampling_callback( await release.wait() return types.CreateMessageResult(role="assistant", content=TextContent(text="too late"), model="test-model") - async with Client(recording, sampling_callback=sampling_callback) as client: + async with Client(recording, mode="legacy", sampling_callback=sampling_callback) as client: result = await client.call_tool("impatient", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="gave up")])) @@ -147,7 +147,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("blocker", on_list_tools=list_tools, on_call_tool=call_tool) - async with Client(server) as client: + async with Client(server, mode="legacy") as client: with pytest.raises(MCPError): await client.call_tool("block", {}, read_timeout_seconds=0.000001) @@ -179,7 +179,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("blocker", on_call_tool=call_tool) - async with Client(server, read_timeout_seconds=0.05) as client: + async with Client(server, mode="legacy", read_timeout_seconds=0.05) as client: with pytest.raises(MCPError) as exc_info: await client.call_tool("block", {}) diff --git a/tests/interaction/lowlevel/test_wire.py b/tests/interaction/lowlevel/test_wire.py index ace780d7ec..d8f9a530bd 100644 --- a/tests/interaction/lowlevel/test_wire.py +++ b/tests/interaction/lowlevel/test_wire.py @@ -65,11 +65,11 @@ async def test_request_ids_are_unique_and_never_null() -> None: """ recording = RecordingTransport(InMemoryTransport(_echo_server())) - async with Client(recording) as client: + async with Client(recording, mode="legacy") as client: await client.list_tools() await client.call_tool("echo", {}) await client.call_tool("echo", {}) - await client.send_ping() + await client.send_ping() # pyright: ignore[reportDeprecated] sent = [message.message for message in recording.sent] request_ids = [message.id for message in sent if isinstance(message, JSONRPCRequest)] @@ -95,9 +95,9 @@ async def list_roots(context: ClientRequestContext) -> ListRootsResult: recording = RecordingTransport(InMemoryTransport(_echo_server())) - async with Client(recording, list_roots_callback=list_roots) as client: + async with Client(recording, mode="legacy", list_roots_callback=list_roots) as client: await client.send_roots_list_changed() # pyright: ignore[reportDeprecated] - await client.send_ping() + await client.send_ping() # pyright: ignore[reportDeprecated] sent = [message.message for message in recording.sent] sent_request_ids = [message.id for message in sent if isinstance(message, JSONRPCRequest)] @@ -134,7 +134,7 @@ async def test_exactly_one_initialized_notification_is_sent_after_the_handshake( """ recording = RecordingTransport(InMemoryTransport(_echo_server())) - async with Client(recording) as client: + async with Client(recording, mode="legacy") as client: await client.list_tools() sent_methods = [ diff --git a/tests/interaction/transports/test_client_transport_http.py b/tests/interaction/transports/test_client_transport_http.py index 65ed03f1e4..173371d75d 100644 --- a/tests/interaction/transports/test_client_transport_http.py +++ b/tests/interaction/transports/test_client_transport_http.py @@ -181,7 +181,7 @@ async def filter_methods(scope: Scope, receive: Receive, send: Send) -> None: ): transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) with anyio.fail_after(5): # pragma: no branch - async with Client(transport) as client: # pragma: no branch + async with Client(transport, mode="legacy") as client: # pragma: no branch result = await client.list_tools() assert [tool.name for tool in result.tools] == ["echo"] @@ -240,7 +240,7 @@ async def first_post_then_404(scope: Scope, receive: Receive, send: Send) -> Non ): transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) with anyio.fail_after(5): # pragma: no branch - async with Client(transport) as client: # pragma: no branch + async with Client(transport, mode="legacy") as client: # pragma: no branch with pytest.raises(MCPError) as exc_info: # pragma: no branch await client.list_tools() diff --git a/tests/interaction/transports/test_legacy_wire.py b/tests/interaction/transports/test_legacy_wire.py index b65a50759d..99c614e1a6 100644 --- a/tests/interaction/transports/test_legacy_wire.py +++ b/tests/interaction/transports/test_legacy_wire.py @@ -66,7 +66,7 @@ async def on_response(response: httpx.Response) -> None: async with mounted_app(_server(), on_request=on_request, on_response=on_response) as (http, _): recording = RecordingTransport(streamable_http_client(f"{BASE_URL}/mcp", http_client=http)) - async with Client(recording) as client: + async with Client(recording, mode="legacy") as client: result = await client.call_tool("echo", {"text": "legacy"}) assert result == snapshot(CallToolResult(content=[TextContent(text="legacy")])) diff --git a/tests/interaction/transports/test_sse.py b/tests/interaction/transports/test_sse.py index 9c7353dda5..7a8cdcf7d2 100644 --- a/tests/interaction/transports/test_sse.py +++ b/tests/interaction/transports/test_sse.py @@ -51,10 +51,10 @@ def httpx_client_factory( f"{BASE_URL}/sse", httpx_client_factory=httpx_client_factory, on_session_created=captured_session_id.append ) with anyio.fail_after(5): - async with Client(transport) as client: + async with Client(transport, mode="legacy") as client: assert len(captured_session_id) == 1 assert UUID(hex=captured_session_id[0]) in sse._read_stream_writers - assert await client.send_ping() == snapshot(EmptyResult()) + assert await client.send_ping() == snapshot(EmptyResult()) # pyright: ignore[reportDeprecated] assert sse._read_stream_writers == {} diff --git a/tests/interaction/transports/test_stdio.py b/tests/interaction/transports/test_stdio.py index 60a9b93981..9b7813f2d1 100644 --- a/tests/interaction/transports/test_stdio.py +++ b/tests/interaction/transports/test_stdio.py @@ -89,7 +89,7 @@ async def collect(params: LoggingMessageNotificationParams) -> None: # Must exceed session time plus the patched PROCESS_TERMINATION_TIMEOUT (20s). with anyio.fail_after(30): - async with Client(transport, logging_callback=collect) as client: + async with Client(transport, mode="legacy", logging_callback=collect) as client: assert client.server_info.name == "stdio-echo" result = await client.call_tool("echo", {"text": "across\nprocesses"}) diff --git a/tests/server/mcpserver/test_elicitation.py b/tests/server/mcpserver/test_elicitation.py index e31bcff212..226d147195 100644 --- a/tests/server/mcpserver/test_elicitation.py +++ b/tests/server/mcpserver/test_elicitation.py @@ -43,7 +43,7 @@ async def call_tool_and_assert( text_contains: list[str] | None = None, ): """Helper to create session, call tool, and assert result.""" - async with Client(mcp, elicitation_callback=elicitation_callback) as client: + async with Client(mcp, mode="legacy", elicitation_callback=elicitation_callback) as client: result = await client.call_tool(tool_name, args) assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -122,7 +122,7 @@ class InvalidNestedSchema(BaseModel): async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): # pragma: no cover return ElicitResult(action="accept", content={}) - async with Client(mcp, elicitation_callback=elicitation_callback) as client: + async with Client(mcp, mode="legacy", elicitation_callback=elicitation_callback) as client: # Test both invalid schemas for tool_name, field_name in [("invalid_list", "numbers"), ("nested_model", "nested")]: result = await client.call_tool(tool_name, {}) diff --git a/tests/server/mcpserver/test_integration.py b/tests/server/mcpserver/test_integration.py index a5388b17a4..904c8f873d 100644 --- a/tests/server/mcpserver/test_integration.py +++ b/tests/server/mcpserver/test_integration.py @@ -185,7 +185,7 @@ async def message_handler(message: RequestResponder[ServerRequest, ClientResult] if isinstance(message, Exception): # pragma: no cover raise message - async with Client(tool_progress.mcp, message_handler=message_handler) as client: + async with Client(tool_progress.mcp, message_handler=message_handler, mode="legacy") as client: # Test progress callback progress_updates = [] @@ -215,7 +215,7 @@ async def progress_callback(progress: float, total: float | None, message: str | async def test_sampling() -> None: """Test sampling (LLM interaction) functionality.""" - async with Client(sampling.mcp, sampling_callback=sampling_callback) as client: + async with Client(sampling.mcp, sampling_callback=sampling_callback, mode="legacy") as client: assert client.server_capabilities.tools is not None # Test sampling tool @@ -227,7 +227,7 @@ async def test_sampling() -> None: async def test_elicitation() -> None: """Test elicitation (user interaction) functionality.""" - async with Client(elicitation.mcp, elicitation_callback=elicitation_callback) as client: + async with Client(elicitation.mcp, elicitation_callback=elicitation_callback, mode="legacy") as client: # Test booking with unavailable date (triggers elicitation) booking_result = await client.call_tool( "book_table", @@ -264,7 +264,7 @@ async def message_handler(message: RequestResponder[ServerRequest, ClientResult] if isinstance(message, Exception): # pragma: no cover raise message - async with Client(notifications.mcp, message_handler=message_handler) as client: + async with Client(notifications.mcp, message_handler=message_handler, mode="legacy") as client: # Call tool that generates notifications tool_result = await client.call_tool("process_data", {"data": "test_data"}) assert len(tool_result.content) == 1 diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index a48bd7ae47..9b469e566a 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -1119,7 +1119,7 @@ async def logging_tool(msg: str, ctx: Context) -> str: mcp.add_tool(logging_tool) with patch("mcp.server.session.ServerSession.send_log_message") as mock_log: - async with Client(mcp) as client: + async with Client(mcp, mode="legacy") as client: result = await client.call_tool("logging_tool", {"msg": "test"}) assert len(result.content) == 1 content = result.content[0] @@ -1466,7 +1466,7 @@ async def test_get_unknown_prompt(self): """Test error when getting unknown prompt.""" mcp = MCPServer() - async with Client(mcp) as client: + async with Client(mcp, mode="legacy") as client: with pytest.raises(MCPError, match="Unknown prompt"): await client.get_prompt("unknown") @@ -1477,7 +1477,7 @@ async def test_get_prompt_missing_args(self): @mcp.prompt() def prompt_fn(name: str) -> str: ... # pragma: no branch - async with Client(mcp) as client: + async with Client(mcp, mode="legacy") as client: with pytest.raises(MCPError, match="Missing required arguments"): await client.get_prompt("prompt_fn") diff --git a/tests/server/mcpserver/test_title.py b/tests/server/mcpserver/test_title.py index 6624647572..9a3caf1b20 100644 --- a/tests/server/mcpserver/test_title.py +++ b/tests/server/mcpserver/test_title.py @@ -25,12 +25,10 @@ async def test_server_name_title_description_version(): # Start server and connect client async with Client(mcp) as client: - # Access initialization result from session - init_result = await client.session.initialize() - assert init_result.server_info.name == "TestServer" - assert init_result.server_info.title == "Test Server Title" - assert init_result.server_info.description == "This is a test server description." - assert init_result.server_info.version == "1.0" + assert client.server_info.name == "TestServer" + assert client.server_info.title == "Test Server Title" + assert client.server_info.description == "This is a test server description." + assert client.server_info.version == "1.0" @pytest.mark.anyio diff --git a/tests/server/mcpserver/test_url_elicitation.py b/tests/server/mcpserver/test_url_elicitation.py index 9ab03fcdab..8446ed1bad 100644 --- a/tests/server/mcpserver/test_url_elicitation.py +++ b/tests/server/mcpserver/test_url_elicitation.py @@ -34,7 +34,7 @@ async def elicitation_callback(context: ClientRequestContext, params: ElicitRequ assert params.message == "Please provide your API key to continue." return ElicitResult(action="accept") - async with Client(mcp, elicitation_callback=elicitation_callback) as client: + async with Client(mcp, mode="legacy", elicitation_callback=elicitation_callback) as client: result = await client.call_tool("request_api_key", {}) assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -60,7 +60,7 @@ async def elicitation_callback(context: ClientRequestContext, params: ElicitRequ assert params.mode == "url" return ElicitResult(action="decline") - async with Client(mcp, elicitation_callback=elicitation_callback) as client: + async with Client(mcp, mode="legacy", elicitation_callback=elicitation_callback) as client: result = await client.call_tool("oauth_flow", {}) assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -86,7 +86,7 @@ async def elicitation_callback(context: ClientRequestContext, params: ElicitRequ assert params.mode == "url" return ElicitResult(action="cancel") - async with Client(mcp, elicitation_callback=elicitation_callback) as client: + async with Client(mcp, mode="legacy", elicitation_callback=elicitation_callback) as client: result = await client.call_tool("payment_flow", {}) assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -112,7 +112,7 @@ async def setup_credentials(ctx: Context) -> str: async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): return ElicitResult(action="accept") - async with Client(mcp, elicitation_callback=elicitation_callback) as client: + async with Client(mcp, mode="legacy", elicitation_callback=elicitation_callback) as client: result = await client.call_tool("setup_credentials", {}) assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -146,7 +146,7 @@ async def elicitation_callback(context: ClientRequestContext, params: ElicitRequ # Return without content - this is correct for URL mode return ElicitResult(action="accept") - async with Client(mcp, elicitation_callback=elicitation_callback) as client: + async with Client(mcp, mode="legacy", elicitation_callback=elicitation_callback) as client: result = await client.call_tool("check_url_response", {}) assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -177,7 +177,7 @@ async def elicitation_callback(context: ClientRequestContext, params: ElicitRequ assert params.requested_schema is not None return ElicitResult(action="accept", content={"name": "Alice"}) - async with Client(mcp, elicitation_callback=elicitation_callback) as client: + async with Client(mcp, mode="legacy", elicitation_callback=elicitation_callback) as client: result = await client.call_tool("ask_name", {}) assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -208,7 +208,7 @@ async def trigger_elicitation(ctx: Context) -> str: async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): return ElicitResult(action="accept") # pragma: no cover - async with Client(mcp, elicitation_callback=elicitation_callback) as client: + async with Client(mcp, mode="legacy", elicitation_callback=elicitation_callback) as client: result = await client.call_tool("trigger_elicitation", {}) assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -265,7 +265,7 @@ async def test_cancel(ctx: Context) -> str: async def decline_callback(context: ClientRequestContext, params: ElicitRequestParams): return ElicitResult(action="decline") - async with Client(mcp, elicitation_callback=decline_callback) as client: + async with Client(mcp, mode="legacy", elicitation_callback=decline_callback) as client: result = await client.call_tool("test_decline", {}) assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -275,7 +275,7 @@ async def decline_callback(context: ClientRequestContext, params: ElicitRequestP async def cancel_callback(context: ClientRequestContext, params: ElicitRequestParams): return ElicitResult(action="cancel") - async with Client(mcp, elicitation_callback=cancel_callback) as client: + async with Client(mcp, mode="legacy", elicitation_callback=cancel_callback) as client: result = await client.call_tool("test_cancel", {}) assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -308,7 +308,7 @@ async def elicitation_callback(context: ClientRequestContext, params: ElicitRequ assert params.requested_schema is not None return ElicitResult(action="accept", content={"email": "test@example.com"}) - async with Client(mcp, elicitation_callback=elicitation_callback) as client: + async with Client(mcp, mode="legacy", elicitation_callback=elicitation_callback) as client: result = await client.call_tool("use_deprecated_elicit", {}) assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -335,7 +335,7 @@ async def elicitation_callback(context: ClientRequestContext, params: ElicitRequ assert params.elicitation_id == "ctx-test-001" return ElicitResult(action="accept") - async with Client(mcp, elicitation_callback=elicitation_callback) as client: + async with Client(mcp, mode="legacy", elicitation_callback=elicitation_callback) as client: result = await client.call_tool("direct_elicit_url", {}) assert isinstance(result.content[0], TextContent) assert result.content[0].text == "Result: accept" diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index cc157247c9..7c976cc2d9 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -59,7 +59,7 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar server = Server("test-server", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - async with Client(server) as client: + async with Client(server, mode="legacy") as client: # First request (will be cancelled) async def first_request(): try: diff --git a/tests/server/test_completion_with_context.py b/tests/server/test_completion_with_context.py index a01d0d4d72..592af7c35a 100644 --- a/tests/server/test_completion_with_context.py +++ b/tests/server/test_completion_with_context.py @@ -137,7 +137,7 @@ async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestPa server = Server("test-server", on_completion=handle_completion) - async with Client(server) as client: + async with Client(server, mode="legacy") as client: # Try to complete table without database context - should raise error with pytest.raises(Exception) as exc_info: await client.complete( diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 0e8afed509..d738c02cd0 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -333,7 +333,7 @@ async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestP mcp_app.router.lifespan_context(mcp_app), httpx.ASGITransport(mcp_app) as transport, httpx.AsyncClient(transport=transport) as http_client, - Client(streamable_http_client(f"http://{host}/mcp", http_client=http_client)) as client, + Client(streamable_http_client(f"http://{host}/mcp", http_client=http_client), mode="legacy") as client, ): await client.list_tools() diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index 660b5cb3af..1acbb363c0 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -2357,7 +2357,7 @@ async def observe(ctx: Any, call_next: Any) -> Any: server = Server("test-server", on_call_tool=handle_call_tool) server.middleware.append(observe) - async with Client(server) as client: + async with Client(server, mode="legacy") as client: with anyio.fail_after(5): async with anyio.create_task_group() as tg: # pragma: no branch diff --git a/tests/shared/test_otel.py b/tests/shared/test_otel.py index a7df4c4294..ab45576137 100644 --- a/tests/shared/test_otel.py +++ b/tests/shared/test_otel.py @@ -19,7 +19,7 @@ def greet(name: str) -> str: """Greet someone.""" return f"Hello, {name}!" - async with Client(server) as client: + async with Client(server, mode="legacy") as client: result = await client.call_tool("greet", {"name": "World"}) assert isinstance(result.content[0], types.TextContent) From 20a564eee5b040e24e8f9e29db95175c1111cc63 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 25 Jun 2026 13:35:56 +0000 Subject: [PATCH 21/22] DirectDispatcher: resync coverage tracer after each request MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A handler-side exception (even one caught and converted to a result, as MCPServer's tool wrapper does) desyncs CPython 3.11's CTracer when the DirectDispatcher request path resumes the awaiting coroutine chain via throw — the same python/cpython#106749 the other transport seams already work around. The resync_tracer() checkpoint at the end of _dispatch_request restores tracing for the caller's subsequent lines. --- src/mcp/shared/direct_dispatcher.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index de99739b1f..9268feffe1 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -26,6 +26,7 @@ import anyio.abc from pydantic import ValidationError +from mcp.shared._compat import resync_tracer from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT from mcp.shared.exceptions import MCPError, NoBackChannelError from mcp.shared.message import MessageMetadata @@ -251,6 +252,8 @@ async def _dispatch_request( code=REQUEST_TIMEOUT, message=f"Timed out after {opts.get('timeout')}s waiting for {method!r}", ) from None + finally: + await resync_tracer() async def _dispatch_notify(self, method: str, params: Mapping[str, Any] | None) -> None: try: From 0bb9ee772bb23bdca2c2e95537694ab6227cb988 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 25 Jun 2026 13:55:48 +0000 Subject: [PATCH 22/22] Build send_discover params via DiscoverRequest; gate preconnect cancel-suppression by method - send_discover: construct params through types.DiscoverRequest/RequestParams instead of a hand-rolled dict (same wire shape; routing stays raw because the probe's _meta version is per-call, not session state). - _preconnect_stamp: only suppress cancel_on_abandon for initialize and server/discover, so lowlevel ClientSession callers that skip the handshake keep the courtesy cancel on timed-out/abandoned requests. - send_discover docstring: drop the nonexistent ProbeNotRecognized entry; fold the transport-4xx case into the MCPError Raises line. - migration.md: note the MCP_PROTOCOL_VERSION header constant move to mcp.shared.inbound.MCP_PROTOCOL_VERSION_HEADER. --- docs/migration.md | 2 ++ src/mcp/client/session.py | 31 ++++++++++++++++++------------- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index fb82495350..46ec205ee9 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -163,6 +163,8 @@ Note: `sse_client` retains its `headers`, `timeout`, `sse_read_timeout`, and `au The transport no longer holds per-connection protocol state; era-dependent headers (e.g. `MCP-Protocol-Version`) are now supplied per-message by the session. If you were reading `transport.protocol_version` to learn the negotiated version, read `session.protocol_version` (or `client.protocol_version` on the high-level `Client`) instead. +The `MCP_PROTOCOL_VERSION` header-name constant has moved: import `MCP_PROTOCOL_VERSION_HEADER` from `mcp.shared.inbound` instead of `MCP_PROTOCOL_VERSION` from `mcp.client.streamable_http`. + ### `terminate_windows_process` removed The deprecated `mcp.os.win32.utilities.terminate_windows_process` function has been diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index c4445056a7..8ac3e22882 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -52,8 +52,10 @@ def _preconnect_stamp(data: dict[str, Any], opts: CallOptions) -> None: - # Only initialize/discover go out before connect; both forbid cancellation. - opts["cancel_on_abandon"] = False + # initialize/discover forbid cancellation; other pre-handshake requests (lowlevel + # ClientSession callers may skip the handshake entirely) keep the courtesy cancel. + if data["method"] in ("initialize", "server/discover"): + opts["cancel_on_abandon"] = False def _make_handshake_stamp(protocol_version: str) -> Callable[[dict[str, Any], CallOptions], None]: @@ -402,25 +404,28 @@ async def send_discover(self, version: str) -> dict[str, Any]: the connect-time auto-negotiation policy. Raises: - MCPError: The server returned a JSON-RPC error. - ProbeNotRecognized: The transport bounced the request at its own - layer (HTTP 4xx without a JSON-RPC error body). + MCPError: The server returned a JSON-RPC error, or the transport + bounced the request at its own layer (a bare HTTP 4xx is + synthesized into a JSON-RPC error by the transport). """ client_info = self._client_info.model_dump(by_alias=True, mode="json", exclude_none=True) capabilities = self._build_capabilities().model_dump(by_alias=True, mode="json", exclude_none=True) - params = { - "_meta": { - PROTOCOL_VERSION_META_KEY: version, - CLIENT_INFO_META_KEY: client_info, - CLIENT_CAPABILITIES_META_KEY: capabilities, - } - } + request = types.DiscoverRequest( + params=types.RequestParams( + _meta={ + PROTOCOL_VERSION_META_KEY: version, + CLIENT_INFO_META_KEY: client_info, + CLIENT_CAPABILITIES_META_KEY: capabilities, + } + ) + ) + data = request.model_dump(by_alias=True, mode="json", exclude_none=True) opts: CallOptions = { "timeout": DISCOVER_TIMEOUT_SECONDS, "cancel_on_abandon": False, "headers": {MCP_PROTOCOL_VERSION_HEADER: version}, } - return await self._dispatcher.send_raw_request("server/discover", params, opts) + return await self._dispatcher.send_raw_request(data["method"], data.get("params"), opts) async def discover(self) -> types.DiscoverResult: """Probe `server/discover` and adopt the result.