From e58ff0288a67f9cf9ae4822f82fcd30e55c00863 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 25 Jun 2026 15:44:29 +0200 Subject: [PATCH 1/8] Add resolver dependency injection for MCPServer tools A tool parameter annotated `Annotated[T, Resolve(fn)]` is filled by running the resolver `fn` before the tool body, instead of by the calling LLM. Resolvers form a dependency graph: a resolver may declare its own `Resolve(...)` dependencies, read the `Context` (including the new `Context.headers`), and receive the tool's own arguments by name. A resolver may return `Elicit[T]` to ask the client; the SDK runs the elicitation and injects the answer. Each resolver runs at most once per `tools/call`. The injected type follows the consumer's annotation: the unwrapped model aborts the call on decline/cancel, while the elicitation result union lets the consumer branch on the outcome. Resolved parameters are omitted from the tool's input schema; unclassifiable resolver parameters and cyclic resolver dependencies raise at registration time. --- docs/migration.md | 58 ++++ src/mcp/server/mcpserver/__init__.py | 22 +- src/mcp/server/mcpserver/context.py | 20 +- src/mcp/server/mcpserver/resolve.py | 256 +++++++++++++++++ src/mcp/server/mcpserver/tools/base.py | 36 ++- .../mcpserver/utilities/func_metadata.py | 14 +- tests/server/mcpserver/test_resolve.py | 265 ++++++++++++++++++ 7 files changed, 663 insertions(+), 8 deletions(-) create mode 100644 src/mcp/server/mcpserver/resolve.py create mode 100644 tests/server/mcpserver/test_resolve.py diff --git a/docs/migration.md b/docs/migration.md index bf06690c45..873c2f0026 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -1396,6 +1396,64 @@ app = server.streamable_http_app( The lowlevel `Server` also now exposes a `session_manager` property to access the `StreamableHTTPSessionManager` after calling `streamable_http_app()`. +### Resolver dependency injection for tools (`Resolve` / `Elicit`) + +A tool parameter annotated `Annotated[T, Resolve(fn)]` is filled by running the resolver `fn` before the tool body, instead of by the calling LLM. Resolvers form a dependency graph: a resolver may declare its own `Resolve(...)` dependencies, read the `Context` (including `ctx.headers`), and receive the tool's own arguments by name. A resolver may return `Elicit[T]` to ask the client; the SDK runs the elicitation and injects the answer. Each resolver runs at most once per `tools/call`. + +```python +from typing import Annotated + +from pydantic import BaseModel + +from mcp.server.mcpserver import AcceptedElicitation, Context, Elicit, MCPServer, Resolve + +mcp = MCPServer(name="github") + + +class Login(BaseModel): + username: str + + +class Confirm(BaseModel): + ok: bool + + +async def login(ctx: Context) -> Login | Elicit[Login]: + if username := (ctx.headers or {}).get("x-github-user"): + return Login(username=username) # resolved from context, no question + return Elicit("GitHub username?", Login) # must ask + + +async def confirm(repo: str, login: Annotated[Login, Resolve(login)]) -> Elicit[Confirm]: + return Elicit(f"Star {repo} as {login.username}?", Confirm) + + +@mcp.tool() +async def star_repo( + repo: str, + login: Annotated[Login, Resolve(login)], + confirm: Annotated[Confirm, Resolve(confirm)], +) -> str: + """Star a GitHub repo.""" + return f"starred {repo} as {login.username}" if confirm.ok else "cancelled" +``` + +The injected type follows the consumer's annotation. Annotating the unwrapped model (`Annotated[Login, Resolve(login)]`) injects the model on accept and aborts the call with an error result on decline or cancel. To branch on the outcome instead, annotate the elicitation result union: + +```python +@mcp.tool() +async def whoami( + login: Annotated[AcceptedElicitation[Login] | DeclinedElicitation | CancelledElicitation, Resolve(login)], +) -> str: + match login: + case AcceptedElicitation(data=data): + return f"hi {data.username}" + case _: + return "no username provided" +``` + +Resolved parameters are omitted from the tool's input schema, so the client never supplies them. Resolver parameters that cannot be classified, and cyclic resolver dependencies, raise at registration time. + ## Need Help? If you encounter issues during migration: diff --git a/src/mcp/server/mcpserver/__init__.py b/src/mcp/server/mcpserver/__init__.py index 0857e38bd4..c6bc3d5b00 100644 --- a/src/mcp/server/mcpserver/__init__.py +++ b/src/mcp/server/mcpserver/__init__.py @@ -3,7 +3,27 @@ from mcp.types import Icon from .context import Context +from .resolve import ( + AcceptedElicitation, + CancelledElicitation, + DeclinedElicitation, + Elicit, + ElicitationResult, + Resolve, +) from .server import MCPServer from .utilities.types import Audio, Image -__all__ = ["MCPServer", "Context", "Image", "Audio", "Icon"] +__all__ = [ + "MCPServer", + "Context", + "Image", + "Audio", + "Icon", + "Resolve", + "Elicit", + "ElicitationResult", + "AcceptedElicitation", + "DeclinedElicitation", + "CancelledElicitation", +] diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index 0bf0b7ebfd..ce0624df2a 100644 --- a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -1,7 +1,7 @@ from __future__ import annotations -from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Generic +from collections.abc import Iterable, Mapping +from typing import TYPE_CHECKING, Any, Generic, Protocol, cast from pydantic import AnyUrl, BaseModel from typing_extensions import deprecated @@ -22,6 +22,11 @@ from mcp.server.mcpserver.server import MCPServer +class _HasHeaders(Protocol): + @property + def headers(self) -> Mapping[str, str]: ... + + class Context(BaseModel, Generic[LifespanContextT, RequestT]): """Context object providing access to MCP capabilities. @@ -225,6 +230,17 @@ def client_id(self) -> str | None: """ return self.request_context.meta.get("client_id") if self.request_context.meta else None # pragma: no cover + @property + def headers(self) -> Mapping[str, str] | None: + """Request headers carried by this message, when the transport has them. + + Populated by HTTP-based transports; `None` on stdio. + """ + request = self.request_context.request + if request is None: + return None + return cast("_HasHeaders", request).headers + @property def request_id(self) -> str: """Get the unique ID for this request.""" diff --git a/src/mcp/server/mcpserver/resolve.py b/src/mcp/server/mcpserver/resolve.py new file mode 100644 index 0000000000..8cab59c4d2 --- /dev/null +++ b/src/mcp/server/mcpserver/resolve.py @@ -0,0 +1,256 @@ +"""Resolver dependency injection for MCPServer tools. + +A tool parameter annotated `Annotated[T, Resolve(fn)]` is filled by running the +resolver `fn` before the tool body, instead of from the LLM-supplied arguments. +Resolvers form a DAG: a resolver may declare its own `Resolve(...)` dependencies, +take tool arguments by name, and take the `Context`. A resolver may return +`Elicit[T]` to ask the client; the framework runs the elicitation and injects the +answer. + +Whether the consumer receives the unwrapped model or the full +`ElicitationResult` union is decided by the consumer's annotation: + +- `Annotated[T, Resolve(fn)]` -> unwrapped `T`; decline/cancel aborts the call. +- `Annotated[ElicitationResult[T], Resolve(fn)]` (or a specific member) -> the + full outcome; the consumer branches on accept/decline/cancel. + +Each resolver runs at most once per `tools/call` (memoized by function identity). +""" + +from __future__ import annotations + +import inspect +import typing +from collections.abc import Callable, Mapping +from typing import Annotated, Any, Generic, cast, get_args, get_origin + +import anyio.to_thread +from pydantic import BaseModel +from typing_extensions import TypeVar + +from mcp.server.elicitation import ( + AcceptedElicitation, + CancelledElicitation, + DeclinedElicitation, + ElicitationResult, +) +from mcp.server.mcpserver.context import Context +from mcp.server.mcpserver.exceptions import InvalidSignature, ToolError +from mcp.shared._callable_inspection import is_async_callable + +T = TypeVar("T", bound=BaseModel) + +# The union members the framework injects when a consumer opts into the outcome. +_ELICITATION_RESULT_MEMBERS = (AcceptedElicitation, DeclinedElicitation, CancelledElicitation) + + +class Resolve: + """Marker for `Annotated[T, Resolve(fn)]`: fill the parameter by running `fn`.""" + + def __init__(self, fn: Callable[..., Any]) -> None: + self.fn = fn + + +class Elicit(Generic[T]): + """A resolver's request to ask the client. + + Returned from a resolver to signal that the value must be elicited. The + framework runs `ctx.elicit(message, schema)` and injects the outcome. + """ + + def __init__(self, message: str, schema: type[T]) -> None: + self.message = message + self.schema = schema + + +class _ParamPlan: + """How to fill one resolver parameter, decided once at registration.""" + + kind: str # "context" | "resolve" | "by_name" + resolve: Resolve | None + wants_union: bool + + def __init__(self, kind: str, resolve: Resolve | None = None, wants_union: bool = False) -> None: + self.kind = kind + self.resolve = resolve + self.wants_union = wants_union + + +class _ResolverPlan: + """A resolver's parameters and whether it is async, analyzed once.""" + + def __init__(self, fn: Callable[..., Any], params: dict[str, _ParamPlan], is_async: bool) -> None: + self.fn = fn + self.params = params + self.is_async = is_async + + +def find_resolved_parameters(fn: Callable[..., Any]) -> dict[str, tuple[Resolve, bool]]: + """Find parameters of `fn` annotated `Annotated[_, Resolve(...)]`. + + Returns a mapping of parameter name to `(Resolve, wants_union)`, where + `wants_union` is True when the annotated type is an `ElicitationResult` member + (the consumer wants the full outcome rather than the unwrapped model). + """ + hints = typing.get_type_hints(fn, include_extras=True) + resolved: dict[str, tuple[Resolve, bool]] = {} + for name, annotation in hints.items(): + if get_origin(annotation) is not Annotated: + continue + type_arg, *metadata = get_args(annotation) + marker = next((m for m in metadata if isinstance(m, Resolve)), None) + if marker is not None: + resolved[name] = (marker, _wants_union(type_arg)) + return resolved + + +def _wants_union(type_arg: Any) -> bool: + """True when `type_arg` is an `ElicitationResult` member (or a union of them).""" + members = get_args(type_arg) if get_origin(type_arg) is not None else (type_arg,) + return any(isinstance(m, type) and issubclass(m, _ELICITATION_RESULT_MEMBERS) for m in members) + + +def build_resolver_plans( + resolved_params: Mapping[str, tuple[Resolve, bool]], + tool_arg_names: set[str], +) -> dict[int, _ResolverPlan]: + """Statically analyze the resolver DAG rooted at a tool's resolved parameters. + + Raises: + InvalidSignature: If a resolver has a cyclic dependency, or a resolver + parameter cannot be classified (not a `Context`, a nested `Resolve`, + or a tool argument by name). + """ + plans: dict[int, _ResolverPlan] = {} + + def analyze(fn: Callable[..., Any], stack: tuple[int, ...]) -> None: + key = id(fn) + if key in stack: + raise InvalidSignature(f"Resolver {fn.__name__!r} has a cyclic dependency") + if key in plans: + return + + hints = typing.get_type_hints(fn, include_extras=True) + sig = inspect.signature(fn) + params: dict[str, _ParamPlan] = {} + nested: list[Callable[..., Any]] = [] + for param_name in sig.parameters: + annotation = hints.get(param_name) + if annotation is not None and _is_context_annotation(annotation): + params[param_name] = _ParamPlan("context") + continue + marker, wants_union = _resolve_marker(annotation) + if marker is not None: + params[param_name] = _ParamPlan("resolve", marker, wants_union) + nested.append(marker.fn) + continue + if param_name in tool_arg_names: + params[param_name] = _ParamPlan("by_name") + continue + raise InvalidSignature( + f"Resolver {fn.__name__!r} parameter {param_name!r} cannot be resolved: " + "expected a Context, an Annotated[_, Resolve(...)], or a tool argument by name" + ) + + plans[key] = _ResolverPlan(fn, params, is_async_callable(fn)) + for dep in nested: + analyze(dep, stack + (key,)) + + for marker, _ in resolved_params.values(): + analyze(marker.fn, ()) + return plans + + +def _resolve_marker(annotation: Any) -> tuple[Resolve | None, bool]: + if get_origin(annotation) is not Annotated: + return None, False + type_arg, *metadata = get_args(annotation) + marker = next((m for m in metadata if isinstance(m, Resolve)), None) + return marker, (_wants_union(type_arg) if marker is not None else False) + + +def _is_context_annotation(annotation: Any) -> bool: + if get_origin(annotation) is Annotated: + annotation = get_args(annotation)[0] + return isinstance(annotation, type) and issubclass(annotation, Context) + + +async def resolve_arguments( + resolved_params: Mapping[str, tuple[Resolve, bool]], + plans: Mapping[int, _ResolverPlan], + tool_args: Mapping[str, Any], + context: Context[Any, Any], +) -> dict[str, Any]: + """Resolve every `Resolve`-marked tool parameter into a concrete value. + + Each resolver runs at most once (memoized by function identity). Returns a + mapping of tool parameter name to the value to inject. + + Raises: + ToolError: If an elicited value is declined or cancelled and the consumer + asked for the unwrapped model (rather than the result union). + """ + cache: dict[int, ElicitationResult[BaseModel]] = {} + injected: dict[str, Any] = {} + for name, (marker, wants_union) in resolved_params.items(): + outcome = await _resolve(marker.fn, plans, tool_args, context, cache) + injected[name] = outcome if wants_union else _unwrap(outcome, name) + return injected + + +async def _resolve( + fn: Callable[..., Any], + plans: Mapping[int, _ResolverPlan], + tool_args: Mapping[str, Any], + context: Context[Any, Any], + cache: dict[int, ElicitationResult[BaseModel]], +) -> ElicitationResult[BaseModel]: + key = id(fn) + if key in cache: + return cache[key] + + plan = plans[key] + kwargs: dict[str, Any] = {} + for param_name, param_plan in plan.params.items(): + if param_plan.kind == "context": + kwargs[param_name] = context + elif param_plan.kind == "by_name": + kwargs[param_name] = tool_args[param_name] + else: + assert param_plan.resolve is not None + dep_outcome = await _resolve(param_plan.resolve.fn, plans, tool_args, context, cache) + kwargs[param_name] = dep_outcome if param_plan.wants_union else _unwrap(dep_outcome, param_name) + + if plan.is_async: + result = await fn(**kwargs) + else: + result = await anyio.to_thread.run_sync(lambda: fn(**kwargs)) + + outcome: ElicitationResult[BaseModel] + if isinstance(result, Elicit): + elicit = cast("Elicit[BaseModel]", result) + outcome = await context.elicit(elicit.message, elicit.schema) + else: + outcome = AcceptedElicitation(data=result) + + cache[key] = outcome + return outcome + + +def _unwrap(outcome: ElicitationResult[BaseModel], name: str) -> BaseModel: + if isinstance(outcome, AcceptedElicitation): + return outcome.data + raise ToolError(f"Resolver for parameter {name!r} could not resolve: elicitation was {outcome.action}") + + +__all__ = [ + "Resolve", + "Elicit", + "ElicitationResult", + "AcceptedElicitation", + "DeclinedElicitation", + "CancelledElicitation", + "find_resolved_parameters", + "build_resolver_plans", + "resolve_arguments", +] diff --git a/src/mcp/server/mcpserver/tools/base.py b/src/mcp/server/mcpserver/tools/base.py index 29894d7d1d..a3cae05afb 100644 --- a/src/mcp/server/mcpserver/tools/base.py +++ b/src/mcp/server/mcpserver/tools/base.py @@ -7,6 +7,11 @@ from pydantic import BaseModel, Field from mcp.server.mcpserver.exceptions import ToolError +from mcp.server.mcpserver.resolve import ( + build_resolver_plans, + find_resolved_parameters, + resolve_arguments, +) from mcp.server.mcpserver.utilities.context_injection import find_context_parameter from mcp.server.mcpserver.utilities.func_metadata import FuncMetadata, func_metadata from mcp.shared._callable_inspection import is_async_callable @@ -32,6 +37,14 @@ class Tool(BaseModel): ) is_async: bool = Field(description="Whether the tool is async") context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context") + resolved_params: dict[str, Any] = Field( + default_factory=lambda: {}, + exclude=True, + description="Parameters filled by resolvers, mapped to (Resolve, wants_union)", + ) + resolver_plans: dict[int, Any] = Field( + default_factory=lambda: {}, exclude=True, description="Static per-resolver parameter plans" + ) annotations: ToolAnnotations | None = Field(None, description="Optional annotations for the tool") icons: list[Icon] | None = Field(default=None, description="Optional list of icons for this tool") meta: dict[str, Any] | None = Field(default=None, description="Optional metadata for this tool") @@ -67,13 +80,23 @@ def from_function( if context_kwarg is None: # pragma: no branch context_kwarg = find_context_parameter(fn) + resolved_params = find_resolved_parameters(fn) + + skip_names = [context_kwarg] if context_kwarg is not None else [] + skip_names.extend(resolved_params) + func_arg_metadata = func_metadata( fn, - skip_names=[context_kwarg] if context_kwarg is not None else [], + skip_names=skip_names, structured_output=structured_output, ) parameters = func_arg_metadata.arg_model.model_json_schema(by_alias=True) + tool_arg_names = set(func_arg_metadata.arg_model.model_fields) | { + field.alias for field in func_arg_metadata.arg_model.model_fields.values() if field.alias + } + resolver_plans = build_resolver_plans(resolved_params, tool_arg_names) + return cls( fn=fn, name=func_name, @@ -83,6 +106,8 @@ def from_function( fn_metadata=func_arg_metadata, is_async=is_async, context_kwarg=context_kwarg, + resolved_params=dict(resolved_params), + resolver_plans=resolver_plans, annotations=annotations, icons=icons, meta=meta, @@ -100,11 +125,18 @@ async def run( ToolError: If the tool function raises during execution. """ try: + pass_directly: dict[str, Any] = {} + if self.context_kwarg is not None: + pass_directly[self.context_kwarg] = context + if self.resolved_params: + tool_args = self.fn_metadata.validate_arguments(arguments) + pass_directly |= await resolve_arguments(self.resolved_params, self.resolver_plans, tool_args, context) + result = await self.fn_metadata.call_fn_with_arg_validation( self.fn, self.is_async, arguments, - {self.context_kwarg: context} if self.context_kwarg is not None else None, + pass_directly or None, ) if convert_result: diff --git a/src/mcp/server/mcpserver/utilities/func_metadata.py b/src/mcp/server/mcpserver/utilities/func_metadata.py index 6c553fbab9..53284c43b2 100644 --- a/src/mcp/server/mcpserver/utilities/func_metadata.py +++ b/src/mcp/server/mcpserver/utilities/func_metadata.py @@ -65,6 +65,16 @@ class FuncMetadata(BaseModel): output_model: Annotated[type[BaseModel], WithJsonSchema(None)] | None = None wrap_output: bool = False + def validate_arguments(self, arguments_to_validate: dict[str, Any]) -> dict[str, Any]: + """Validate raw arguments into a one-level kwargs dict (no function call). + + Used to feed resolver dependency injection the validated tool arguments + before the tool function itself runs. + """ + arguments_pre_parsed = self.pre_parse_json(arguments_to_validate) + arguments_parsed_model = self.arg_model.model_validate(arguments_pre_parsed) + return arguments_parsed_model.model_dump_one_level() + async def call_fn_with_arg_validation( self, fn: Callable[..., Any | Awaitable[Any]], @@ -77,9 +87,7 @@ async def call_fn_with_arg_validation( Arguments are first attempted to be parsed from JSON, then validated against the argument model, before being passed to the function. """ - arguments_pre_parsed = self.pre_parse_json(arguments_to_validate) - arguments_parsed_model = self.arg_model.model_validate(arguments_pre_parsed) - arguments_parsed_dict = arguments_parsed_model.model_dump_one_level() + arguments_parsed_dict = self.validate_arguments(arguments_to_validate) arguments_parsed_dict |= arguments_to_pass_directly or {} diff --git a/tests/server/mcpserver/test_resolve.py b/tests/server/mcpserver/test_resolve.py new file mode 100644 index 0000000000..5b62bd310f --- /dev/null +++ b/tests/server/mcpserver/test_resolve.py @@ -0,0 +1,265 @@ +"""Tests for resolver dependency injection (MRTR) on MCPServer tools.""" + +from typing import Annotated + +import pytest +from pydantic import BaseModel, Field + +from mcp import Client +from mcp.client import ClientRequestContext +from mcp.server.mcpserver import ( + AcceptedElicitation, + CancelledElicitation, + Context, + DeclinedElicitation, + Elicit, + MCPServer, + Resolve, +) +from mcp.server.mcpserver.exceptions import InvalidSignature +from mcp.server.mcpserver.tools.base import Tool +from mcp.types import ElicitRequestParams, ElicitResult, TextContent + + +class Login(BaseModel): + username: str + + +class Confirm(BaseModel): + ok: bool + + +def _accept(content: dict[str, str | int | float | bool | list[str] | None]): + async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="accept", content=content) + + return callback + + +async def _decline(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="decline") + + +async def _text(client: Client, tool: str, args: dict[str, object]) -> str: + result = await client.call_tool(tool, args) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + return result.content[0].text + + +@pytest.mark.anyio +async def test_resolver_returns_value_directly_without_eliciting(): + mcp = MCPServer(name="Direct") + + async def login(ctx: Context) -> Login | Elicit[Login]: + username = (ctx.headers or {}).get("x-github-user") + if username: # pragma: no cover - no headers on in-memory transport + return Login(username=username) + return Login(username="from-resolver") + + @mcp.tool() + async def whoami(login: Annotated[Login, Resolve(login)]) -> str: + return login.username + + async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + raise AssertionError("should not elicit") + + async with Client(mcp, elicitation_callback=never) as client: + assert await _text(client, "whoami", {}) == "from-resolver" + + +@pytest.mark.anyio +async def test_resolver_elicits_and_injects_unwrapped_model_on_accept(): + mcp = MCPServer(name="Accept") + + async def login(ctx: Context) -> Login | Elicit[Login]: + return Elicit("GitHub username?", Login) + + @mcp.tool() + async def whoami(login: Annotated[Login, Resolve(login)]) -> str: + return login.username + + async with Client(mcp, elicitation_callback=_accept({"username": "octocat"})) as client: + assert await _text(client, "whoami", {}) == "octocat" + + +@pytest.mark.anyio +async def test_consumer_receives_result_union_and_branches(): + mcp = MCPServer(name="Union") + + async def login(ctx: Context) -> Login | Elicit[Login]: + return Elicit("GitHub username?", Login) + + @mcp.tool() + async def whoami( + login: Annotated[AcceptedElicitation[Login] | DeclinedElicitation | CancelledElicitation, Resolve(login)], + ) -> str: + match login: + case AcceptedElicitation(data=data): + return f"hi {data.username}" + case _: # pragma: no cover - accepted in this test + return "no username" + + async with Client(mcp, elicitation_callback=_accept({"username": "octocat"})) as client: + assert await _text(client, "whoami", {}) == "hi octocat" + + +@pytest.mark.anyio +async def test_decline_reaches_union_consumer_without_aborting(): + mcp = MCPServer(name="UnionDecline") + + async def login(ctx: Context) -> Login | Elicit[Login]: + return Elicit("GitHub username?", Login) + + @mcp.tool() + async def whoami( + login: Annotated[AcceptedElicitation[Login] | DeclinedElicitation | CancelledElicitation, Resolve(login)], + ) -> str: + if isinstance(login, DeclinedElicitation): + return "declined gracefully" + raise NotImplementedError + + async with Client(mcp, elicitation_callback=_decline) as client: + assert await _text(client, "whoami", {}) == "declined gracefully" + + +@pytest.mark.anyio +async def test_decline_aborts_when_consumer_wants_unwrapped(): + mcp = MCPServer(name="UnwrappedDecline") + + async def login(ctx: Context) -> Login | Elicit[Login]: + return Elicit("GitHub username?", Login) + + @mcp.tool() + async def whoami(login: Annotated[Login, Resolve(login)]) -> str: + raise NotImplementedError # pragma: no cover - never reached + + async with Client(mcp, elicitation_callback=_decline) as client: + result = await client.call_tool("whoami", {}) + assert result.is_error + assert isinstance(result.content[0], TextContent) + assert "decline" in result.content[0].text + + +@pytest.mark.anyio +async def test_nested_resolver_sees_dependency_and_tool_args(): + mcp = MCPServer(name="Nested") + + async def login(ctx: Context) -> Login | Elicit[Login]: + return Elicit("GitHub username?", Login) + + async def confirm(repo: str, login: Annotated[Login, Resolve(login)]) -> Elicit[Confirm]: + return Elicit(f"Star {repo} as {login.username}?", Confirm) + + @mcp.tool() + async def star_repo( + repo: str, + login: Annotated[Login, Resolve(login)], + confirm: Annotated[Confirm, Resolve(confirm)], + ) -> str: + if confirm.ok: + return f"starred {repo} as {login.username}" + raise NotImplementedError + + async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + if "username" in params.message: + return ElicitResult(action="accept", content={"username": "octocat"}) + assert "Star modelcontextprotocol/python-sdk as octocat?" in params.message + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp, elicitation_callback=callback) as client: + text = await _text(client, "star_repo", {"repo": "modelcontextprotocol/python-sdk"}) + assert text == "starred modelcontextprotocol/python-sdk as octocat" + + +@pytest.mark.anyio +async def test_resolver_runs_once_for_two_consumers(): + mcp = MCPServer(name="ExactlyOnce") + elicit_count = 0 + + async def login(ctx: Context) -> Login | Elicit[Login]: + return Elicit("GitHub username?", Login) + + async def confirm(login: Annotated[Login, Resolve(login)]) -> Elicit[Confirm]: + return Elicit(f"As {login.username}?", Confirm) + + @mcp.tool() + async def star_repo( + login: Annotated[Login, Resolve(login)], + confirm: Annotated[Confirm, Resolve(confirm)], + ) -> str: + return f"{login.username}:{confirm.ok}" + + async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + nonlocal elicit_count + if "username" in params.message: + elicit_count += 1 + return ElicitResult(action="accept", content={"username": "octocat"}) + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp, elicitation_callback=callback) as client: + assert await _text(client, "star_repo", {}) == "octocat:True" + assert elicit_count == 1 + + +@pytest.mark.anyio +async def test_sync_resolver(): + mcp = MCPServer(name="Sync") + + def login(ctx: Context) -> Login: + return Login(username="sync-user") + + @mcp.tool() + async def whoami(login: Annotated[Login, Resolve(login)]) -> str: + return login.username + + async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + raise AssertionError("should not elicit") + + async with Client(mcp, elicitation_callback=never) as client: + assert await _text(client, "whoami", {}) == "sync-user" + + +def test_resolved_params_absent_from_input_schema(): + async def login(ctx: Context) -> Login: + return Login(username="x") + + async def tool( + repo: Annotated[str, Field(description="repo name")], + login: Annotated[Login, Resolve(login)], + ) -> str: + return repo + + built = Tool.from_function(tool) + properties = built.parameters["properties"] + assert "repo" in properties + assert "login" not in properties + + +def test_cycle_detection_raises_at_registration(): + async def a(dep: Login) -> Login: + return dep # pragma: no cover + + async def b(dep: Login) -> Login: + return dep # pragma: no cover + + # Close the loop after both exist: a depends on b, b depends on a. + a.__annotations__["dep"] = Annotated[Login, Resolve(b)] + b.__annotations__["dep"] = Annotated[Login, Resolve(a)] + + async def tool(value: Annotated[Login, Resolve(a)]) -> str: + return value.username # pragma: no cover + + with pytest.raises(InvalidSignature, match="cyclic"): + Tool.from_function(tool) + + +def test_unresolvable_resolver_param_raises_at_registration(): + async def login(mystery: int) -> Login: + return Login(username="x") # pragma: no cover + + async def tool(login: Annotated[Login, Resolve(login)]) -> str: + return login.username # pragma: no cover + + with pytest.raises(InvalidSignature, match="cannot be resolved"): + Tool.from_function(tool) From e1100939bb04ce52adf5ad3254d43498d33b22ef Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 25 Jun 2026 16:27:49 +0200 Subject: [PATCH 2/8] Cover Context.headers and resolver schema-only paths The headers property's request-present branch and the schema-inspection helpers in the resolver tests were not exercised, breaking the 100% coverage gate. Add direct Context.headers tests and mark the never-run helper bodies. --- tests/server/mcpserver/test_resolve.py | 4 ++-- tests/server/mcpserver/test_server.py | 22 ++++++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/tests/server/mcpserver/test_resolve.py b/tests/server/mcpserver/test_resolve.py index 5b62bd310f..a58e200065 100644 --- a/tests/server/mcpserver/test_resolve.py +++ b/tests/server/mcpserver/test_resolve.py @@ -222,13 +222,13 @@ async def never(context: ClientRequestContext, params: ElicitRequestParams) -> E def test_resolved_params_absent_from_input_schema(): async def login(ctx: Context) -> Login: - return Login(username="x") + return Login(username="x") # pragma: no cover - only the schema is inspected async def tool( repo: Annotated[str, Field(description="repo name")], login: Annotated[Login, Resolve(login)], ) -> str: - return repo + return repo # pragma: no cover - only the schema is inspected built = Tool.from_function(tool) properties = built.parameters["properties"] diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 554fe50215..72ab2f48d7 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -1,5 +1,6 @@ import base64 from pathlib import Path +from types import SimpleNamespace from typing import Any from unittest.mock import AsyncMock, MagicMock, patch @@ -1547,6 +1548,27 @@ async def test_report_progress_passes_related_request_id(): ) +def _request_context(request: object | None) -> ServerRequestContext[None, object]: + return ServerRequestContext( + session=AsyncMock(), + method="tools/call", + lifespan_context=None, + protocol_version="2025-11-25", + request=request, + ) + + +def test_context_headers_returns_request_headers(): + request = SimpleNamespace(headers={"x-github-user": "octocat"}) + ctx = Context(request_context=_request_context(request), mcp_server=MagicMock()) + assert ctx.headers == {"x-github-user": "octocat"} + + +def test_context_headers_is_none_without_request(): + ctx = Context(request_context=_request_context(None), mcp_server=MagicMock()) + assert ctx.headers is None + + async def test_read_resource_template_error(): """Template-creation failure must surface as INTERNAL_ERROR, not INVALID_PARAMS (not-found).""" mcp = MCPServer() From cafe8f325def42bb88dc2fb804c09a9174611a45 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 25 Jun 2026 16:56:08 +0200 Subject: [PATCH 3/8] Resolve type hints for callable-object tools in resolver detection find_resolved_parameters called typing.get_type_hints on the callable directly, which raises for a callable instance (an object with __call__), breaking tool registration for callable objects. Resolve hints off __call__ and tolerate unresolvable hints, mirroring find_context_parameter. --- src/mcp/server/mcpserver/resolve.py | 20 +++++++++++++++++--- tests/server/mcpserver/test_resolve.py | 9 +++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/mcp/server/mcpserver/resolve.py b/src/mcp/server/mcpserver/resolve.py index 8cab59c4d2..065f8dd71e 100644 --- a/src/mcp/server/mcpserver/resolve.py +++ b/src/mcp/server/mcpserver/resolve.py @@ -85,6 +85,21 @@ def __init__(self, fn: Callable[..., Any], params: dict[str, _ParamPlan], is_asy self.is_async = is_async +def _type_hints(fn: Callable[..., Any]) -> dict[str, Any]: + """Resolve type hints for a function or a callable object. + + `typing.get_type_hints` raises on a callable *instance*; fall back to its + `__call__`. Returns an empty mapping when hints cannot be resolved, matching + `find_context_parameter`'s tolerance so callables without annotations (or with + unresolvable ones) simply have no resolved parameters. + """ + target = fn if inspect.isroutine(fn) else getattr(type(fn), "__call__", fn) + try: + return typing.get_type_hints(target, include_extras=True) + except Exception: + return {} + + def find_resolved_parameters(fn: Callable[..., Any]) -> dict[str, tuple[Resolve, bool]]: """Find parameters of `fn` annotated `Annotated[_, Resolve(...)]`. @@ -92,9 +107,8 @@ def find_resolved_parameters(fn: Callable[..., Any]) -> dict[str, tuple[Resolve, `wants_union` is True when the annotated type is an `ElicitationResult` member (the consumer wants the full outcome rather than the unwrapped model). """ - hints = typing.get_type_hints(fn, include_extras=True) resolved: dict[str, tuple[Resolve, bool]] = {} - for name, annotation in hints.items(): + for name, annotation in _type_hints(fn).items(): if get_origin(annotation) is not Annotated: continue type_arg, *metadata = get_args(annotation) @@ -130,7 +144,7 @@ def analyze(fn: Callable[..., Any], stack: tuple[int, ...]) -> None: if key in plans: return - hints = typing.get_type_hints(fn, include_extras=True) + hints = _type_hints(fn) sig = inspect.signature(fn) params: dict[str, _ParamPlan] = {} nested: list[Callable[..., Any]] = [] diff --git a/tests/server/mcpserver/test_resolve.py b/tests/server/mcpserver/test_resolve.py index a58e200065..cff0d5dc5d 100644 --- a/tests/server/mcpserver/test_resolve.py +++ b/tests/server/mcpserver/test_resolve.py @@ -17,6 +17,7 @@ Resolve, ) from mcp.server.mcpserver.exceptions import InvalidSignature +from mcp.server.mcpserver.resolve import find_resolved_parameters from mcp.server.mcpserver.tools.base import Tool from mcp.types import ElicitRequestParams, ElicitResult, TextContent @@ -254,6 +255,14 @@ async def tool(value: Annotated[Login, Resolve(a)]) -> str: Tool.from_function(tool) +def test_find_resolved_parameters_tolerates_unresolvable_hints(): + def fn(x: int) -> int: + return x # pragma: no cover + + fn.__annotations__["x"] = "DoesNotExist" + assert find_resolved_parameters(fn) == {} + + def test_unresolvable_resolver_param_raises_at_registration(): async def login(mystery: int) -> Login: return Login(username="x") # pragma: no cover From 9e9282a126bc3567af0f9facbef66f6652062382 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 25 Jun 2026 17:22:59 +0200 Subject: [PATCH 4/8] Pin elicitation resolver tests to legacy mode for 2026-07-28 default After merging main, LATEST_PROTOCOL_VERSION is 2026-07-28, which defines no server-to-client requests, so elicitation/create is unavailable at the default negotiated version. Pin these tests to mode='legacy' (negotiates 2025-11-25) where elicitation is supported, matching test_elicitation.py. --- tests/server/mcpserver/test_resolve.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/server/mcpserver/test_resolve.py b/tests/server/mcpserver/test_resolve.py index cff0d5dc5d..9a6575b6fe 100644 --- a/tests/server/mcpserver/test_resolve.py +++ b/tests/server/mcpserver/test_resolve.py @@ -65,7 +65,7 @@ async def whoami(login: Annotated[Login, Resolve(login)]) -> str: async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover raise AssertionError("should not elicit") - async with Client(mcp, elicitation_callback=never) as client: + async with Client(mcp, mode="legacy", elicitation_callback=never) as client: assert await _text(client, "whoami", {}) == "from-resolver" @@ -80,7 +80,7 @@ async def login(ctx: Context) -> Login | Elicit[Login]: async def whoami(login: Annotated[Login, Resolve(login)]) -> str: return login.username - async with Client(mcp, elicitation_callback=_accept({"username": "octocat"})) as client: + async with Client(mcp, mode="legacy", elicitation_callback=_accept({"username": "octocat"})) as client: assert await _text(client, "whoami", {}) == "octocat" @@ -101,7 +101,7 @@ async def whoami( case _: # pragma: no cover - accepted in this test return "no username" - async with Client(mcp, elicitation_callback=_accept({"username": "octocat"})) as client: + async with Client(mcp, mode="legacy", elicitation_callback=_accept({"username": "octocat"})) as client: assert await _text(client, "whoami", {}) == "hi octocat" @@ -120,7 +120,7 @@ async def whoami( return "declined gracefully" raise NotImplementedError - async with Client(mcp, elicitation_callback=_decline) as client: + async with Client(mcp, mode="legacy", elicitation_callback=_decline) as client: assert await _text(client, "whoami", {}) == "declined gracefully" @@ -135,7 +135,7 @@ async def login(ctx: Context) -> Login | Elicit[Login]: async def whoami(login: Annotated[Login, Resolve(login)]) -> str: raise NotImplementedError # pragma: no cover - never reached - async with Client(mcp, elicitation_callback=_decline) as client: + async with Client(mcp, mode="legacy", elicitation_callback=_decline) as client: result = await client.call_tool("whoami", {}) assert result.is_error assert isinstance(result.content[0], TextContent) @@ -168,7 +168,7 @@ async def callback(context: ClientRequestContext, params: ElicitRequestParams) - assert "Star modelcontextprotocol/python-sdk as octocat?" in params.message return ElicitResult(action="accept", content={"ok": True}) - async with Client(mcp, elicitation_callback=callback) as client: + async with Client(mcp, mode="legacy", elicitation_callback=callback) as client: text = await _text(client, "star_repo", {"repo": "modelcontextprotocol/python-sdk"}) assert text == "starred modelcontextprotocol/python-sdk as octocat" @@ -198,7 +198,7 @@ async def callback(context: ClientRequestContext, params: ElicitRequestParams) - return ElicitResult(action="accept", content={"username": "octocat"}) return ElicitResult(action="accept", content={"ok": True}) - async with Client(mcp, elicitation_callback=callback) as client: + async with Client(mcp, mode="legacy", elicitation_callback=callback) as client: assert await _text(client, "star_repo", {}) == "octocat:True" assert elicit_count == 1 @@ -217,7 +217,7 @@ async def whoami(login: Annotated[Login, Resolve(login)]) -> str: async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover raise AssertionError("should not elicit") - async with Client(mcp, elicitation_callback=never) as client: + async with Client(mcp, mode="legacy", elicitation_callback=never) as client: assert await _text(client, "whoami", {}) == "sync-user" From c3ea531bb76cf7fdbb4c688d74477316cf0fc409 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 25 Jun 2026 17:46:31 +0200 Subject: [PATCH 5/8] Address cubic review: by-name aliasing, return-annotation, callable-resolver naming - tools/base.py: build tool_arg_names as 'alias or field_name' to match the runtime kwarg keys, so a by-name resolver param on an aliased field resolves instead of raising KeyError at call time. - resolve.py: iterate inspect.signature params (not get_type_hints items, which include 'return') so a Resolve marker on a return annotation is ignored; add _resolver_name so callable-object resolvers raise InvalidSignature instead of AttributeError in error messages. - migration.md: import DeclinedElicitation/CancelledElicitation used in the branching example so the snippet is runnable. Add regression tests for each. --- docs/migration.md | 10 +++++- src/mcp/server/mcpserver/resolve.py | 13 ++++++-- src/mcp/server/mcpserver/tools/base.py | 6 ++-- tests/server/mcpserver/test_resolve.py | 42 ++++++++++++++++++++++++++ 4 files changed, 64 insertions(+), 7 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index 6c82596863..73343f67c3 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -1435,7 +1435,15 @@ from typing import Annotated from pydantic import BaseModel -from mcp.server.mcpserver import AcceptedElicitation, Context, Elicit, MCPServer, Resolve +from mcp.server.mcpserver import ( + AcceptedElicitation, + CancelledElicitation, + Context, + DeclinedElicitation, + Elicit, + MCPServer, + Resolve, +) mcp = MCPServer(name="github") diff --git a/src/mcp/server/mcpserver/resolve.py b/src/mcp/server/mcpserver/resolve.py index 065f8dd71e..dfdf1e7658 100644 --- a/src/mcp/server/mcpserver/resolve.py +++ b/src/mcp/server/mcpserver/resolve.py @@ -100,6 +100,11 @@ def _type_hints(fn: Callable[..., Any]) -> dict[str, Any]: return {} +def _resolver_name(fn: Callable[..., Any]) -> str: + """Best-effort display name for error messages (callable objects lack `__name__`).""" + return getattr(fn, "__name__", None) or type(fn).__name__ + + def find_resolved_parameters(fn: Callable[..., Any]) -> dict[str, tuple[Resolve, bool]]: """Find parameters of `fn` annotated `Annotated[_, Resolve(...)]`. @@ -107,8 +112,10 @@ def find_resolved_parameters(fn: Callable[..., Any]) -> dict[str, tuple[Resolve, `wants_union` is True when the annotated type is an `ElicitationResult` member (the consumer wants the full outcome rather than the unwrapped model). """ + hints = _type_hints(fn) resolved: dict[str, tuple[Resolve, bool]] = {} - for name, annotation in _type_hints(fn).items(): + for name in inspect.signature(fn).parameters: + annotation = hints.get(name) if get_origin(annotation) is not Annotated: continue type_arg, *metadata = get_args(annotation) @@ -140,7 +147,7 @@ def build_resolver_plans( def analyze(fn: Callable[..., Any], stack: tuple[int, ...]) -> None: key = id(fn) if key in stack: - raise InvalidSignature(f"Resolver {fn.__name__!r} has a cyclic dependency") + raise InvalidSignature(f"Resolver {_resolver_name(fn)!r} has a cyclic dependency") if key in plans: return @@ -162,7 +169,7 @@ def analyze(fn: Callable[..., Any], stack: tuple[int, ...]) -> None: params[param_name] = _ParamPlan("by_name") continue raise InvalidSignature( - f"Resolver {fn.__name__!r} parameter {param_name!r} cannot be resolved: " + f"Resolver {_resolver_name(fn)!r} parameter {param_name!r} cannot be resolved: " "expected a Context, an Annotated[_, Resolve(...)], or a tool argument by name" ) diff --git a/src/mcp/server/mcpserver/tools/base.py b/src/mcp/server/mcpserver/tools/base.py index a3cae05afb..b326108bc8 100644 --- a/src/mcp/server/mcpserver/tools/base.py +++ b/src/mcp/server/mcpserver/tools/base.py @@ -92,9 +92,9 @@ def from_function( ) parameters = func_arg_metadata.arg_model.model_json_schema(by_alias=True) - tool_arg_names = set(func_arg_metadata.arg_model.model_fields) | { - field.alias for field in func_arg_metadata.arg_model.model_fields.values() if field.alias - } + # Match `model_dump_one_level`'s kwarg keys (alias when present, else field name) + # so a by-name resolver param resolves to a key that exists at call time. + tool_arg_names = {field.alias or name for name, field in func_arg_metadata.arg_model.model_fields.items()} resolver_plans = build_resolver_plans(resolved_params, tool_arg_names) return cls( diff --git a/tests/server/mcpserver/test_resolve.py b/tests/server/mcpserver/test_resolve.py index 9a6575b6fe..13345ce706 100644 --- a/tests/server/mcpserver/test_resolve.py +++ b/tests/server/mcpserver/test_resolve.py @@ -272,3 +272,45 @@ async def tool(login: Annotated[Login, Resolve(login)]) -> str: with pytest.raises(InvalidSignature, match="cannot be resolved"): Tool.from_function(tool) + + +def test_resolve_marker_on_return_annotation_is_ignored(): + async def login(ctx: Context) -> Login: + return Login(username="x") # pragma: no cover + + async def tool(repo: str) -> Annotated[str, Resolve(login)]: + return repo # pragma: no cover + + assert find_resolved_parameters(tool) == {} + + +def test_callable_object_resolver_error_uses_type_name(): + class BadResolver: + async def __call__(self, mystery: int) -> Login: + return Login(username="x") # pragma: no cover + + async def tool(login: Annotated[Login, Resolve(BadResolver())]) -> str: + return login.username # pragma: no cover + + with pytest.raises(InvalidSignature, match="'BadResolver'"): + Tool.from_function(tool) + + +@pytest.mark.anyio +async def test_by_name_resolver_param_uses_aliased_tool_arg(): + mcp = MCPServer(name="Aliased") + + # `schema` collides with a BaseModel attribute, so func_metadata aliases the field; + # the runtime kwarg key is the alias, which is what a by-name resolver must match. + async def upper(schema: str) -> Login: + return Login(username=schema.upper()) + + @mcp.tool() + async def run(schema: str, shouted: Annotated[Login, Resolve(upper)]) -> str: + return shouted.username + + async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + raise AssertionError("should not elicit") + + async with Client(mcp, mode="legacy", elicitation_callback=never) as client: + assert await _text(client, "run", {"schema": "gpt"}) == "GPT" From aac86dc0c8847489fd6b622df2b1229446f3b9ab Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 25 Jun 2026 18:13:21 +0200 Subject: [PATCH 6/8] Fix resolver edge cases: non-BaseModel returns, optional Context, bound-method memoization bughunter findings on #2969: - Resolvers may return any type, not just BaseModel. Wrapping the return in AcceptedElicitation(data=...) validated it against the schema bound, so e.g. Annotated[str, Resolve(get_token)] failed every call with a cryptic ValidationError. Use model_construct to wrap the value without validation (the Elicit[T] path still validates via ctx.elicit). - _is_context_annotation now unwraps unions, so a resolver param typed Context | None is accepted, matching find_context_parameter on tools. - Memoize resolvers by the callable itself (hash/eq) instead of id(fn), so a bound-method resolver referenced as auth.login in two places runs at most once and participates in cycle detection. Fresh bound-method objects share identity by (__func__, __self__). Add regression tests for each. --- src/mcp/server/mcpserver/resolve.py | 47 +++++++++----- src/mcp/server/mcpserver/tools/base.py | 4 +- tests/server/mcpserver/test_resolve.py | 89 ++++++++++++++++++++++++++ 3 files changed, 123 insertions(+), 17 deletions(-) diff --git a/src/mcp/server/mcpserver/resolve.py b/src/mcp/server/mcpserver/resolve.py index dfdf1e7658..1c60ca7f39 100644 --- a/src/mcp/server/mcpserver/resolve.py +++ b/src/mcp/server/mcpserver/resolve.py @@ -21,7 +21,7 @@ import inspect import typing -from collections.abc import Callable, Mapping +from collections.abc import Callable, Hashable, Mapping from typing import Annotated, Any, Generic, cast, get_args, get_origin import anyio.to_thread @@ -131,10 +131,24 @@ def _wants_union(type_arg: Any) -> bool: return any(isinstance(m, type) and issubclass(m, _ELICITATION_RESULT_MEMBERS) for m in members) +def _resolver_key(fn: Callable[..., Any]) -> Hashable: + """Stable, equality-based key for memoizing a resolver. + + Bound methods are recreated on each attribute access (`id(auth.login)` differs + every time) but hash/compare by `(__func__, __self__)`, so the callable itself + is the right key. Falls back to `id` only for the rare unhashable callable. + """ + try: + hash(fn) + except TypeError: # pragma: no cover - unhashable callables are pathological + return id(fn) + return fn + + def build_resolver_plans( resolved_params: Mapping[str, tuple[Resolve, bool]], tool_arg_names: set[str], -) -> dict[int, _ResolverPlan]: +) -> dict[Hashable, _ResolverPlan]: """Statically analyze the resolver DAG rooted at a tool's resolved parameters. Raises: @@ -142,10 +156,10 @@ def build_resolver_plans( parameter cannot be classified (not a `Context`, a nested `Resolve`, or a tool argument by name). """ - plans: dict[int, _ResolverPlan] = {} + plans: dict[Hashable, _ResolverPlan] = {} - def analyze(fn: Callable[..., Any], stack: tuple[int, ...]) -> None: - key = id(fn) + def analyze(fn: Callable[..., Any], stack: tuple[Hashable, ...]) -> None: + key = _resolver_key(fn) if key in stack: raise InvalidSignature(f"Resolver {_resolver_name(fn)!r} has a cyclic dependency") if key in plans: @@ -193,12 +207,13 @@ def _resolve_marker(annotation: Any) -> tuple[Resolve | None, bool]: def _is_context_annotation(annotation: Any) -> bool: if get_origin(annotation) is Annotated: annotation = get_args(annotation)[0] - return isinstance(annotation, type) and issubclass(annotation, Context) + candidates = get_args(annotation) if get_origin(annotation) is not None else (annotation,) + return any(isinstance(c, type) and issubclass(c, Context) for c in candidates) async def resolve_arguments( resolved_params: Mapping[str, tuple[Resolve, bool]], - plans: Mapping[int, _ResolverPlan], + plans: Mapping[Hashable, _ResolverPlan], tool_args: Mapping[str, Any], context: Context[Any, Any], ) -> dict[str, Any]: @@ -211,7 +226,7 @@ async def resolve_arguments( ToolError: If an elicited value is declined or cancelled and the consumer asked for the unwrapped model (rather than the result union). """ - cache: dict[int, ElicitationResult[BaseModel]] = {} + cache: dict[Hashable, ElicitationResult[Any]] = {} injected: dict[str, Any] = {} for name, (marker, wants_union) in resolved_params.items(): outcome = await _resolve(marker.fn, plans, tool_args, context, cache) @@ -221,12 +236,12 @@ async def resolve_arguments( async def _resolve( fn: Callable[..., Any], - plans: Mapping[int, _ResolverPlan], + plans: Mapping[Hashable, _ResolverPlan], tool_args: Mapping[str, Any], context: Context[Any, Any], - cache: dict[int, ElicitationResult[BaseModel]], -) -> ElicitationResult[BaseModel]: - key = id(fn) + cache: dict[Hashable, ElicitationResult[Any]], +) -> ElicitationResult[Any]: + key = _resolver_key(fn) if key in cache: return cache[key] @@ -247,18 +262,20 @@ async def _resolve( else: result = await anyio.to_thread.run_sync(lambda: fn(**kwargs)) - outcome: ElicitationResult[BaseModel] + outcome: ElicitationResult[Any] if isinstance(result, Elicit): elicit = cast("Elicit[BaseModel]", result) outcome = await context.elicit(elicit.message, elicit.schema) else: - outcome = AcceptedElicitation(data=result) + # A resolver may return any type (not just `BaseModel`); `model_construct` + # wraps it as an accepted result without validating against the schema bound. + outcome = cast("AcceptedElicitation[Any]", AcceptedElicitation.model_construct(data=result)) cache[key] = outcome return outcome -def _unwrap(outcome: ElicitationResult[BaseModel], name: str) -> BaseModel: +def _unwrap(outcome: ElicitationResult[Any], name: str) -> Any: if isinstance(outcome, AcceptedElicitation): return outcome.data raise ToolError(f"Resolver for parameter {name!r} could not resolve: elicitation was {outcome.action}") diff --git a/src/mcp/server/mcpserver/tools/base.py b/src/mcp/server/mcpserver/tools/base.py index b326108bc8..14f8086856 100644 --- a/src/mcp/server/mcpserver/tools/base.py +++ b/src/mcp/server/mcpserver/tools/base.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Hashable from functools import cached_property from typing import TYPE_CHECKING, Any @@ -42,7 +42,7 @@ class Tool(BaseModel): exclude=True, description="Parameters filled by resolvers, mapped to (Resolve, wants_union)", ) - resolver_plans: dict[int, Any] = Field( + resolver_plans: dict[Hashable, Any] = Field( default_factory=lambda: {}, exclude=True, description="Static per-resolver parameter plans" ) annotations: ToolAnnotations | None = Field(None, description="Optional annotations for the tool") diff --git a/tests/server/mcpserver/test_resolve.py b/tests/server/mcpserver/test_resolve.py index 13345ce706..e45a58f7d2 100644 --- a/tests/server/mcpserver/test_resolve.py +++ b/tests/server/mcpserver/test_resolve.py @@ -314,3 +314,92 @@ async def never(context: ClientRequestContext, params: ElicitRequestParams) -> E async with Client(mcp, mode="legacy", elicitation_callback=never) as client: assert await _text(client, "run", {"schema": "gpt"}) == "GPT" + + +@pytest.mark.anyio +async def test_resolver_may_return_non_basemodel_value(): + mcp = MCPServer(name="NonModel") + + async def get_token(ctx: Context) -> str: + return "secret-token" + + @mcp.tool() + async def use_token(token: Annotated[str, Resolve(get_token)]) -> str: + return token + + async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + raise AssertionError("should not elicit") + + async with Client(mcp, mode="legacy", elicitation_callback=never) as client: + assert await _text(client, "use_token", {}) == "secret-token" + + +@pytest.mark.anyio +async def test_resolver_accepts_optional_context_annotation(): + mcp = MCPServer(name="OptionalContext") + + async def whoami(ctx: Context | None) -> str: + assert ctx is not None + return "has-context" + + @mcp.tool() + async def run(who: Annotated[str, Resolve(whoami)]) -> str: + return who + + async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + raise AssertionError("should not elicit") + + async with Client(mcp, mode="legacy", elicitation_callback=never) as client: + assert await _text(client, "run", {}) == "has-context" + + +@pytest.mark.anyio +async def test_bound_method_resolver_runs_once_across_references(): + mcp = MCPServer(name="BoundMethod") + calls = 0 + + class Service: + async def token(self, ctx: Context) -> str: + nonlocal calls + calls += 1 + return "tok" + + service = Service() + + # Each `service.token` access is a fresh bound-method object; keying by the + # callable (not id) keeps the resolver memoized to a single call. + async def downstream(token: Annotated[str, Resolve(service.token)]) -> str: + return token.upper() + + @mcp.tool() + async def run( + token: Annotated[str, Resolve(service.token)], + shouted: Annotated[str, Resolve(downstream)], + ) -> str: + return f"{token}:{shouted}" + + async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + raise AssertionError("should not elicit") + + async with Client(mcp, mode="legacy", elicitation_callback=never) as client: + assert await _text(client, "run", {}) == "tok:TOK" + assert calls == 1 + + +def test_bound_method_cycle_is_detected(): + class Service: + async def a(self, dep: Login) -> Login: + return dep # pragma: no cover + + async def b(self, dep: Login) -> Login: + return dep # pragma: no cover + + service = Service() + service.a.__func__.__annotations__["dep"] = Annotated[Login, Resolve(service.b)] + service.b.__func__.__annotations__["dep"] = Annotated[Login, Resolve(service.a)] + + async def tool(value: Annotated[Login, Resolve(service.a)]) -> str: + return value.username # pragma: no cover + + with pytest.raises(InvalidSignature, match="cyclic"): + Tool.from_function(tool) From 37c038c6fcdef965450469b6617be66d78e67d19 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 25 Jun 2026 18:29:41 +0200 Subject: [PATCH 7/8] Validate resolver tool args once; key resolvers by method identity Review follow-ups on #2969: - Tool.run validated arguments twice when resolvers were present (once to feed resolvers, once in call_fn_with_arg_validation). A field with default_factory or a stateful validator could hand a by-name resolver a different value than the tool body. Validate once and pass it through via a new pre_validated argument so both observe the same value. - Key the resolver cache/plans by (id(__func__), id(__self__)) for bound methods and id(fn) otherwise, instead of the callable's equality, so two distinct callables that compare equal can no longer share a plan/cache entry while bound-method memoization still works. Add regression tests. --- src/mcp/server/mcpserver/resolve.py | 17 +++++------ src/mcp/server/mcpserver/tools/base.py | 12 ++++++-- .../mcpserver/utilities/func_metadata.py | 10 +++++-- tests/server/mcpserver/test_resolve.py | 29 +++++++++++++++++++ 4 files changed, 55 insertions(+), 13 deletions(-) diff --git a/src/mcp/server/mcpserver/resolve.py b/src/mcp/server/mcpserver/resolve.py index 1c60ca7f39..5202ff804c 100644 --- a/src/mcp/server/mcpserver/resolve.py +++ b/src/mcp/server/mcpserver/resolve.py @@ -132,17 +132,16 @@ def _wants_union(type_arg: Any) -> bool: def _resolver_key(fn: Callable[..., Any]) -> Hashable: - """Stable, equality-based key for memoizing a resolver. + """Identity key for memoizing a resolver. - Bound methods are recreated on each attribute access (`id(auth.login)` differs - every time) but hash/compare by `(__func__, __self__)`, so the callable itself - is the right key. Falls back to `id` only for the rare unhashable callable. + A bound method is recreated on each attribute access (`id(auth.login)` differs + every time), so key it by `(id(__func__), id(__self__))` to keep `auth.login` + referenced in two places memoized to one call. Everything else keys by `id`, + so two distinct callables never collide even if they compare equal. """ - try: - hash(fn) - except TypeError: # pragma: no cover - unhashable callables are pathological - return id(fn) - return fn + if inspect.ismethod(fn): + return (id(fn.__func__), id(fn.__self__)) + return id(fn) def build_resolver_plans( diff --git a/src/mcp/server/mcpserver/tools/base.py b/src/mcp/server/mcpserver/tools/base.py index 14f8086856..2ee2862794 100644 --- a/src/mcp/server/mcpserver/tools/base.py +++ b/src/mcp/server/mcpserver/tools/base.py @@ -128,15 +128,23 @@ async def run( pass_directly: dict[str, Any] = {} if self.context_kwarg is not None: pass_directly[self.context_kwarg] = context + + # Resolvers see the same validated arguments the tool body receives: + # validate once and reuse it, so a `default_factory`/stateful validator + # can't hand a by-name resolver a different value than the body. + pre_validated: dict[str, Any] | None = None if self.resolved_params: - tool_args = self.fn_metadata.validate_arguments(arguments) - pass_directly |= await resolve_arguments(self.resolved_params, self.resolver_plans, tool_args, context) + pre_validated = self.fn_metadata.validate_arguments(arguments) + pass_directly |= await resolve_arguments( + self.resolved_params, self.resolver_plans, pre_validated, context + ) result = await self.fn_metadata.call_fn_with_arg_validation( self.fn, self.is_async, arguments, pass_directly or None, + pre_validated=pre_validated, ) if convert_result: diff --git a/src/mcp/server/mcpserver/utilities/func_metadata.py b/src/mcp/server/mcpserver/utilities/func_metadata.py index 53284c43b2..abc552efe0 100644 --- a/src/mcp/server/mcpserver/utilities/func_metadata.py +++ b/src/mcp/server/mcpserver/utilities/func_metadata.py @@ -81,13 +81,19 @@ async def call_fn_with_arg_validation( fn_is_async: bool, arguments_to_validate: dict[str, Any], arguments_to_pass_directly: dict[str, Any] | None, + pre_validated: dict[str, Any] | None = None, ) -> Any: """Call the given function with arguments validated and injected. Arguments are first attempted to be parsed from JSON, then validated against - the argument model, before being passed to the function. + the argument model, before being passed to the function. Pass `pre_validated` + (the output of `validate_arguments`) to reuse an earlier validation pass - + validating twice can re-run `default_factory`/stateful validators and hand the + function different values than a caller already observed. """ - arguments_parsed_dict = self.validate_arguments(arguments_to_validate) + arguments_parsed_dict = ( + pre_validated if pre_validated is not None else self.validate_arguments(arguments_to_validate) + ) arguments_parsed_dict |= arguments_to_pass_directly or {} diff --git a/tests/server/mcpserver/test_resolve.py b/tests/server/mcpserver/test_resolve.py index e45a58f7d2..8d065fd6f5 100644 --- a/tests/server/mcpserver/test_resolve.py +++ b/tests/server/mcpserver/test_resolve.py @@ -403,3 +403,32 @@ async def tool(value: Annotated[Login, Resolve(service.a)]) -> str: with pytest.raises(InvalidSignature, match="cyclic"): Tool.from_function(tool) + + +@pytest.mark.anyio +async def test_resolver_and_body_see_the_same_validated_default(): + mcp = MCPServer(name="DefaultFactory") + counter = {"n": 0} + + def next_id() -> int: + counter["n"] += 1 + return counter["n"] + + # A by-name resolver and the tool body must observe one validation pass, so the + # `default_factory` runs once and both see the same generated value. + async def echo_id(request_id: int) -> int: + return request_id + + @mcp.tool() + async def run( + request_id: Annotated[int, Field(default_factory=next_id)], + resolved_id: Annotated[int, Resolve(echo_id)], + ) -> str: + return f"{request_id}:{resolved_id}" + + async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + raise AssertionError("should not elicit") + + async with Client(mcp, mode="legacy", elicitation_callback=never) as client: + assert await _text(client, "run", {}) == "1:1" + assert counter["n"] == 1 From 58238b137d785156d083f97ddf67bcccb1b4e373 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 25 Jun 2026 18:59:50 +0200 Subject: [PATCH 8/8] Memoize built-in bound-method resolvers; stop mutating pre_validated Review follow-ups on #2969: - _resolver_key now keys any bound method (pure-python or built-in) by its underlying function/name plus __self__ identity, so a built-in bound method (no __func__, fresh object each access) referenced twice still memoizes to one call. - call_fn_with_arg_validation copies the validated args before merging the injected kwargs, so a caller-provided pre_validated dict is never mutated. Add regression tests. --- src/mcp/server/mcpserver/resolve.py | 19 +++++++++----- .../mcpserver/utilities/func_metadata.py | 3 ++- tests/server/mcpserver/test_func_metadata.py | 22 ++++++++++++++++ tests/server/mcpserver/test_resolve.py | 26 ++++++++++++++++++- 4 files changed, 62 insertions(+), 8 deletions(-) diff --git a/src/mcp/server/mcpserver/resolve.py b/src/mcp/server/mcpserver/resolve.py index 5202ff804c..8244362e6e 100644 --- a/src/mcp/server/mcpserver/resolve.py +++ b/src/mcp/server/mcpserver/resolve.py @@ -134,13 +134,20 @@ def _wants_union(type_arg: Any) -> bool: def _resolver_key(fn: Callable[..., Any]) -> Hashable: """Identity key for memoizing a resolver. - A bound method is recreated on each attribute access (`id(auth.login)` differs - every time), so key it by `(id(__func__), id(__self__))` to keep `auth.login` - referenced in two places memoized to one call. Everything else keys by `id`, - so two distinct callables never collide even if they compare equal. + A bound method - pure-python (`inspect.ismethod`) or built-in (e.g. `obj.meth` + on a C-extension type) - is recreated on each attribute access, so `id(fn)` + differs every time. Key it by its underlying function (or name) plus its + `__self__` identity so `auth.login` referenced in two places memoizes to one + call. Everything else keys by `id`, so two distinct callables never collide + even if they compare equal. """ - if inspect.ismethod(fn): - return (id(fn.__func__), id(fn.__self__)) + bound_self = getattr(fn, "__self__", None) + if bound_self is not None: + # `__func__` (pure-python) has a stable identity; built-ins expose only a + # stable `__name__`. Use the function's id or the name's value accordingly. + func = getattr(fn, "__func__", None) + underlying: Hashable = id(func) if func is not None else getattr(fn, "__name__", id(fn)) + return (underlying, id(bound_self)) return id(fn) diff --git a/src/mcp/server/mcpserver/utilities/func_metadata.py b/src/mcp/server/mcpserver/utilities/func_metadata.py index abc552efe0..05b4563df6 100644 --- a/src/mcp/server/mcpserver/utilities/func_metadata.py +++ b/src/mcp/server/mcpserver/utilities/func_metadata.py @@ -91,7 +91,8 @@ async def call_fn_with_arg_validation( validating twice can re-run `default_factory`/stateful validators and hand the function different values than a caller already observed. """ - arguments_parsed_dict = ( + # Copy so a caller-provided `pre_validated` dict is never mutated in place. + arguments_parsed_dict = dict( pre_validated if pre_validated is not None else self.validate_arguments(arguments_to_validate) ) diff --git a/tests/server/mcpserver/test_func_metadata.py b/tests/server/mcpserver/test_func_metadata.py index 2763b3f503..be22722319 100644 --- a/tests/server/mcpserver/test_func_metadata.py +++ b/tests/server/mcpserver/test_func_metadata.py @@ -155,6 +155,28 @@ async def test_complex_function_runtime_arg_validation_with_json(): assert result == "ok!" +@pytest.mark.anyio +async def test_call_fn_does_not_mutate_pre_validated(): + """A caller-provided `pre_validated` dict must not be mutated by the call.""" + + def fn(x: int, ctx: str) -> str: + return f"{x}:{ctx}" + + meta = func_metadata(fn, skip_names=["ctx"]) + pre_validated = meta.validate_arguments({"x": 1}) + snapshot = dict(pre_validated) + + result = await meta.call_fn_with_arg_validation( + fn, + fn_is_async=False, + arguments_to_validate={"x": 1}, + arguments_to_pass_directly={"ctx": "injected"}, + pre_validated=pre_validated, + ) + assert result == "1:injected" + assert pre_validated == snapshot # `ctx` was not leaked into the caller's dict + + def test_str_vs_list_str(): """Test handling of string vs list[str] type annotations. diff --git a/tests/server/mcpserver/test_resolve.py b/tests/server/mcpserver/test_resolve.py index 8d065fd6f5..3c94ebdf39 100644 --- a/tests/server/mcpserver/test_resolve.py +++ b/tests/server/mcpserver/test_resolve.py @@ -17,7 +17,7 @@ Resolve, ) from mcp.server.mcpserver.exceptions import InvalidSignature -from mcp.server.mcpserver.resolve import find_resolved_parameters +from mcp.server.mcpserver.resolve import _resolver_key, find_resolved_parameters from mcp.server.mcpserver.tools.base import Tool from mcp.types import ElicitRequestParams, ElicitResult, TextContent @@ -432,3 +432,27 @@ async def never(context: ClientRequestContext, params: ElicitRequestParams) -> E async with Client(mcp, mode="legacy", elicitation_callback=never) as client: assert await _text(client, "run", {}) == "1:1" assert counter["n"] == 1 + + +def test_resolver_key_is_stable_for_methods_and_distinct_callables(): + class Service: + def handler(self) -> None: ... # pragma: no cover + + a, b = Service(), Service() + + # Pure-python bound methods: stable across accesses, distinct per instance. + assert _resolver_key(a.handler) == _resolver_key(a.handler) + assert _resolver_key(a.handler) != _resolver_key(b.handler) + + # Built-in bound methods (no `__func__`): fresh object each access, but the key + # is stable and keyed to `__self__`. + items: list[int] = [] + others: list[int] = [] + assert _resolver_key(items.append) == _resolver_key(items.append) + assert _resolver_key(items.append) != _resolver_key(others.append) + assert _resolver_key(items.append) != _resolver_key(items.pop) + + # Plain functions key by identity. + def fn() -> None: ... # pragma: no cover + + assert _resolver_key(fn) == _resolver_key(fn)