Skip to content

Commit 2522d4c

Browse files
Kludexmaxisbey
authored andcommitted
Slim ServerMiddleware to (ctx, call_next) and add OpenTelemetryMiddleware
Move `method` and `params` onto `ServerRequestContext` so context-tier middleware reads `ctx.method`/`ctx.params` instead of separate positional args. `CallNext` now takes the context, so middleware can rewrite the inbound message with `call_next(replace(ctx, params=...))`. Add a context-tier `OpenTelemetryMiddleware` alongside the existing dispatch-tier `otel_middleware`, which is left intact.
1 parent 2397319 commit 2522d4c

11 files changed

Lines changed: 296 additions & 87 deletions

File tree

docs/migration.md

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,7 @@ ctx: ClientRequestContext
746746
server_ctx: ServerRequestContext[LifespanContextT, RequestT]
747747
```
748748

749-
`ServerRequestContext` is now a standalone dataclass — it no longer subclasses `RequestContext[ServerSession]`. It carries the same fields (`session`, `request_id`, `meta`, `lifespan_context`, `request`, `close_sse_stream`, `close_standalone_sse_stream`) plus a new `protocol_version: str` field, so handler code is unaffected, but `isinstance(ctx, RequestContext)` checks and `RequestContext[ServerSession]` annotations need updating to `ServerRequestContext`.
749+
`ServerRequestContext` is now a standalone dataclass — it no longer subclasses `RequestContext[ServerSession]`. It carries the same fields (`session`, `request_id`, `meta`, `lifespan_context`, `request`, `close_sse_stream`, `close_standalone_sse_stream`) plus new `protocol_version: str`, `method: str`, and raw `params: Mapping[str, Any] | None` fields (the last two let middleware read and rewrite the inbound message), so handler code is unaffected, but `isinstance(ctx, RequestContext)` checks and `RequestContext[ServerSession]` annotations need updating to `ServerRequestContext`.
750750

751751
The high-level `Context` class (injected into `@mcp.tool()` etc.) similarly dropped its `ServerSessionT` parameter: `Context[ServerSessionT, LifespanContextT, RequestT]``Context[LifespanContextT, RequestT]`. Both remaining parameters have defaults, so bare `Context` is usually sufficient:
752752

@@ -961,27 +961,24 @@ server.add_notification_handler("notifications/custom", MyNotifyParams, my_notif
961961
These were private, but some users subclassed `Server` and overrode them to intercept requests. Use middleware instead:
962962

963963
```python
964-
from collections.abc import Mapping
965964
from typing import Any
966965

967966
from mcp.server import Server, ServerRequestContext
968967
from mcp.server.context import CallNext, HandlerResult
969968

970969

971-
async def logging_middleware(
972-
ctx: ServerRequestContext[Any, Any], method: str, params: Mapping[str, Any] | None, call_next: CallNext
973-
) -> HandlerResult:
974-
print(f"handling {method}")
975-
result = await call_next()
976-
print(f"done {method}")
970+
async def logging_middleware(ctx: ServerRequestContext[Any, Any], call_next: CallNext) -> HandlerResult:
971+
print(f"handling {ctx.method}")
972+
result = await call_next(ctx)
973+
print(f"done {ctx.method}")
977974
return result
978975

979976

