diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index cceefccce..4dc2be931 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -455,6 +455,46 @@ async def _receive_loop(self) -> None: pass self._response_streams.clear() + def _pop_response_stream( + self, response_id: RequestId + ) -> MemoryObjectSendStream[JSONRPCResponse | JSONRPCError] | None: + """ + Pop a response stream by ID, trying alternative type representations. + + JSON-RPC allows request IDs to be strings or integers. Some servers may + return a string ID (e.g., "0") even when the client sent an integer ID + (e.g., 0). Since Python dict lookups are type-sensitive (0 != "0"), + this method tries both representations to handle such mismatches. + + Args: + response_id: The response ID from the incoming message. + + Returns: + The response stream if found, None otherwise. + """ + # Try exact match first + stream = self._response_streams.pop(response_id, None) + if stream is not None: + return stream + + # Try alternative type representation + if isinstance(response_id, str): + # Response ID is string, try integer lookup + try: + int_id = int(response_id) + stream = self._response_streams.pop(int_id, None) + if stream is not None: + return stream + except ValueError: + pass + else: + # Response ID is integer, try string lookup + stream = self._response_streams.pop(str(response_id), None) + if stream is not None: + return stream + + return None + async def _handle_response(self, message: SessionMessage) -> None: """ Handle an incoming response or error message. @@ -486,8 +526,8 @@ async def _handle_response(self, message: SessionMessage) -> None: if router.route_response(response_id, response_data): return # Handled - # Fall back to normal response streams - stream = self._response_streams.pop(response_id, None) + # Fall back to normal response streams, with ID type normalization + stream = self._pop_response_stream(response_id) if stream: # pragma: no cover await stream.send(root) else: # pragma: no cover diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 313ec9926..e609397e5 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -9,12 +9,18 @@ from mcp.server.lowlevel.server import Server from mcp.shared.exceptions import McpError from mcp.shared.memory import create_client_server_memory_streams, create_connected_server_and_client_session +from mcp.shared.message import SessionMessage from mcp.types import ( CancelledNotification, CancelledNotificationParams, ClientNotification, ClientRequest, EmptyResult, + ErrorData, + JSONRPCError, + JSONRPCMessage, + JSONRPCRequest, + JSONRPCResponse, TextContent, ) @@ -122,6 +128,171 @@ async def make_request(client_session: ClientSession): await ev_cancelled.wait() +@pytest.mark.anyio +async def test_response_id_type_mismatch_string_to_int(): + """ + Test that responses with string IDs are correctly matched to requests sent with + integer IDs. + + This handles the case where a server returns "id": "0" (string) but the client + sent "id": 0 (integer). Without ID type normalization, this would cause a timeout. + """ + ev_response_received = anyio.Event() + result_holder: list[types.EmptyResult] = [] + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def mock_server(): + """Receive a request and respond with a string ID instead of integer.""" + message = await server_read.receive() + assert isinstance(message, SessionMessage) + root = message.message.root + assert isinstance(root, JSONRPCRequest) + # Get the original request ID (which is an integer) + request_id = root.id + assert isinstance(request_id, int), f"Expected int, got {type(request_id)}" + + # Respond with the ID as a string (simulating a buggy server) + response = JSONRPCResponse( + jsonrpc="2.0", + id=str(request_id), # Convert to string to simulate mismatch + result={}, + ) + await server_write.send(SessionMessage(message=JSONRPCMessage(response))) + + async def make_request(client_session: ClientSession): + nonlocal result_holder + # Send a ping request (uses integer ID internally) + result = await client_session.send_ping() + result_holder.append(result) + ev_response_received.set() + + async with ( + anyio.create_task_group() as tg, + ClientSession(read_stream=client_read, write_stream=client_write) as client_session, + ): + tg.start_soon(mock_server) + tg.start_soon(make_request, client_session) + + with anyio.fail_after(2): + await ev_response_received.wait() + + assert len(result_holder) == 1 + assert isinstance(result_holder[0], EmptyResult) + + +@pytest.mark.anyio +async def test_error_response_id_type_mismatch_string_to_int(): + """ + Test that error responses with string IDs are correctly matched to requests + sent with integer IDs. + + This handles the case where a server returns an error with "id": "0" (string) + but the client sent "id": 0 (integer). + """ + ev_error_received = anyio.Event() + error_holder: list[McpError] = [] + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def mock_server(): + """Receive a request and respond with an error using a string ID.""" + message = await server_read.receive() + assert isinstance(message, SessionMessage) + root = message.message.root + assert isinstance(root, JSONRPCRequest) + request_id = root.id + assert isinstance(request_id, int) + + # Respond with an error, using the ID as a string + error_response = JSONRPCError( + jsonrpc="2.0", + id=str(request_id), # Convert to string to simulate mismatch + error=ErrorData(code=-32600, message="Test error"), + ) + await server_write.send(SessionMessage(message=JSONRPCMessage(error_response))) + + async def make_request(client_session: ClientSession): + nonlocal error_holder + try: + await client_session.send_ping() + pytest.fail("Expected McpError to be raised") # pragma: no cover + except McpError as e: + error_holder.append(e) + ev_error_received.set() + + async with ( + anyio.create_task_group() as tg, + ClientSession(read_stream=client_read, write_stream=client_write) as client_session, + ): + tg.start_soon(mock_server) + tg.start_soon(make_request, client_session) + + with anyio.fail_after(2): + await ev_error_received.wait() + + assert len(error_holder) == 1 + assert "Test error" in str(error_holder[0]) + + +@pytest.mark.anyio +async def test_response_id_non_numeric_string_no_match(): + """ + Test that responses with non-numeric string IDs don't incorrectly match + integer request IDs. + + If a server returns "id": "abc" (non-numeric string), it should not match + a request sent with "id": 0 (integer). + """ + ev_timeout = anyio.Event() + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def mock_server(): + """Receive a request and respond with a non-numeric string ID.""" + message = await server_read.receive() + assert isinstance(message, SessionMessage) + + # Respond with a non-numeric string ID (should not match) + response = JSONRPCResponse( + jsonrpc="2.0", + id="not_a_number", # Non-numeric string + result={}, + ) + await server_write.send(SessionMessage(message=JSONRPCMessage(response))) + + async def make_request(client_session: ClientSession): + try: + # Use a short timeout since we expect this to fail + from datetime import timedelta + + await client_session.send_request( + ClientRequest(types.PingRequest()), + types.EmptyResult, + request_read_timeout_seconds=timedelta(seconds=0.5), + ) + pytest.fail("Expected timeout") # pragma: no cover + except McpError as e: + assert "Timed out" in str(e) + ev_timeout.set() + + async with ( + anyio.create_task_group() as tg, + ClientSession(read_stream=client_read, write_stream=client_write) as client_session, + ): + tg.start_soon(mock_server) + tg.start_soon(make_request, client_session) + + with anyio.fail_after(2): + await ev_timeout.wait() + + @pytest.mark.anyio async def test_connection_closed(): """