|
16 | 16 | from __future__ import annotations |
17 | 17 |
|
18 | 18 | import logging |
19 | | -from collections.abc import Mapping |
| 19 | +from collections.abc import Awaitable, Mapping |
20 | 20 | from dataclasses import dataclass, field |
21 | 21 | from functools import partial, reduce |
22 | 22 | from typing import TYPE_CHECKING, Any, Generic, cast |
|
33 | 33 | from mcp.shared._otel import extract_trace_context, otel_span |
34 | 34 | from mcp.shared.dispatcher import DispatchContext, Dispatcher, DispatchMiddleware, OnRequest |
35 | 35 | from mcp.shared.exceptions import MCPError |
| 36 | +from mcp.shared.jsonrpc_dispatcher import handler_exception_to_error_data |
36 | 37 | from mcp.shared.message import MessageMetadata, ServerMessageMetadata |
37 | 38 | from mcp.shared.transport_context import TransportContext |
38 | 39 | from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS |
|
46 | 47 | Implementation, |
47 | 48 | InitializeRequestParams, |
48 | 49 | InitializeResult, |
| 50 | + JSONRPCError, |
| 51 | + JSONRPCResponse, |
| 52 | + RequestId, |
49 | 53 | RequestParams, |
50 | 54 | RequestParamsMeta, |
51 | 55 | ) |
|
54 | 58 | if TYPE_CHECKING: |
55 | 59 | from mcp.server.lowlevel.server import Server |
56 | 60 |
|
57 | | -__all__ = ["CallNext", "ServerMiddleware", "ServerRunner", "otel_middleware"] |
| 61 | +__all__ = ["CallNext", "ServerMiddleware", "ServerRunner", "aclose_shielded", "otel_middleware", "to_jsonrpc_response"] |
58 | 62 |
|
59 | 63 | logger = logging.getLogger(__name__) |
60 | 64 |
|
|
64 | 68 | _INIT_EXEMPT: frozenset[str] = frozenset({"ping"}) |
65 | 69 |
|
66 | 70 | _EXIT_STACK_CLOSE_TIMEOUT: float = 5 |
67 | | -"""Bound for the shielded exit-stack unwind in `run()`; a hung cleanup |
68 | | -callback must not wedge shutdown.""" |
| 71 | +"""Bound for `aclose_shielded`'s exit-stack unwind; a hung cleanup callback |
| 72 | +must not wedge shutdown.""" |
69 | 73 |
|
70 | 74 |
|
71 | 75 | def _extract_meta(params: Mapping[str, Any] | None) -> RequestParamsMeta | None: |
@@ -169,6 +173,47 @@ def _dump_result(result: Any) -> dict[str, Any]: |
169 | 173 | raise TypeError(f"handler returned {type(result).__name__}; expected BaseModel, dict, or None") |
170 | 174 |
|
171 | 175 |
|
| 176 | +async def aclose_shielded(connection: Connection) -> None: |
| 177 | + """Unwind ``connection.exit_stack`` under a shielded, bounded scope. |
| 178 | +
|
| 179 | + Called from a driver's ``finally``: the shield lets per-connection cleanup |
| 180 | + callbacks run even when the driver itself is being cancelled, the |
| 181 | + `_EXIT_STACK_CLOSE_TIMEOUT` bound stops a hung callback wedging shutdown, |
| 182 | + and a raising callback is logged-and-swallowed so it never masks the |
| 183 | + driver's own exception. |
| 184 | + """ |
| 185 | + with anyio.move_on_after(_EXIT_STACK_CLOSE_TIMEOUT, shield=True) as scope: |
| 186 | + try: |
| 187 | + await connection.exit_stack.aclose() |
| 188 | + except Exception: |
| 189 | + logger.exception("connection exit_stack cleanup raised") |
| 190 | + if scope.cancelled_caught: |
| 191 | + logger.warning( |
| 192 | + "connection exit_stack cleanup exceeded %s seconds; abandoning remaining callbacks", |
| 193 | + _EXIT_STACK_CLOSE_TIMEOUT, |
| 194 | + ) |
| 195 | + |
| 196 | + |
| 197 | +async def to_jsonrpc_response(request_id: RequestId, coro: Awaitable[dict[str, Any]]) -> JSONRPCResponse | JSONRPCError: |
| 198 | + """Await ``coro`` and wrap its outcome as the JSON-RPC reply for ``request_id``. |
| 199 | +
|
| 200 | + The exception-to-wire boundary for the request-per-call drivers |
| 201 | + (`serve_one`, the modern HTTP entry). `MCPError` and `ValidationError` |
| 202 | + map via the shared `handler_exception_to_error_data` ladder; any other |
| 203 | + exception is logged and surfaced as `INTERNAL_ERROR` so handler internals |
| 204 | + never reach the wire. |
| 205 | + """ |
| 206 | + try: |
| 207 | + result = await coro |
| 208 | + except Exception as exc: |
| 209 | + error = handler_exception_to_error_data(exc) |
| 210 | + if error is None: |
| 211 | + logger.exception("request handler raised") |
| 212 | + error = ErrorData(code=INTERNAL_ERROR, message="Internal server error") |
| 213 | + return JSONRPCError(jsonrpc="2.0", id=request_id, error=error) |
| 214 | + return JSONRPCResponse(jsonrpc="2.0", id=request_id, result=result) |
| 215 | + |
| 216 | + |
172 | 217 | @dataclass |
173 | 218 | class ServerRunner(Generic[LifespanT]): |
174 | 219 | """Per-connection orchestrator. One instance per client connection.""" |
@@ -205,26 +250,14 @@ async def run(self, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STAT |
205 | 250 | to `dispatcher.run()`. `task_status.started()` is forwarded so callers |
206 | 251 | can `await tg.start(runner.run)` and resume once the dispatcher is |
207 | 252 | ready to accept requests. Once the dispatcher exits, |
208 | | - `connection.exit_stack` is unwound (shielded from outer cancellation, |
209 | | - bounded by `_EXIT_STACK_CLOSE_TIMEOUT`) so any per-connection cleanup |
210 | | - registered by handlers or middleware gets a chance to run without a |
211 | | - misbehaving callback hanging shutdown indefinitely. |
| 253 | + `connection.exit_stack` is unwound via `aclose_shielded` so any |
| 254 | + per-connection cleanup registered by handlers or middleware gets a |
| 255 | + chance to run. |
212 | 256 | """ |
213 | 257 | try: |
214 | 258 | await self.dispatcher.run(self._compose_on_request(), self._on_notify, task_status=task_status) |
215 | 259 | finally: |
216 | | - with anyio.move_on_after(_EXIT_STACK_CLOSE_TIMEOUT, shield=True) as scope: |
217 | | - try: |
218 | | - await self.connection.exit_stack.aclose() |
219 | | - except Exception: |
220 | | - # Raising here would mask dispatcher.run()'s exception and |
221 | | - # crash stdio servers on normal disconnect. |
222 | | - logger.exception("connection exit_stack cleanup raised") |
223 | | - if scope.cancelled_caught: |
224 | | - logger.warning( |
225 | | - "connection exit_stack cleanup exceeded %s seconds; abandoning remaining callbacks", |
226 | | - _EXIT_STACK_CLOSE_TIMEOUT, |
227 | | - ) |
| 260 | + await aclose_shielded(self.connection) |
228 | 261 |
|
229 | 262 | def _compose_on_request(self) -> OnRequest: |
230 | 263 | """Wrap `_on_request` in `dispatch_middleware`, outermost-first. |
|
0 commit comments