Skip to content

Commit 1d33743

Browse files
committed
Client gains mode= and prior_discover= policy knobs (legacy and version-pin)
- mode='legacy' (default) performs the initialize handshake; a version string (e.g. '2026-07-28') adopts that version directly via .adopt() - prior_discover= reuses a known DiscoverResult; omitting it synthesizes a minimal one - 'auto' (server/discover probe with fallback) follows once .discover() lands - Interaction-suite connect fixture passes mode= for the modern arm and yields Client for all arms again; the W1b-era ClientSession adapter and type suppression are removed
1 parent 52546fa commit 1d33743

4 files changed

Lines changed: 50 additions & 74 deletions

File tree

src/mcp/client/client.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44

55
from contextlib import AsyncExitStack
66
from dataclasses import KW_ONLY, dataclass, field
7-
from typing import Any
7+
from typing import Any, Literal
88

99
from typing_extensions import deprecated
1010

11+
from mcp import types
1112
from mcp.client._memory import InMemoryTransport
1213
from mcp.client._transport import Transport
1314
from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
@@ -36,6 +37,17 @@
3637
)
3738

3839

40+
def _synthesize_discover(protocol_version: str) -> types.DiscoverResult:
41+
return types.DiscoverResult(
42+
supported_versions=[protocol_version],
43+
capabilities=types.ServerCapabilities(),
44+
server_info=types.Implementation(name="", version=""),
45+
result_type="complete",
46+
ttl_ms=0,
47+
cache_scope="public",
48+
)
49+
50+
3951
@dataclass
4052
class Client:
4153
"""A high-level MCP client for connecting to MCP servers.
@@ -95,6 +107,15 @@ async def main():
95107
client_info: Implementation | None = None
96108
"""Client implementation info to send to server."""
97109

110+
mode: Literal["legacy"] | str = "legacy"
111+
"""'legacy' performs the initialize handshake. A protocol-version string (e.g. '2026-07-28') adopts that
112+
version directly without a handshake — supply prior_discover to reuse a known DiscoverResult, or omit it
113+
to synthesize a minimal one."""
114+
115+
prior_discover: types.DiscoverResult | None = None
116+
"""A previously-obtained DiscoverResult to install via .adopt() when mode is a version pin.
117+
Ignored when mode='legacy'."""
118+
98119
elicitation_callback: ElicitationFnT | None = None
99120
"""Callback for handling elicitation requests."""
100121

@@ -132,7 +153,10 @@ async def __aenter__(self) -> Client:
132153
)
133154
)
134155

135-
await self._session.initialize()
156+
if self.mode == "legacy":
157+
await self._session.initialize()
158+
else:
159+
self._session.adopt(self.prior_discover or _synthesize_discover(self.mode))
136160

137161
# Transfer ownership to self for __aexit__ to handle
138162
self._exit_stack = exit_stack.pop_all()

tests/interaction/_connect.py

Lines changed: 16 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from starlette.routing import Mount, Route
2121

2222
from mcp.client.client import Client
23-
from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
23+
from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
2424
from mcp.client.sse import sse_client
2525
from mcp.client.streamable_http import streamable_http_client
2626
from mcp.server import Server
@@ -35,13 +35,11 @@
3535
from mcp.types import (
3636
LATEST_PROTOCOL_VERSION,
3737
ClientCapabilities,
38-
DiscoverResult,
3938
Implementation,
4039
InitializeRequestParams,
4140
JSONRPCMessage,
4241
JSONRPCRequest,
4342
JSONRPCResponse,
44-
ServerCapabilities,
4543
jsonrpc_message_adapter,
4644
)
4745
from tests.interaction.transports._bridge import StreamingASGITransport
@@ -129,9 +127,9 @@ async def connect_over_streamable_http(
129127
resumability tests pass an `event_store` (with `retry_interval=0` so the client's
130128
reconnection wait is a no-op).
131129
132-
When `spec_version` is a modern (2026-07-28+) revision, the modern path is exercised: a bare
133-
`ClientSession` is built over the streams and adopted from a synthesized `DiscoverResult`
134-
instead of negotiating via `Client`'s legacy initialize handshake.
130+
When `spec_version` is a modern (2026-07-28+) revision the Client is opened with
131+
`mode=<version>`, which adopts a synthesized DiscoverResult instead of running the legacy
132+
initialize handshake.
135133
"""
136134
app = server.streamable_http_app(
137135
stateless_http=stateless_http,
@@ -143,48 +141,19 @@ async def connect_over_streamable_http(
143141
async with (
144142
server.session_manager.run(),
145143
httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http_client,
144+
Client(
145+
streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client),
146+
mode=spec_version if spec_version in MODERN_PROTOCOL_VERSIONS else "legacy",
147+
read_timeout_seconds=read_timeout_seconds,
148+
sampling_callback=sampling_callback,
149+
list_roots_callback=list_roots_callback,
150+
logging_callback=logging_callback,
151+
message_handler=message_handler,
152+
client_info=client_info,
153+
elicitation_callback=elicitation_callback,
154+
) as client,
146155
):
147-
if spec_version in MODERN_PROTOCOL_VERSIONS:
148-
async with (
149-
streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read, write),
150-
ClientSession(
151-
read,
152-
write,
153-
read_timeout_seconds=read_timeout_seconds,
154-
sampling_callback=sampling_callback,
155-
list_roots_callback=list_roots_callback,
156-
logging_callback=logging_callback,
157-
message_handler=message_handler,
158-
client_info=client_info,
159-
elicitation_callback=elicitation_callback,
160-
) as session,
161-
):
162-
session.adopt(
163-
DiscoverResult(
164-
supported_versions=[spec_version],
165-
capabilities=ServerCapabilities(),
166-
server_info=Implementation(name="test", version="0"),
167-
result_type="complete",
168-
ttl_ms=0,
169-
cache_scope="public",
170-
)
171-
)
172-
# ClientSession quacks as Client for every modern-arm requirement; the surfaces
173-
# diverge (cursor= vs params=, .session, nullable initialize_result) so widening
174-
# Connect to the union cascades ~58 errors. Contained here until the two converge.
175-
yield session # pyright: ignore[reportReturnType]
176-
else:
177-
async with Client(
178-
streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client),
179-
read_timeout_seconds=read_timeout_seconds,
180-
sampling_callback=sampling_callback,
181-
list_roots_callback=list_roots_callback,
182-
logging_callback=logging_callback,
183-
message_handler=message_handler,
184-
client_info=client_info,
185-
elicitation_callback=elicitation_callback,
186-
) as client:
187-
yield client
156+
yield client
188157

