Skip to content

Commit 08c7fa1

Browse files
author
冯基魁
committed
fix: drain stdio responses after stdin EOF
1 parent 2397319 commit 08c7fa1

6 files changed

Lines changed: 157 additions & 1 deletion

File tree

src/mcp/server/lowlevel/server.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -436,6 +436,8 @@ async def run(
436436
# but also make tracing exceptions much easier during testing and when using
437437
# in-process servers.
438438
raise_exceptions: bool = False,
439+
drain_in_flight_on_read_eof: bool = False,
440+
read_eof_response_drain_timeout: float = 5.0,
439441
) -> None:
440442
"""Serve a single connection over the given streams until the read side closes.
441443
@@ -448,6 +450,8 @@ async def run(
448450
self,
449451
read_stream,
450452
write_stream,
453+
drain_in_flight_on_read_eof=drain_in_flight_on_read_eof,
454+
read_eof_response_drain_timeout=read_eof_response_drain_timeout,
451455
lifespan_state=lifespan_context,
452456
init_options=initialization_options,
453457
raise_exceptions=raise_exceptions,

src/mcp/server/mcpserver/server.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -838,6 +838,7 @@ async def run_stdio_async(self) -> None:
838838
read_stream,
839839
write_stream,
840840
self._lowlevel_server.create_initialization_options(),
841+
drain_in_flight_on_read_eof=True,
841842
)
842843

843844
async def run_sse_async( # pragma: no cover

src/mcp/server/runner.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -447,6 +447,8 @@ async def serve_loop(
447447
session_id: str | None = None,
448448
init_options: InitializationOptions | None = None,
449449
raise_exceptions: bool = False,
450+
drain_in_flight_on_read_eof: bool = False,
451+
read_eof_response_drain_timeout: float = 5.0,
450452
) -> None:
451453
"""Drive ``server`` in loop mode over a stream pair until the channel closes.
452454
@@ -460,6 +462,8 @@ async def serve_loop(
460462
read_stream,
461463
write_stream,
462464
raise_handler_exceptions=raise_exceptions,
465+
drain_in_flight_on_read_eof=drain_in_flight_on_read_eof,
466+
read_eof_response_drain_timeout=read_eof_response_drain_timeout,
463467
# Handle `initialize` inline so a client that pipelines it with the
464468
# next request (spec: SHOULD NOT, not MUST NOT) sees the initialized
465469
# state instead of failing the init-gate.

src/mcp/shared/jsonrpc_dispatcher.py

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -239,6 +239,8 @@ def __init__(
239239
raise_handler_exceptions: bool = False,
240240
inline_methods: frozenset[str] = frozenset(),
241241
on_stream_exception: Callable[[Exception], Awaitable[None]] | None = None,
242+
drain_in_flight_on_read_eof: bool = False,
243+
read_eof_response_drain_timeout: float = 5.0,
242244
) -> None:
243245
"""Wire a dispatcher over a transport's `SessionMessage` stream pair.
244246
@@ -264,12 +266,23 @@ def __init__(
264266
)
265267
self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode
266268
self._raise_handler_exceptions = raise_handler_exceptions
269+
self._drain_in_flight_on_read_eof = drain_in_flight_on_read_eof
270+
self._read_eof_response_drain_timeout = read_eof_response_drain_timeout
271+
# Request methods handled inline in the read loop (awaited before the
272+
# next message is dequeued) instead of spawned concurrently. Use for
273+
# methods whose side effects must be observable to the next message,
274+
# e.g. `initialize`, so a pipelined follow-up sees the initialized state.
275+
# Only suitable for handlers that complete quickly, since inline handling
276+
# blocks dequeuing; a handler that awaits the peer (`send_raw_request`)
277+
# while inline will deadlock because the parked read loop cannot dequeue
278+
# the response.
267279
self._inline_methods = inline_methods
268280
self._on_stream_exception = on_stream_exception
269281

270282
self._next_id = 0
271283
self._pending: dict[RequestId, _Pending] = {}
272284
self._in_flight: dict[RequestId, _InFlight[TransportT]] = {}
285+
self._responses_in_flight: set[RequestId] = set()
273286
self._tg: anyio.abc.TaskGroup | None = None
274287
self._running = False
275288
self._closed = False
@@ -451,6 +464,12 @@ async def run(
451464
except anyio.ClosedResourceError:
452465
# Receive end closed under us (stateless SHTTP teardown); same as EOF.
453466
logger.debug("read stream closed by transport; treating as EOF")
467+
if self._drain_in_flight_on_read_eof:
468+
with anyio.move_on_after(self._read_eof_response_drain_timeout) as scope:
469+
while self._in_flight or self._responses_in_flight:
470+
await anyio.sleep(0)
471+
if scope.cancelled_caught:
472+
logger.debug("timed out draining in-flight responses after read EOF")
454473
# EOF: wake blocked `send_raw_request` waiters with CONNECTION_CLOSED.
455474
self._running = False
456475
self._closed = True
@@ -722,16 +741,24 @@ async def _write(self, message: JSONRPCMessage, metadata: MessageMetadata = None
722741
await self._write_stream.send(SessionMessage(message=message, metadata=metadata))
723742

724743
async def _write_result(self, request_id: RequestId, result: dict[str, Any]) -> None:
    """Send a JSON-RPC success response for ``request_id``.

    While the write is pending, the coerced request id is held in
    ``self._responses_in_flight``; it is removed again whether the write
    succeeds, is dropped, or raises. A write stream that is already broken
    or closed is not treated as an error: the response is dropped and a
    debug message is logged.

    Args:
        request_id: Id of the request being answered.
        result: Result payload placed in the response.
    """
    tracked_id = _coerce_id(request_id)
    self._responses_in_flight.add(tracked_id)
    try:
        response = JSONRPCResponse(jsonrpc="2.0", id=request_id, result=result)
        await self._write(response)
    except (anyio.BrokenResourceError, anyio.ClosedResourceError):
        # The peer is gone; there is nobody left to receive this result.
        logger.debug("dropped result for %r: write stream closed", request_id)
    finally:
        # Always clear the marker so the id never lingers as "in flight".
        self._responses_in_flight.discard(tracked_id)
729752

730753
async def _write_error(self, request_id: RequestId, error: ErrorData) -> None:
    """Send a JSON-RPC error response for ``request_id``.

    While the write is pending, the coerced request id is held in
    ``self._responses_in_flight``; it is removed again whether the write
    succeeds, is dropped, or raises. A write stream that is already broken
    or closed is not treated as an error: the response is dropped and a
    debug message is logged.

    Args:
        request_id: Id of the request being answered.
        error: Error payload placed in the response.
    """
    tracked_id = _coerce_id(request_id)
    self._responses_in_flight.add(tracked_id)
    try:
        response = JSONRPCError(jsonrpc="2.0", id=request_id, error=error)
        await self._write(response)
    except (anyio.BrokenResourceError, anyio.ClosedResourceError):
        # The peer is gone; there is nobody left to receive this error.
        logger.debug("dropped error for %r: write stream closed", request_id)
    finally:
        # Always clear the marker so the id never lingers as "in flight".
        self._responses_in_flight.discard(tracked_id)
735762

736763
async def _final_write(
737764
self,

tests/server/test_cancel_handling.py

Lines changed: 65 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -172,6 +172,71 @@ async def run_server():
172172
assert handler_cancelled.is_set()
173173

174174

175+
@pytest.mark.anyio
async def test_server_cancels_in_flight_handlers_when_read_eof_drain_times_out():
    """The read-EOF drain is bounded: a handler that never returns is still cancelled.

    The server runs with draining enabled and a very short drain timeout. A
    tool call is started whose handler sleeps forever, then the client side of
    the read stream is closed. ``server.run`` must return anyway, and the
    handler must have observed cancellation.
    """
    # Wire messages, built up front so the scenario below reads as a script.
    initialize_params = InitializeRequestParams(
        protocol_version=LATEST_PROTOCOL_VERSION,
        capabilities=ClientCapabilities(),
        client_info=Implementation(name="test", version="1.0"),
    )
    initialize_request = JSONRPCRequest(
        jsonrpc="2.0",
        id=1,
        method="initialize",
        params=initialize_params.model_dump(by_alias=True, mode="json", exclude_none=True),
    )
    initialized_notification = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")
    tool_call_params = CallToolRequestParams(name="slow", arguments={})
    tool_call_request = JSONRPCRequest(
        jsonrpc="2.0",
        id=2,
        method="tools/call",
        params=tool_call_params.model_dump(by_alias=True, mode="json"),
    )

    # Observable milestones of the scenario.
    started = anyio.Event()
    cancelled = anyio.Event()
    run_finished = anyio.Event()

    async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
        started.set()
        try:
            await anyio.sleep_forever()
        finally:
            # Runs when the sleep is interrupted, i.e. on cancellation.
            cancelled.set()
        raise AssertionError  # pragma: no cover

    server = Server("test", on_call_tool=handle_call_tool)

    to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10)
    server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10)

    async def serve_until_eof():
        await server.run(
            server_read,
            server_write,
            server.create_initialization_options(),
            drain_in_flight_on_read_eof=True,
            read_eof_response_drain_timeout=0.01,
        )
        run_finished.set()

    with anyio.fail_after(5):
        async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server:
            tg.start_soon(serve_until_eof)

            # Handshake: initialize, consume its response, then confirm.
            await to_server.send(SessionMessage(initialize_request))
            await from_server.receive()
            await to_server.send(SessionMessage(initialized_notification))

            # Start the never-ending tool call and wait until it is running.
            await to_server.send(SessionMessage(tool_call_request))
            await started.wait()

            # Closing the client end is what the server sees as read EOF.
            await to_server.aclose()

            await run_finished.wait()

    assert cancelled.is_set()
238+
239+
175240
@pytest.mark.anyio
176241
async def test_server_handles_transport_close_with_pending_server_to_client_requests():
177242
"""When the transport closes while handlers are blocked on server→client

tests/server/test_stdio.py

Lines changed: 56 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import io
2+
import json
23
import sys
34
import threading
45
from collections.abc import AsyncIterator
@@ -7,11 +8,12 @@
78

89
import anyio
910
import pytest
11+
from anyio.lowlevel import checkpoint
1012

1113
from mcp.server.mcpserver import MCPServer
1214
from mcp.server.stdio import stdio_server
1315
from mcp.shared.message import SessionMessage
14-
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter
16+
from mcp.types import JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter
1517

1618

1719
@pytest.mark.anyio
@@ -142,6 +144,59 @@ def test_mcpserver_run_stdio_serves_until_stdin_closes(monkeypatch: pytest.Monke
142144
assert response == JSONRPCResponse(jsonrpc="2.0", id=1, result={})
143145

144146

147+
def test_mcpserver_run_stdio_drains_in_flight_tool_responses_after_stdin_eof(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Responses to already-accepted requests must survive stdin reaching EOF.

    The whole session (initialize, initialized, two tool calls) is fed through
    a finite stdin, so EOF arrives while the tool calls may still be running.
    Every request must nevertheless get its response on stdout.
    """
    server = MCPServer(name="DrainStdioServer")

    # Yields once before returning so the call is genuinely in flight at EOF.
    @server.tool()
    async def slow_echo(text: str) -> str:
        await checkpoint()
        return text

    def tool_call(request_id: int, text: str) -> JSONRPCRequest:
        # Build a `tools/call` request for `slow_echo` echoing `text`.
        return JSONRPCRequest(
            jsonrpc="2.0",
            id=request_id,
            method="tools/call",
            params={"name": "slow_echo", "arguments": {"text": text}},
        )

    session_messages = [
        JSONRPCRequest(
            jsonrpc="2.0",
            id=0,
            method="initialize",
            params={
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {"name": "stdio-replay", "version": "0.1"},
            },
        ),
        JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized", params={}),
        tool_call(1, "first"),
        tool_call(2, "second"),
    ]
    # One JSON document per line, each newline-terminated.
    wire_input = "".join(
        message.model_dump_json(by_alias=True, exclude_none=True) + "\n" for message in session_messages
    )

    captured = _KeepOpenBytesIO()
    monkeypatch.setattr(sys, "stdin", TextIOWrapper(io.BytesIO(wire_input.encode()), encoding="utf-8"))
    monkeypatch.setattr(sys, "stdout", TextIOWrapper(captured, encoding="utf-8"))

    _run_stdio_bounded(server)

    responses = []
    for raw_line in captured.getvalue().decode().splitlines():
        if raw_line:
            responses.append(json.loads(raw_line))

    response_ids = [response["id"] for response in responses]
    assert response_ids == [0, 1, 2]
    assert responses[1]["result"]["content"][0]["text"] == "first"
    assert responses[2]["result"]["content"][0]["text"] == "second"
199+
145200
def test_mcpserver_run_stdio_runs_lifespan_cleanup_after_stdin_closes(monkeypatch: pytest.MonkeyPatch) -> None:
146201
"""Code after `yield` in a lifespan runs when stdin EOF ends `run("stdio")`.
147202

0 commit comments

Comments
 (0)