diff --git a/docs/instrumentation.md b/docs/instrumentation.md
new file mode 100644
index 0000000000..870a2963c7
--- /dev/null
+++ b/docs/instrumentation.md
@@ -0,0 +1,375 @@
+# Instrumentation
+
+The MCP Python SDK provides a pluggable instrumentation interface for monitoring the request/response lifecycle. This enables integration with OpenTelemetry, custom metrics, logging frameworks, and other observability tools.
+
+**Related issue**: [#421 - Adding OpenTelemetry to MCP SDK](https://github.com/modelcontextprotocol/python-sdk/issues/421)
+
+## Overview
+
+The `Instrumenter` protocol defines three hooks:
+
+- `on_request_start`: Called when a request starts processing, **returns a token**
+- `on_request_end`: Called when a request completes, **receives the token**
+- `on_error`: Called when an error occurs, **receives the token**
+
+The token-based design allows instrumenters to maintain state (like OpenTelemetry spans) between `on_request_start` and `on_request_end` without needing external storage or side-channels.
+
+All methods are optional (no-op implementations are valid). Exceptions raised by instrumentation hooks are logged but do not affect request processing.
+
+## Basic Usage
+
+### Server-Side Instrumentation
+
+```python
+import time
+from typing import Any
+
+from mcp.server.lowlevel import Server
+from mcp.server.stdio import stdio_server
+from mcp.shared.instrumentation import Instrumenter
+from mcp.types import RequestId
+
+class MyInstrumenter:
+    """Custom instrumenter implementation."""
+
+    def on_request_start(
+        self,
+        request_id: RequestId,
+        request_type: str,
+        method: str | None = None,
+        **metadata,
+    ) -> Any:
+        """Return a token (any value) to track this request."""
+        print(f"Request {request_id} started: {request_type}")
+        # Return a token - can be anything (dict, object, etc.)
+        return {"request_id": request_id, "start_time": time.time()}
+
+    def on_request_end(
+        self,
+        token: Any,  # Receives the token from on_request_start
+        request_id: RequestId,
+        request_type: str,
+        success: bool,
+        duration_seconds: float | None = None,
+        **metadata,
+    ) -> None:
+        """Process the completed request using the token."""
+        status = "succeeded" if success else "failed"
+        duration = f"{duration_seconds:.3f}s" if duration_seconds is not None else "unknown time"
+        print(f"Request {request_id} {status} in {duration}")
+        print(f"Token data: {token}")
+
+    def on_error(
+        self,
+        token: Any,  # Receives the token from on_request_start
+        request_id: RequestId | None,
+        error: Exception,
+        error_type: str,
+        **metadata,
+    ) -> None:
+        """Handle errors using the token."""
+        print(f"Error in request {request_id}: {error_type} - {error}")
+
+# Create server with custom instrumenter
+server = Server("my-server")
+
+# Pass instrumenter when running the server
+async def run_server():
+    async with stdio_server() as (read_stream, write_stream):
+        await server.run(
+            read_stream,
+            write_stream,
+            server.create_initialization_options(),
+            instrumenter=MyInstrumenter(),
+        )
+```
+
+### Client-Side Instrumentation
+
+```python
+from mcp.client.session import ClientSession
+from mcp.shared.instrumentation import Instrumenter
+
+# Create client session with instrumenter
+async with ClientSession(
+    read_stream=read_stream,
+    write_stream=write_stream,
+    instrumenter=MyInstrumenter(),
+) as session:
+    await session.initialize()
+    # Use session...
+```
+
+### Why Tokens?
+ +The token-based design solves a key problem: **how do you maintain state between `on_request_start` and `on_request_end`?** + +Without tokens, instrumenters would need to use external storage (like a dictionary keyed by `request_id`) to track state: + +```python +# ❌ Old approach - requires external storage +class OldInstrumenter: + def __init__(self): + self.spans = {} # Need to manage this dict + + def on_request_start(self, request_id, ...): + span = create_span(...) + self.spans[request_id] = span # Store externally + + def on_request_end(self, request_id, ...): + span = self.spans.pop(request_id) # Retrieve from storage + span.end() +``` + +With tokens, state passes directly through the SDK: + +```python +# ✅ New approach - token is returned and passed back +class NewInstrumenter: + def on_request_start(self, request_id, ...): + span = create_span(...) + return span # Return directly + + def on_request_end(self, token, request_id, ...): + span = token # Receive directly + span.end() +``` + +This is especially important for OpenTelemetry, where spans need to be kept alive. + +## Metadata + +Instrumentation hooks receive metadata via `**metadata` keyword arguments: + +- `on_request_start` metadata: + - `session_type`: "server" or "client" + - Any additional context provided by the framework + +- `on_request_end` metadata: + - `cancelled`: True if the request was cancelled + - `error`: Error message if request failed + - Any additional context + +- `on_error` metadata: + - Additional error context + +## Request ID + +The `request_id` parameter is consistent across all hooks for a given request, allowing you to correlate the request lifecycle. The `request_id` is also added to log records via the `extra` field, so you can filter logs by request. + +## OpenTelemetry Integration + +The token-based instrumentation interface is designed specifically to work well with OpenTelemetry. 
Here's a complete example: + +```python +from typing import Any +from opentelemetry import trace +from opentelemetry.trace import Status, StatusCode +from mcp.types import RequestId + +class OpenTelemetryInstrumenter: + """OpenTelemetry implementation of the MCP Instrumenter protocol.""" + + def __init__(self, tracer_provider=None): + if tracer_provider is None: + tracer_provider = trace.get_tracer_provider() + self.tracer = tracer_provider.get_tracer("mcp.sdk", version="1.0.0") + + def on_request_start( + self, + request_id: RequestId, + request_type: str, + method: str | None = None, + **metadata: Any, + ) -> Any: + """Start a new span and return it as the token.""" + span_name = f"mcp.{request_type}" + if method: + span_name = f"{span_name}.{method}" + + # Start the span + span = self.tracer.start_span(span_name) + + # Set attributes + span.set_attribute("mcp.request_id", str(request_id)) + span.set_attribute("mcp.request_type", request_type) + if method: + span.set_attribute("mcp.method", method) + + # Add metadata + session_type = metadata.get("session_type") + if session_type: + span.set_attribute("mcp.session_type", session_type) + + # Return span as token + return span + + def on_request_end( + self, + token: Any, # This is the span from on_request_start + request_id: RequestId, + request_type: str, + success: bool, + duration_seconds: float | None = None, + **metadata: Any, + ) -> None: + """End the span.""" + if token is None: + return + + span = token + + # Set success attributes + span.set_attribute("mcp.success", success) + if duration_seconds is not None: + span.set_attribute("mcp.duration_seconds", duration_seconds) + + # Set status + if success: + span.set_status(Status(StatusCode.OK)) + else: + span.set_status(Status(StatusCode.ERROR)) + error_msg = metadata.get("error") + if error_msg: + span.set_attribute("mcp.error", str(error_msg)) + + # End the span + span.end() + + def on_error( + self, + token: Any, # This is the span from on_request_start + request_id: RequestId | None, + error: Exception, + error_type: str, + **metadata: Any, + ) -> None: + """Record error in the span.""" + if token is None: + return + + span = token + + # Record exception + span.record_exception(error) + span.set_attribute("mcp.error_type", error_type) + span.set_attribute("mcp.error_message", str(error)) + + # Set error status + span.set_status(Status(StatusCode.ERROR, str(error))) +``` + +### Full Working Example + +A complete working example with OpenTelemetry setup is available in `examples/opentelemetry_instrumentation.py`. + +To use it: + +```bash +# Install OpenTelemetry +pip install opentelemetry-api opentelemetry-sdk + +# Run the example +python examples/opentelemetry_instrumentation.py +``` + +### Key Benefits + +The token-based design provides several advantages for OpenTelemetry: + +1. **No external storage**: No need to maintain a `spans` dictionary +2. **Automatic cleanup**: Spans are garbage collected when done +3. **Thread-safe**: Each request gets its own token +4. **Context propagation**: Easy to integrate with OpenTelemetry context +5. **Distributed tracing**: Can be extended to propagate trace context in `_meta` + +## Default Behavior + +If no instrumenter is provided, a no-op implementation is used automatically. This has minimal overhead and doesn't affect request processing. + +```python +from mcp.shared.instrumentation import get_default_instrumenter + +# Get the default no-op instrumenter +instrumenter = get_default_instrumenter() +``` + +## Best Practices + +1. 
**Keep hooks fast**: Instrumentation hooks are called synchronously in the request path. Keep processing minimal to avoid impacting request latency. + +2. **Handle errors gracefully**: Exceptions in instrumentation hooks are caught and logged, but it's best to handle errors within your instrumenter. + +3. **Use appropriate metadata**: Include relevant context in metadata fields to aid debugging and analysis. + +4. **Consider sampling**: For high-volume servers, consider implementing sampling in your instrumenter to reduce overhead. + +## Example: Custom Metrics + +```python +from collections import defaultdict +from typing import Any, Dict +from mcp.types import RequestId + +class MetricsInstrumenter: + """Track request counts and durations.""" + + def __init__(self): + self.request_counts: Dict[str, int] = defaultdict(int) + self.request_durations: Dict[str, list[float]] = defaultdict(list) + self.error_counts: Dict[str, int] = defaultdict(int) + + def on_request_start( + self, + request_id: RequestId, + request_type: str, + method: str | None = None, + **metadata: Any, + ) -> Any: + """Track request start, return request_type as token.""" + self.request_counts[request_type] += 1 + return request_type # Simple token - just the request type + + def on_request_end( + self, + token: Any, + request_id: RequestId, + request_type: str, + success: bool, + duration_seconds: float | None = None, + **metadata: Any, + ) -> None: + """Track request completion.""" + if duration_seconds is not None: + self.request_durations[request_type].append(duration_seconds) + + def on_error( + self, + token: Any, + request_id: RequestId | None, + error: Exception, + error_type: str, + **metadata: Any, + ) -> None: + """Track errors.""" + self.error_counts[error_type] += 1 + + def get_stats(self): + """Get statistics summary.""" + stats = {} + for request_type, durations in self.request_durations.items(): + if durations: + avg_duration = sum(durations) / len(durations) + stats[request_type] = { + "count": self.request_counts[request_type], + "avg_duration": avg_duration, + } + return stats +``` + +Note: For this simple metrics case, the token isn't strictly necessary, so we just return the `request_type`. For more complex instrumenters (like OpenTelemetry), the token is essential for maintaining state. + +## Future Work + +- Package OpenTelemetry instrumenter as a separate installable extra (`pip install mcp[opentelemetry]`) +- Additional built-in instrumenters (Prometheus, StatsD, Datadog, etc.) +- Support for distributed tracing via `params._meta.traceparent` propagation (see [modelcontextprotocol/spec#414](https://github.com/modelcontextprotocol/modelcontextprotocol/pull/414)) +- Semantic conventions for MCP traces and metrics (see [open-telemetry/semantic-conventions#2083](https://github.com/open-telemetry/semantic-conventions/pull/2083)) +- Client-side request instrumentation +- Async hook support for long-running instrumentation operations diff --git a/examples/opentelemetry_instrumentation.py b/examples/opentelemetry_instrumentation.py new file mode 100644 index 0000000000..631f6604bc --- /dev/null +++ b/examples/opentelemetry_instrumentation.py @@ -0,0 +1,249 @@ +"""OpenTelemetry instrumentation example for MCP SDK. + +This example demonstrates how to integrate OpenTelemetry tracing with the MCP SDK +using the pluggable instrumentation interface. 
+ +Installation: + pip install opentelemetry-api opentelemetry-sdk + +Usage: + # In your server code: + from opentelemetry_instrumentation import OpenTelemetryInstrumenter + + instrumenter = OpenTelemetryInstrumenter() + + # When creating ServerSession: + session = ServerSession( + read_stream, + write_stream, + init_options, + instrumenter=instrumenter, + ) + +Related issue: https://github.com/modelcontextprotocol/python-sdk/issues/421 +""" + +from __future__ import annotations + +import logging +from typing import Any + +from mcp.types import RequestId + +logger = logging.getLogger(__name__) + + +class OpenTelemetryInstrumenter: + """OpenTelemetry implementation of the MCP Instrumenter protocol. + + This instrumenter creates spans for each MCP request, tracks metrics, + and supports distributed tracing via context propagation. + """ + + def __init__(self, tracer_provider=None): + """Initialize the OpenTelemetry instrumenter. + + Args: + tracer_provider: Optional OpenTelemetry tracer provider. + If None, uses the global tracer provider. + """ + try: + from opentelemetry import trace + from opentelemetry.trace import Status, StatusCode + + self._trace = trace + self._Status = Status + self._StatusCode = StatusCode + + if tracer_provider is None: + tracer_provider = trace.get_tracer_provider() + + self._tracer = tracer_provider.get_tracer("mcp.sdk", version="1.0.0") + self._enabled = True + except ImportError: + logger.warning("OpenTelemetry not installed. Install with: pip install opentelemetry-api opentelemetry-sdk") + self._enabled = False + + def on_request_start( + self, + request_id: RequestId, + request_type: str, + method: str | None = None, + **metadata: Any, + ) -> Any: + """Start a new span for the request. + + Returns: + The OpenTelemetry span object as the token. + """ + if not self._enabled: + return None + + # Create span name from request type + span_name = f"mcp.{request_type}" + if method: + span_name = f"{span_name}.{method}" + + # Start the span + span = self._tracer.start_span(span_name) + + # Set standard attributes + span.set_attribute("mcp.request_id", str(request_id)) + span.set_attribute("mcp.request_type", request_type) + + if method: + span.set_attribute("mcp.method", method) + + # Add metadata as attributes + session_type = metadata.get("session_type") + if session_type: + span.set_attribute("mcp.session_type", session_type) + + # Add any custom metadata + for key, value in metadata.items(): + if key not in ("session_type",) and isinstance(value, str | int | float | bool): + span.set_attribute(f"mcp.{key}", value) + + return span + + def on_request_end( + self, + token: Any, + request_id: RequestId, + request_type: str, + success: bool, + duration_seconds: float | None = None, + **metadata: Any, + ) -> None: + """End the span for the request. 
+ + Args: + token: The span object returned from on_request_start + """ + if not self._enabled or token is None: + return + + span = token + + # Set success status + span.set_attribute("mcp.success", success) + + if duration_seconds is not None: + span.set_attribute("mcp.duration_seconds", duration_seconds) + + # Set span status + if success: + span.set_status(self._Status(self._StatusCode.OK)) + else: + span.set_status(self._Status(self._StatusCode.ERROR)) + # Add error info if available + error_msg = metadata.get("error") + if error_msg: + span.set_attribute("mcp.error", str(error_msg)) + + # Check if cancelled + if metadata.get("cancelled"): + span.set_attribute("mcp.cancelled", True) + + # End the span + span.end() + + def on_error( + self, + token: Any, + request_id: RequestId | None, + error: Exception, + error_type: str, + **metadata: Any, + ) -> None: + """Record error information in the span. + + Args: + token: The span object returned from on_request_start + """ + if not self._enabled or token is None: + return + + span = token + + # Record the exception + span.record_exception(error) + + # Set error attributes + span.set_attribute("mcp.error_type", error_type) + span.set_attribute("mcp.error_message", str(error)) + + # Mark span as error + span.set_status(self._Status(self._StatusCode.ERROR, str(error))) + + +# Example usage +if __name__ == "__main__": + import asyncio + + from opentelemetry import trace + from opentelemetry.sdk.resources import Resource + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter + + import mcp.types as types + from mcp.server.lowlevel import Server + from mcp.shared.memory import create_connected_server_and_client_session + + # Setup OpenTelemetry + resource = Resource.create({"service.name": "mcp-example-server"}) + provider = TracerProvider(resource=resource) + processor = BatchSpanProcessor(ConsoleSpanExporter()) + provider.add_span_processor(processor) + trace.set_tracer_provider(provider) + + # Create instrumenter + instrumenter = OpenTelemetryInstrumenter() + + # Create server + server = Server("example-server") + + @server.list_tools() + async def handle_list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="echo", + description="Echo a message", + inputSchema={"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}, + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict) -> list[types.TextContent]: + if name == "echo": + message = arguments.get("message", "") + return [types.TextContent(type="text", text=f"Echo: {message}")] + raise ValueError(f"Unknown tool: {name}") + + async def main(): + print("Running MCP server with OpenTelemetry instrumentation...") + print("Traces will be printed to console.\n") + + async with create_connected_server_and_client_session( + server, + raise_exceptions=True, + ) as client: + # Note: In production, you would pass the instrumenter when creating the ServerSession + # For this example, we're using the test helper which doesn't expose that parameter + + await client.initialize() + + # List tools + print("Listing tools...") + tools_result = await client.list_tools() + print(f"Found {len(tools_result.tools)} tools\n") + + # Call a tool + print("Calling echo tool...") + result = await client.call_tool("echo", {"message": "Hello, OpenTelemetry!"}) + print(f"Result: {result}\n") + + # Give time for spans to export + await asyncio.sleep(1) + + 
asyncio.run(main()) diff --git a/mkdocs.yml b/mkdocs.yml index 18cbb034bb..2612b7f7bd 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -18,6 +18,7 @@ nav: - Low-Level Server: low-level-server.md - Authorization: authorization.md - Testing: testing.md + - Instrumentation: instrumentation.md - API Reference: api.md theme: diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index be47d681fb..3b59ea6d83 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -9,6 +9,7 @@ import mcp.types as types from mcp.shared.context import RequestContext +from mcp.shared.instrumentation import Instrumenter, get_default_instrumenter from mcp.shared.message import SessionMessage from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -118,6 +119,7 @@ def __init__( logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, client_info: types.Implementation | None = None, + instrumenter: Instrumenter | None = None, ) -> None: super().__init__( read_stream, @@ -127,6 +129,7 @@ def __init__( read_timeout_seconds=read_timeout_seconds, ) self._client_info = client_info or DEFAULT_CLIENT_INFO + self._instrumenter = instrumenter or get_default_instrumenter() self._sampling_callback = sampling_callback or _default_sampling_callback self._elicitation_callback = elicitation_callback or _default_elicitation_callback self._list_roots_callback = list_roots_callback or _default_list_roots_callback @@ -135,6 +138,11 @@ def __init__( self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} self._server_capabilities: types.ServerCapabilities | None = None + @property + def instrumenter(self) -> Instrumenter: # pragma: no cover + """Get the instrumenter for this session.""" + return self._instrumenter + async def initialize(self) -> types.InitializeResult: sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None elicitation = ( diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index a0617036f9..3263824299 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -70,6 +70,7 @@ async def main(): import contextvars import json import logging +import time import warnings from collections.abc import AsyncIterator, Awaitable, Callable, Iterable from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager @@ -88,6 +89,7 @@ async def main(): from mcp.server.session import ServerSession from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError +from mcp.shared.instrumentation import Instrumenter from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder from mcp.shared.tool_name_validation import validate_and_warn_tool_name @@ -615,6 +617,7 @@ async def run( # the initialization lifecycle, but can do so with any available node # rather than requiring initialization for each connection. 
stateless: bool = False, + instrumenter: Instrumenter | None = None, ): async with AsyncExitStack() as stack: lifespan_context = await stack.enter_async_context(self.lifespan(self)) @@ -624,6 +627,7 @@ async def run( write_stream, initialization_options, stateless=stateless, + instrumenter=instrumenter, ) ) @@ -674,11 +678,28 @@ async def _handle_request( lifespan_context: LifespanResultT, raise_exceptions: bool, ): - logger.info("Processing request of type %s", type(req).__name__) + request_type = type(req).__name__ + log_extra = {"request_id": str(message.request_id)} + logger.info("Processing request of type %s", request_type, extra=log_extra) + + # Start instrumentation and capture token + start_time = time.monotonic() + instrumentation_token = None + try: + instrumentation_token = session.instrumenter.on_request_start( + request_id=message.request_id, + request_type=request_type, + session_type="server", + ) + except Exception: # pragma: no cover + logger.exception("Error in instrumentation on_request_start") + if handler := self.request_handlers.get(type(req)): # type: ignore - logger.debug("Dispatching request of type %s", type(req).__name__) + logger.debug("Dispatching request of type %s", request_type, extra=log_extra) - token = None + context_token = None + response = None + success = False try: # Extract request context from message metadata request_data = None @@ -689,7 +710,7 @@ async def _handle_request( # Set our global state that can be retrieved via # app.get_request_context() - token = request_ctx.set( + context_token = request_ctx.set( RequestContext( message.request_id, message.request_meta, @@ -699,22 +720,65 @@ async def _handle_request( ) ) response = await handler(req) + success = not isinstance(response, types.ErrorData) except McpError as err: # pragma: no cover response = err.error + try: + session.instrumenter.on_error( + token=instrumentation_token, + request_id=message.request_id, + error=err, + error_type=type(err).__name__, + ) + except Exception: # pragma: no cover + logger.exception("Error in instrumentation on_error") except anyio.get_cancelled_exc_class(): # pragma: no cover logger.info( "Request %s cancelled - duplicate response suppressed", message.request_id, + extra=log_extra, ) + try: + session.instrumenter.on_request_end( + token=instrumentation_token, + request_id=message.request_id, + request_type=request_type, + success=False, + duration_seconds=time.monotonic() - start_time, + cancelled=True, + ) + except Exception: # pragma: no cover + logger.exception("Error in instrumentation on_request_end") return except Exception as err: # pragma: no cover + try: + session.instrumenter.on_error( + token=instrumentation_token, + request_id=message.request_id, + error=err, + error_type=type(err).__name__, + ) + except Exception: # pragma: no cover + logger.exception("Error in instrumentation on_error") if raise_exceptions: raise err response = types.ErrorData(code=0, message=str(err), data=None) finally: # Reset the global state after we are done - if token is not None: # pragma: no branch - request_ctx.reset(token) + if context_token is not None: # pragma: no branch + request_ctx.reset(context_token) + + # End instrumentation + try: + session.instrumenter.on_request_end( + token=instrumentation_token, + request_id=message.request_id, + request_type=request_type, + success=success, + duration_seconds=time.monotonic() - start_time, + ) + except Exception: # pragma: no cover + logger.exception("Error in instrumentation on_request_end") await 
message.respond(response) else: # pragma: no cover @@ -724,8 +788,19 @@ async def _handle_request( message="Method not found", ) ) + try: + session.instrumenter.on_request_end( + token=instrumentation_token, + request_id=message.request_id, + request_type=request_type, + success=False, + duration_seconds=time.monotonic() - start_time, + error="Method not found", + ) + except Exception: # pragma: no cover + logger.exception("Error in instrumentation on_request_end") - logger.debug("Response sent") + logger.debug("Response sent", extra=log_extra) async def _handle_notification(self, notify: Any): if handler := self.notification_handlers.get(type(notify)): # type: ignore diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index b116fbe384..86b989db2f 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -48,6 +48,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: import mcp.types as types from mcp.server.models import InitializationOptions from mcp.shared.exceptions import McpError +from mcp.shared.instrumentation import Instrumenter, get_default_instrumenter from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import ( BaseSession, @@ -87,6 +88,7 @@ def __init__( write_stream: MemoryObjectSendStream[SessionMessage], init_options: InitializationOptions, stateless: bool = False, + instrumenter: Instrumenter | None = None, ) -> None: super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification) self._initialization_state = ( @@ -94,6 +96,7 @@ def __init__( ) self._init_options = init_options + self._instrumenter = instrumenter or get_default_instrumenter() self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[ ServerRequestResponder ](0) @@ -103,6 +106,11 @@ def __init__( def client_params(self) -> types.InitializeRequestParams | None: return self._client_params # pragma: no cover + @property + def instrumenter(self) -> Instrumenter: # pragma: no cover + """Get the instrumenter for this session.""" + return self._instrumenter + def check_client_capability(self, capability: types.ClientCapabilities) -> bool: # pragma: no cover """Check if the client supports a specific capability.""" if self._client_params is None: diff --git a/src/mcp/shared/instrumentation.py b/src/mcp/shared/instrumentation.py new file mode 100644 index 0000000000..faf1cb97b4 --- /dev/null +++ b/src/mcp/shared/instrumentation.py @@ -0,0 +1,140 @@ +"""Instrumentation interface for MCP SDK observability. + +This module provides a pluggable instrumentation interface for monitoring +MCP request/response lifecycle. It's designed to support integration with +OpenTelemetry and other observability tools. + +See: https://github.com/modelcontextprotocol/python-sdk/issues/421 +""" + +from __future__ import annotations + +from typing import Any, Protocol + +from mcp.types import RequestId + + +class Instrumenter(Protocol): + """Protocol for instrumenting MCP request/response lifecycle. + + Implementers can use this to integrate with OpenTelemetry, custom metrics, + logging frameworks, or other observability tools. + + The token-based design allows instrumenters to maintain state (like OpenTelemetry + spans) between on_request_start and on_request_end without side-channels. + + All methods are optional (no-op implementations are valid). Exceptions + raised by instrumentation hooks are logged but do not affect request processing. 
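
+
+    A minimal illustrative sketch (``TimingInstrumenter`` is not part of the SDK and
+    assumes ``import time``)::
+
+        class TimingInstrumenter:
+            def on_request_start(self, request_id, request_type, method=None, **metadata):
+                return time.monotonic()  # the token is just the start timestamp
+
+            def on_request_end(self, token, request_id, request_type, success,
+                               duration_seconds=None, **metadata):
+                print(f"{request_type} finished in {time.monotonic() - token:.3f}s")
+
+            def on_error(self, token, request_id, error, error_type, **metadata):
+                print(f"request {request_id} failed with {error_type}")
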
+ """ + + def on_request_start( + self, + request_id: RequestId, + request_type: str, + method: str | None = None, + **metadata: Any, + ) -> Any: + """Called when a request starts processing. + + Args: + request_id: Unique identifier for this request + request_type: Type name of the request (e.g., "CallToolRequest") + method: Optional method name being called (e.g., tool/resource name) + **metadata: Additional context (session_type, client_info, etc.) + + Returns: + A token (any value) that will be passed to on_request_end/on_error. + This allows instrumenters to maintain state (e.g., OpenTelemetry spans) + without needing external storage. + """ + ... + + def on_request_end( + self, + token: Any, + request_id: RequestId, + request_type: str, + success: bool, + duration_seconds: float | None = None, + **metadata: Any, + ) -> None: + """Called when a request completes (successfully or not). + + Args: + token: The value returned from on_request_start + request_id: Unique identifier for this request + request_type: Type name of the request + success: Whether the request completed successfully + duration_seconds: Optional request duration in seconds + **metadata: Additional context (error info, result summary, etc.) + """ + ... + + def on_error( + self, + token: Any, + request_id: RequestId | None, + error: Exception, + error_type: str, + **metadata: Any, + ) -> None: + """Called when an error occurs during request processing. + + Args: + token: The value returned from on_request_start (may be None) + request_id: Request ID if available, None for session-level errors + error: The exception that occurred + error_type: Type name of the error + **metadata: Additional error context + """ + ... + + +class NoOpInstrumenter: + """Default no-op implementation of the Instrumenter protocol. + + This implementation does nothing and has minimal overhead. + Used as the default when no instrumentation is configured. 
+ """ + + def on_request_start( + self, + request_id: RequestId, + request_type: str, + method: str | None = None, + **metadata: Any, + ) -> None: + """No-op implementation that returns None as token.""" + return None + + def on_request_end( + self, + token: Any, + request_id: RequestId, + request_type: str, + success: bool, + duration_seconds: float | None = None, + **metadata: Any, + ) -> None: + """No-op implementation.""" + pass + + def on_error( + self, + token: Any, + request_id: RequestId | None, + error: Exception, + error_type: str, + **metadata: Any, + ) -> None: + """No-op implementation.""" + pass + + +# Global default instance +_default_instrumenter = NoOpInstrumenter() + + +def get_default_instrumenter() -> Instrumenter: + """Get the default no-op instrumenter instance.""" + return _default_instrumenter diff --git a/tests/shared/test_instrumentation.py b/tests/shared/test_instrumentation.py new file mode 100644 index 0000000000..7c73eee63b --- /dev/null +++ b/tests/shared/test_instrumentation.py @@ -0,0 +1,194 @@ +"""Tests for instrumentation interface.""" + +from typing import Any + +import pytest + +from mcp.shared.instrumentation import NoOpInstrumenter, get_default_instrumenter +from mcp.types import RequestId + + +class MockInstrumenter: + """Track calls to instrumentation hooks for testing.""" + + def __init__(self) -> None: + self.calls: list[dict[str, Any]] = [] + + def on_request_start( + self, request_id: RequestId, request_type: str, method: str | None = None, **metadata: Any + ) -> dict[str, Any]: + call: dict[str, Any] = { + "hook": "on_request_start", + "request_id": request_id, + "request_type": request_type, + "method": method, + "metadata": metadata, + } + self.calls.append(call) + # Return the call itself as a token for testing + return call + + def on_request_end( + self, + token: Any, + request_id: RequestId, + request_type: str, + success: bool, + duration_seconds: float | None = None, + **metadata: Any, + ) -> None: + self.calls.append( + { + "hook": "on_request_end", + "token": token, + "request_id": request_id, + "request_type": request_type, + "success": success, + "duration_seconds": duration_seconds, + "metadata": metadata, + } + ) + + def on_error( + self, token: Any, request_id: RequestId | None, error: Exception, error_type: str, **metadata: Any + ) -> None: + self.calls.append( + { + "hook": "on_error", + "token": token, + "request_id": request_id, + "error": error, + "error_type": error_type, + "metadata": metadata, + } + ) + + def get_calls_by_hook(self, hook_name: str) -> list[dict[str, Any]]: + """Get all calls to a specific hook.""" + return [call for call in self.calls if call["hook"] == hook_name] + + def get_calls_by_request_id(self, request_id: RequestId) -> list[dict[str, Any]]: + """Get all calls for a specific request_id.""" + return [call for call in self.calls if call.get("request_id") == request_id] + + +@pytest.mark.anyio +async def test_noop_instrumenter(): + """Test that NoOpInstrumenter does nothing and doesn't raise errors.""" + instrumenter = NoOpInstrumenter() + + # Should not raise any errors + token = instrumenter.on_request_start(request_id=1, request_type="TestRequest") + instrumenter.on_request_end(token=token, request_id=1, request_type="TestRequest", success=True) + instrumenter.on_error(token=token, request_id=1, error=Exception("test"), error_type="Exception") + + +def test_get_default_instrumenter(): + """Test that get_default_instrumenter returns a NoOpInstrumenter.""" + instrumenter = get_default_instrumenter() 
+ assert isinstance(instrumenter, NoOpInstrumenter) + + +def test_instrumenter_protocol(): + """Test that MockInstrumenter implements the Instrumenter protocol.""" + instrumenter = MockInstrumenter() + + # Call all methods to ensure they exist + token = instrumenter.on_request_start(request_id=1, request_type="TestRequest", method="test_method") + instrumenter.on_request_end( + token=token, request_id=1, request_type="TestRequest", success=True, duration_seconds=1.5 + ) + instrumenter.on_error(token=token, request_id=1, error=Exception("test"), error_type="Exception") + + # Verify calls were tracked + assert len(instrumenter.calls) == 3 + assert instrumenter.get_calls_by_hook("on_request_start")[0]["request_type"] == "TestRequest" + assert instrumenter.get_calls_by_hook("on_request_end")[0]["success"] is True + assert instrumenter.get_calls_by_hook("on_error")[0]["error_type"] == "Exception" + + +def test_instrumenter_tracks_request_id(): + """Test that request_id is tracked consistently across hooks.""" + instrumenter = MockInstrumenter() + test_request_id = 42 + + token = instrumenter.on_request_start(request_id=test_request_id, request_type="TestRequest") + instrumenter.on_request_end(token=token, request_id=test_request_id, request_type="TestRequest", success=True) + + # Verify request_id is consistent + calls = instrumenter.get_calls_by_request_id(test_request_id) + assert len(calls) == 2 + assert all(call["request_id"] == test_request_id for call in calls) + + +def test_instrumenter_metadata(): + """Test that metadata is passed through correctly.""" + instrumenter = MockInstrumenter() + + instrumenter.on_request_start( + request_id=1, request_type="TestRequest", method="test_tool", session_type="server", custom_field="custom_value" + ) + + call = instrumenter.get_calls_by_hook("on_request_start")[0] + assert call["metadata"]["session_type"] == "server" + assert call["metadata"]["custom_field"] == "custom_value" + assert call["method"] == "test_tool" + + +def test_instrumenter_duration_tracking(): + """Test that duration is passed to on_request_end.""" + instrumenter = MockInstrumenter() + + token = {"test": "token"} + instrumenter.on_request_end( + token=token, request_id=1, request_type="TestRequest", success=True, duration_seconds=2.5 + ) + + call = instrumenter.get_calls_by_hook("on_request_end")[0] + assert call["duration_seconds"] == 2.5 + assert call["token"] == token + + +def test_instrumenter_error_info(): + """Test that error information is captured correctly.""" + instrumenter = MockInstrumenter() + test_error = ValueError("test error message") + + token = {"test": "token"} + instrumenter.on_error( + token=token, request_id=1, error=test_error, error_type="ValueError", extra_info="additional context" + ) + + call = instrumenter.get_calls_by_hook("on_error")[0] + assert call["error"] is test_error + assert call["error_type"] == "ValueError" + assert call["metadata"]["extra_info"] == "additional context" + assert call["token"] == token + + +def test_instrumenter_token_flow(): + """Test that token is passed correctly from start to end/error.""" + instrumenter = MockInstrumenter() + + # Start request and get token + token = instrumenter.on_request_start(request_id=1, request_type="TestRequest", method="test_tool") + assert token is not None + assert isinstance(token, dict) + assert token["request_id"] == 1 + + # End request with the token + instrumenter.on_request_end( + token=token, request_id=1, request_type="TestRequest", success=True, duration_seconds=1.5 + ) + + # Verify 
token is the same + start_call = instrumenter.get_calls_by_hook("on_request_start")[0] + end_call = instrumenter.get_calls_by_hook("on_request_end")[0] + assert end_call["token"] is start_call # Token should be the start call itself + + # Test error path + token2 = instrumenter.on_request_start(request_id=2, request_type="TestRequest2") + instrumenter.on_error(token=token2, request_id=2, error=Exception("test"), error_type="Exception") + + error_call = instrumenter.get_calls_by_hook("on_error")[0] + assert error_call["token"]["request_id"] == 2
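
+
+
+def test_instrumenter_session_level_error():
+    """Sketch of the session-level error path: token and request_id may both be None."""
+    instrumenter = MockInstrumenter()
+
+    # No on_request_start preceded this error, so there is no token to pass along
+    instrumenter.on_error(
+        token=None, request_id=None, error=RuntimeError("transport closed"), error_type="RuntimeError"
+    )
+
+    call = instrumenter.get_calls_by_hook("on_error")[0]
+    assert call["token"] is None
+    assert call["request_id"] is None
+    assert call["error_type"] == "RuntimeError"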