189158

190159
connect_over_streamable_http_stateless: Connect = partial(connect_over_streamable_http, stateless_http=True)

tests/interaction/lowlevel/test_cancellation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara
148148
server = Server("calm", on_list_tools=list_tools, on_call_tool=call_tool)
149149

150150
async with connect(server) as client:
151-
session = client if isinstance(client, ClientSession) else client.session # modern arm yields a bare session
152-
await session.send_notification(
151+
await client.session.send_notification(
153152
types.CancelledNotification(params=types.CancelledNotificationParams(request_id=9999))
154153
)
155154
result = await client.call_tool("echo", {})

tests/interaction/lowlevel/test_pagination.py

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@
55
pagination scheme.
66
"""
77

8-
from collections.abc import Callable, Coroutine
9-
from typing import Any, TypeVar
10-
118
import pytest
129
from inline_snapshot import snapshot
1310

@@ -29,19 +26,6 @@
2926

3027
pytestmark = pytest.mark.anyio
3128

32-
_ResultT = TypeVar("_ResultT")
33-
34-
35-
def _page(list_fn: Callable[..., Coroutine[Any, Any, _ResultT]], cursor: str | None) -> Coroutine[Any, Any, _ResultT]:
36-
"""Call a paginated list method with a cursor on whichever client surface `connect` yielded.
37-
38-
`Client.list_*` takes `cursor=`; `ClientSession.list_*` takes `params=PaginatedRequestParams(...)`.
39-
"""
40-
try:
41-
return list_fn(cursor=cursor)
42-
except TypeError:
43-
return list_fn(params=types.PaginatedRequestParams(cursor=cursor))
44-
4529

4630
@requirement("tools:list:pagination")
4731
async def test_next_cursor_round_trips_through_the_client(connect: Connect) -> None:
@@ -65,7 +49,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa
6549

6650
async with connect(server) as client:
6751
first_page = await client.list_tools()
68-
second_page = await _page(client.list_tools, first_page.next_cursor)
52+
second_page = await client.list_tools(cursor=first_page.next_cursor)
6953

7054
assert first_page.next_cursor == cursor
7155
assert seen_cursors == [None, cursor]
@@ -95,7 +79,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa
9579
requests_made = 0
9680
async with connect(server) as client:
9781
while True:
98-
result = await _page(client.list_tools, cursor)
82+
result = await client.list_tools(cursor=cursor)
9983
requests_made += 1
10084
assert requests_made <= len(pages), "the server kept returning next_cursor past the last page"
10185
collected.extend(tool.name for tool in result.tools)
@@ -138,7 +122,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa
138122
cursor: str | None = None
139123
async with connect(server) as client:
140124
while True:
141-
result = await _page(client.list_tools, cursor)
125+
result = await client.list_tools(cursor=cursor)
142126
page_sizes.append(len(result.tools))
143127
if result.next_cursor is None:
144128
break
@@ -166,7 +150,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa
166150

167151
async with connect(server) as client:
168152
with pytest.raises(MCPError) as exc_info:
169-
await _page(client.list_tools, "never-issued")
153+
await client.list_tools(cursor="never-issued")
170154

171155
assert exc_info.value.error.code == INVALID_PARAMS
172156

@@ -190,7 +174,7 @@ async def list_resources(
190174

191175
async with connect(server) as client:
192176
first_page = await client.list_resources()
193-
second_page = await _page(client.list_resources, first_page.next_cursor)
177+
second_page = await client.list_resources(cursor=first_page.next_cursor)
194178

195179
assert first_page.next_cursor == cursor
196180
assert seen_cursors == [None, cursor]
@@ -223,7 +207,7 @@ async def list_resource_templates(
223207

224208
async with connect(server) as client:
225209
first_page = await client.list_resource_templates()
226-
second_page = await _page(client.list_resource_templates, first_page.next_cursor)
210+
second_page = await client.list_resource_templates(cursor=first_page.next_cursor)
227211

228212
assert first_page.next_cursor == cursor
229213
assert seen_cursors == [None, cursor]
@@ -249,7 +233,7 @@ async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequest
249233

250234
async with connect(server) as client:
251235
first_page = await client.list_prompts()
252-
second_page = await _page(client.list_prompts, first_page.next_cursor)
236+
second_page = await client.list_prompts(cursor=first_page.next_cursor)
253237

254238
assert first_page.next_cursor == cursor
255239
assert seen_cursors == [None, cursor]

0 commit comments

Comments
 (0)