diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index fdb127ca0a..53fa17d8d8 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -23,6 +23,7 @@ from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.shared.version import MODERN_PROTOCOL_VERSIONS from mcp.types import ( + CONNECTION_CLOSED, INTERNAL_ERROR, INVALID_REQUEST, PARSE_ERROR, @@ -413,10 +414,16 @@ async def _handle_sse_response( except Exception: logger.debug("SSE stream ended", exc_info=True) # pragma: no cover - # Stream ended without response - reconnect if we received an event with ID - if last_event_id is not None: # pragma: no branch - logger.info("SSE stream disconnected, reconnecting...") - await self._handle_reconnection(ctx, last_event_id, retry_interval_ms) + # Stream ended without a terminal response/error. If the server provided an event id, + # try resuming; otherwise fail the request instead of hanging forever. + if last_event_id is None: + error_data = ErrorData(code=CONNECTION_CLOSED, message="SSE stream disconnected before response completed") + error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=original_request_id, error=error_data)) + await ctx.read_stream_writer.send(error_msg) + return + + logger.info("SSE stream disconnected, reconnecting...") + await self._handle_reconnection(ctx, last_event_id, retry_interval_ms) async def _handle_reconnection( self, @@ -427,7 +434,16 @@ async def _handle_reconnection( ) -> None: """Reconnect with Last-Event-ID to resume stream after server disconnect.""" # Bail if max retries exceeded - if attempt >= MAX_RECONNECTION_ATTEMPTS: # pragma: no cover + if attempt >= MAX_RECONNECTION_ATTEMPTS: + assert isinstance(ctx.session_message.message, JSONRPCRequest) + original_request_id = ctx.session_message.message.id + error_data = ErrorData( + code=CONNECTION_CLOSED, + message="SSE stream disconnected and could not be resumed", + data={"last_event_id": last_event_id}, + ) + error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=original_request_id, error=error_data)) + await ctx.read_stream_writer.send(error_msg) logger.debug(f"Max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded") return @@ -471,7 +487,7 @@ async def _handle_reconnection( # Stream ended again without response - reconnect again (reset attempt counter) logger.info("SSE stream disconnected, reconnecting...") await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, 0) - except Exception as e: # pragma: no cover + except Exception as e: logger.debug(f"Reconnection failed: {e}") # Try to reconnect again if we still have an event ID await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1) diff --git a/tests/client/test_streamable_http.py b/tests/client/test_streamable_http.py index bbe3e67fee..e372406e4f 100644 --- a/tests/client/test_streamable_http.py +++ b/tests/client/test_streamable_http.py @@ -17,11 +17,21 @@ from mcp.client import ClientSession from mcp.client.streamable_http import ( MCP_PROTOCOL_VERSION, + RequestContext, StreamableHTTPTransport, _encode_header_value, streamable_http_client, ) -from mcp.types import JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse +from mcp.shared._context_streams import create_context_streams +from mcp.shared.message import SessionMessage +from mcp.types import ( + CONNECTION_CLOSED, + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, +) @pytest.mark.parametrize( @@ -98,6 +108,31 @@ def test_mcp_name_header_values_are_base64_wrapped_when_unsafe_for_an_http_field assert encoded == raw +@pytest.mark.anyio +async def test_sse_response_disconnect_before_any_event_id_fails_request() -> None: + transport = StreamableHTTPTransport("http://example.com/mcp") + async with httpx.AsyncClient() as client: + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](1) + request = JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/call", params={"name": "noop", "arguments": {}}) + ctx = RequestContext( + client=client, + session_id=None, + session_message=SessionMessage(request), + metadata=None, + read_stream_writer=read_stream_writer, + ) + response = httpx.Response(200, headers={"content-type": "text/event-stream"}, content=b"") + + async with read_stream_writer, read_stream: + await transport._handle_sse_response(response, ctx) + message = await read_stream.receive() + + assert isinstance(message, SessionMessage) + assert isinstance(message.message, JSONRPCError) + assert message.message.id == 1 + assert message.message.error.code == CONNECTION_CLOSED + + @pytest.mark.anyio async def test_pinned_transport_ignores_returned_session_id_and_never_opens_get_or_delete() -> None: """A server-issued ``Mcp-Session-Id`` never reaches a pinned client's wire: only POSTs are sent. diff --git a/tests/interaction/transports/test_hosting_resume.py b/tests/interaction/transports/test_hosting_resume.py index b22df0ff2b..34ce43b911 100644 --- a/tests/interaction/transports/test_hosting_resume.py +++ b/tests/interaction/transports/test_hosting_resume.py @@ -296,6 +296,82 @@ async def call() -> None: assert received == snapshot(["before close", "after close"]) +@requirement("hosting:resume:close-stream") +@requirement("transport:streamable-http:resumability") +@requirement("client-transport:http:reconnect-post-priming") +@requirement("client-transport:http:reconnect-retry-value") +async def test_a_call_whose_stream_closes_and_cannot_be_resumed_fails_instead_of_hanging() -> None: + """If a resumable response stream disconnects and the server session is gone, the client fails + the request instead of hanging forever. + + The server closes the call's SSE stream after emitting one related notification. The test then + deletes the active server-side session to force the client's reconnect GET to return 404. + Without a terminal response/error on the read stream, ClientSession.send_request waits forever + (read timeout defaults to None). The transport must surface a request-scoped error when it + gives up reconnecting. + """ + reconnect_attempted = anyio.Event() + allow_exit = anyio.Event() + done = anyio.Event() + raised: list[BaseException] = [] + manager_ref = None + deleted_session = False + + mcp = MCPServer("resumable") + + @mcp.tool() + async def interrupt(ctx: Context) -> str: + await ctx.info("before close") # pyright: ignore[reportDeprecated] + await ctx.close_sse_stream() + await allow_exit.wait() + return "unreachable" + + async def record_request(request: httpx.Request) -> None: + nonlocal deleted_session + if request.method != "GET": + return + if request.headers.get("last-event-id") is None: + return + reconnect_attempted.set() + if deleted_session or manager_ref is None: + return + session_ids = list(manager_ref._server_instances.keys()) + if session_ids: # pragma: no branch + del manager_ref._server_instances[session_ids[0]] + deleted_session = True + + async with mounted_app(mcp, event_store=SequencedEventStore(), retry_interval=0, on_request=record_request) as ( + http, + manager, + ): + manager_ref = manager + with anyio.fail_after(5): # pragma: no branch + async with ( # pragma: no branch + streamable_http_client(f"{BASE_URL}/mcp", http_client=http, terminate_on_close=False) as (r, w), + ClientSession(r, w) as session, + anyio.create_task_group() as tg, + ): + await session.initialize() + + async def call() -> None: + try: + await session.call_tool("interrupt", {}) + except BaseException as exc: + raised.append(exc) + finally: + done.set() + + tg.start_soon(call) + await reconnect_attempted.wait() + await done.wait() + allow_exit.set() + tg.cancel_scope.cancel() + + assert len(raised) == 1 + assert isinstance(raised[0], Exception) + assert "disconnected" in str(raised[0]).lower() + + @requirement("client-transport:http:resume-stream-api") async def test_a_captured_resumption_token_replays_missed_messages_on_a_new_connection() -> None: """A resumption token captured via on_resumption_token_update on one connection lets a fresh