Skip to content

Commit feac453

Browse files
committed
Merge remote-tracking branch 'origin/main' into worktree-eager-growing-unicorn
# Conflicts: # tests/client/test_session.py
2 parents d1ef950 + 03681ed commit feac453

8 files changed

Lines changed: 251 additions & 18 deletions

File tree

docs/migration.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,10 @@ For an in-process `Client(server)` (where `server` is a `Server` or `MCPServer`
393393

394394
`Client.send_ping()` is deprecated (ping is removed in 2026-07-28); pin `mode='legacy'` if you need it.
395395

396+
### `call_tool` can return `InputRequiredResult` (opt-in)
397+
398+
For protocol 2026-07-28, a `tools/call` request may return an `InputRequiredResult` asking the client to supply additional input and retry. By default `call_tool` (on `ClientSession`, `Client`, and `ClientSessionGroup`) still returns `CallToolResult` and raises `RuntimeError` if the server requests input. Pass `allow_input_required=True` to receive the `InputRequiredResult` instead, then retry with `input_responses=` / `request_state=`.
399+
396400
### `McpError` renamed to `MCPError`
397401

398402
The `McpError` exception class has been renamed to `MCPError` for consistent naming with the MCP acronym style used throughout the SDK.

src/mcp/client/client.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import Awaitable, Callable, Mapping
66
from contextlib import AsyncExitStack
77
from dataclasses import KW_ONLY, dataclass, field
8-
from typing import Any, Literal, TypeVar
8+
from typing import Any, Literal, TypeVar, overload
99

1010
import anyio
1111
import mcp_types as types
@@ -15,6 +15,8 @@
1515
EmptyResult,
1616
GetPromptResult,
1717
Implementation,
18+
InputRequiredResult,
19+
InputResponses,
1820
ListPromptsResult,
1921
ListResourcesResult,
2022
ListResourceTemplatesResult,
@@ -374,33 +376,79 @@ async def unsubscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None
374376
"""Unsubscribe from resource updates."""
375377
return await self.session.unsubscribe_resource(uri, meta=meta)
376378

379+
@overload
377380
async def call_tool(
378381
self,
379382
name: str,
380383
arguments: dict[str, Any] | None = None,
381384
read_timeout_seconds: float | None = None,
382385
progress_callback: ProgressFnT | None = None,
383386
*,
387+
input_responses: InputResponses | None = None,
388+
request_state: str | None = None,
384389
meta: RequestParamsMeta | None = None,
385-
) -> CallToolResult:
390+
allow_input_required: Literal[False] = False,
391+
) -> CallToolResult: ...
392+
393+
@overload
394+
async def call_tool(
395+
self,
396+
name: str,
397+
arguments: dict[str, Any] | None = None,
398+
read_timeout_seconds: float | None = None,
399+
progress_callback: ProgressFnT | None = None,
400+
*,
401+
input_responses: InputResponses | None = None,
402+
request_state: str | None = None,
403+
meta: RequestParamsMeta | None = None,
404+
allow_input_required: bool,
405+
) -> CallToolResult | InputRequiredResult: ...
406+
407+
async def call_tool(
408+
self,
409+
name: str,
410+
arguments: dict[str, Any] | None = None,
411+
read_timeout_seconds: float | None = None,
412+
progress_callback: ProgressFnT | None = None,
413+
*,
414+
input_responses: InputResponses | None = None,
415+
request_state: str | None = None,
416+
meta: RequestParamsMeta | None = None,
417+
allow_input_required: bool = False,
418+
) -> CallToolResult | InputRequiredResult:
386419
"""Call a tool on the server.
387420
388421
Args:
389422
name: The name of the tool to call
390423
arguments: Arguments to pass to the tool
391424
read_timeout_seconds: Timeout for the tool call
392425
progress_callback: Callback for progress updates
426+
input_responses: Responses to a prior `InputRequiredResult.input_requests`
427+
request_state: Opaque state echoed from a prior `InputRequiredResult`
393428
meta: Additional metadata for the request
429+
allow_input_required: When ``False`` (default), an `InputRequiredResult`
430+
from the server raises `RuntimeError`; when ``True``, it is returned
431+
so the caller can resolve the requests and retry.
394432
395433
Returns:
396-
The tool result.
434+
The tool result. When ``allow_input_required=True``, may instead be an
435+
`InputRequiredResult` carrying the server's input requests and opaque
436+
``request_state`` for the retry.
437+
438+
Raises:
439+
RuntimeError: If the server returns an `InputRequiredResult` and
440+
``allow_input_required`` is ``False``.
397441
"""
442+
# TODO(L84): stop forwarding allow_input_required; run the MRTR auto-loop driver here (S6).
398443
return await self.session.call_tool(
399444
name=name,
400445
arguments=arguments,
401446
read_timeout_seconds=read_timeout_seconds,
402447
progress_callback=progress_callback,
448+
input_responses=input_responses,
449+
request_state=request_state,
403450
meta=meta,
451+
allow_input_required=allow_input_required,
404452
)
405453

