Skip to content

Commit 14b325f

Browse files
committed
Add protocol_version pin to ClientSession for stateless 2026-07-28 mode
When ClientSession is constructed with protocol_version="2026-07-28", each outgoing request carries the io.modelcontextprotocol/* envelope (protocolVersion, clientInfo, clientCapabilities) in params._meta, and initialize() raises if called. Capabilities derivation is extracted to _build_capabilities() so both paths share it. The streamable-HTTP transport derives MCP-Protocol-Version, Mcp-Method and (for tools/call) Mcp-Name headers per POST from the body's envelope; non-header-safe values are Base64-sentinel-encoded per the spec. Envelope-less bodies get no derived headers, so unpinned behaviour is unchanged. Session-id capture, the standalone GET stream and DELETE on close are gated on traffic the pinned mode never produces. Claude-Session: https://claude.ai/code/session_017S3aJaxEHeMvftp6whnHWK
1 parent f1fa4ec commit 14b325f

2 files changed

Lines changed: 67 additions & 9 deletions

File tree

src/mcp/client/session.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,15 @@
2222
from mcp.shared.session import RequestResponder
2323
from mcp.shared.transport_context import TransportContext
2424
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
25-
from mcp.types import INTERNAL_ERROR, METHOD_NOT_FOUND, RequestId, RequestParamsMeta
25+
from mcp.types import (
26+
CLIENT_CAPABILITIES_META_KEY,
27+
CLIENT_INFO_META_KEY,
28+
INTERNAL_ERROR,
29+
METHOD_NOT_FOUND,
30+
PROTOCOL_VERSION_META_KEY,
31+
RequestId,
32+
RequestParamsMeta,
33+
)
2634
from mcp.types import methods as _methods
2735

2836
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
@@ -141,11 +149,13 @@ def __init__(
141149
message_handler: MessageHandlerFnT | None = None,
142150
client_info: types.Implementation | None = None,
143151
*,
152+
protocol_version: str | None = None,
144153
sampling_capabilities: types.SamplingCapability | None = None,
145154
dispatcher: Dispatcher[Any] | None = None,
146155
) -> None:
147156
self._session_read_timeout_seconds = read_timeout_seconds
148157
self._client_info = client_info or DEFAULT_CLIENT_INFO
158+
self._pinned_version = protocol_version
149159
self._sampling_callback = sampling_callback or _default_sampling_callback
150160
self._sampling_capabilities = sampling_capabilities
151161
self._elicitation_callback = elicitation_callback or _default_elicitation_callback
@@ -218,6 +228,18 @@ async def send_request(
218228
"""
219229
data = request.model_dump(by_alias=True, mode="json", exclude_none=True)
220230
method: str = data["method"]
231+
if self._pinned_version is not None:
232+
params = data.setdefault("params", {})
233+
envelope_meta = params.setdefault("_meta", {})
234+
envelope_meta.setdefault(PROTOCOL_VERSION_META_KEY, self._pinned_version)
235+
envelope_meta.setdefault(
236+
CLIENT_INFO_META_KEY,
237+
self._client_info.model_dump(by_alias=True, mode="json", exclude_none=True),
238+
)
239+
envelope_meta.setdefault(
240+
CLIENT_CAPABILITIES_META_KEY,
241+
self._build_capabilities().model_dump(by_alias=True, mode="json", exclude_none=True),
242+
)
221243
opts: CallOptions = {}
222244
timeout = (
223245
request_read_timeout_seconds
@@ -254,7 +276,7 @@ async def send_notification(self, notification: types.ClientNotification) -> Non
254276
data = notification.model_dump(by_alias=True, mode="json", exclude_none=True)
255277
await self._dispatcher.notify(data["method"], data.get("params"))
256278

257-
async def initialize(self) -> types.InitializeResult:
279+
def _build_capabilities(self) -> types.ClientCapabilities:
258280
sampling = (
259281
(self._sampling_capabilities or types.SamplingCapability())
260282
if self._sampling_callback is not _default_sampling_callback
@@ -273,17 +295,17 @@ async def initialize(self) -> types.InitializeResult:
273295
if self._list_roots_callback is not _default_list_roots_callback
274296
else None
275297
)
298+
return types.ClientCapabilities(sampling=sampling, elicitation=elicitation, experimental=None, roots=roots)
276299

300+
async def initialize(self) -> types.InitializeResult:
301+
if self._pinned_version is not None:
302+
raise RuntimeError("initialize() must not be called on a session pinned to a stateless protocol version")
303+
capabilities = self._build_capabilities()
277304
result = await self.send_request(
278305
types.InitializeRequest(
279306
params=types.InitializeRequestParams(
280307
protocol_version=types.LATEST_PROTOCOL_VERSION,
281-
capabilities=types.ClientCapabilities(
282-
sampling=sampling,
283-
elicitation=elicitation,
284-
experimental=None,
285-
roots=roots,
286-
),
308+
capabilities=capabilities,
287309
client_info=self._client_info,
288310
),
289311
),
@@ -309,7 +331,9 @@ def initialize_result(self) -> types.InitializeResult | None:
309331

310332
@property
311333
def protocol_version(self) -> str | None:
312-
"""The negotiated protocol version. None until `initialize()` has completed."""
334+
"""Negotiated or pinned protocol version. None until initialize() unless pinned at construction."""
335+
if self._pinned_version is not None:
336+
return self._pinned_version
313337
return self._initialize_result.protocol_version if self._initialize_result else None
314338

315339
async def send_ping(self, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult:

src/mcp/client/streamable_http.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
from __future__ import annotations as _annotations
44

5+
import base64
56
import contextlib
67
import logging
8+
import re
79
from collections.abc import AsyncGenerator, Awaitable, Callable
810
from contextlib import asynccontextmanager
911
from dataclasses import dataclass
@@ -23,6 +25,7 @@
2325
INTERNAL_ERROR,
2426
INVALID_REQUEST,
2527
PARSE_ERROR,
28+
PROTOCOL_VERSION_META_KEY,
2629
ErrorData,
2730
InitializeResult,
2831
JSONRPCError,
@@ -44,12 +47,42 @@
4447

4548
MCP_SESSION_ID = "mcp-session-id"
4649
MCP_PROTOCOL_VERSION = "mcp-protocol-version"
50+
MCP_METHOD = "mcp-method"
51+
MCP_NAME = "mcp-name"
4752
LAST_EVENT_ID = "last-event-id"
4853

4954
# Reconnection defaults
5055
DEFAULT_RECONNECTION_DELAY_MS = 1000 # 1 second fallback when server doesn't provide retry
5156
MAX_RECONNECTION_ATTEMPTS = 2 # Max retry attempts before giving up
5257

58+
_B64_SENTINEL = re.compile(r"^=\?base64\?.*\?=$")
59+
# RFC 7230 token chars minus DEL; visible ASCII 0x20-0x7E is the practical bound for a header value.
60+
_HEADER_SAFE = re.compile(r"^[\x20-\x7E]*$")
61+
62+
63+
def _encode_header_value(value: str) -> str:
64+
if _HEADER_SAFE.fullmatch(value) and not _B64_SENTINEL.fullmatch(value):
65+
return value
66+
return f"=?base64?{base64.b64encode(value.encode('utf-8')).decode('ascii')}?="
67+
68+
69+
def _body_derived_headers(message: JSONRPCMessage) -> dict[str, str]:
70+
"""Derive 2026-era headers from an envelope-bearing request body. Empty dict for legacy bodies."""
71+
if not isinstance(message, JSONRPCRequest) or message.params is None:
72+
return {}
73+
meta = message.params.get("_meta")
74+
if meta is None:
75+
return {}
76+
version = meta.get(PROTOCOL_VERSION_META_KEY)
77+
if not isinstance(version, str):
78+
return {}
79+
headers: dict[str, str] = {MCP_PROTOCOL_VERSION: version, MCP_METHOD: message.method}
80+
if message.method == "tools/call":
81+
name = message.params.get("name")
82+
if isinstance(name, str):
83+
headers[MCP_NAME] = _encode_header_value(name)
84+
return headers
85+
5386

5487
class StreamableHTTPError(Exception):
5588
"""Base exception for StreamableHTTP transport errors."""
@@ -256,6 +289,7 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
256289
"""Handle a POST request with response processing."""
257290
headers = self._prepare_headers()
258291
message = ctx.session_message.message
292+
headers.update(_body_derived_headers(message))
259293
is_initialization = self._is_initialization_request(message)
260294

261295
async with ctx.client.stream(

0 commit comments

Comments
 (0)