980977
server = Server("my-server", on_call_tool=...)
981978
server.middleware.append(logging_middleware)
982979
```
983980

984-
Middleware runs before params validation, so `params` is the raw inbound mapping (or `None`), and it also wraps unknown methods.
981+
The method and the raw inbound params are `ctx.method` and `ctx.params` (`params` is `None` when the message carries none). Middleware runs before params validation and also wraps unknown methods. To rewrite the method or params before the handler runs, pass an adjusted context through: `await call_next(replace(ctx, params=...))`.
985982

986983
### Lowlevel `Server.run(raise_exceptions=True)`: transport errors no longer re-raised
987984

src/mcp/server/_otel.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
from opentelemetry.trace import SpanKind, StatusCode
6+
from pydantic import ValidationError
7+
8+
from mcp.server.context import CallNext, HandlerResult, ServerMiddleware, ServerRequestContext
9+
from mcp.shared._otel import extract_trace_context, otel_span
10+
from mcp.shared.exceptions import MCPError
11+
12+
13+
class OpenTelemetryMiddleware(ServerMiddleware[Any]):
14+
"""Context-tier middleware that wraps each inbound message in an OpenTelemetry span.
15+
16+
Span name `"MCP handle <method> [<target>]"`, `mcp.method.name` attribute, W3C
17+
trace context extracted from `params._meta` (SEP-414), and an ERROR status if
18+
the handler raises. Requests and notifications both get a span;
19+
`jsonrpc.request.id` is set only when `ctx.request_id` is present (notifications
20+
have none).
21+
"""
22+
23+
async def __call__(self, ctx: ServerRequestContext[Any, Any], call_next: CallNext) -> HandlerResult:
24+
name = ctx.params.get("name") if ctx.params else None
25+
target = name if isinstance(name, str) else None
26+
27+
attributes: dict[str, Any] = {"mcp.method.name": ctx.method}
28+
if ctx.request_id is not None:
29+
attributes["jsonrpc.request.id"] = str(ctx.request_id)
30+
31+
with otel_span(
32+
name=f"MCP handle {ctx.method}{f' {target}' if target else ''}",
33+
kind=SpanKind.SERVER,
34+
attributes=attributes,
35+
context=extract_trace_context(ctx.meta or {}),
36+
record_exception=False,
37+
set_status_on_exception=False,
38+
) as span:
39+
try:
40+
return await call_next(ctx)
41+
except MCPError as e:
42+
span.set_status(StatusCode.ERROR, e.error.message)
43+
raise
44+
except ValidationError:
45+
# Mirror the sanitized wire response; pydantic messages carry client input.
46+
span.set_status(StatusCode.ERROR, "Invalid request parameters")
47+
raise
48+
except Exception as e:
49+
span.record_exception(e)
50+
span.set_status(StatusCode.ERROR, str(e))
51+
raise

src/mcp/server/context.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from mcp.shared.transport_context import TransportContext
1616
from mcp.types import LoggingLevel, RequestId, RequestParamsMeta
1717

18-
# Invariant: parameterizes a mutable dataclass field; dict default matches the default lifespan.
18+
# Invariant: parametrizes a mutable dataclass field; dict default matches the default lifespan.
1919
LifespanContextT = TypeVar("LifespanContextT", default=dict[str, Any])
2020
RequestT = TypeVar("RequestT", default=Any)
2121

@@ -33,6 +33,8 @@ class ServerRequestContext(Generic[LifespanContextT, RequestT]):
3333
session: ServerSession
3434
lifespan_context: LifespanContextT
3535
protocol_version: str
36+
method: str
37+
params: Mapping[str, Any] | None = None
3638
request_id: RequestId | None = None
3739
meta: RequestParamsMeta | None = None
3840
request: RequestT | None = None
@@ -113,39 +115,41 @@ async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *
113115
"""What a request handler (or middleware) may return. `ServerRunner` serializes
114116
all three to a result dict."""
115117

116-
CallNext = Callable[[], Awaitable[HandlerResult]]
118+
CallNext = Callable[["ServerRequestContext[Any, Any]"], Awaitable[HandlerResult]]
119+
"""Invokes the rest of the chain. Pass the `ctx` through; rewrite `method` or
120+
`params` with `dataclasses.replace(ctx, ...)` to alter what the handler sees."""
117121

118122
_MwLifespanT = TypeVar("_MwLifespanT")
119123

120124

121125
class ServerMiddleware(Protocol[_MwLifespanT]):
122-
"""Context-tier middleware: `(ctx, method, params, call_next) -> result`.
126+
"""Context-tier middleware: `(ctx, call_next) -> result`.
123127
124128
Runs at the top of `ServerRunner._on_request` / `_on_notify` after `ctx`
125129
is built but before any validation, lookup, or handshake. Wraps every
126130
inbound request and notification: `initialize`, the pre-init gate,
127131
`METHOD_NOT_FOUND`, params validation, the handler call, and
128-
`notifications/initialized` all run inside `call_next()`.
132+
`notifications/initialized` all run inside `call_next(ctx)`.
129133
`notifications/cancelled` is observed too; the dispatcher applies the
130134
cancellation itself, then forwards the notification. A request-side
131135
failure reaches the middleware as a raised `MCPError` (or
132136
`ValidationError` for malformed params) so observation/logging middleware
133137
can record it. Listed outermost-first on `Server.middleware`.
134138
139+
The method and the raw inbound params are `ctx.method` and `ctx.params` (no
140+
model validation has happened yet). To rewrite either before the handler
141+
runs, pass an adjusted context: `await call_next(replace(ctx, params=...))`.
135142
`ctx.request_id is None` distinguishes a notification from a request. For
136-
notifications `call_next()` returns `None` (a dropped or unhandled
143+
notifications `call_next(ctx)` returns `None` (a dropped or unhandled
137144
notification also returns `None`) and the middleware's own return value is
138145
discarded.
139146
140-
`params` is the raw inbound mapping (no model validation has happened
141-
yet). For typed inspection, validate against the model the middleware
142-
expects.
143-
144-
Warning: `initialize` is handled inline - the dispatcher does not read
145-
further inbound messages until the middleware chain returns. Awaiting a
146-
server-to-client request (`ctx.session.send_request`, `send_ping`, ...)
147-
while handling `initialize` therefore deadlocks the connection: the
148-
response can never be dequeued. Send-and-forget notifications are safe.
147+
!!! warning
148+
`initialize` is handled inline - the dispatcher does not read
149+
further inbound messages until the middleware chain returns. Awaiting a
150+
server-to-client request (`ctx.session.send_request`, `send_ping`, ...)
151+
while handling `initialize` therefore deadlocks the connection: the
152+
response can never be dequeued. Send-and-forget notifications are safe.
149153
150154
`Server[L].middleware` holds `ServerMiddleware[L]`, so an app-specific
151155
middleware sees `ctx.lifespan_context: L`. While the context is the
@@ -162,7 +166,5 @@ class ServerMiddleware(Protocol[_MwLifespanT]):
162166
async def __call__(
163167
self,
164168
ctx: ServerRequestContext[_MwLifespanT, Any],
165-
method: str,
166-
params: Mapping[str, Any] | None,
167169
call_next: CallNext,
168170
) -> HandlerResult: ...

src/mcp/server/lowlevel/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def __init__(
225225
self._session_manager: StreamableHTTPSessionManager | None = None
226226
# Context-tier middleware: wraps every inbound request (including
227227
# `initialize`, lookup, validation, handler) with
228-
# `(ctx, method, params, call_next)`. Applied in `ServerRunner._on_request`.
228+
# `(ctx, call_next)`. Applied in `ServerRunner._on_request`.
229229
# TODO(L54): provisional - signature and semantics change with the
230230
# Context/middleware rework (covariant `Context[L]`, outbound seam) before
231231
# v2 final.

src/mcp/server/runner.py

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def _extract_meta(params: Mapping[str, Any] | None) -> RequestParamsMeta | None:
9393
return None
9494

9595

96-
def otel_middleware(next_on_request: OnRequest) -> OnRequest:
96+
def otel_middleware(call_next: OnRequest) -> OnRequest:
9797
"""Dispatch-tier middleware that wraps each request in an OpenTelemetry span.
9898
9999
Mirrors the span shape of the existing `Server._handle_request`: span name
@@ -129,7 +129,7 @@ async def wrapped(
129129
set_status_on_exception=False,
130130
) as span:
131131
try:
132-
return await next_on_request(dctx, method, params)
132+
return await call_next(dctx, method, params)
133133
except MCPError as e:
134134
span.set_status(StatusCode.ERROR, e.error.message)
135135
raise
@@ -200,6 +200,14 @@ async def to_jsonrpc_response(request_id: RequestId, coro: Awaitable[dict[str, A
200200
return JSONRPCResponse(jsonrpc="2.0", id=request_id, result=result)
201201

202202

203+
def _apply_middleware(
204+
mw: ServerMiddleware[Any], call_next: CallNext, ctx: ServerRequestContext[Any, Any]
205+
) -> Awaitable[HandlerResult]:
206+
"""Adapt one middleware to the `CallNext` shape: bind `call_next`, take
207+
`ctx` at call time so a rewritten context flows down the chain."""
208+
return mw(ctx, call_next)
209+
210+
203211
@dataclass
204212
class ServerRunner(Generic[LifespanT]):
205213
"""Per-connection handler kernel. One instance per client connection."""
@@ -234,15 +242,18 @@ async def _on_request(
234242
) -> dict[str, Any]:
235243
meta = _extract_meta(params)
236244
version = self.connection.protocol_version
237-
ctx = self._make_context(dctx, meta, version)
245+
ctx = self._make_context(dctx, method, params, meta, version)
238246
is_spec_method = method in _methods.SPEC_CLIENT_METHODS
239247

240-
async def _inner() -> HandlerResult:
248+
async def _inner(ctx: ServerRequestContext[LifespanT, Any]) -> HandlerResult:
249+
# Read method/params off `ctx` so a middleware that rewrote them via
250+
# `call_next(replace(ctx, ...))` reaches lookup and the handler.
251+
method, params = ctx.method, ctx.params
241252
# Pinned compat: spec methods are surface-validated before lookup,
242253
# so malformed params are INVALID_PARAMS even with no handler
243254
# registered. Custom methods miss the monolith map and fall through
244255
# to `entry.params_type` exactly as before.
245-
if is_spec_method:
256+
if method in _methods.SPEC_CLIENT_METHODS:
246257
try:
247258
_methods.validate_client_request(method, version, params)
248259
except KeyError:
@@ -272,8 +283,8 @@ async def _inner() -> HandlerResult:
272283
raise MCPError.from_error_data(result)
273284
return result
274285

275-
call = self._compose_server_middleware(ctx, method, params, _inner)
276-
result = _dump_result(await call())
286+
call = self._compose_server_middleware(_inner)
287+
result = _dump_result(await call(ctx))
277288
# TODO(L56): reject resultType values outside {"complete", "input_required"} unless the
278289
# corresponding extension is in this request's _meta clientCapabilities.extensions; the
279290
# explicit MUST-reject is client-side (basic/index.mdx ResultType), this enforces it proactively.
@@ -303,9 +314,10 @@ async def _on_notify(
303314
) -> None:
304315
meta = _extract_meta(params)
305316
version = self.connection.protocol_version
306-
ctx = self._make_context(dctx, meta, version)
317+
ctx = self._make_context(dctx, method, params, meta, version)
307318

308-
async def _inner() -> None:
319+
async def _inner(ctx: ServerRequestContext[LifespanT, Any]) -> None:
320+
method, params = ctx.method, ctx.params
309321
if method in _methods.SPEC_CLIENT_NOTIFICATION_METHODS:
310322
try:
311323
_methods.validate_client_notification(method, version, params)
@@ -335,33 +347,33 @@ async def _inner() -> None:
335347
return
336348
await entry.handler(ctx, typed_params)
337349

338-
call = self._compose_server_middleware(ctx, method, params, _inner)
350+
call = self._compose_server_middleware(_inner)
339351
try:
340-
await call()
352+
await call(ctx)
341353
except Exception:
342354
# A crashing handler must not cancel the dispatcher's task group;
343355
# middleware saw the raise out of call_next() first.
344356
logger.exception("notification handler for %r raised", method)
345357

346-
def _compose_server_middleware(
347-
self,
348-
ctx: ServerRequestContext[LifespanT, Any],
349-
method: str,
350-
params: Mapping[str, Any] | None,
351-
inner: CallNext,
352-
) -> CallNext:
358+
def _compose_server_middleware(self, inner: CallNext) -> CallNext:
353359
"""Wrap `inner` in `Server.middleware`, outermost-first.
354360
355361
Shared by `_on_request` and `_on_notify` so the same middleware chain
356-
observes every inbound message.
362+
observes every inbound message. The composed callable takes the `ctx`
363+
at call time, so a middleware can rewrite it for the rest of the chain.
357364
"""
358-
call = inner
365+
call: CallNext = inner
359366
for mw in reversed(self.server.middleware):
360-
call = partial(mw, ctx, method, params, call)
367+
call = partial(_apply_middleware, mw, call)
361368
return call
362369

363370
def _make_context(
364-
self, dctx: DispatchContext[TransportContext], meta: RequestParamsMeta | None, protocol_version: str
371+
self,
372+
dctx: DispatchContext[TransportContext],
373+
method: str,
374+
params: Mapping[str, Any] | None,
375+
meta: RequestParamsMeta | None,
376+
protocol_version: str,
365377
) -> ServerRequestContext[LifespanT, Any]:
366378
# TODO(L54): remove for Context rework. Reads the SHTTP per-request
367379
# data off the raw `dctx.message_metadata` carrier; replace with the
@@ -380,6 +392,8 @@ def _make_context(
380392
return ServerRequestContext(
381393
session=session,
382394
lifespan_context=self.lifespan_state,
395+
method=method,
396+
params=params,
383397
request_id=dctx.request_id,
384398
meta=meta,
385399
protocol_version=protocol_version,

src/mcp/shared/_otel.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22

33
from __future__ import annotations
44

5-
from collections.abc import Iterator
5+
from collections.abc import Generator, Mapping
66
from contextlib import contextmanager
77
from typing import Any
88

99
from opentelemetry.context import Context
1010
from opentelemetry.propagate import extract, inject
1111
from opentelemetry.trace import SpanKind, get_tracer
12+
from opentelemetry.trace.span import Span
1213

1314
_tracer = get_tracer("mcp-python-sdk")
1415

@@ -22,7 +23,7 @@ def otel_span(
2223
context: Context | None = None,
2324
record_exception: bool = True,
2425
set_status_on_exception: bool = True,
25-
) -> Iterator[Any]:
26+
) -> Generator[Span]:
2627
"""Create an OTel span."""
2728
with _tracer.start_as_current_span(
2829
name,
@@ -40,13 +41,10 @@ def inject_trace_context(meta: dict[str, Any]) -> None:
4041
inject(meta)
4142

4243

43-
def extract_trace_context(meta: dict[str, Any]) -> Context | None:
44-
"""Extract W3C trace context from a `_meta` dict.
45-
46-
Returns `None` when the carrier is malformed; telemetry parsing must
47-
never fail the request it annotates.
48-
"""
44+
def extract_trace_context(meta: Mapping[str, Any]) -> Context:
45+
"""Extract W3C trace context from a `_meta` dict."""
4946
try:
5047
return extract(meta)
51-
except (TypeError, ValueError):
52-
return None
48+
except (ValueError, TypeError):
49+
# If the traceparent is malformed, degrade to no parent rather than failing the request.
50+
return Context()

tests/issues/test_176_progress_token.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ async def test_progress_token_zero_first_call():
1919
request_context = ServerRequestContext(
2020
request_id="test-request",
2121
session=mock_session,
22+
method="tools/call",
2223
meta={"progress_token": 0},
2324
lifespan_context=None,
2425
protocol_version="2025-11-25",

tests/server/mcpserver/test_server.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1528,6 +1528,7 @@ async def test_report_progress_passes_related_request_id():
15281528
request_context = ServerRequestContext(
15291529
request_id="req-abc-123",
15301530
session=mock_session,
1531+
method="tools/call",
15311532
meta={"progress_token": "tok-1"},
15321533
lifespan_context=None,
15331534
protocol_version="2025-11-25",

0 commit comments

Comments
 (0)