diff --git a/docs/migration.md b/docs/migration.md index 46ec205ee9..73343f67c3 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -1426,6 +1426,72 @@ app = server.streamable_http_app( The lowlevel `Server` also now exposes a `session_manager` property to access the `StreamableHTTPSessionManager` after calling `streamable_http_app()`. +### Resolver dependency injection for tools (`Resolve` / `Elicit`) + +A tool parameter annotated `Annotated[T, Resolve(fn)]` is filled by running the resolver `fn` before the tool body, instead of by the calling LLM. Resolvers form a dependency graph: a resolver may declare its own `Resolve(...)` dependencies, read the `Context` (including `ctx.headers`), and receive the tool's own arguments by name. A resolver may return `Elicit[T]` to ask the client; the SDK runs the elicitation and injects the answer. Each resolver runs at most once per `tools/call`. + +```python +from typing import Annotated + +from pydantic import BaseModel + +from mcp.server.mcpserver import ( + AcceptedElicitation, + CancelledElicitation, + Context, + DeclinedElicitation, + Elicit, + MCPServer, + Resolve, +) + +mcp = MCPServer(name="github") + + +class Login(BaseModel): + username: str + + +class Confirm(BaseModel): + ok: bool + + +async def login(ctx: Context) -> Login | Elicit[Login]: + if username := (ctx.headers or {}).get("x-github-user"): + return Login(username=username) # resolved from context, no question + return Elicit("GitHub username?", Login) # must ask + + +async def confirm(repo: str, login: Annotated[Login, Resolve(login)]) -> Elicit[Confirm]: + return Elicit(f"Star {repo} as {login.username}?", Confirm) + + +@mcp.tool() +async def star_repo( + repo: str, + login: Annotated[Login, Resolve(login)], + confirm: Annotated[Confirm, Resolve(confirm)], +) -> str: + """Star a GitHub repo.""" + return f"starred {repo} as {login.username}" if confirm.ok else "cancelled" +``` + +The injected type follows the consumer's annotation. Annotating the unwrapped model (`Annotated[Login, Resolve(login)]`) injects the model on accept and aborts the call with an error result on decline or cancel. To branch on the outcome instead, annotate the elicitation result union: + +```python +@mcp.tool() +async def whoami( + login: Annotated[AcceptedElicitation[Login] | DeclinedElicitation | CancelledElicitation, Resolve(login)], +) -> str: + match login: + case AcceptedElicitation(data=data): + return f"hi {data.username}" + case _: + return "no username provided" +``` + +Resolved parameters are omitted from the tool's input schema, so the client never supplies them. Resolver parameters that cannot be classified, and cyclic resolver dependencies, raise at registration time. + ## Need Help? If you encounter issues during migration: diff --git a/src/mcp/server/mcpserver/__init__.py b/src/mcp/server/mcpserver/__init__.py index 0857e38bd4..c6bc3d5b00 100644 --- a/src/mcp/server/mcpserver/__init__.py +++ b/src/mcp/server/mcpserver/__init__.py @@ -3,7 +3,27 @@ from mcp.types import Icon from .context import Context +from .resolve import ( + AcceptedElicitation, + CancelledElicitation, + DeclinedElicitation, + Elicit, + ElicitationResult, + Resolve, +) from .server import MCPServer from .utilities.types import Audio, Image -__all__ = ["MCPServer", "Context", "Image", "Audio", "Icon"] +__all__ = [ + "MCPServer", + "Context", + "Image", + "Audio", + "Icon", + "Resolve", + "Elicit", + "ElicitationResult", + "AcceptedElicitation", + "DeclinedElicitation", + "CancelledElicitation", +] diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index 7856e32185..ec04c64fe8 100644 --- a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -1,7 +1,7 @@ from __future__ import annotations -from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Generic +from collections.abc import Iterable, Mapping +from typing import TYPE_CHECKING, Any, Generic, Protocol, cast from pydantic import AnyUrl, BaseModel from typing_extensions import deprecated @@ -22,6 +22,11 @@ from mcp.server.mcpserver.server import MCPServer +class _HasHeaders(Protocol): + @property + def headers(self) -> Mapping[str, str]: ... + + class Context(BaseModel, Generic[LifespanContextT, RequestT]): """Context object providing access to MCP capabilities. @@ -214,6 +219,17 @@ def client_id(self) -> str | None: """ return self.request_context.meta.get("client_id") if self.request_context.meta else None # pragma: no cover + @property + def headers(self) -> Mapping[str, str] | None: + """Request headers carried by this message, when the transport has them. + + Populated by HTTP-based transports; `None` on stdio. + """ + request = self.request_context.request + if request is None: + return None + return cast("_HasHeaders", request).headers + @property def request_id(self) -> str: """Get the unique ID for this request.""" diff --git a/src/mcp/server/mcpserver/resolve.py b/src/mcp/server/mcpserver/resolve.py new file mode 100644 index 0000000000..8244362e6e --- /dev/null +++ b/src/mcp/server/mcpserver/resolve.py @@ -0,0 +1,300 @@ +"""Resolver dependency injection for MCPServer tools. + +A tool parameter annotated `Annotated[T, Resolve(fn)]` is filled by running the +resolver `fn` before the tool body, instead of from the LLM-supplied arguments. +Resolvers form a DAG: a resolver may declare its own `Resolve(...)` dependencies, +take tool arguments by name, and take the `Context`. A resolver may return +`Elicit[T]` to ask the client; the framework runs the elicitation and injects the +answer. + +Whether the consumer receives the unwrapped model or the full +`ElicitationResult` union is decided by the consumer's annotation: + +- `Annotated[T, Resolve(fn)]` -> unwrapped `T`; decline/cancel aborts the call. +- `Annotated[ElicitationResult[T], Resolve(fn)]` (or a specific member) -> the + full outcome; the consumer branches on accept/decline/cancel. + +Each resolver runs at most once per `tools/call` (memoized by function identity). +""" + +from __future__ import annotations + +import inspect +import typing +from collections.abc import Callable, Hashable, Mapping +from typing import Annotated, Any, Generic, cast, get_args, get_origin + +import anyio.to_thread +from pydantic import BaseModel +from typing_extensions import TypeVar + +from mcp.server.elicitation import ( + AcceptedElicitation, + CancelledElicitation, + DeclinedElicitation, + ElicitationResult, +) +from mcp.server.mcpserver.context import Context +from mcp.server.mcpserver.exceptions import InvalidSignature, ToolError +from mcp.shared._callable_inspection import is_async_callable + +T = TypeVar("T", bound=BaseModel) + +# The union members the framework injects when a consumer opts into the outcome. +_ELICITATION_RESULT_MEMBERS = (AcceptedElicitation, DeclinedElicitation, CancelledElicitation) + + +class Resolve: + """Marker for `Annotated[T, Resolve(fn)]`: fill the parameter by running `fn`.""" + + def __init__(self, fn: Callable[..., Any]) -> None: + self.fn = fn + + +class Elicit(Generic[T]): + """A resolver's request to ask the client. + + Returned from a resolver to signal that the value must be elicited. The + framework runs `ctx.elicit(message, schema)` and injects the outcome. + """ + + def __init__(self, message: str, schema: type[T]) -> None: + self.message = message + self.schema = schema + + +class _ParamPlan: + """How to fill one resolver parameter, decided once at registration.""" + + kind: str # "context" | "resolve" | "by_name" + resolve: Resolve | None + wants_union: bool + + def __init__(self, kind: str, resolve: Resolve | None = None, wants_union: bool = False) -> None: + self.kind = kind + self.resolve = resolve + self.wants_union = wants_union + + +class _ResolverPlan: + """A resolver's parameters and whether it is async, analyzed once.""" + + def __init__(self, fn: Callable[..., Any], params: dict[str, _ParamPlan], is_async: bool) -> None: + self.fn = fn + self.params = params + self.is_async = is_async + + +def _type_hints(fn: Callable[..., Any]) -> dict[str, Any]: + """Resolve type hints for a function or a callable object. + + `typing.get_type_hints` raises on a callable *instance*; fall back to its + `__call__`. Returns an empty mapping when hints cannot be resolved, matching + `find_context_parameter`'s tolerance so callables without annotations (or with + unresolvable ones) simply have no resolved parameters. + """ + target = fn if inspect.isroutine(fn) else getattr(type(fn), "__call__", fn) + try: + return typing.get_type_hints(target, include_extras=True) + except Exception: + return {} + + +def _resolver_name(fn: Callable[..., Any]) -> str: + """Best-effort display name for error messages (callable objects lack `__name__`).""" + return getattr(fn, "__name__", None) or type(fn).__name__ + + +def find_resolved_parameters(fn: Callable[..., Any]) -> dict[str, tuple[Resolve, bool]]: + """Find parameters of `fn` annotated `Annotated[_, Resolve(...)]`. + + Returns a mapping of parameter name to `(Resolve, wants_union)`, where + `wants_union` is True when the annotated type is an `ElicitationResult` member + (the consumer wants the full outcome rather than the unwrapped model). + """ + hints = _type_hints(fn) + resolved: dict[str, tuple[Resolve, bool]] = {} + for name in inspect.signature(fn).parameters: + annotation = hints.get(name) + if get_origin(annotation) is not Annotated: + continue + type_arg, *metadata = get_args(annotation) + marker = next((m for m in metadata if isinstance(m, Resolve)), None) + if marker is not None: + resolved[name] = (marker, _wants_union(type_arg)) + return resolved + + +def _wants_union(type_arg: Any) -> bool: + """True when `type_arg` is an `ElicitationResult` member (or a union of them).""" + members = get_args(type_arg) if get_origin(type_arg) is not None else (type_arg,) + return any(isinstance(m, type) and issubclass(m, _ELICITATION_RESULT_MEMBERS) for m in members) + + +def _resolver_key(fn: Callable[..., Any]) -> Hashable: + """Identity key for memoizing a resolver. + + A bound method - pure-python (`inspect.ismethod`) or built-in (e.g. `obj.meth` + on a C-extension type) - is recreated on each attribute access, so `id(fn)` + differs every time. Key it by its underlying function (or name) plus its + `__self__` identity so `auth.login` referenced in two places memoizes to one + call. Everything else keys by `id`, so two distinct callables never collide + even if they compare equal. + """ + bound_self = getattr(fn, "__self__", None) + if bound_self is not None: + # `__func__` (pure-python) has a stable identity; built-ins expose only a + # stable `__name__`. Use the function's id or the name's value accordingly. + func = getattr(fn, "__func__", None) + underlying: Hashable = id(func) if func is not None else getattr(fn, "__name__", id(fn)) + return (underlying, id(bound_self)) + return id(fn) + + +def build_resolver_plans( + resolved_params: Mapping[str, tuple[Resolve, bool]], + tool_arg_names: set[str], +) -> dict[Hashable, _ResolverPlan]: + """Statically analyze the resolver DAG rooted at a tool's resolved parameters. + + Raises: + InvalidSignature: If a resolver has a cyclic dependency, or a resolver + parameter cannot be classified (not a `Context`, a nested `Resolve`, + or a tool argument by name). + """ + plans: dict[Hashable, _ResolverPlan] = {} + + def analyze(fn: Callable[..., Any], stack: tuple[Hashable, ...]) -> None: + key = _resolver_key(fn) + if key in stack: + raise InvalidSignature(f"Resolver {_resolver_name(fn)!r} has a cyclic dependency") + if key in plans: + return + + hints = _type_hints(fn) + sig = inspect.signature(fn) + params: dict[str, _ParamPlan] = {} + nested: list[Callable[..., Any]] = [] + for param_name in sig.parameters: + annotation = hints.get(param_name) + if annotation is not None and _is_context_annotation(annotation): + params[param_name] = _ParamPlan("context") + continue + marker, wants_union = _resolve_marker(annotation) + if marker is not None: + params[param_name] = _ParamPlan("resolve", marker, wants_union) + nested.append(marker.fn) + continue + if param_name in tool_arg_names: + params[param_name] = _ParamPlan("by_name") + continue + raise InvalidSignature( + f"Resolver {_resolver_name(fn)!r} parameter {param_name!r} cannot be resolved: " + "expected a Context, an Annotated[_, Resolve(...)], or a tool argument by name" + ) + + plans[key] = _ResolverPlan(fn, params, is_async_callable(fn)) + for dep in nested: + analyze(dep, stack + (key,)) + + for marker, _ in resolved_params.values(): + analyze(marker.fn, ()) + return plans + + +def _resolve_marker(annotation: Any) -> tuple[Resolve | None, bool]: + if get_origin(annotation) is not Annotated: + return None, False + type_arg, *metadata = get_args(annotation) + marker = next((m for m in metadata if isinstance(m, Resolve)), None) + return marker, (_wants_union(type_arg) if marker is not None else False) + + +def _is_context_annotation(annotation: Any) -> bool: + if get_origin(annotation) is Annotated: + annotation = get_args(annotation)[0] + candidates = get_args(annotation) if get_origin(annotation) is not None else (annotation,) + return any(isinstance(c, type) and issubclass(c, Context) for c in candidates) + + +async def resolve_arguments( + resolved_params: Mapping[str, tuple[Resolve, bool]], + plans: Mapping[Hashable, _ResolverPlan], + tool_args: Mapping[str, Any], + context: Context[Any, Any], +) -> dict[str, Any]: + """Resolve every `Resolve`-marked tool parameter into a concrete value. + + Each resolver runs at most once (memoized by function identity). Returns a + mapping of tool parameter name to the value to inject. + + Raises: + ToolError: If an elicited value is declined or cancelled and the consumer + asked for the unwrapped model (rather than the result union). + """ + cache: dict[Hashable, ElicitationResult[Any]] = {} + injected: dict[str, Any] = {} + for name, (marker, wants_union) in resolved_params.items(): + outcome = await _resolve(marker.fn, plans, tool_args, context, cache) + injected[name] = outcome if wants_union else _unwrap(outcome, name) + return injected + + +async def _resolve( + fn: Callable[..., Any], + plans: Mapping[Hashable, _ResolverPlan], + tool_args: Mapping[str, Any], + context: Context[Any, Any], + cache: dict[Hashable, ElicitationResult[Any]], +) -> ElicitationResult[Any]: + key = _resolver_key(fn) + if key in cache: + return cache[key] + + plan = plans[key] + kwargs: dict[str, Any] = {} + for param_name, param_plan in plan.params.items(): + if param_plan.kind == "context": + kwargs[param_name] = context + elif param_plan.kind == "by_name": + kwargs[param_name] = tool_args[param_name] + else: + assert param_plan.resolve is not None + dep_outcome = await _resolve(param_plan.resolve.fn, plans, tool_args, context, cache) + kwargs[param_name] = dep_outcome if param_plan.wants_union else _unwrap(dep_outcome, param_name) + + if plan.is_async: + result = await fn(**kwargs) + else: + result = await anyio.to_thread.run_sync(lambda: fn(**kwargs)) + + outcome: ElicitationResult[Any] + if isinstance(result, Elicit): + elicit = cast("Elicit[BaseModel]", result) + outcome = await context.elicit(elicit.message, elicit.schema) + else: + # A resolver may return any type (not just `BaseModel`); `model_construct` + # wraps it as an accepted result without validating against the schema bound. + outcome = cast("AcceptedElicitation[Any]", AcceptedElicitation.model_construct(data=result)) + + cache[key] = outcome + return outcome + + +def _unwrap(outcome: ElicitationResult[Any], name: str) -> Any: + if isinstance(outcome, AcceptedElicitation): + return outcome.data + raise ToolError(f"Resolver for parameter {name!r} could not resolve: elicitation was {outcome.action}") + + +__all__ = [ + "Resolve", + "Elicit", + "ElicitationResult", + "AcceptedElicitation", + "DeclinedElicitation", + "CancelledElicitation", + "find_resolved_parameters", + "build_resolver_plans", + "resolve_arguments", +] diff --git a/src/mcp/server/mcpserver/tools/base.py b/src/mcp/server/mcpserver/tools/base.py index 29894d7d1d..2ee2862794 100644 --- a/src/mcp/server/mcpserver/tools/base.py +++ b/src/mcp/server/mcpserver/tools/base.py @@ -1,12 +1,17 @@ from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Hashable from functools import cached_property from typing import TYPE_CHECKING, Any from pydantic import BaseModel, Field from mcp.server.mcpserver.exceptions import ToolError +from mcp.server.mcpserver.resolve import ( + build_resolver_plans, + find_resolved_parameters, + resolve_arguments, +) from mcp.server.mcpserver.utilities.context_injection import find_context_parameter from mcp.server.mcpserver.utilities.func_metadata import FuncMetadata, func_metadata from mcp.shared._callable_inspection import is_async_callable @@ -32,6 +37,14 @@ class Tool(BaseModel): ) is_async: bool = Field(description="Whether the tool is async") context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context") + resolved_params: dict[str, Any] = Field( + default_factory=lambda: {}, + exclude=True, + description="Parameters filled by resolvers, mapped to (Resolve, wants_union)", + ) + resolver_plans: dict[Hashable, Any] = Field( + default_factory=lambda: {}, exclude=True, description="Static per-resolver parameter plans" + ) annotations: ToolAnnotations | None = Field(None, description="Optional annotations for the tool") icons: list[Icon] | None = Field(default=None, description="Optional list of icons for this tool") meta: dict[str, Any] | None = Field(default=None, description="Optional metadata for this tool") @@ -67,13 +80,23 @@ def from_function( if context_kwarg is None: # pragma: no branch context_kwarg = find_context_parameter(fn) + resolved_params = find_resolved_parameters(fn) + + skip_names = [context_kwarg] if context_kwarg is not None else [] + skip_names.extend(resolved_params) + func_arg_metadata = func_metadata( fn, - skip_names=[context_kwarg] if context_kwarg is not None else [], + skip_names=skip_names, structured_output=structured_output, ) parameters = func_arg_metadata.arg_model.model_json_schema(by_alias=True) + # Match `model_dump_one_level`'s kwarg keys (alias when present, else field name) + # so a by-name resolver param resolves to a key that exists at call time. + tool_arg_names = {field.alias or name for name, field in func_arg_metadata.arg_model.model_fields.items()} + resolver_plans = build_resolver_plans(resolved_params, tool_arg_names) + return cls( fn=fn, name=func_name, @@ -83,6 +106,8 @@ def from_function( fn_metadata=func_arg_metadata, is_async=is_async, context_kwarg=context_kwarg, + resolved_params=dict(resolved_params), + resolver_plans=resolver_plans, annotations=annotations, icons=icons, meta=meta, @@ -100,11 +125,26 @@ async def run( ToolError: If the tool function raises during execution. """ try: + pass_directly: dict[str, Any] = {} + if self.context_kwarg is not None: + pass_directly[self.context_kwarg] = context + + # Resolvers see the same validated arguments the tool body receives: + # validate once and reuse it, so a `default_factory`/stateful validator + # can't hand a by-name resolver a different value than the body. + pre_validated: dict[str, Any] | None = None + if self.resolved_params: + pre_validated = self.fn_metadata.validate_arguments(arguments) + pass_directly |= await resolve_arguments( + self.resolved_params, self.resolver_plans, pre_validated, context + ) + result = await self.fn_metadata.call_fn_with_arg_validation( self.fn, self.is_async, arguments, - {self.context_kwarg: context} if self.context_kwarg is not None else None, + pass_directly or None, + pre_validated=pre_validated, ) if convert_result: diff --git a/src/mcp/server/mcpserver/utilities/func_metadata.py b/src/mcp/server/mcpserver/utilities/func_metadata.py index 6c553fbab9..05b4563df6 100644 --- a/src/mcp/server/mcpserver/utilities/func_metadata.py +++ b/src/mcp/server/mcpserver/utilities/func_metadata.py @@ -65,21 +65,36 @@ class FuncMetadata(BaseModel): output_model: Annotated[type[BaseModel], WithJsonSchema(None)] | None = None wrap_output: bool = False + def validate_arguments(self, arguments_to_validate: dict[str, Any]) -> dict[str, Any]: + """Validate raw arguments into a one-level kwargs dict (no function call). + + Used to feed resolver dependency injection the validated tool arguments + before the tool function itself runs. + """ + arguments_pre_parsed = self.pre_parse_json(arguments_to_validate) + arguments_parsed_model = self.arg_model.model_validate(arguments_pre_parsed) + return arguments_parsed_model.model_dump_one_level() + async def call_fn_with_arg_validation( self, fn: Callable[..., Any | Awaitable[Any]], fn_is_async: bool, arguments_to_validate: dict[str, Any], arguments_to_pass_directly: dict[str, Any] | None, + pre_validated: dict[str, Any] | None = None, ) -> Any: """Call the given function with arguments validated and injected. Arguments are first attempted to be parsed from JSON, then validated against - the argument model, before being passed to the function. + the argument model, before being passed to the function. Pass `pre_validated` + (the output of `validate_arguments`) to reuse an earlier validation pass - + validating twice can re-run `default_factory`/stateful validators and hand the + function different values than a caller already observed. """ - arguments_pre_parsed = self.pre_parse_json(arguments_to_validate) - arguments_parsed_model = self.arg_model.model_validate(arguments_pre_parsed) - arguments_parsed_dict = arguments_parsed_model.model_dump_one_level() + # Copy so a caller-provided `pre_validated` dict is never mutated in place. + arguments_parsed_dict = dict( + pre_validated if pre_validated is not None else self.validate_arguments(arguments_to_validate) + ) arguments_parsed_dict |= arguments_to_pass_directly or {} diff --git a/tests/server/mcpserver/test_func_metadata.py b/tests/server/mcpserver/test_func_metadata.py index 2763b3f503..be22722319 100644 --- a/tests/server/mcpserver/test_func_metadata.py +++ b/tests/server/mcpserver/test_func_metadata.py @@ -155,6 +155,28 @@ async def test_complex_function_runtime_arg_validation_with_json(): assert result == "ok!" +@pytest.mark.anyio +async def test_call_fn_does_not_mutate_pre_validated(): + """A caller-provided `pre_validated` dict must not be mutated by the call.""" + + def fn(x: int, ctx: str) -> str: + return f"{x}:{ctx}" + + meta = func_metadata(fn, skip_names=["ctx"]) + pre_validated = meta.validate_arguments({"x": 1}) + snapshot = dict(pre_validated) + + result = await meta.call_fn_with_arg_validation( + fn, + fn_is_async=False, + arguments_to_validate={"x": 1}, + arguments_to_pass_directly={"ctx": "injected"}, + pre_validated=pre_validated, + ) + assert result == "1:injected" + assert pre_validated == snapshot # `ctx` was not leaked into the caller's dict + + def test_str_vs_list_str(): """Test handling of string vs list[str] type annotations. diff --git a/tests/server/mcpserver/test_resolve.py b/tests/server/mcpserver/test_resolve.py new file mode 100644 index 0000000000..3c94ebdf39 --- /dev/null +++ b/tests/server/mcpserver/test_resolve.py @@ -0,0 +1,458 @@ +"""Tests for resolver dependency injection (MRTR) on MCPServer tools.""" + +from typing import Annotated + +import pytest +from pydantic import BaseModel, Field + +from mcp import Client +from mcp.client import ClientRequestContext +from mcp.server.mcpserver import ( + AcceptedElicitation, + CancelledElicitation, + Context, + DeclinedElicitation, + Elicit, + MCPServer, + Resolve, +) +from mcp.server.mcpserver.exceptions import InvalidSignature +from mcp.server.mcpserver.resolve import _resolver_key, find_resolved_parameters +from mcp.server.mcpserver.tools.base import Tool +from mcp.types import ElicitRequestParams, ElicitResult, TextContent + + +class Login(BaseModel): + username: str + + +class Confirm(BaseModel): + ok: bool + + +def _accept(content: dict[str, str | int | float | bool | list[str] | None]): + async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="accept", content=content) + + return callback + + +async def _decline(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="decline") + + +async def _text(client: Client, tool: str, args: dict[str, object]) -> str: + result = await client.call_tool(tool, args) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + return result.content[0].text + + +@pytest.mark.anyio +async def test_resolver_returns_value_directly_without_eliciting(): + mcp = MCPServer(name="Direct") + + async def login(ctx: Context) -> Login | Elicit[Login]: + username = (ctx.headers or {}).get("x-github-user") + if username: # pragma: no cover - no headers on in-memory transport + return Login(username=username) + return Login(username="from-resolver") + + @mcp.tool() + async def whoami(login: Annotated[Login, Resolve(login)]) -> str: + return login.username + + async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + raise AssertionError("should not elicit") + + async with Client(mcp, mode="legacy", elicitation_callback=never) as client: + assert await _text(client, "whoami", {}) == "from-resolver" + + +@pytest.mark.anyio +async def test_resolver_elicits_and_injects_unwrapped_model_on_accept(): + mcp = MCPServer(name="Accept") + + async def login(ctx: Context) -> Login | Elicit[Login]: + return Elicit("GitHub username?", Login) + + @mcp.tool() + async def whoami(login: Annotated[Login, Resolve(login)]) -> str: + return login.username + + async with Client(mcp, mode="legacy", elicitation_callback=_accept({"username": "octocat"})) as client: + assert await _text(client, "whoami", {}) == "octocat" + + +@pytest.mark.anyio +async def test_consumer_receives_result_union_and_branches(): + mcp = MCPServer(name="Union") + + async def login(ctx: Context) -> Login | Elicit[Login]: + return Elicit("GitHub username?", Login) + + @mcp.tool() + async def whoami( + login: Annotated[AcceptedElicitation[Login] | DeclinedElicitation | CancelledElicitation, Resolve(login)], + ) -> str: + match login: + case AcceptedElicitation(data=data): + return f"hi {data.username}" + case _: # pragma: no cover - accepted in this test + return "no username" + + async with Client(mcp, mode="legacy", elicitation_callback=_accept({"username": "octocat"})) as client: + assert await _text(client, "whoami", {}) == "hi octocat" + + +@pytest.mark.anyio +async def test_decline_reaches_union_consumer_without_aborting(): + mcp = MCPServer(name="UnionDecline") + + async def login(ctx: Context) -> Login | Elicit[Login]: + return Elicit("GitHub username?", Login) + + @mcp.tool() + async def whoami( + login: Annotated[AcceptedElicitation[Login] | DeclinedElicitation | CancelledElicitation, Resolve(login)], + ) -> str: + if isinstance(login, DeclinedElicitation): + return "declined gracefully" + raise NotImplementedError + + async with Client(mcp, mode="legacy", elicitation_callback=_decline) as client: + assert await _text(client, "whoami", {}) == "declined gracefully" + + +@pytest.mark.anyio +async def test_decline_aborts_when_consumer_wants_unwrapped(): + mcp = MCPServer(name="UnwrappedDecline") + + async def login(ctx: Context) -> Login | Elicit[Login]: + return Elicit("GitHub username?", Login) + + @mcp.tool() + async def whoami(login: Annotated[Login, Resolve(login)]) -> str: + raise NotImplementedError # pragma: no cover - never reached + + async with Client(mcp, mode="legacy", elicitation_callback=_decline) as client: + result = await client.call_tool("whoami", {}) + assert result.is_error + assert isinstance(result.content[0], TextContent) + assert "decline" in result.content[0].text + + +@pytest.mark.anyio +async def test_nested_resolver_sees_dependency_and_tool_args(): + mcp = MCPServer(name="Nested") + + async def login(ctx: Context) -> Login | Elicit[Login]: + return Elicit("GitHub username?", Login) + + async def confirm(repo: str, login: Annotated[Login, Resolve(login)]) -> Elicit[Confirm]: + return Elicit(f"Star {repo} as {login.username}?", Confirm) + + @mcp.tool() + async def star_repo( + repo: str, + login: Annotated[Login, Resolve(login)], + confirm: Annotated[Confirm, Resolve(confirm)], + ) -> str: + if confirm.ok: + return f"starred {repo} as {login.username}" + raise NotImplementedError + + async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + if "username" in params.message: + return ElicitResult(action="accept", content={"username": "octocat"}) + assert "Star modelcontextprotocol/python-sdk as octocat?" in params.message + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp, mode="legacy", elicitation_callback=callback) as client: + text = await _text(client, "star_repo", {"repo": "modelcontextprotocol/python-sdk"}) + assert text == "starred modelcontextprotocol/python-sdk as octocat" + + +@pytest.mark.anyio +async def test_resolver_runs_once_for_two_consumers(): + mcp = MCPServer(name="ExactlyOnce") + elicit_count = 0 + + async def login(ctx: Context) -> Login | Elicit[Login]: + return Elicit("GitHub username?", Login) + + async def confirm(login: Annotated[Login, Resolve(login)]) -> Elicit[Confirm]: + return Elicit(f"As {login.username}?", Confirm) + + @mcp.tool() + async def star_repo( + login: Annotated[Login, Resolve(login)], + confirm: Annotated[Confirm, Resolve(confirm)], + ) -> str: + return f"{login.username}:{confirm.ok}" + + async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + nonlocal elicit_count + if "username" in params.message: + elicit_count += 1 + return ElicitResult(action="accept", content={"username": "octocat"}) + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp, mode="legacy", elicitation_callback=callback) as client: + assert await _text(client, "star_repo", {}) == "octocat:True" + assert elicit_count == 1 + + +@pytest.mark.anyio +async def test_sync_resolver(): + mcp = MCPServer(name="Sync") + + def login(ctx: Context) -> Login: + return Login(username="sync-user") + + @mcp.tool() + async def whoami(login: Annotated[Login, Resolve(login)]) -> str: + return login.username + + async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + raise AssertionError("should not elicit") + + async with Client(mcp, mode="legacy", elicitation_callback=never) as client: + assert await _text(client, "whoami", {}) == "sync-user" + + +def test_resolved_params_absent_from_input_schema(): + async def login(ctx: Context) -> Login: + return Login(username="x") # pragma: no cover - only the schema is inspected + + async def tool( + repo: Annotated[str, Field(description="repo name")], + login: Annotated[Login, Resolve(login)], + ) -> str: + return repo # pragma: no cover - only the schema is inspected + + built = Tool.from_function(tool) + properties = built.parameters["properties"] + assert "repo" in properties + assert "login" not in properties + + +def test_cycle_detection_raises_at_registration(): + async def a(dep: Login) -> Login: + return dep # pragma: no cover + + async def b(dep: Login) -> Login: + return dep # pragma: no cover + + # Close the loop after both exist: a depends on b, b depends on a. + a.__annotations__["dep"] = Annotated[Login, Resolve(b)] + b.__annotations__["dep"] = Annotated[Login, Resolve(a)] + + async def tool(value: Annotated[Login, Resolve(a)]) -> str: + return value.username # pragma: no cover + + with pytest.raises(InvalidSignature, match="cyclic"): + Tool.from_function(tool) + + +def test_find_resolved_parameters_tolerates_unresolvable_hints(): + def fn(x: int) -> int: + return x # pragma: no cover + + fn.__annotations__["x"] = "DoesNotExist" + assert find_resolved_parameters(fn) == {} + + +def test_unresolvable_resolver_param_raises_at_registration(): + async def login(mystery: int) -> Login: + return Login(username="x") # pragma: no cover + + async def tool(login: Annotated[Login, Resolve(login)]) -> str: + return login.username # pragma: no cover + + with pytest.raises(InvalidSignature, match="cannot be resolved"): + Tool.from_function(tool) + + +def test_resolve_marker_on_return_annotation_is_ignored(): + async def login(ctx: Context) -> Login: + return Login(username="x") # pragma: no cover + + async def tool(repo: str) -> Annotated[str, Resolve(login)]: + return repo # pragma: no cover + + assert find_resolved_parameters(tool) == {} + + +def test_callable_object_resolver_error_uses_type_name(): + class BadResolver: + async def __call__(self, mystery: int) -> Login: + return Login(username="x") # pragma: no cover + + async def tool(login: Annotated[Login, Resolve(BadResolver())]) -> str: + return login.username # pragma: no cover + + with pytest.raises(InvalidSignature, match="'BadResolver'"): + Tool.from_function(tool) + + +@pytest.mark.anyio +async def test_by_name_resolver_param_uses_aliased_tool_arg(): + mcp = MCPServer(name="Aliased") + + # `schema` collides with a BaseModel attribute, so func_metadata aliases the field; + # the runtime kwarg key is the alias, which is what a by-name resolver must match. + async def upper(schema: str) -> Login: + return Login(username=schema.upper()) + + @mcp.tool() + async def run(schema: str, shouted: Annotated[Login, Resolve(upper)]) -> str: + return shouted.username + + async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + raise AssertionError("should not elicit") + + async with Client(mcp, mode="legacy", elicitation_callback=never) as client: + assert await _text(client, "run", {"schema": "gpt"}) == "GPT" + + +@pytest.mark.anyio +async def test_resolver_may_return_non_basemodel_value(): + mcp = MCPServer(name="NonModel") + + async def get_token(ctx: Context) -> str: + return "secret-token" + + @mcp.tool() + async def use_token(token: Annotated[str, Resolve(get_token)]) -> str: + return token + + async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + raise AssertionError("should not elicit") + + async with Client(mcp, mode="legacy", elicitation_callback=never) as client: + assert await _text(client, "use_token", {}) == "secret-token" + + +@pytest.mark.anyio +async def test_resolver_accepts_optional_context_annotation(): + mcp = MCPServer(name="OptionalContext") + + async def whoami(ctx: Context | None) -> str: + assert ctx is not None + return "has-context" + + @mcp.tool() + async def run(who: Annotated[str, Resolve(whoami)]) -> str: + return who + + async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + raise AssertionError("should not elicit") + + async with Client(mcp, mode="legacy", elicitation_callback=never) as client: + assert await _text(client, "run", {}) == "has-context" + + +@pytest.mark.anyio +async def test_bound_method_resolver_runs_once_across_references(): + mcp = MCPServer(name="BoundMethod") + calls = 0 + + class Service: + async def token(self, ctx: Context) -> str: + nonlocal calls + calls += 1 + return "tok" + + service = Service() + + # Each `service.token` access is a fresh bound-method object; keying by the + # callable (not id) keeps the resolver memoized to a single call. + async def downstream(token: Annotated[str, Resolve(service.token)]) -> str: + return token.upper() + + @mcp.tool() + async def run( + token: Annotated[str, Resolve(service.token)], + shouted: Annotated[str, Resolve(downstream)], + ) -> str: + return f"{token}:{shouted}" + + async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + raise AssertionError("should not elicit") + + async with Client(mcp, mode="legacy", elicitation_callback=never) as client: + assert await _text(client, "run", {}) == "tok:TOK" + assert calls == 1 + + +def test_bound_method_cycle_is_detected(): + class Service: + async def a(self, dep: Login) -> Login: + return dep # pragma: no cover + + async def b(self, dep: Login) -> Login: + return dep # pragma: no cover + + service = Service() + service.a.__func__.__annotations__["dep"] = Annotated[Login, Resolve(service.b)] + service.b.__func__.__annotations__["dep"] = Annotated[Login, Resolve(service.a)] + + async def tool(value: Annotated[Login, Resolve(service.a)]) -> str: + return value.username # pragma: no cover + + with pytest.raises(InvalidSignature, match="cyclic"): + Tool.from_function(tool) + + +@pytest.mark.anyio +async def test_resolver_and_body_see_the_same_validated_default(): + mcp = MCPServer(name="DefaultFactory") + counter = {"n": 0} + + def next_id() -> int: + counter["n"] += 1 + return counter["n"] + + # A by-name resolver and the tool body must observe one validation pass, so the + # `default_factory` runs once and both see the same generated value. + async def echo_id(request_id: int) -> int: + return request_id + + @mcp.tool() + async def run( + request_id: Annotated[int, Field(default_factory=next_id)], + resolved_id: Annotated[int, Resolve(echo_id)], + ) -> str: + return f"{request_id}:{resolved_id}" + + async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + raise AssertionError("should not elicit") + + async with Client(mcp, mode="legacy", elicitation_callback=never) as client: + assert await _text(client, "run", {}) == "1:1" + assert counter["n"] == 1 + + +def test_resolver_key_is_stable_for_methods_and_distinct_callables(): + class Service: + def handler(self) -> None: ... # pragma: no cover + + a, b = Service(), Service() + + # Pure-python bound methods: stable across accesses, distinct per instance. + assert _resolver_key(a.handler) == _resolver_key(a.handler) + assert _resolver_key(a.handler) != _resolver_key(b.handler) + + # Built-in bound methods (no `__func__`): fresh object each access, but the key + # is stable and keyed to `__self__`. + items: list[int] = [] + others: list[int] = [] + assert _resolver_key(items.append) == _resolver_key(items.append) + assert _resolver_key(items.append) != _resolver_key(others.append) + assert _resolver_key(items.append) != _resolver_key(items.pop) + + # Plain functions key by identity. + def fn() -> None: ... # pragma: no cover + + assert _resolver_key(fn) == _resolver_key(fn) diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 9b469e566a..c06f859c97 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -1,5 +1,6 @@ import base64 from pathlib import Path +from types import SimpleNamespace from typing import Any from unittest.mock import AsyncMock, MagicMock, patch @@ -1541,6 +1542,27 @@ async def test_report_progress_delegates_to_session_report_progress(): mock_session.report_progress.assert_awaited_once_with(50, 100, "halfway") +def _request_context(request: object | None) -> ServerRequestContext[None, object]: + return ServerRequestContext( + session=AsyncMock(), + method="tools/call", + lifespan_context=None, + protocol_version="2025-11-25", + request=request, + ) + + +def test_context_headers_returns_request_headers(): + request = SimpleNamespace(headers={"x-github-user": "octocat"}) + ctx = Context(request_context=_request_context(request), mcp_server=MagicMock()) + assert ctx.headers == {"x-github-user": "octocat"} + + +def test_context_headers_is_none_without_request(): + ctx = Context(request_context=_request_context(None), mcp_server=MagicMock()) + assert ctx.headers is None + + async def test_read_resource_template_error(): """Template-creation failure must surface as INTERNAL_ERROR, not INVALID_PARAMS (not-found).""" mcp = MCPServer()