Skip to content

Commit 127b209

Browse files
committed
Mint priming before per-request state; cover old-pv replay path
Hoists _mint_priming_event to the top of the SSE arm so a user EventStore raising on the priming row returns a 500 with no per-request state allocated (previously _request_streams[id] and _sse_stream_writers[id] leaked for the session). The shared _request_streams registration is pushed into each branch. Adds an old-pv-reconnect test in test_hosting_resume.py covering the priming_event-is-None replay arm; drops the no-branch pragma. The new priming-failure test covers the outer except handler, so its pragmas and the dead 'if writer:' check are removed.
1 parent dde2df9 commit 127b209

3 files changed

Lines changed: 115 additions & 22 deletions

File tree

src/mcp/server/streamable_http.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -546,15 +546,13 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
546546
else request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION)
547547
)
548548

549-
# Extract the request ID outside the try block for proper scope
550549
request_id = str(message.id)
551-
# Register this stream for the request ID
552-
self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](
553-
REQUEST_STREAM_BUFFER_SIZE
554-
)
555-
request_stream_reader = self._request_streams[request_id][1]
556550

557551
if self.is_json_response_enabled:
552+
self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](
553+
REQUEST_STREAM_BUFFER_SIZE
554+
)
555+
request_stream_reader = self._request_streams[request_id][1]
558556
# Process the message
559557
metadata = ServerMessageMetadata(request_context=request)
560558
session_message = SessionMessage(message, metadata=metadata)
@@ -598,16 +596,18 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
598596
finally:
599597
await self._clean_up_memory_streams(request_id)
600598
else:
601-
# Create SSE stream
602-
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0)
599+
# Mint the priming event before any per-request state exists:
600+
# `EventStore.store_event` is user code and may raise, in which
601+
# case the outer handler returns a 500 with nothing to clean up.
602+
# Still strictly precedes dispatch, so storage order == wire order.
603+
priming_event = await self._mint_priming_event(request_id, protocol_version)
603604

604-
# Store writer reference so close_sse_stream() can close it
605+
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0)
605606
self._sse_stream_writers[request_id] = sse_stream_writer
606-
607-
# Store the priming event before the request is dispatched so its
608-
# event-store position precedes anything message_router can store
609-
# for this id (storage order == wire order by construction).
610-
priming_event = await self._mint_priming_event(request_id, protocol_version)
607+
self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](
608+
REQUEST_STREAM_BUFFER_SIZE
609+
)
610+
request_stream_reader = self._request_streams[request_id][1]
611611

