diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 53fc53a1f..445e8e4cf 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -25,7 +25,7 @@ async def __call__( self, context: RequestContext["ClientSession", Any], params: types.CreateMessageRequestParams, - ) -> types.CreateMessageResult | types.ErrorData: ... # pragma: no branch + ) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: ... # pragma: no branch class ElicitationFnT(Protocol): @@ -65,7 +65,7 @@ async def _default_message_handler( async def _default_sampling_callback( context: RequestContext["ClientSession", Any], params: types.CreateMessageRequestParams, -) -> types.CreateMessageResult | types.ErrorData: +) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: return types.ErrorData( code=types.INVALID_REQUEST, message="Sampling not supported", @@ -115,6 +115,7 @@ def __init__( write_stream: MemoryObjectSendStream[SessionMessage], read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, + sampling_capabilities: types.SamplingCapability | None = None, elicitation_callback: ElicitationFnT | None = None, list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, @@ -132,6 +133,7 @@ def __init__( ) self._client_info = client_info or DEFAULT_CLIENT_INFO self._sampling_callback = sampling_callback or _default_sampling_callback + self._sampling_capabilities = sampling_capabilities self._elicitation_callback = elicitation_callback or _default_elicitation_callback self._list_roots_callback = list_roots_callback or _default_list_roots_callback self._logging_callback = logging_callback or _default_logging_callback @@ -144,7 +146,11 @@ def __init__( self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers() async def initialize(self) -> types.InitializeResult: - sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None + sampling = ( + (self._sampling_capabilities or types.SamplingCapability()) + if self._sampling_callback is not _default_sampling_callback + else None + ) elicitation = ( types.ElicitationCapability( form=types.FormElicitationCapability(), diff --git a/src/mcp/types.py b/src/mcp/types.py index 9ca8ffc18..654c00660 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -1928,6 +1928,7 @@ class ElicitationRequiredErrorData(BaseModel): ClientResultType: TypeAlias = ( EmptyResult | CreateMessageResult + | CreateMessageResultWithTools | ListRootsResult | ElicitResult | GetTaskResult diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 8d0ef68a9..eb2683fbd 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -497,6 +497,8 @@ async def mock_server(): # Custom sampling callback provided assert received_capabilities.sampling is not None assert isinstance(received_capabilities.sampling, types.SamplingCapability) + # Default sampling capabilities (no tools) + assert received_capabilities.sampling.tools is None # Custom list_roots callback provided assert received_capabilities.roots is not None assert isinstance(received_capabilities.roots, types.RootsCapability) @@ -504,6 +506,84 @@ async def mock_server(): assert received_capabilities.roots.listChanged is True +@pytest.mark.anyio +async def test_client_capabilities_with_sampling_tools(): + """Test that sampling capabilities with tools are properly advertised""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + received_capabilities = None + + async def custom_sampling_callback( # pragma: no cover + context: RequestContext["ClientSession", Any], + params: types.CreateMessageRequestParams, + ) -> types.CreateMessageResult | types.ErrorData: + return types.CreateMessageResult( + role="assistant", + content=types.TextContent(type="text", text="test"), + model="test-model", + ) + + async def mock_server(): + nonlocal received_capabilities + + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + received_capabilities = request.root.params.capabilities + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + # Receive initialized notification + await client_to_server_receive.receive() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + sampling_callback=custom_sampling_callback, + sampling_capabilities=types.SamplingCapability(tools=types.SamplingToolsCapability()), + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Assert that sampling capabilities with tools are properly advertised + assert received_capabilities is not None + assert received_capabilities.sampling is not None + assert isinstance(received_capabilities.sampling, types.SamplingCapability) + # Tools capability should be present + assert received_capabilities.sampling.tools is not None + assert isinstance(received_capabilities.sampling.tools, types.SamplingToolsCapability) + + @pytest.mark.anyio async def test_get_server_capabilities(): """Test that get_server_capabilities returns None before init and capabilities after"""