diff --git a/pyproject.toml b/pyproject.toml index 737839a23..7d6ac742e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ classifiers = [ ] dependencies = [ "anyio>=4.5", + "exceptiongroup>=1.2.0; python_version < '3.11'", "httpx>=0.27.1", "httpx-sse>=0.4", "pydantic>=2.12.0", diff --git a/src/mcp/client/_memory.py b/src/mcp/client/_memory.py index e6e938673..9cd17ec50 100644 --- a/src/mcp/client/_memory.py +++ b/src/mcp/client/_memory.py @@ -5,7 +5,15 @@ from collections.abc import AsyncIterator from contextlib import AbstractAsyncContextManager, asynccontextmanager from types import TracebackType -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from builtins import BaseExceptionGroup +else: + try: + from builtins import BaseExceptionGroup + except ImportError: + from exceptiongroup import BaseExceptionGroup import anyio @@ -49,20 +57,27 @@ async def _connect(self) -> AsyncIterator[TransportStreams]: server_read, server_write = server_streams async with anyio.create_task_group() as tg: - # Start server in background - tg.start_soon( - lambda: actual_server.run( - server_read, - server_write, - actual_server.create_initialization_options(), - raise_exceptions=self._raise_exceptions, + try: + # Start server in background + tg.start_soon( + lambda: actual_server.run( + server_read, + server_write, + actual_server.create_initialization_options(), + raise_exceptions=self._raise_exceptions, + ) ) - ) - try: - yield client_read, client_write - finally: - tg.cancel_scope.cancel() + try: + yield client_read, client_write + finally: + tg.cancel_scope.cancel() + except BaseExceptionGroup as e: + from mcp.shared.exceptions import unwrap_task_group_exception + + real_exc = unwrap_task_group_exception(e) + if real_exc is not e: + raise real_exc async def __aenter__(self) -> TransportStreams: """Connect to the server and return streams for communication.""" diff --git a/src/mcp/client/session_group.py 
b/src/mcp/client/session_group.py index 961021264..9cb5518c6 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -11,7 +11,15 @@ from collections.abc import Callable from dataclasses import dataclass from types import TracebackType -from typing import Any, TypeAlias +from typing import TYPE_CHECKING, Any, TypeAlias + +if TYPE_CHECKING: + from builtins import BaseExceptionGroup +else: + try: + from builtins import BaseExceptionGroup + except ImportError: + from exceptiongroup import BaseExceptionGroup import anyio import httpx @@ -167,8 +175,15 @@ async def __aexit__( # Concurrently close session stacks. async with anyio.create_task_group() as tg: - for exit_stack in self._session_exit_stacks.values(): - tg.start_soon(exit_stack.aclose) + try: + for exit_stack in self._session_exit_stacks.values(): + tg.start_soon(exit_stack.aclose) + except BaseExceptionGroup as e: + from mcp.shared.exceptions import unwrap_task_group_exception + + real_exc = unwrap_task_group_exception(e) + if real_exc is not e: + raise real_exc @property def sessions(self) -> list[mcp.ClientSession]: diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 61026aa0c..0433b0073 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -1,9 +1,17 @@ import logging from collections.abc import Callable from contextlib import asynccontextmanager -from typing import Any +from typing import TYPE_CHECKING, Any from urllib.parse import parse_qs, urljoin, urlparse +if TYPE_CHECKING: + from builtins import BaseExceptionGroup +else: + try: + from builtins import BaseExceptionGroup + except ImportError: + from exceptiongroup import BaseExceptionGroup + import anyio import httpx from anyio.abc import TaskStatus @@ -157,6 +165,12 @@ async def post_writer(endpoint_url: str): yield read_stream, write_stream finally: tg.cancel_scope.cancel() + except BaseExceptionGroup as e: + from mcp.shared.exceptions import unwrap_task_group_exception + + real_exc = 
unwrap_task_group_exception(e) + if real_exc is not e: + raise real_exc finally: await read_stream_writer.aclose() await write_stream.aclose() diff --git a/src/mcp/client/stdio.py b/src/mcp/client/stdio.py index 902dc8576..dcca4f57c 100644 --- a/src/mcp/client/stdio.py +++ b/src/mcp/client/stdio.py @@ -3,7 +3,15 @@ import sys from contextlib import asynccontextmanager from pathlib import Path -from typing import Literal, TextIO +from typing import TYPE_CHECKING, Literal, TextIO + +if TYPE_CHECKING: + from builtins import BaseExceptionGroup +else: + try: + from builtins import BaseExceptionGroup + except ImportError: + from exceptiongroup import BaseExceptionGroup import anyio import anyio.lowlevel @@ -178,37 +186,44 @@ async def stdin_writer(): await anyio.lowlevel.checkpoint() async with anyio.create_task_group() as tg, process: - tg.start_soon(stdout_reader) - tg.start_soon(stdin_writer) try: - yield read_stream, write_stream - finally: - # MCP spec: stdio shutdown sequence - # 1. Close input stream to server - # 2. Wait for server to exit, or send SIGTERM if it doesn't exit in time - # 3. Send SIGKILL if still not exited - if process.stdin: # pragma: no branch + tg.start_soon(stdout_reader) + tg.start_soon(stdin_writer) + try: + yield read_stream, write_stream + finally: + # MCP spec: stdio shutdown sequence + # 1. Close input stream to server + # 2. Wait for server to exit, or send SIGTERM if it doesn't exit in time + # 3. 
Send SIGKILL if still not exited + if process.stdin: # pragma: no branch + try: + await process.stdin.aclose() + except Exception: # pragma: no cover + # stdin might already be closed, which is fine + pass + try: - await process.stdin.aclose() - except Exception: # pragma: no cover - # stdin might already be closed, which is fine + # Give the process time to exit gracefully after stdin closes + with anyio.fail_after(PROCESS_TERMINATION_TIMEOUT): + await process.wait() + except TimeoutError: + # Process didn't exit from stdin closure, use platform-specific termination + # which handles SIGTERM -> SIGKILL escalation + await _terminate_process_tree(process) + except ProcessLookupError: # pragma: no cover + # Process already exited, which is fine pass - - try: - # Give the process time to exit gracefully after stdin closes - with anyio.fail_after(PROCESS_TERMINATION_TIMEOUT): - await process.wait() - except TimeoutError: - # Process didn't exit from stdin closure, use platform-specific termination - # which handles SIGTERM -> SIGKILL escalation - await _terminate_process_tree(process) - except ProcessLookupError: # pragma: no cover - # Process already exited, which is fine - pass - await read_stream.aclose() - await write_stream.aclose() - await read_stream_writer.aclose() - await write_stream_reader.aclose() + await read_stream.aclose() + await write_stream.aclose() + await read_stream_writer.aclose() + await write_stream_reader.aclose() + except BaseExceptionGroup as e: + from mcp.shared.exceptions import unwrap_task_group_exception + + real_exc = unwrap_task_group_exception(e) + if real_exc is not e: + raise real_exc def _get_executable_command(command: str) -> str: diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 9f3dd5e0b..c5accf9e1 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -7,6 +7,15 @@ from collections.abc import AsyncGenerator, Awaitable, Callable from contextlib import 
asynccontextmanager from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from builtins import BaseExceptionGroup +else: + try: + from builtins import BaseExceptionGroup + except ImportError: + from exceptiongroup import BaseExceptionGroup import anyio import httpx @@ -574,6 +583,13 @@ def start_get_stream() -> None: if transport.session_id and terminate_on_close: await transport.terminate_session(client) tg.cancel_scope.cancel() + except BaseExceptionGroup as e: + # Unwrap ExceptionGroup to get only the real error + from mcp.shared.exceptions import unwrap_task_group_exception + + real_exc = unwrap_task_group_exception(e) + if real_exc is not e: + raise real_exc finally: await read_stream_writer.aclose() await write_stream.aclose() diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index 79e75fad1..69d16446a 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -1,6 +1,15 @@ import json from collections.abc import AsyncGenerator from contextlib import asynccontextmanager +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from builtins import BaseExceptionGroup +else: + try: + from builtins import BaseExceptionGroup + except ImportError: + from exceptiongroup import BaseExceptionGroup import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -69,12 +78,19 @@ async def ws_writer(): await ws.send(json.dumps(msg_dict)) async with anyio.create_task_group() as tg: - # Start reader and writer tasks - tg.start_soon(ws_reader) - tg.start_soon(ws_writer) + try: + # Start reader and writer tasks + tg.start_soon(ws_reader) + tg.start_soon(ws_writer) + + # Yield the receive/send streams + yield (read_stream, write_stream) - # Yield the receive/send streams - yield (read_stream, write_stream) + # Once the caller's 'async with' block exits, we shut down + tg.cancel_scope.cancel() + except BaseExceptionGroup as e: + from mcp.shared.exceptions import 
unwrap_task_group_exception - # Once the caller's 'async with' block exits, we shut down - tg.cancel_scope.cancel() + real_exc = unwrap_task_group_exception(e) + if real_exc is not e: + raise real_exc diff --git a/src/mcp/server/experimental/task_result_handler.py b/src/mcp/server/experimental/task_result_handler.py index b2268bc1c..e7fd2bef8 100644 --- a/src/mcp/server/experimental/task_result_handler.py +++ b/src/mcp/server/experimental/task_result_handler.py @@ -10,7 +10,15 @@ """ import logging -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from builtins import BaseExceptionGroup +else: + try: + from builtins import BaseExceptionGroup + except ImportError: + from exceptiongroup import BaseExceptionGroup import anyio @@ -163,25 +171,32 @@ async def _wait_for_task_update(self, task_id: str) -> None: Races between store update and queue message - first one wins. """ async with anyio.create_task_group() as tg: - - async def wait_for_store() -> None: - try: - await self._store.wait_for_update(task_id) - except Exception: - pass - finally: - tg.cancel_scope.cancel() - - async def wait_for_queue() -> None: - try: - await self._queue.wait_for_message(task_id) - except Exception: - pass - finally: - tg.cancel_scope.cancel() - - tg.start_soon(wait_for_store) - tg.start_soon(wait_for_queue) + try: + + async def wait_for_store() -> None: + try: + await self._store.wait_for_update(task_id) + except Exception: + pass + finally: + tg.cancel_scope.cancel() + + async def wait_for_queue() -> None: + try: + await self._queue.wait_for_message(task_id) + except Exception: + pass + finally: + tg.cancel_scope.cancel() + + tg.start_soon(wait_for_store) + tg.start_soon(wait_for_queue) + except BaseExceptionGroup as e: + from mcp.shared.exceptions import unwrap_task_group_exception + + real_exc = unwrap_task_group_exception(e) + if real_exc is not e: + raise real_exc def route_response(self, request_id: RequestId, response: dict[str, Any]) -> bool: 
"""Route a response back to the waiting resolver. diff --git a/src/mcp/server/experimental/task_support.py b/src/mcp/server/experimental/task_support.py index b54219504..dde23a7f0 100644 --- a/src/mcp/server/experimental/task_support.py +++ b/src/mcp/server/experimental/task_support.py @@ -7,6 +7,15 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from builtins import BaseExceptionGroup +else: + try: + from builtins import BaseExceptionGroup + except ImportError: + from exceptiongroup import BaseExceptionGroup import anyio from anyio.abc import TaskGroup @@ -80,11 +89,18 @@ async def run(self) -> AsyncIterator[None]: ... """ async with anyio.create_task_group() as tg: - self._task_group = tg try: - yield - finally: - self._task_group = None + self._task_group = tg + try: + yield + finally: + self._task_group = None + except BaseExceptionGroup as e: + from mcp.shared.exceptions import unwrap_task_group_exception + + real_exc = unwrap_task_group_exception(e) + if real_exc is not e: + raise real_exc def configure_session(self, session: ServerSession) -> None: """Configure a session for task support. 
diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index aee644040..09cbe40a8 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -42,7 +42,15 @@ async def main(): from collections.abc import AsyncIterator, Awaitable, Callable from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager from importlib.metadata import version as importlib_version -from typing import Any, Generic +from typing import TYPE_CHECKING, Any, Generic + +if TYPE_CHECKING: + from builtins import BaseExceptionGroup +else: + try: + from builtins import BaseExceptionGroup + except ImportError: + from exceptiongroup import BaseExceptionGroup import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -390,16 +398,23 @@ async def run( await stack.enter_async_context(task_support.run()) async with anyio.create_task_group() as tg: - async for message in session.incoming_messages: - logger.debug("Received message: %s", message) - - tg.start_soon( - self._handle_message, - message, - session, - lifespan_context, - raise_exceptions, - ) + try: + async for message in session.incoming_messages: + logger.debug("Received message: %s", message) + + tg.start_soon( + self._handle_message, + message, + session, + lifespan_context, + raise_exceptions, + ) + except BaseExceptionGroup as e: + from mcp.shared.exceptions import unwrap_task_group_exception + + real_exc = unwrap_task_group_exception(e) + if real_exc is not e: + raise real_exc async def _handle_message( self, diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 9007230ce..6863ba368 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -38,10 +38,18 @@ async def handle_sse(request): import logging from contextlib import asynccontextmanager -from typing import Any +from typing import TYPE_CHECKING, Any from urllib.parse import quote from uuid import UUID, uuid4 +if TYPE_CHECKING: + from builtins 
import BaseExceptionGroup +else: + try: + from builtins import BaseExceptionGroup + except ImportError: + from exceptiongroup import BaseExceptionGroup + import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import ValidationError @@ -175,24 +183,31 @@ async def sse_writer(): ) async with anyio.create_task_group() as tg: + try: + + async def response_wrapper(scope: Scope, receive: Receive, send: Send): + """The EventSourceResponse returning signals a client close / disconnect. + In this case we close our side of the streams to signal the client that + the connection has been closed. + """ + await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)( + scope, receive, send + ) + await read_stream_writer.aclose() + await write_stream_reader.aclose() + logging.debug(f"Client session disconnected {session_id}") + + logger.debug("Starting SSE response task") + tg.start_soon(response_wrapper, scope, receive, send) + + logger.debug("Yielding read and write streams") + yield (read_stream, write_stream) + except BaseExceptionGroup as e: + from mcp.shared.exceptions import unwrap_task_group_exception - async def response_wrapper(scope: Scope, receive: Receive, send: Send): - """The EventSourceResponse returning signals a client close / disconnect. - In this case we close our side of the streams to signal the client that - the connection has been closed. 
- """ - await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)( - scope, receive, send - ) - await read_stream_writer.aclose() - await write_stream_reader.aclose() - logging.debug(f"Client session disconnected {session_id}") - - logger.debug("Starting SSE response task") - tg.start_soon(response_wrapper, scope, receive, send) - - logger.debug("Yielding read and write streams") - yield (read_stream, write_stream) + real_exc = unwrap_task_group_exception(e) + if real_exc is not e: + raise real_exc async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: no cover logger.debug("Handling POST message") diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index e526bab56..6f0ac0ed9 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -20,6 +20,15 @@ async def run_server(): import sys from contextlib import asynccontextmanager from io import TextIOWrapper +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from builtins import BaseExceptionGroup +else: + try: + from builtins import BaseExceptionGroup + except ImportError: + from exceptiongroup import BaseExceptionGroup import anyio import anyio.lowlevel @@ -78,6 +87,13 @@ async def stdout_writer(): await anyio.lowlevel.checkpoint() async with anyio.create_task_group() as tg: - tg.start_soon(stdin_reader) - tg.start_soon(stdout_writer) - yield read_stream, write_stream + try: + tg.start_soon(stdin_reader) + tg.start_soon(stdout_writer) + yield read_stream, write_stream + except BaseExceptionGroup as e: + from mcp.shared.exceptions import unwrap_task_group_exception + + real_exc = unwrap_task_group_exception(e) + if real_exc is not e: + raise real_exc diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 04aed345e..4a3495ccf 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -13,7 +13,15 @@ from contextlib import asynccontextmanager from 
dataclasses import dataclass from http import HTTPStatus -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from builtins import BaseExceptionGroup +else: + try: + from builtins import BaseExceptionGroup + except ImportError: + from exceptiongroup import BaseExceptionGroup import anyio import pydantic_core @@ -429,7 +437,9 @@ async def _validate_accept_header(self, request: Request, scope: Scope, send: Se return False return True - async def _handle_post_request(self, scope: Scope, request: Request, receive: Receive, send: Send) -> None: + async def _handle_post_request( # noqa: C901 - Function is complex but handles multiple request types + self, scope: Scope, request: Request, receive: Receive, send: Send + ) -> None: """Handle POST requests containing JSON-RPC messages.""" writer = self._read_stream_writer if writer is None: # pragma: no cover @@ -615,10 +625,19 @@ async def sse_writer(): # pragma: lax no cover try: # First send the response to establish the SSE connection async with anyio.create_task_group() as tg: - tg.start_soon(response, scope, receive, send) - # Then send the message to be processed by the server - session_message = self._create_session_message(message, request, request_id, protocol_version) - await writer.send(session_message) + try: + tg.start_soon(response, scope, receive, send) + # Then send the message to be processed by the server + session_message = self._create_session_message( + message, request, request_id, protocol_version + ) + await writer.send(session_message) + except BaseExceptionGroup as e: + from mcp.shared.exceptions import unwrap_task_group_exception + + real_exc = unwrap_task_group_exception(e) + if real_exc is not e: + raise real_exc except Exception: # pragma: no cover logger.exception("SSE response error") await sse_stream_writer.aclose() @@ -971,67 +990,88 @@ async def connect( # Start a task group for message routing async with anyio.create_task_group() as tg: - # Create a message 
router that distributes messages to request streams - async def message_router(): - try: - async for session_message in write_stream_reader: # pragma: no branch - # Determine which request stream(s) should receive this message - message = session_message.message - target_request_id = None - # Check if this is a response with a known request id. - # Null-id errors (e.g., parse errors) fall through to - # the GET stream since they can't be correlated. - if isinstance(message, JSONRPCResponse | JSONRPCError) and message.id is not None: - target_request_id = str(message.id) - # Extract related_request_id from meta if it exists - elif ( # pragma: no cover - session_message.metadata is not None - and isinstance( - session_message.metadata, - ServerMessageMetadata, - ) - and session_message.metadata.related_request_id is not None - ): - target_request_id = str(session_message.metadata.related_request_id) - - request_stream_id = target_request_id if target_request_id is not None else GET_STREAM_KEY - - # Store the event if we have an event store, - # regardless of whether a client is connected - # messages will be replayed on the re-connect - event_id = None - if self._event_store: # pragma: lax no cover - event_id = await self._event_store.store_event(request_stream_id, message) - logger.debug(f"Stored {event_id} from {request_stream_id}") - - if request_stream_id in self._request_streams: - try: - # Send both the message and the event ID - await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id)) - except (anyio.BrokenResourceError, anyio.ClosedResourceError): # pragma: no cover - # Stream might be closed, remove from registry - self._request_streams.pop(request_stream_id, None) - else: # pragma: no cover - logger.debug( - f"""Request stream {request_stream_id} not found - for message. 
Still processing message as the client - might reconnect and replay.""" - ) - except anyio.ClosedResourceError: - if self._terminated: - logger.debug("Read stream closed by client") - else: - logger.exception("Unexpected closure of read stream in message router") - except Exception: # pragma: lax no cover - logger.exception("Error in message router") + try: + # Create a message router that distributes messages to request streams + async def message_router(): + try: + async for session_message in write_stream_reader: # pragma: no branch + # Determine which request stream(s) should receive this message + message = session_message.message + target_request_id = None + # Check if this is a response with a known request id. + # Null-id errors (e.g., parse errors) fall through to + # the GET stream since they can't be correlated. + if isinstance(message, JSONRPCResponse | JSONRPCError) and message.id is not None: + target_request_id = str(message.id) + # Extract related_request_id from meta if it exists + elif ( # pragma: no cover + session_message.metadata is not None + and isinstance( + session_message.metadata, + ServerMessageMetadata, + ) + and session_message.metadata.related_request_id is not None + ): + target_request_id = str(session_message.metadata.related_request_id) + + request_stream_id = target_request_id if target_request_id is not None else GET_STREAM_KEY + + # Store the event if we have an event store, + # regardless of whether a client is connected + # messages will be replayed on the re-connect + event_id = None + if self._event_store: # pragma: lax no cover + event_id = await self._event_store.store_event(request_stream_id, message) + logger.debug(f"Stored {event_id} from {request_stream_id}") + + if request_stream_id in self._request_streams: + try: + # Send both the message and the event ID + await self._request_streams[request_stream_id][0].send( + EventMessage(message, event_id) + ) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): # 
pragma: no cover + # Stream might be closed, remove from registry + self._request_streams.pop(request_stream_id, None) + else: # pragma: no cover + logger.debug( + f"""Request stream {request_stream_id} not found + for message. Still processing message as the client + might reconnect and replay.""" + ) + except anyio.ClosedResourceError: + if self._terminated: + logger.debug("Read stream closed by client") + else: + logger.exception("Unexpected closure of read stream in message router") + except Exception: # pragma: lax no cover + logger.exception("Error in message router") - # Start the message router - tg.start_soon(message_router) + # Start the message router + tg.start_soon(message_router) - try: - # Yield the streams for the caller to use - yield read_stream, write_stream - finally: + try: + # Yield the streams for the caller to use + yield read_stream, write_stream + finally: + for stream_id in list(self._request_streams.keys()): # pragma: lax no cover + await self._clean_up_memory_streams(stream_id) + self._request_streams.clear() + + # Clean up the read and write streams + try: + await read_stream_writer.aclose() + await read_stream.aclose() + await write_stream_reader.aclose() + await write_stream.aclose() + except Exception: # pragma: no cover + logger.exception("Error closing streams") + except BaseExceptionGroup as e: + from mcp.shared.exceptions import unwrap_task_group_exception + + real_exc = unwrap_task_group_exception(e) + if real_exc is not e: + raise real_exc for stream_id in list(self._request_streams.keys()): # pragma: lax no cover await self._clean_up_memory_streams(stream_id) self._request_streams.clear() diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 50bcd5e79..04a61f8f3 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -9,6 +9,14 @@ from typing import TYPE_CHECKING, Any from uuid import uuid4 +if TYPE_CHECKING: + from builtins 
import BaseExceptionGroup +else: + try: + from builtins import BaseExceptionGroup + except ImportError: + from exceptiongroup import BaseExceptionGroup + import anyio from anyio.abc import TaskStatus from starlette.requests import Request @@ -123,18 +131,25 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]: self._has_started = True async with anyio.create_task_group() as tg: - # Store the task group for later use - self._task_group = tg - logger.info("StreamableHTTP session manager started") try: - yield # Let the application run - finally: - logger.info("StreamableHTTP session manager shutting down") - # Cancel task group to stop all spawned tasks - tg.cancel_scope.cancel() - self._task_group = None - # Clear any remaining server instances - self._server_instances.clear() + # Store the task group for later use + self._task_group = tg + logger.info("StreamableHTTP session manager started") + try: + yield # Let the application run + finally: + logger.info("StreamableHTTP session manager shutting down") + # Cancel task group to stop all spawned tasks + tg.cancel_scope.cancel() + self._task_group = None + # Clear any remaining server instances + self._server_instances.clear() + except BaseExceptionGroup as e: + from mcp.shared.exceptions import unwrap_task_group_exception + + real_exc = unwrap_task_group_exception(e) + if real_exc is not e: + raise real_exc async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: """Process ASGI request with proper session handling and transport setup. 
diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index 3e675da5f..9b0d3fc7c 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -1,4 +1,13 @@ from contextlib import asynccontextmanager +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from builtins import BaseExceptionGroup +else: + try: + from builtins import BaseExceptionGroup + except ImportError: + from exceptiongroup import BaseExceptionGroup import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -53,6 +62,13 @@ async def ws_writer(): await websocket.close() async with anyio.create_task_group() as tg: - tg.start_soon(ws_reader) - tg.start_soon(ws_writer) - yield (read_stream, write_stream) + try: + tg.start_soon(ws_reader) + tg.start_soon(ws_writer) + yield (read_stream, write_stream) + except BaseExceptionGroup as e: + from mcp.shared.exceptions import unwrap_task_group_exception + + real_exc = unwrap_task_group_exception(e) + if real_exc is not e: + raise real_exc diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index f153ea319..c09e05e4b 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -1,6 +1,14 @@ from __future__ import annotations -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast + +if TYPE_CHECKING: + from builtins import BaseExceptionGroup +else: + try: + from builtins import BaseExceptionGroup + except ImportError: + from exceptiongroup import BaseExceptionGroup from mcp.types import URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData, JSONRPCError @@ -104,3 +112,48 @@ def from_error(cls, error: ErrorData) -> UrlElicitationRequiredError: raw_elicitations = cast(list[dict[str, Any]], data.get("elicitations", [])) elicitations = [ElicitRequestURLParams.model_validate(e) for e in raw_elicitations] return cls(elicitations, error.message) + + +def unwrap_task_group_exception(exc: BaseException) -> BaseException: + 
"""Unwrap an exception from a task group, extracting only the real error. + + When anyio task groups fail, they raise BaseExceptionGroup containing: + - The original error that caused the failure + - CancelledError from sibling tasks that were cancelled + + This function extracts only the real error, ignoring cancelled siblings. + + Args: + exc: The exception to unwrap (could be any exception) + + Returns: + The unwrapped exception if it was an ExceptionGroup with a real error, + otherwise the original exception + + Example: + ```python + try: + async with anyio.create_task_group() as tg: + tg.start_soon(task1) + tg.start_soon(task2) + except BaseExceptionGroup as e: + # Extract only the real error, ignore CancelledError + real_exc = unwrap_task_group_exception(e) + raise real_exc + ``` + """ + import anyio + + # If not an exception group, return as-is + if not isinstance(exc, BaseExceptionGroup): + return exc + + # Find the first non-cancelled exception + cancelled_exc_class = anyio.get_cancelled_exc_class() + for sub_exc in exc.exceptions: # type: ignore[reportUnknownVariableType] + if not isinstance(sub_exc, cancelled_exc_class): + # Type narrowing: we know this is not a CancelledError + return sub_exc # type: ignore[reportUnknownVariableType] + + # All were cancelled, return the group + return exc # type: ignore[reportUnknownVariableType] diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index b617d702f..2a5c3225a 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -223,12 +223,23 @@ async def __aexit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> bool | None: + from mcp.shared.exceptions import unwrap_task_group_exception + await self._exit_stack.aclose() # Using BaseSession as a context manager should not block on exit (this # would be very surprising behavior), so make sure to cancel the tasks # in the task group. 
self._task_group.cancel_scope.cancel() - return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + + # Exit the task group and unwrap any ExceptionGroup + try: + return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + except BaseException as e: + # Unwrap ExceptionGroup to get only the real error + unwrapped = unwrap_task_group_exception(e) + if unwrapped is not e: + raise unwrapped + raise async def send_request( self, diff --git a/tests/shared/test_exceptions.py b/tests/shared/test_exceptions.py index 9a7466264..72f0c8511 100644 --- a/tests/shared/test_exceptions.py +++ b/tests/shared/test_exceptions.py @@ -1,7 +1,18 @@ """Tests for MCP exception classes.""" +from typing import TYPE_CHECKING + +import anyio import pytest +if TYPE_CHECKING: + from builtins import BaseExceptionGroup +else: + try: + from builtins import BaseExceptionGroup + except ImportError: + from exceptiongroup import BaseExceptionGroup + from mcp.shared.exceptions import MCPError, UrlElicitationRequiredError from mcp.types import URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData @@ -162,3 +173,65 @@ def test_url_elicitation_required_error_exception_message() -> None: # The exception's string representation should match the message assert str(error) == "URL elicitation required" + + +# Tests for unwrap_task_group_exception + + +@pytest.mark.anyio +async def test_unwrap_single_error() -> None: + """Test that a single exception is returned as-is.""" + from mcp.shared.exceptions import unwrap_task_group_exception + + error = ValueError("test error") + result = unwrap_task_group_exception(error) + assert result is error + + +@pytest.mark.anyio +async def test_unwrap_exception_group_with_real_error() -> None: + """Test that real error is extracted from ExceptionGroup.""" + from mcp.shared.exceptions import unwrap_task_group_exception + + real_error = ConnectionError("connection failed") + + # Simulate what anyio does: create exception group with real error + 
@pytest.mark.anyio
async def test_unwrap_exception_group_with_real_error() -> None:
    """Test that the real error is extracted from an ExceptionGroup."""
    from mcp.shared.exceptions import unwrap_task_group_exception

    real_error = ConnectionError("connection failed")

    async def raise_real_error() -> None:
        # start_soon needs an async callable; the original sync lambda
        # raised synchronously inside start_soon itself, so the sibling
        # task was never spawned and the scenario below was never created.
        raise real_error

    # Simulate what anyio does: a group holding the real error while the
    # sibling task is cancelled.
    try:
        async with anyio.create_task_group() as tg:
            tg.start_soon(raise_real_error)
            tg.start_soon(anyio.sleep, 999)  # Will be cancelled
    except BaseExceptionGroup as e:
        result = unwrap_task_group_exception(e)
        assert isinstance(result, ConnectionError)
        assert str(result) == "connection failed"
    else:
        # Without this guard the test passed vacuously when no group was
        # raised, since all assertions live in the except branch.
        pytest.fail("expected the task group to raise BaseExceptionGroup")


@pytest.mark.anyio
async def test_unwrap_exception_group_all_cancelled() -> None:
    """Test that a group containing only cancellations is returned as-is."""
    from mcp.shared.exceptions import unwrap_task_group_exception

    # Capture a genuine backend cancellation exception. Cancelling a task
    # group's own scope is absorbed by anyio and raises nothing, so the
    # original test's except branch (and its assertions) never ran.
    cancelled_exc: BaseException | None = None
    with anyio.CancelScope() as scope:
        scope.cancel()
        try:
            await anyio.sleep(0)
        except anyio.get_cancelled_exc_class() as exc:
            cancelled_exc = exc
            raise  # re-raise so the cancel scope can absorb it

    assert cancelled_exc is not None
    group = BaseExceptionGroup("all cancelled", [cancelled_exc])

    # Should return the group if all members are cancellations
    result = unwrap_task_group_exception(group)
    assert result is group


@pytest.mark.anyio
async def test_unwrap_preserves_non_cancelled_errors() -> None:
    """Test that the first non-cancelled exception is returned."""
    from mcp.shared.exceptions import unwrap_task_group_exception

    error1 = ValueError("error 1")
    error2 = RuntimeError("error 2")

    # Create an exception group with multiple real errors
    group = BaseExceptionGroup("multiple", [error1, error2])

    # Should return the first non-cancelled error
    result = unwrap_task_group_exception(group)
    assert result is error1
BaseSession.""" + + @property + def _receive_request_adapter(self) -> TypeAdapter[dict[str, object]]: + return TypeAdapter(dict) + + @property + def _receive_notification_adapter(self) -> TypeAdapter[dict[str, object]]: + return TypeAdapter(dict) + + +@pytest.mark.anyio +async def test_session_propagates_real_error_not_exception_group() -> None: + """Test that real errors propagate unwrapped from session task groups.""" + # Create streams + read_sender, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception]() + write_stream, write_receiver = anyio.create_memory_object_stream[SessionMessage]() + + try: + session = _TestSession( + read_stream=read_stream, + write_stream=write_stream, + read_timeout_seconds=None, + ) + + # The session's receive loop will start in __aenter__ + # If it fails with ExceptionGroup, we want only the real error + with pytest.raises(ConnectionError, match="connection failed"): + async with session: + # Raise a connection error to trigger exception group behavior + raise ConnectionError("connection failed") + + finally: + await read_sender.aclose() + await read_stream.aclose() + await write_stream.aclose() + await write_receiver.aclose() diff --git a/uv.lock b/uv.lock index d01d510f1..1b72be53a 100644 --- a/uv.lock +++ b/uv.lock @@ -529,14 +529,14 @@ wheels = [ [[package]] name = "exceptiongroup" -version = "1.3.0" +version = "1.3.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" } +sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = 
"sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/36/f4/c6e662dade71f56cd2f3735141b265c3c79293c109549c1e6933b0651ffc/exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10", size = 16674, upload-time = "2025-05-10T17:42:49.33Z" }, + { url = "https://files.pythonhosted.org/packages/8a/0e/97c33bf5009bdbac74fd2beace167cab3f978feb69cc36f1ef79360d6c4e/exceptiongroup-1.3.1-py3-none-any.whl", hash = "sha256:a7a39a3bd276781e98394987d3a5701d0c4edffb633bb7a5144577f82c773598", size = 16740, upload-time = "2025-11-21T23:01:53.443Z" }, ] [[package]] @@ -784,6 +784,7 @@ name = "mcp" source = { editable = "." } dependencies = [ { name = "anyio" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "httpx" }, { name = "httpx-sse" }, { name = "jsonschema" }, @@ -838,6 +839,7 @@ docs = [ [package.metadata] requires-dist = [ { name = "anyio", specifier = ">=4.5" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11'", specifier = ">=1.2.0" }, { name = "httpx", specifier = ">=0.27.1" }, { name = "httpx-sse", specifier = ">=0.4" }, { name = "jsonschema", specifier = ">=4.20.0" },