612612
headers = {
613613
"Cache-Control": "no-cache, no-transform",
@@ -638,20 +638,16 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
638638
finally:
639639
await sse_stream_reader.aclose()
640640

641-
except Exception as err: # pragma: lax no cover
642-
# Reached only when something raises during POST handling outside
643-
# the per-SSE-stream guard above; whether tests reach this depends
644-
# on client teardown timing.
641+
except Exception as err:
645642
logger.exception("Error handling POST request")
646643
response = self._create_error_response(
647644
f"Error handling POST request: {err}",
648645
HTTPStatus.INTERNAL_SERVER_ERROR,
649646
INTERNAL_ERROR,
650647
)
651648
await response(scope, receive, send)
652-
if writer:
653-
await writer.send(Exception(err))
654-
return # pragma: no cover
649+
await writer.send(Exception(err))
650+
return
655651

656652
async def _handle_get_request(self, request: Request, send: Send) -> None:
657653
"""Handle GET request to establish SSE.
@@ -900,7 +896,7 @@ async def send_event(event_message: EventMessage) -> None:
900896
# is re-registered. The replay→live-tail ordering window here
901897
# is pre-existing and tracked separately.
902898
priming_event = await self._mint_priming_event(stream_id, replay_protocol_version)
903-
if priming_event is not None: # pragma: no branch
899+
if priming_event is not None:
904900
await sse_stream_writer.send(priming_event)
905901

906902
# Create new request streams for this connection

tests/interaction/transports/test_hosting_resume.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,46 @@ async def count(ctx: Context) -> str:
182182
)
183183

184184

185+
@requirement("hosting:resume:priming")
186+
async def test_a_pre_2025_11_25_reconnect_replays_without_minting_a_priming_event() -> None:
187+
"""A pre-2025-11-25 client reconnecting via Last-Event-ID gets the replay with no priming row.
188+
189+
The store-length assertion is the load-bearing proof that no priming cursor was minted.
190+
"""
191+
release = anyio.Event()
192+
store = SequencedEventStore()
193+
mcp = MCPServer("resumable")
194+
195+
@mcp.tool()
196+
async def count(ctx: Context) -> str:
197+
await ctx.info("tick 1") # pyright: ignore[reportDeprecated]
198+
await release.wait()
199+
await ctx.info("tick 2") # pyright: ignore[reportDeprecated]
200+
return "counted"
201+
202+
async with mounted_app(mcp, event_store=store, retry_interval=0) as (http, _):
203+
session_id = await initialize_via_http(http)
204+
with anyio.fail_after(5):
205+
async with http.stream(
206+
"POST", "/mcp", content=_tools_call(1, "count", {}), headers=base_headers(session_id=session_id)
207+
) as response:
208+
_, first = await _read_events(response, 2)
209+
release.set()
210+
await store.wait_until_stored(6)
211+
old_client_headers = base_headers(session_id=session_id) | {
212+
"mcp-protocol-version": "2025-06-18",
213+
"last-event-id": first.id,
214+
}
215+
async with http.stream("GET", "/mcp", headers=old_client_headers) as replay: # pragma: no branch
216+
assert replay.status_code == 200
217+
missed = await _read_events(replay, 2)
218+
219+
assert [(event.id, bool(event.data)) for event in missed] == snapshot([("5", True), ("6", True)])
220+
# No priming cursor was minted on reconnect: the store still holds only the six rows
221+
# written before the GET (init priming+response, POST priming, tick 1, tick 2, result).
222+
assert len(store._events) == 6
223+
224+
185225
@requirement("hosting:resume:bad-event-id")
186226
async def test_an_unknown_last_event_id_yields_an_empty_replay_stream() -> None:
187227
"""A Last-Event-ID the event store cannot map produces an empty SSE stream rather than an error.

tests/server/test_streamable_http_router.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import anyio
44
import pytest
5+
from starlette.types import Message, Scope
56

67
from mcp.server.streamable_http import (
78
REQUEST_STREAM_BUFFER_SIZE,
@@ -28,6 +29,14 @@ async def replay_events_after(self, last_event_id: EventId, send_callback: Event
2829
raise NotImplementedError
2930

3031

32+
class _PrimingFailingStore(EventStore):
33+
async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) -> EventId:
34+
raise RuntimeError("backend unavailable")
35+
36+
async def replay_events_after(self, last_event_id: EventId, send_callback: EventCallback) -> StreamId | None:
37+
raise NotImplementedError
38+
39+
3140
@pytest.mark.anyio
3241
async def test_router_unconsumed_request_stream_does_not_block_siblings() -> None:
3342
"""A response whose `sse_writer` is not yet receiving must not park the router (#1764).
@@ -99,3 +108,51 @@ async def test_priming_event_is_stored_before_any_routed_message() -> None:
99108
assert store.stored[0] == ("A", None)
100109
assert [sid for sid, _ in store.stored] == ["A"] * 6
101110
assert all(msg is not None for _, msg in store.stored[1:])
111+
112+
113+
@pytest.mark.anyio
114+
async def test_priming_store_failure_leaves_no_per_request_state() -> None:
115+
"""`EventStore.store_event` raising on the priming row must not leak per-request entries."""
116+
transport = StreamableHTTPServerTransport(
117+
mcp_session_id=None,
118+
is_json_response_enabled=False,
119+
event_store=_PrimingFailingStore(),
120+
)
121+
122+
body = b'{"jsonrpc":"2.0","id":"req-1","method":"tools/list","params":{}}'
123+
scope: Scope = {
124+
"type": "http",
125+
"method": "POST",
126+
"path": "/",
127+
"query_string": b"",
128+
"headers": [
129+
(b"accept", b"application/json, text/event-stream"),
130+
(b"content-type", b"application/json"),
131+
(b"mcp-protocol-version", b"2025-11-25"),
132+
],
133+
}
134+
body_sent = False
135+
136+
async def receive() -> Message:
137+
nonlocal body_sent
138+
if not body_sent:
139+
body_sent = True
140+
return {"type": "http.request", "body": body, "more_body": False}
141+
raise NotImplementedError
142+
143+
sent: list[Message] = []
144+
145+
async def asgi_send(message: Message) -> None:
146+
sent.append(message)
147+
148+
async with transport.connect() as (read_stream, _write_stream):
149+
async with anyio.create_task_group() as tg:
150+
tg.start_soon(transport.handle_request, scope, receive, asgi_send)
151+
with anyio.fail_after(5):
152+
forwarded = await read_stream.receive()
153+
assert isinstance(forwarded, Exception)
154+
155+
assert transport._request_streams == {}
156+
assert transport._sse_stream_writers == {}
157+
assert sent[0]["type"] == "http.response.start"
158+
assert sent[0]["status"] == 500

0 commit comments

Comments
 (0)