Skip to content

Commit ad81ca2

Browse files
Kludex and maxisbey authored
Slim ServerMiddleware to (ctx, call_next) and add OpenTelemetryMiddleware (#2941)
Co-authored-by: Max Isbey <224885523+maxisbey@users.noreply.github.com>
1 parent 5e013d9 commit ad81ca2

11 files changed

Lines changed: 374 additions & 86 deletions

File tree

docs/migration.md

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,7 @@ ctx: ClientRequestContext
746746
server_ctx: ServerRequestContext[LifespanContextT, RequestT]
747747
```
748748

749-
`ServerRequestContext` is now a standalone dataclass — it no longer subclasses `RequestContext[ServerSession]`. It carries the same fields (`session`, `request_id`, `meta`, `lifespan_context`, `request`, `close_sse_stream`, `close_standalone_sse_stream`) plus a new `protocol_version: str` field, so handler code is unaffected, but `isinstance(ctx, RequestContext)` checks and `RequestContext[ServerSession]` annotations need updating to `ServerRequestContext`.
749+
`ServerRequestContext` is now a standalone dataclass — it no longer subclasses `RequestContext[ServerSession]`. It carries the same fields (`session`, `request_id`, `meta`, `lifespan_context`, `request`, `close_sse_stream`, `close_standalone_sse_stream`) plus new `protocol_version: str`, `method: str`, and raw `params: Mapping[str, Any] | None` fields (the last two let middleware read and rewrite the inbound message), so handler code is unaffected, but `isinstance(ctx, RequestContext)` checks and `RequestContext[ServerSession]` annotations need updating to `ServerRequestContext`.
750750

751751
The high-level `Context` class (injected into `@mcp.tool()` etc.) similarly dropped its `ServerSessionT` parameter: `Context[ServerSessionT, LifespanContextT, RequestT]` → `Context[LifespanContextT, RequestT]`. Both remaining parameters have defaults, so bare `Context` is usually sufficient:
752752

@@ -961,27 +961,24 @@ server.add_notification_handler("notifications/custom", MyNotifyParams, my_notif
961961
These were private, but some users subclassed `Server` and overrode them to intercept requests. Use middleware instead:
962962

963963
```python
964-
from collections.abc import Mapping
965964
from typing import Any
966965

967966
from mcp.server import Server, ServerRequestContext
968967
from mcp.server.context import CallNext, HandlerResult
969968

970969

971-
async def logging_middleware(
972-
ctx: ServerRequestContext[Any, Any], method: str, params: Mapping[str, Any] | None, call_next: CallNext
973-
) -> HandlerResult:
974-
print(f"handling {method}")
975-
result = await call_next()
976-
print(f"done {method}")
970+
async def logging_middleware(ctx: ServerRequestContext[Any, Any], call_next: CallNext) -> HandlerResult:
971+
print(f"handling {ctx.method}")
972+
result = await call_next(ctx)
973+
print(f"done {ctx.method}")
977974
return result
978975

979976

980977
server = Server("my-server", on_call_tool=...)
981978
server.middleware.append(logging_middleware)
982979
```
983980

984-
Middleware runs before params validation, so `params` is the raw inbound mapping (or `None`), and it also wraps unknown methods.
981+
The method and the raw inbound params are `ctx.method` and `ctx.params` (`params` is `None` when the message carries none). Middleware runs before params validation and also wraps unknown methods. To rewrite the method or params before the handler runs, pass an adjusted context through with `dataclasses.replace`: `await call_next(replace(ctx, params=...))`.
985982

986983
### Lowlevel `Server.run(raise_exceptions=True)`: transport errors no longer re-raised
987984

src/mcp/server/_otel.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
from opentelemetry.trace import SpanKind, StatusCode
6+
from pydantic import ValidationError
7+
8+
from mcp.server.context import CallNext, HandlerResult, ServerMiddleware, ServerRequestContext
9+
from mcp.shared._otel import extract_trace_context, otel_span
10+
from mcp.shared.exceptions import MCPError
11+
12+
13+
class OpenTelemetryMiddleware(ServerMiddleware[Any]):
14+
"""Context-tier middleware that wraps each inbound message in an OpenTelemetry span.
15+
16+
Span name `"MCP handle <method> [<target>]"`, `mcp.method.name` attribute, W3C
17+
trace context extracted from `params._meta` (SEP-414), and an ERROR status if
18+
the handler raises. Requests and notifications both get a span;
19+
`jsonrpc.request.id` is set only when `ctx.request_id` is present (notifications
20+
have none).
21+
"""
22+
23+
async def __call__(self, ctx: ServerRequestContext[Any, Any], call_next: CallNext) -> HandlerResult:
24+
name = ctx.params.get("name") if ctx.params else None
25+
target = name if isinstance(name, str) else None
26+
27+
attributes: dict[str, Any] = {"mcp.method.name": ctx.method}
28+
if ctx.request_id is not None:
29+
attributes["jsonrpc.request.id"] = str(ctx.request_id)
30+
31+
with otel_span(
32+
name=f"MCP handle {ctx.method}{f' {target}' if target else ''}",
33+
kind=SpanKind.SERVER,
34+
attributes=attributes,
35+
context=extract_trace_context(ctx.meta),
36+
record_exception=False,
37+
set_status_on_exception=False,
38+
) as span:
39+
try:
40+
return await call_next(ctx)
41+
except MCPError as e:
42+
span.set_status(StatusCode.ERROR, e.error.message)
43+
raise
44+
except ValidationError:
45+
# Mirror the sanitized wire response; pydantic messages carry client input.
46+
span.set_status(StatusCode.ERROR, "Invalid request parameters")
47+
raise
48+
except Exception as e:
49+
span.record_exception(e)
50+
span.set_status(StatusCode.ERROR, str(e))
51+
raise

src/mcp/server/context.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from mcp.shared.transport_context import TransportContext
1616
from mcp.types import LoggingLevel, RequestId, RequestParamsMeta
1717

18-
# Invariant: parameterizes a mutable dataclass field; dict default matches the default lifespan.
18+
# Invariant: parametrizes a mutable dataclass field; dict default matches the default lifespan.
1919
LifespanContextT = TypeVar("LifespanContextT", default=dict[str, Any])
2020
RequestT = TypeVar("RequestT", default=Any)
2121

@@ -33,6 +33,8 @@ class ServerRequestContext(Generic[LifespanContextT, RequestT]):
3333
session: ServerSession
3434
lifespan_context: LifespanContextT
3535
protocol_version: str
36+
method: str
37+
params: Mapping[str, Any] | None = None
3638
request_id: RequestId | None = None
3739
meta: RequestParamsMeta | None = None
3840
request: RequestT | None = None
@@ -113,39 +115,44 @@ async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *
113115
"""What a request handler (or middleware) may return. `ServerRunner` serializes
114116
all three to a result dict."""
115117

116-
CallNext = Callable[[], Awaitable[HandlerResult]]
118+
CallNext = Callable[["ServerRequestContext[Any, Any]"], Awaitable[HandlerResult]]
119+
"""Invokes the rest of the chain. Pass the `ctx` through; rewrite `method` or
120+
`params` with `dataclasses.replace(ctx, ...)` to alter what the handler sees."""
117121

118122
_MwLifespanT = TypeVar("_MwLifespanT")
119123

120124

121125
class ServerMiddleware(Protocol[_MwLifespanT]):
122-
"""Context-tier middleware: `(ctx, method, params, call_next) -> result`.
126+
"""Context-tier middleware: `(ctx, call_next) -> result`.
123127
124128
Runs at the top of `ServerRunner._on_request` / `_on_notify` after `ctx`
125129
is built but before any validation, lookup, or handshake. Wraps every
126130
inbound request and notification: `initialize`, the pre-init gate,
127131
`METHOD_NOT_FOUND`, params validation, the handler call, and
128-
`notifications/initialized` all run inside `call_next()`.
132+
`notifications/initialized` all run inside `call_next(ctx)`.
129133
`notifications/cancelled` is observed too; the dispatcher applies the
130134
cancellation itself, then forwards the notification. A request-side
131135
failure reaches the middleware as a raised `MCPError` (or
132136
`ValidationError` for malformed params) so observation/logging middleware
133137
can record it. Listed outermost-first on `Server.middleware`.
134138
139+
The method and the raw inbound params are `ctx.method` and `ctx.params` (no
140+
model validation has happened yet). To rewrite either before the handler
141+
runs, pass an adjusted context: `await call_next(replace(ctx, params=...))`.
135142
`ctx.request_id is None` distinguishes a notification from a request. For
136-
notifications `call_next()` returns `None` (a dropped or unhandled
143+
notifications `call_next(ctx)` returns `None` (a dropped or unhandled
137144
notification also returns `None`) and the middleware's own return value is
138145
discarded.
139146
140-
`params` is the raw inbound mapping (no model validation has happened
141-
yet). For typed inspection, validate against the model the middleware
142-
expects.
143-
144-
Warning: `initialize` is handled inline - the dispatcher does not read
145-
further inbound messages until the middleware chain returns. Awaiting a
146-
server-to-client request (`ctx.session.send_request`, `send_ping`, ...)
147-
while handling `initialize` therefore deadlocks the connection: the
148-
response can never be dequeued. Send-and-forget notifications are safe.
147+
!!! warning
148+
`initialize` is handled inline - the dispatcher does not read
149+
further inbound messages until the middleware chain returns. Awaiting a
150+
server-to-client request (`ctx.session.send_request`, `send_ping`, ...)
151+
while handling `initialize` therefore deadlocks the connection: the
152+
response can never be dequeued. Send-and-forget notifications are safe.
153+
`initialize` is observed but not rewritable: the post-chain handshake
154+
commit reads the wire params, so to veto the handshake raise *before*
155+
`call_next(ctx)`.
149156
150157
`Server[L].middleware` holds `ServerMiddleware[L]`, so an app-specific
151158
middleware sees `ctx.lifespan_context: L`. While the context is the
@@ -162,7 +169,5 @@ class ServerMiddleware(Protocol[_MwLifespanT]):
162169
async def __call__(
163170
self,
164171
ctx: ServerRequestContext[_MwLifespanT, Any],
165-
method: str,
166-
params: Mapping[str, Any] | None,
167172
call_next: CallNext,
168173
) -> HandlerResult: ...

src/mcp/server/lowlevel/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def __init__(
225225
self._session_manager: StreamableHTTPSessionManager | None = None
226226
# Context-tier middleware: wraps every inbound request (including
227227
# `initialize`, lookup, validation, handler) with
228-
# `(ctx, method, params, call_next)`. Applied in `ServerRunner._on_request`.
228+
# `(ctx, call_next)`. Applied in `ServerRunner._on_request`.
229229
# TODO(L54): provisional - signature and semantics change with the
230230
# Context/middleware rework (covariant `Context[L]`, outbound seam) before
231231
# v2 final.

src/mcp/server/runner.py

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def _extract_meta(params: Mapping[str, Any] | None) -> RequestParamsMeta | None:
9393
return None
9494

9595

96-
def otel_middleware(next_on_request: OnRequest) -> OnRequest:
96+
def otel_middleware(call_next: OnRequest) -> OnRequest:
9797
"""Dispatch-tier middleware that wraps each request in an OpenTelemetry span.
9898
9999
Mirrors the span shape of the existing `Server._handle_request`: span name
@@ -129,7 +129,7 @@ async def wrapped(
129129
set_status_on_exception=False,
130130
) as span:
131131
try:
132-
return await next_on_request(dctx, method, params)
132+
return await call_next(dctx, method, params)
133133
except MCPError as e:
134134
span.set_status(StatusCode.ERROR, e.error.message)
135135
raise
@@ -200,6 +200,14 @@ async def to_jsonrpc_response(request_id: RequestId, coro: Awaitable[dict[str, A
200200
return JSONRPCResponse(jsonrpc="2.0", id=request_id, result=result)
201201

202202

203+
def _apply_middleware(
204+
middleware: ServerMiddleware[Any], call_next: CallNext, ctx: ServerRequestContext[Any, Any]
205+
) -> Awaitable[HandlerResult]:
206+
"""Adapt one middleware to the `CallNext` shape: bind `call_next`, take
207+
`ctx` at call time so a rewritten context flows down the chain."""
208+
return middleware(ctx, call_next)
209+
210+
203211
@dataclass
204212
class ServerRunner(Generic[LifespanT]):
205213
"""Per-connection handler kernel. One instance per client connection."""
@@ -220,7 +228,9 @@ def on_request(self) -> OnRequest:
220228
wraps everything - initialize, METHOD_NOT_FOUND, validation failures
221229
included.
222230
"""
223-
return reduce(lambda h, mw: mw(h), reversed(self.dispatch_middleware), self._on_request)
231+
return reduce(
232+
lambda handler, middleware: middleware(handler), reversed(self.dispatch_middleware), self._on_request
233+
)
224234

225235
@cached_property
226236
def on_notify(self) -> OnNotify:
@@ -234,15 +244,18 @@ async def _on_request(
234244
) -> dict[str, Any]:
235245
meta = _extract_meta(params)
236246
version = self.connection.protocol_version
237-
ctx = self._make_context(dctx, meta, version)
247+
ctx = self._make_context(dctx, method, params, meta, version)
238248
is_spec_method = method in _methods.SPEC_CLIENT_METHODS
239249

240-
async def _inner() -> HandlerResult:
250+
async def _inner(ctx: ServerRequestContext[LifespanT, Any]) -> HandlerResult:
251+
# Read method/params off `ctx` so a middleware that rewrote them via
252+
# `call_next(replace(ctx, ...))` reaches lookup and the handler.
253+
method, params = ctx.method, ctx.params
241254
# Pinned compat: spec methods are surface-validated before lookup,
242255
# so malformed params are INVALID_PARAMS even with no handler
243256
# registered. Custom methods miss the monolith map and fall through
244257
# to `entry.params_type` exactly as before.
245-
if is_spec_method:
258+
if method in _methods.SPEC_CLIENT_METHODS:
246259
try:
247260
_methods.validate_client_request(method, version, params)
248261
except KeyError:
@@ -272,8 +285,8 @@ async def _inner() -> HandlerResult:
272285
raise MCPError.from_error_data(result)
273286
return result
274287

275-
call = self._compose_server_middleware(ctx, method, params, _inner)
276-
result = _dump_result(await call())
288+
call = self._compose_server_middleware(_inner)
289+
result = _dump_result(await call(ctx))
277290
# TODO(L56): reject resultType values outside {"complete", "input_required"} unless the
278291
# corresponding extension is in this request's _meta clientCapabilities.extensions; the
279292
# explicit MUST-reject is client-side (basic/index.mdx ResultType), this enforces it proactively.
@@ -292,6 +305,11 @@ async def _inner() -> HandlerResult:
292305
if method == "initialize":
293306
# Commit only on chain success, so a middleware veto leaves no state.
294307
# Race-free: the read loop is parked until this call returns.
308+
# TODO: this re-reads the wire `params`, so a middleware that rewrote
309+
# `ctx.params` (or `ctx.method`, or short-circuited without `call_next`)
310+
# can leave `connection.protocol_version` out of step with the
311+
# `InitializeResult` `_inner` produced. Resolve when `initialize` becomes
312+
# a built-in handler so commit and result derive from one negotiation.
295313
self.connection.client_params, self.connection.protocol_version = self._negotiate_initialize(params)
296314
return result
297315

@@ -303,9 +321,10 @@ async def _on_notify(
303321
) -> None:
304322
meta = _extract_meta(params)
305323
version = self.connection.protocol_version
306-
ctx = self._make_context(dctx, meta, version)
324+
ctx = self._make_context(dctx, method, params, meta, version)
307325

308-
async def _inner() -> None:
326+
async def _inner(ctx: ServerRequestContext[LifespanT, Any]) -> None:
327+
method, params = ctx.method, ctx.params
309328
if method in _methods.SPEC_CLIENT_NOTIFICATION_METHODS:
310329
try:
311330
_methods.validate_client_notification(method, version, params)
@@ -335,33 +354,33 @@ async def _inner() -> None:
335354
return
336355
await entry.handler(ctx, typed_params)
337356

338-
call = self._compose_server_middleware(ctx, method, params, _inner)
357+
call = self._compose_server_middleware(_inner)
339358
try:
340-
await call()
359+
await call(ctx)
341360
except Exception:
342361
# A crashing handler must not cancel the dispatcher's task group;
343362
# middleware saw the raise out of call_next(ctx) first.
344363
logger.exception("notification handler for %r raised", method)
345364

346-
def _compose_server_middleware(
347-
self,
348-
ctx: ServerRequestContext[LifespanT, Any],
349-
method: str,
350-
params: Mapping[str, Any] | None,
351-
inner: CallNext,
352-
) -> CallNext:
365+
def _compose_server_middleware(self, inner: CallNext) -> CallNext:
353366
"""Wrap `inner` in `Server.middleware`, outermost-first.
354367
355368
Shared by `_on_request` and `_on_notify` so the same middleware chain
356-
observes every inbound message.
369+
observes every inbound message. The composed callable takes the `ctx`
370+
at call time, so a middleware can rewrite it for the rest of the chain.
357371
"""
358372
call = inner
359-
for mw in reversed(self.server.middleware):
360-
call = partial(mw, ctx, method, params, call)
373+
for middleware in reversed(self.server.middleware):
374+
call = partial(_apply_middleware, middleware, call)
361375
return call
362376

363377
def _make_context(
364-
self, dctx: DispatchContext[TransportContext], meta: RequestParamsMeta | None, protocol_version: str
378+
self,
379+
dctx: DispatchContext[TransportContext],
380+
method: str,
381+
params: Mapping[str, Any] | None,
382+
meta: RequestParamsMeta | None,
383+
protocol_version: str,
365384
) -> ServerRequestContext[LifespanT, Any]:
366385
# TODO(L54): remove for Context rework. Reads the SHTTP per-request
367386
# data off the raw `dctx.message_metadata` carrier; replace with the
@@ -380,6 +399,8 @@ def _make_context(
380399
return ServerRequestContext(
381400
session=session,
382401
lifespan_context=self.lifespan_state,
402+
method=method,
403+
params=params,
383404
request_id=dctx.request_id,
384405
meta=meta,
385406
protocol_version=protocol_version,

src/mcp/shared/_otel.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22

33
from __future__ import annotations
44

5-
from collections.abc import Iterator
5+
from collections.abc import Generator, Mapping
66
from contextlib import contextmanager
77
from typing import Any
88

99
from opentelemetry.context import Context
1010
from opentelemetry.propagate import extract, inject
11-
from opentelemetry.trace import SpanKind, get_tracer
11+
from opentelemetry.trace import SpanKind, get_current_span, get_tracer
12+
from opentelemetry.trace.span import Span
1213

1314
_tracer = get_tracer("mcp-python-sdk")
1415

@@ -22,7 +23,7 @@ def otel_span(
2223
context: Context | None = None,
2324
record_exception: bool = True,
2425
set_status_on_exception: bool = True,
25-
) -> Iterator[Any]:
26+
) -> Generator[Span]:
2627
"""Create an OTel span."""
2728
with _tracer.start_as_current_span(
2829
name,
@@ -40,13 +41,20 @@ def inject_trace_context(meta: dict[str, Any]) -> None:
4041
inject(meta)
4142

4243

43-
def extract_trace_context(meta: dict[str, Any]) -> Context | None:
44+
def extract_trace_context(meta: Mapping[str, Any] | None) -> Context | None:
4445
"""Extract W3C trace context from a `_meta` dict.
4546
46-
Returns `None` when the carrier is malformed; telemetry parsing must
47-
never fail the request it annotates.
47+
Returns `None` when the carrier is absent, malformed, or carries no
48+
valid `traceparent`, so callers fall through to ambient parenting; an
49+
explicit empty `Context` would orphan the span instead of nesting under
50+
the current one.
4851
"""
52+
if not meta:
53+
return None
4954
try:
50-
return extract(meta)
51-
except (TypeError, ValueError):
55+
ctx = extract(meta)
56+
except (ValueError, TypeError):
57+
return None
58+
if not get_current_span(ctx).get_span_context().is_valid:
5259
return None
60+
return ctx

0 commit comments

Comments
 (0)