Skip to content

Commit f9a15e1

Browse files
committed
Client: collapse _inproc_server/_transport into a single _connect closure
__post_init__ now resolves a single _connect closure from the shape of the server argument alone (in-process vs URL vs Transport instance). mode and raise_exceptions are passed to the closure at enter time so they're read at the same moment __aenter__ reads them for the handshake step. _build_session collapses to one line of logic; the mutually-exclusive Optional fields and the assert that guarded them are gone. JSONRPCDispatcher.on_stream_exception is now public-mutable so ClientSession can install its message_handler routing after the dispatcher is built; the install only happens when no caller-supplied hook is already set. ClientSession.adopt() now clears the opposite result slot so at most one of initialize_result/discover_result is non-None by construction.
1 parent 78823d4 commit f9a15e1

3 files changed

Lines changed: 72 additions & 39 deletions

File tree

src/mcp/client/client.py

Lines changed: 52 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from collections.abc import Mapping
5+
from collections.abc import Awaitable, Callable, Mapping
66
from contextlib import AsyncExitStack
77
from dataclasses import KW_ONLY, dataclass, field
88
from typing import Any, Literal, TypeVar
@@ -21,6 +21,7 @@
2121
from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair
2222
from mcp.shared.dispatcher import Dispatcher, ProgressFnT
2323
from mcp.shared.exceptions import MCPDeprecationWarning, MCPError
24+
from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher
2425
from mcp.shared.version import HANDSHAKE_PROTOCOL_VERSIONS, MODERN_PROTOCOL_VERSIONS
2526
from mcp.types import (
2627
METHOD_NOT_FOUND,
@@ -50,6 +51,43 @@
5051

5152
_T = TypeVar("_T")
5253

54+
_Connector = Callable[[AsyncExitStack, ConnectMode, bool], Awaitable["Dispatcher[Any]"]]
55+
"""Resolved at ``__post_init__`` from the shape of ``server`` alone: enter whatever resources
56+
are needed onto the exit stack and hand back the ``Dispatcher`` ``ClientSession`` will drive.
57+
``mode`` and ``raise_exceptions`` are passed at call time so they're read at the same moment
58+
``__aenter__`` reads them for the handshake step."""
59+
60+
61+
def _connect_transport(transport: Transport) -> _Connector:
62+
"""Connector for the stream-backed paths (URL, user-supplied ``Transport``)."""
63+
64+
async def connect(exit_stack: AsyncExitStack, _mode: ConnectMode, _raise_exceptions: bool) -> Dispatcher[Any]:
65+
read_stream, write_stream = await exit_stack.enter_async_context(transport)
66+
return JSONRPCDispatcher(read_stream, write_stream)
67+
68+
return connect
69+
70+
71+
def _connect_inproc(server: Server[Any]) -> _Connector:
72+
"""Connector for an in-process ``Server``: legacy mode drives the stream loop via
73+
``InMemoryTransport``; any other mode drives the modern per-request path through a
74+
``DirectDispatcher`` peer pair (no streams, no JSON-RPC framing, no initialize handshake)."""
75+
76+
async def connect(exit_stack: AsyncExitStack, mode: ConnectMode, raise_exceptions: bool) -> Dispatcher[Any]:
77+
if mode == "legacy":
78+
transport = InMemoryTransport(server, raise_exceptions=raise_exceptions)
79+
read_stream, write_stream = await exit_stack.enter_async_context(transport)
80+
return JSONRPCDispatcher(read_stream, write_stream)
81+
lifespan_state = await exit_stack.enter_async_context(server.lifespan(server))
82+
client_disp, server_disp = create_direct_dispatcher_pair()
83+
tg = await exit_stack.enter_async_context(anyio.create_task_group())
84+
exit_stack.callback(server_disp.close)
85+
on_request = modern_on_request(server, lifespan_state, raise_exceptions=raise_exceptions)
86+
await tg.start(server_disp.run, on_request, _no_inbound_client_notifications)
87+
return client_disp
88+
89+
return connect
90+
5391

5492
def _connected(value: _T | None) -> _T:
5593
"""Narrow a post-handshake session attribute from ``T | None`` to ``T``.
@@ -161,19 +199,9 @@ async def main():
161199
_entered: bool = field(init=False, default=False)
162200
_session: ClientSession | None = field(init=False, default=None)
163201
_exit_stack: AsyncExitStack | None = field(init=False, default=None)
164-
_transport: Transport | None = field(init=False, default=None)
165-
_inproc_server: Server[Any] | None = field(init=False, default=None)
202+
_connect: _Connector = field(init=False, repr=False, compare=False)
166203

167204
def __post_init__(self) -> None:
168-
if isinstance(self.server, MCPServer):
169-
self._inproc_server = self.server._lowlevel_server # pyright: ignore[reportPrivateUsage]
170-
elif isinstance(self.server, Server):
171-
self._inproc_server = self.server
172-
elif isinstance(self.server, str):
173-
self._transport = streamable_http_client(self.server)
174-
else:
175-
self._transport = self.server
176-
177205
if self.mode not in ("legacy", "auto") and self.mode not in MODERN_PROTOCOL_VERSIONS:
178206
hint = (
179207
f" ({self.mode!r} is a handshake-era version — use mode='legacy')"
@@ -184,31 +212,20 @@ def __post_init__(self) -> None:
184212
f"mode must be 'legacy', 'auto', or one of {list(MODERN_PROTOCOL_VERSIONS)}; got {self.mode!r}{hint}"
185213
)
186214

187-
async def _build_session(self, exit_stack: AsyncExitStack) -> ClientSession:
188-
"""Set up the dispatcher/transport and return an un-entered ClientSession."""
189-
dispatcher: Dispatcher[Any] | None
190-
if self._inproc_server is not None and self.mode != "legacy":
191-
# Modern in-process path: drive the server through a DirectDispatcher peer-pair
192-
# with one `serve_one` per request — no streams, no initialize handshake.
193-
lifespan_state = await exit_stack.enter_async_context(self._inproc_server.lifespan(self._inproc_server))
194-
client_disp, server_disp = create_direct_dispatcher_pair()
195-
tg = await exit_stack.enter_async_context(anyio.create_task_group())
196-
exit_stack.callback(server_disp.close)
197-
on_request = modern_on_request(self._inproc_server, lifespan_state, raise_exceptions=self.raise_exceptions)
198-
await tg.start(server_disp.run, on_request, _no_inbound_client_notifications)
199-
dispatcher = client_disp
200-
read_stream = write_stream = None
215+
srv = self.server
216+
if isinstance(srv, MCPServer):
217+
srv = srv._lowlevel_server # pyright: ignore[reportPrivateUsage]
218+
if isinstance(srv, Server):
219+
self._connect = _connect_inproc(srv)
220+
elif isinstance(srv, str):
221+
self._connect = _connect_transport(streamable_http_client(srv))
201222
else:
202-
if self._inproc_server is not None:
203-
transport: Transport = InMemoryTransport(self._inproc_server, raise_exceptions=self.raise_exceptions)
204-
else:
205-
assert self._transport is not None
206-
transport = self._transport
207-
read_stream, write_stream = await exit_stack.enter_async_context(transport)
208-
dispatcher = None
223+
self._connect = _connect_transport(srv)
224+
225+
async def _build_session(self, exit_stack: AsyncExitStack) -> ClientSession:
226+
"""Enter the resolved connector and return an un-entered ClientSession."""
227+
dispatcher = await self._connect(exit_stack, self.mode, self.raise_exceptions)
209228
return ClientSession(
210-
read_stream=read_stream,
211-
write_stream=write_stream,
212229
dispatcher=dispatcher,
213230
read_timeout_seconds=self.read_timeout_seconds,
214231
sampling_callback=self.sampling_callback,

src/mcp/client/session.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,14 @@ def __init__(
216216
if read_stream is not None or write_stream is not None:
217217
raise ValueError("pass read_stream/write_stream or dispatcher, not both")
218218
self._dispatcher: Dispatcher[Any] = dispatcher
219+
if isinstance(dispatcher, JSONRPCDispatcher) and dispatcher.on_stream_exception is None:
220+
# Route transport-level Exception items into message_handler — only
221+
# stream-backed dispatchers carry these; DirectDispatcher has none.
222+
# Don't clobber a caller-supplied hook.
223+
# TODO(maxisbey): this leaves a bound-method ref on the dispatcher after
224+
# the session exits (memory pin) and a second wrap of the same dispatcher
225+
# would skip install. The Transport-as-Dispatcher rework removes this seam.
226+
dispatcher.on_stream_exception = self._on_stream_exception
219227
else:
220228
if read_stream is None or write_stream is None:
221229
raise ValueError("read_stream and write_stream are required when no dispatcher is given")
@@ -358,6 +366,9 @@ async def initialize(self) -> types.InitializeResult:
358366
def adopt(self, result: types.InitializeResult | types.DiscoverResult) -> None:
359367
"""Install negotiated state from a result the caller already holds (no wire traffic).
360368
369+
Clears the opposite slot, so at most one of `initialize_result` /
370+
`discover_result` is ever non-None.
371+
361372
Raises:
362373
RuntimeError: `result` is a `DiscoverResult` whose `supported_versions`
363374
shares nothing with this client's `MODERN_PROTOCOL_VERSIONS`.
@@ -374,10 +385,12 @@ def adopt(self, result: types.InitializeResult | types.DiscoverResult) -> None:
374385
capabilities = self._build_capabilities().model_dump(by_alias=True, mode="json", exclude_none=True)
375386
self._stamp = _make_modern_stamp(mutual[-1], client_info, capabilities)
376387
self._discover_result = result
388+
self._initialize_result = None
377389
self._negotiated_version = mutual[-1]
378390
else:
379391
self._stamp = _make_handshake_stamp(result.protocol_version)
380392
self._initialize_result = result
393+
self._discover_result = None
381394
self._negotiated_version = result.protocol_version
382395

383396
async def discover(self) -> types.DiscoverResult:

src/mcp/shared/jsonrpc_dispatcher.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def handler_exception_to_error_data(exc: BaseException) -> ErrorData | None:
7575
with empty ``data`` (no pydantic text on the wire). Returns ``None`` for
7676
any other exception so each caller applies its own catch-all -
7777
`JSONRPCDispatcher` currently pins ``code=0`` for v1 compat,
78-
`to_jsonrpc_response` uses `INTERNAL_ERROR`.
78+
the modern HTTP entry uses `INTERNAL_ERROR`.
7979
"""
8080
if isinstance(exc, MCPError):
8181
return exc.error
@@ -268,7 +268,10 @@ def __init__(
268268
self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode
269269
self._raise_handler_exceptions = raise_handler_exceptions
270270
self._inline_methods = inline_methods
271-
self._on_stream_exception = on_stream_exception
271+
self.on_stream_exception = on_stream_exception
272+
"""Observer for ``Exception`` items on the read stream. Mutable so a session can
273+
bind it after the dispatcher is built (e.g. ``ClientSession`` routing into
274+
``message_handler``); only consulted inside ``run()`` so pre-enter assignment is safe."""
272275

273276
self._next_id = 0
274277
self._pending: dict[RequestId, _Pending] = {}
@@ -484,11 +487,11 @@ async def _dispatch(
484487
are awaited; any other `await` would head-of-line block the read loop.
485488
"""
486489
if isinstance(item, Exception):
487-
if self._on_stream_exception is None:
490+
if self.on_stream_exception is None:
488491
logger.debug("transport yielded exception: %r", item)
489492
return
490493
try:
491-
await self._on_stream_exception(item)
494+
await self.on_stream_exception(item)
492495
except Exception:
493496
logger.exception("on_stream_exception observer raised")
494497
return

0 commit comments

Comments
 (0)