406454
async def list_prompts(

src/mcp/client/session.py

Lines changed: 67 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections.abc import Callable, Mapping
55
from dataclasses import dataclass
66
from types import TracebackType
7-
from typing import Any, Protocol, cast
7+
from typing import Any, Literal, Protocol, cast, overload
88

99
import anyio
1010
import anyio.abc
@@ -173,6 +173,10 @@ async def _default_logging_callback(
173173

174174
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
175175

176+
_CallToolResultAdapter: TypeAdapter[types.CallToolResult | types.InputRequiredResult] = TypeAdapter(
177+
types.CallToolResult | types.InputRequiredResult
178+
)
179+
176180

177181
class ClientSession:
178182
"""Client half of an MCP connection, running on a `Dispatcher`.
@@ -269,7 +273,7 @@ async def __aexit__(
269273
async def send_request(
270274
self,
271275
request: types.ClientRequest,
272-
result_type: type[ReceiveResultT],
276+
result_type: type[ReceiveResultT] | TypeAdapter[ReceiveResultT],
273277
request_read_timeout_seconds: float | None = None,
274278
metadata: ClientMessageMetadata | None = None,
275279
progress_callback: ProgressFnT | None = None,
@@ -308,6 +312,8 @@ async def send_request(
308312
_methods.validate_server_result(method, version, raw)
309313
except KeyError:
310314
pass
315+
if isinstance(result_type, TypeAdapter):
316+
return result_type.validate_python(raw, by_name=False)
311317
return result_type.model_validate(raw, by_name=False)
312318

313319
async def send_notification(self, notification: types.ClientNotification) -> None:
@@ -596,29 +602,83 @@ async def unsubscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None
596602
types.EmptyResult,
597603
)
598604

605+
@overload
606+
async def call_tool(
607+
self,
608+
name: str,
609+
arguments: dict[str, Any] | None = None,
610+
read_timeout_seconds: float | None = None,
611+
progress_callback: ProgressFnT | None = None,
612+
*,
613+
input_responses: types.InputResponses | None = None,
614+
request_state: str | None = None,
615+
meta: RequestParamsMeta | None = None,
616+
allow_input_required: Literal[False] = False,
617+
) -> types.CallToolResult: ...
618+
619+
@overload
620+
async def call_tool(
621+
self,
622+
name: str,
623+
arguments: dict[str, Any] | None = None,
624+
read_timeout_seconds: float | None = None,
625+
progress_callback: ProgressFnT | None = None,
626+
*,
627+
input_responses: types.InputResponses | None = None,
628+
request_state: str | None = None,
629+
meta: RequestParamsMeta | None = None,
630+
allow_input_required: bool,
631+
) -> types.CallToolResult | types.InputRequiredResult: ...
632+
599633
async def call_tool(
600634
self,
601635
name: str,
602636
arguments: dict[str, Any] | None = None,
603637
read_timeout_seconds: float | None = None,
604638
progress_callback: ProgressFnT | None = None,
605639
*,
640+
input_responses: types.InputResponses | None = None,
641+
request_state: str | None = None,
606642
meta: RequestParamsMeta | None = None,
607-
) -> types.CallToolResult:
608-
"""Send a tools/call request with optional progress callback support."""
643+
allow_input_required: bool = False,
644+
) -> types.CallToolResult | types.InputRequiredResult:
645+
"""Send a tools/call request with optional progress callback support.
646+
647+
Args:
648+
input_responses: Responses to a prior `InputRequiredResult.input_requests`.
649+
request_state: Opaque state echoed from a prior `InputRequiredResult`.
650+
allow_input_required: When ``False`` (default), an `InputRequiredResult`
651+
from the server raises `RuntimeError`; when ``True``, it is returned
652+
so the caller can resolve the requests and retry.
653+
654+
Raises:
655+
RuntimeError: If the server returns an `InputRequiredResult` and
656+
``allow_input_required`` is ``False``.
657+
"""
609658

610659
result = await self.send_request(
611660
types.CallToolRequest(
612-
params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=meta),
661+
params=types.CallToolRequestParams(
662+
name=name,
663+
arguments=arguments,
664+
input_responses=input_responses,
665+
request_state=request_state,
666+
_meta=meta,
667+
),
613668
),
614-
types.CallToolResult,
669+
_CallToolResultAdapter,
615670
request_read_timeout_seconds=read_timeout_seconds,
616671
progress_callback=progress_callback,
617672
)
618673

619-
if not result.is_error:
674+
if isinstance(result, types.CallToolResult) and not result.is_error:
620675
await self._validate_tool_result(name, result)
621676

677+
if isinstance(result, types.InputRequiredResult) and not allow_input_required:
678+
raise RuntimeError(
679+
"Server returned InputRequiredResult; pass allow_input_required=True to receive it "
680+
"and retry call_tool(..., input_responses=..., request_state=result.request_state)."
681+
)
622682
return result
623683

624684
async def _validate_tool_result(self, name: str, result: types.CallToolResult) -> None:

src/mcp/client/session_group.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from collections.abc import Callable
1212
from dataclasses import dataclass
1313
from types import TracebackType
14-
from typing import Any, TypeAlias
14+
from typing import Any, Literal, TypeAlias, overload
1515

1616
import anyio
1717
import httpx
@@ -190,24 +190,63 @@ def tools(self) -> dict[str, types.Tool]:
190190
"""Returns the tools as a dictionary of names to tools."""
191191
return self._tools
192192

193+
@overload
193194
async def call_tool(
194195
self,
195196
name: str,
196197
arguments: dict[str, Any] | None = None,
197198
read_timeout_seconds: float | None = None,
198199
progress_callback: ProgressFnT | None = None,
199200
*,
201+
input_responses: types.InputResponses | None = None,
202+
request_state: str | None = None,
200203
meta: types.RequestParamsMeta | None = None,
201-
) -> types.CallToolResult:
202-
"""Executes a tool given its name and arguments."""
204+
allow_input_required: Literal[False] = False,
205+
) -> types.CallToolResult: ...
206+
207+
@overload
208+
async def call_tool(
209+
self,
210+
name: str,
211+
arguments: dict[str, Any] | None = None,
212+
read_timeout_seconds: float | None = None,
213+
progress_callback: ProgressFnT | None = None,
214+
*,
215+
input_responses: types.InputResponses | None = None,
216+
request_state: str | None = None,
217+
meta: types.RequestParamsMeta | None = None,
218+
allow_input_required: bool,
219+
) -> types.CallToolResult | types.InputRequiredResult: ...
220+
221+
async def call_tool(
222+
self,
223+
name: str,
224+
arguments: dict[str, Any] | None = None,
225+
read_timeout_seconds: float | None = None,
226+
progress_callback: ProgressFnT | None = None,
227+
*,
228+
input_responses: types.InputResponses | None = None,
229+
request_state: str | None = None,
230+
meta: types.RequestParamsMeta | None = None,
231+
allow_input_required: bool = False,
232+
) -> types.CallToolResult | types.InputRequiredResult:
233+
"""Executes a tool given its name and arguments.
234+
235+
Raises:
236+
RuntimeError: If the server returns an `InputRequiredResult` and
237+
``allow_input_required`` is ``False``.
238+
"""
203239
session = self._tool_to_session[name]
204240
session_tool_name = self.tools[name].name
205241
return await session.call_tool(
206242
session_tool_name,
207243
arguments=arguments,
208244
read_timeout_seconds=read_timeout_seconds,
209245
progress_callback=progress_callback,
246+
input_responses=input_responses,
247+
request_state=request_state,
210248
meta=meta,
249+
allow_input_required=allow_input_required,
211250
)
212251

213252
async def disconnect_from_server(self, session: mcp.ClientSession) -> None:

src/mcp/server/_otel.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,12 @@ async def __call__(self, ctx: ServerRequestContext[Any, Any], call_next: CallNex
5959
span.set_status(StatusCode.ERROR, str(e))
6060
raise
6161
if ctx.method == "tools/call":
62+
# Tool errors are detected pre-serialization, so only shapes that reach the wire as an error
63+
# count: the model, or the camelCase alias (`is_error` is dropped by the alias-only wire
64+
# validation). A raw-dict `isError` is matched as a literal bool only - non-bool coercible
65+
# values (1, "true") would serialize to an error but are rare enough to leave undetected.
6266
match result:
63-
case CallToolResult(is_error=True) | {"isError": True} | {"is_error": True}:
67+
case CallToolResult(is_error=True) | {"isError": True}:
6468
span.set_attribute("error.type", "tool_error")
6569
span.set_status(StatusCode.ERROR)
6670
case _:

tests/client/test_session.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@
3737

3838
from mcp import MCPError
3939
from mcp.client import ClientRequestContext
40+
from mcp.client.client import Client
4041
from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession
42+
from mcp.server import Server, ServerRequestContext
4143
from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair
4244
from mcp.shared.dispatcher import CallOptions, DispatchContext, OnNotify, OnRequest
4345
from mcp.shared.message import SessionMessage
@@ -1657,3 +1659,59 @@ async def test_discover_reraises_unsupported_version_with_malformed_error_data()
16571659
await session.discover()
16581660
assert exc.value.error.code == UNSUPPORTED_PROTOCOL_VERSION
16591661
assert [m for m, _ in dispatcher.calls] == ["server/discover"]
1662+
1663+
1664+
@pytest.mark.anyio
1665+
async def test_call_tool_returns_input_required_result_when_server_requests_input() -> None:
1666+
# `on_call_tool` is still typed `-> CallToolResult` on this branch (#2967 widens it later);
1667+
# `add_request_handler` is `HandlerResult`-typed and accepts `InputRequiredResult` cleanly.
1668+
async def handler(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.InputRequiredResult:
1669+
return types.InputRequiredResult(request_state="s")
1670+
1671+
server = Server("test")
1672+
server.add_request_handler("tools/call", types.CallToolRequestParams, handler)
1673+
with anyio.fail_after(5):
1674+
async with Client(server, mode="2026-07-28") as client:
1675+
result = await client.call_tool("ask", allow_input_required=True)
1676+
assert isinstance(result, types.InputRequiredResult)
1677+
assert result.request_state == "s"
1678+
1679+
1680+
@pytest.mark.anyio
1681+
async def test_call_tool_threads_input_responses_and_request_state_into_params() -> None:
1682+
captured: list[types.CallToolRequestParams] = []
1683+
1684+
async def on_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult:
1685+
captured.append(params)
1686+
return CallToolResult(content=[])
1687+
1688+
async def on_list_tools(
1689+
ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
1690+
) -> types.ListToolsResult:
1691+
return types.ListToolsResult(tools=[])
1692+
1693+
server = Server("test", on_call_tool=on_call_tool, on_list_tools=on_list_tools)
1694+
with anyio.fail_after(5):
1695+
async with Client(server, mode="2026-07-28") as client:
1696+
await client.call_tool(
1697+
"ask",
1698+
input_responses={"k": types.ElicitResult(action="decline")},
1699+
request_state="s",
1700+
)
1701+
assert captured[0].input_responses == {"k": types.ElicitResult(action="decline")}
1702+
assert captured[0].request_state == "s"
1703+
1704+
1705+
@pytest.mark.anyio
1706+
async def test_client_call_tool_raises_on_input_required_without_opt_in() -> None:
1707+
async def handler(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.InputRequiredResult:
1708+
return types.InputRequiredResult(request_state="s")
1709+
1710+
server = Server("test")
1711+
server.add_request_handler("tools/call", types.CallToolRequestParams, handler)
1712+
with anyio.fail_after(5):
1713+
async with Client(server, mode="2026-07-28") as client:
1714+
with pytest.raises(RuntimeError):
1715+
await client.call_tool("t")
1716+
result = await client.call_tool("t", allow_input_required=True)
1717+
assert isinstance(result, types.InputRequiredResult)

0 commit comments

Comments
 (0)