|
5 | 5 | from collections.abc import Awaitable, Callable, Mapping |
6 | 6 | from contextlib import AsyncExitStack |
7 | 7 | from dataclasses import KW_ONLY, dataclass, field |
8 | | -from typing import Any, Literal, TypeVar |
| 8 | +from typing import Any, Literal, TypeVar, overload |
9 | 9 |
|
10 | 10 | import anyio |
11 | 11 | from typing_extensions import deprecated |
|
30 | 30 | EmptyResult, |
31 | 31 | GetPromptResult, |
32 | 32 | Implementation, |
| 33 | + InputRequiredResult, |
| 34 | + InputResponses, |
33 | 35 | ListPromptsResult, |
34 | 36 | ListResourcesResult, |
35 | 37 | ListResourceTemplatesResult, |
@@ -374,34 +376,85 @@ async def unsubscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None |
374 | 376 | """Unsubscribe from resource updates.""" |
375 | 377 | return await self.session.unsubscribe_resource(uri, meta=meta) |
376 | 378 |
|
| 379 | + @overload |
377 | 380 | async def call_tool( |
378 | 381 | self, |
379 | 382 | name: str, |
380 | 383 | arguments: dict[str, Any] | None = None, |
381 | 384 | read_timeout_seconds: float | None = None, |
382 | 385 | progress_callback: ProgressFnT | None = None, |
383 | 386 | *, |
| 387 | + input_responses: InputResponses | None = None, |
| 388 | + request_state: str | None = None, |
384 | 389 | meta: RequestParamsMeta | None = None, |
385 | | - ) -> CallToolResult: |
| 390 | + allow_input_required: Literal[False] = False, |
| 391 | + ) -> CallToolResult: ... |
| 392 | + |
| 393 | + @overload |
| 394 | + async def call_tool( |
| 395 | + self, |
| 396 | + name: str, |
| 397 | + arguments: dict[str, Any] | None = None, |
| 398 | + read_timeout_seconds: float | None = None, |
| 399 | + progress_callback: ProgressFnT | None = None, |
| 400 | + *, |
| 401 | + input_responses: InputResponses | None = None, |
| 402 | + request_state: str | None = None, |
| 403 | + meta: RequestParamsMeta | None = None, |
| 404 | + allow_input_required: Literal[True], |
| 405 | + ) -> CallToolResult | InputRequiredResult: ... |
| 406 | + |
| 407 | + async def call_tool( |
| 408 | + self, |
| 409 | + name: str, |
| 410 | + arguments: dict[str, Any] | None = None, |
| 411 | + read_timeout_seconds: float | None = None, |
| 412 | + progress_callback: ProgressFnT | None = None, |
| 413 | + *, |
| 414 | + input_responses: InputResponses | None = None, |
| 415 | + request_state: str | None = None, |
| 416 | + meta: RequestParamsMeta | None = None, |
| 417 | + allow_input_required: bool = False, |
| 418 | + ) -> CallToolResult | InputRequiredResult: |
386 | 419 | """Call a tool on the server. |
387 | 420 |
|
388 | 421 | Args: |
389 | 422 | name: The name of the tool to call |
390 | 423 | arguments: Arguments to pass to the tool |
391 | 424 | read_timeout_seconds: Timeout for the tool call |
392 | 425 | progress_callback: Callback for progress updates |
| 426 | + input_responses: Responses to a prior `InputRequiredResult.input_requests` |
| 427 | + request_state: Opaque state echoed from a prior `InputRequiredResult` |
393 | 428 | meta: Additional metadata for the request |
| 429 | + allow_input_required: When ``False`` (default), an `InputRequiredResult` |
| 430 | + from the server raises `RuntimeError`; when ``True``, it is returned |
| 431 | + so the caller can resolve the requests and retry. |
394 | 432 |
|
395 | 433 | Returns: |
396 | | - The tool result. |
| 434 | + The tool result. When ``allow_input_required=True``, may instead be an |
| 435 | + `InputRequiredResult` carrying the server's input requests and opaque |
| 436 | + ``request_state`` for the retry. |
| 437 | +
|
| 438 | + Raises: |
| 439 | + RuntimeError: If the server returns an `InputRequiredResult` and |
| 440 | + ``allow_input_required`` is ``False``. |
397 | 441 | """ |
398 | | - return await self.session.call_tool( |
| 442 | + result = await self.session.call_tool( |
399 | 443 | name=name, |
400 | 444 | arguments=arguments, |
401 | 445 | read_timeout_seconds=read_timeout_seconds, |
402 | 446 | progress_callback=progress_callback, |
| 447 | + input_responses=input_responses, |
| 448 | + request_state=request_state, |
403 | 449 | meta=meta, |
404 | 450 | ) |
| 451 | + if isinstance(result, InputRequiredResult) and not allow_input_required: |
| 452 | + # TODO(L80): replace this raise with the MRTR auto-loop driver (S6). |
| 453 | + raise RuntimeError( |
| 454 | + "Server returned InputRequiredResult; pass allow_input_required=True to receive it " |
| 455 | + "and retry call_tool(..., input_responses=..., request_state=result.request_state)." |
| 456 | + ) |
| 457 | + return result |
405 | 458 |
|
406 | 459 | async def list_prompts( |
407 | 460 | self, |
|
0 commit comments