Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import anyio
import httpx
from anyio.abc import TaskGroup
from anyio.abc import TaskGroup, TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse

Expand Down Expand Up @@ -489,10 +489,14 @@ async def post_writer(
write_stream: MemoryObjectSendStream[SessionMessage],
start_get_stream: Callable[[], None],
tg: TaskGroup,
*,
task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED,
) -> None:
"""Handle writing requests to the server."""
try:
async with write_stream_reader:
# Signal that we're ready to receive messages
task_status.started(None)
async for session_message in write_stream_reader:
message = session_message.message
metadata = (
Expand Down Expand Up @@ -606,7 +610,10 @@ async def streamablehttp_client(
def start_get_stream() -> None:
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)

tg.start_soon(
# Use tg.start() to ensure post_writer is ready before yielding.
# This prevents a race condition where the client might try to send
# a message before the writer task is ready to receive it.
await tg.start(
transport.post_writer,
client,
write_stream_reader,
Expand Down
64 changes: 64 additions & 0 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -1761,6 +1761,70 @@ async def test_handle_sse_event_skips_empty_data():
await read_stream.aclose()


@pytest.mark.anyio
async def test_streamablehttp_no_race_condition_on_consecutive_requests(basic_server: None, basic_server_url: str):
"""Test that consecutive requests after initialize() work reliably.

This test verifies the fix for the race condition where list_tools()
could intermittently return empty results immediately after initialize().
The fix ensures post_writer is fully ready before yielding from the
context manager by using tg.start() instead of tg.start_soon().

We run multiple iterations to catch any intermittent issues.
"""
for iteration in range(10): # pragma: no branch
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
read_stream,
write_stream,
_,
):
async with ClientSession(read_stream, write_stream) as session:
# Initialize the session
result = await session.initialize()
assert isinstance(result, InitializeResult)
assert result.serverInfo.name == SERVER_NAME

# Immediately call list_tools() - this should never fail or return empty
tools = await session.list_tools()
assert len(tools.tools) > 0, f"Iteration {iteration}: list_tools() returned empty"
assert tools.tools[0].name == "test_tool"

# Make several more consecutive requests to ensure stability
tools2 = await session.list_tools()
assert len(tools2.tools) == len(tools.tools)

# Read a resource
resource = await session.read_resource(uri=AnyUrl("foobar://test-iteration"))
assert len(resource.contents) == 1


@pytest.mark.anyio
async def test_streamablehttp_rapid_request_sequence(basic_server: None, basic_server_url: str):
"""Test that rapid sequences of requests work correctly.

This stress test verifies that the transport handles rapid request sequences
without race conditions or message loss.
"""
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
read_stream,
write_stream,
_,
):
async with ClientSession(read_stream, write_stream) as session:
# Initialize
result = await session.initialize()
assert isinstance(result, InitializeResult)

# Rapid sequence of requests
for i in range(20):
tools = await session.list_tools()
assert len(tools.tools) == 10, f"Request {i}: Expected 10 tools, got {len(tools.tools)}"

# Verify we can still make other types of requests
resource = await session.read_resource(uri=AnyUrl("foobar://final-test"))
assert len(resource.contents) == 1


@pytest.mark.anyio
async def test_streamablehttp_client_receives_priming_event(
event_server: tuple[SimpleEventStore, str],
Expand Down