Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,46 @@ async def _receive_loop(self) -> None:
pass
self._response_streams.clear()

def _pop_response_stream(
self, response_id: RequestId
) -> MemoryObjectSendStream[JSONRPCResponse | JSONRPCError] | None:
"""
Pop a response stream by ID, trying alternative type representations.

JSON-RPC allows request IDs to be strings or integers. Some servers may
return a string ID (e.g., "0") even when the client sent an integer ID
(e.g., 0). Since Python dict lookups are type-sensitive (0 != "0"),
this method tries both representations to handle such mismatches.

Args:
response_id: The response ID from the incoming message.

Returns:
The response stream if found, None otherwise.
"""
# Try exact match first
stream = self._response_streams.pop(response_id, None)
if stream is not None:
return stream

# Try alternative type representation
if isinstance(response_id, str):
# Response ID is string, try integer lookup
try:
int_id = int(response_id)
stream = self._response_streams.pop(int_id, None)
if stream is not None:
return stream
except ValueError:
pass
else:
# Response ID is integer, try string lookup
stream = self._response_streams.pop(str(response_id), None)
if stream is not None:
return stream

return None

async def _handle_response(self, message: SessionMessage) -> None:
"""
Handle an incoming response or error message.
Expand Down Expand Up @@ -486,8 +526,8 @@ async def _handle_response(self, message: SessionMessage) -> None:
if router.route_response(response_id, response_data):
return # Handled

# Fall back to normal response streams
stream = self._response_streams.pop(response_id, None)
# Fall back to normal response streams, with ID type normalization
stream = self._pop_response_stream(response_id)
if stream: # pragma: no cover
await stream.send(root)
else: # pragma: no cover
Expand Down
171 changes: 171 additions & 0 deletions tests/shared/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,18 @@
from mcp.server.lowlevel.server import Server
from mcp.shared.exceptions import McpError
from mcp.shared.memory import create_client_server_memory_streams, create_connected_server_and_client_session
from mcp.shared.message import SessionMessage
from mcp.types import (
CancelledNotification,
CancelledNotificationParams,
ClientNotification,
ClientRequest,
EmptyResult,
ErrorData,
JSONRPCError,
JSONRPCMessage,
JSONRPCRequest,
JSONRPCResponse,
TextContent,
)

Expand Down Expand Up @@ -122,6 +128,171 @@ async def make_request(client_session: ClientSession):
await ev_cancelled.wait()


@pytest.mark.anyio
async def test_response_id_type_mismatch_string_to_int():
"""
Test that responses with string IDs are correctly matched to requests sent with
integer IDs.

This handles the case where a server returns "id": "0" (string) but the client
sent "id": 0 (integer). Without ID type normalization, this would cause a timeout.
"""
ev_response_received = anyio.Event()
result_holder: list[types.EmptyResult] = []

async with create_client_server_memory_streams() as (client_streams, server_streams):
client_read, client_write = client_streams
server_read, server_write = server_streams

async def mock_server():
"""Receive a request and respond with a string ID instead of integer."""
message = await server_read.receive()
assert isinstance(message, SessionMessage)
root = message.message.root
assert isinstance(root, JSONRPCRequest)
# Get the original request ID (which is an integer)
request_id = root.id
assert isinstance(request_id, int), f"Expected int, got {type(request_id)}"

# Respond with the ID as a string (simulating a buggy server)
response = JSONRPCResponse(
jsonrpc="2.0",
id=str(request_id), # Convert to string to simulate mismatch
result={},
)
await server_write.send(SessionMessage(message=JSONRPCMessage(response)))

async def make_request(client_session: ClientSession):
nonlocal result_holder
# Send a ping request (uses integer ID internally)
result = await client_session.send_ping()
result_holder.append(result)
ev_response_received.set()

async with (
anyio.create_task_group() as tg,
ClientSession(read_stream=client_read, write_stream=client_write) as client_session,
):
tg.start_soon(mock_server)
tg.start_soon(make_request, client_session)

with anyio.fail_after(2):
await ev_response_received.wait()

assert len(result_holder) == 1
assert isinstance(result_holder[0], EmptyResult)


@pytest.mark.anyio
async def test_error_response_id_type_mismatch_string_to_int():
"""
Test that error responses with string IDs are correctly matched to requests
sent with integer IDs.

This handles the case where a server returns an error with "id": "0" (string)
but the client sent "id": 0 (integer).
"""
ev_error_received = anyio.Event()
error_holder: list[McpError] = []

async with create_client_server_memory_streams() as (client_streams, server_streams):
client_read, client_write = client_streams
server_read, server_write = server_streams

async def mock_server():
"""Receive a request and respond with an error using a string ID."""
message = await server_read.receive()
assert isinstance(message, SessionMessage)
root = message.message.root
assert isinstance(root, JSONRPCRequest)
request_id = root.id
assert isinstance(request_id, int)

# Respond with an error, using the ID as a string
error_response = JSONRPCError(
jsonrpc="2.0",
id=str(request_id), # Convert to string to simulate mismatch
error=ErrorData(code=-32600, message="Test error"),
)
await server_write.send(SessionMessage(message=JSONRPCMessage(error_response)))

async def make_request(client_session: ClientSession):
nonlocal error_holder
try:
await client_session.send_ping()
pytest.fail("Expected McpError to be raised") # pragma: no cover
except McpError as e:
error_holder.append(e)
ev_error_received.set()

async with (
anyio.create_task_group() as tg,
ClientSession(read_stream=client_read, write_stream=client_write) as client_session,
):
tg.start_soon(mock_server)
tg.start_soon(make_request, client_session)

with anyio.fail_after(2):
await ev_error_received.wait()

assert len(error_holder) == 1
assert "Test error" in str(error_holder[0])


@pytest.mark.anyio
async def test_response_id_non_numeric_string_no_match():
"""
Test that responses with non-numeric string IDs don't incorrectly match
integer request IDs.

If a server returns "id": "abc" (non-numeric string), it should not match
a request sent with "id": 0 (integer).
"""
ev_timeout = anyio.Event()

async with create_client_server_memory_streams() as (client_streams, server_streams):
client_read, client_write = client_streams
server_read, server_write = server_streams

async def mock_server():
"""Receive a request and respond with a non-numeric string ID."""
message = await server_read.receive()
assert isinstance(message, SessionMessage)

# Respond with a non-numeric string ID (should not match)
response = JSONRPCResponse(
jsonrpc="2.0",
id="not_a_number", # Non-numeric string
result={},
)
await server_write.send(SessionMessage(message=JSONRPCMessage(response)))

async def make_request(client_session: ClientSession):
try:
# Use a short timeout since we expect this to fail
from datetime import timedelta

await client_session.send_request(
ClientRequest(types.PingRequest()),
types.EmptyResult,
request_read_timeout_seconds=timedelta(seconds=0.5),
)
pytest.fail("Expected timeout") # pragma: no cover
except McpError as e:
assert "Timed out" in str(e)
ev_timeout.set()

async with (
anyio.create_task_group() as tg,
ClientSession(read_stream=client_read, write_stream=client_write) as client_session,
):
tg.start_soon(mock_server)
tg.start_soon(make_request, client_session)

with anyio.fail_after(2):
await ev_timeout.wait()


@pytest.mark.anyio
async def test_connection_closed():
"""
Expand Down
Loading