diff --git a/src/mcp/server/fastmcp/resources/base.py b/src/mcp/server/fastmcp/resources/base.py index c733e1a46b..fb9b0020c5 100644 --- a/src/mcp/server/fastmcp/resources/base.py +++ b/src/mcp/server/fastmcp/resources/base.py @@ -1,7 +1,9 @@ """Base classes and interfaces for FastMCP resources.""" +from __future__ import annotations + import abc -from typing import Annotated +from typing import TYPE_CHECKING, Annotated from pydantic import ( AnyUrl, @@ -15,6 +17,11 @@ from mcp.types import Annotations, Icon +if TYPE_CHECKING: + from mcp.server.fastmcp.server import Context + from mcp.server.session import ServerSessionT + from mcp.shared.context import LifespanContextT, RequestT + class Resource(BaseModel, abc.ABC): """Base class for all resources.""" @@ -44,6 +51,9 @@ def set_default_name(cls, name: str | None, info: ValidationInfo) -> str: raise ValueError("Either name or uri must be provided") @abc.abstractmethod - async def read(self) -> str | bytes: + async def read( + self, + context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, + ) -> str | bytes: """Read the resource content.""" pass # pragma: no cover diff --git a/src/mcp/server/fastmcp/resources/templates.py b/src/mcp/server/fastmcp/resources/templates.py index a98d37f0ac..809967ef7a 100644 --- a/src/mcp/server/fastmcp/resources/templates.py +++ b/src/mcp/server/fastmcp/resources/templates.py @@ -2,16 +2,15 @@ from __future__ import annotations -import inspect import re from collections.abc import Callable from typing import TYPE_CHECKING, Any -from pydantic import BaseModel, Field, validate_call +from pydantic import BaseModel, Field from mcp.server.fastmcp.resources.types import FunctionResource, Resource -from mcp.server.fastmcp.utilities.context_injection import find_context_parameter, inject_context -from mcp.server.fastmcp.utilities.func_metadata import func_metadata +from mcp.server.fastmcp.utilities.context_injection import find_context_parameter +from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata, is_async_callable from mcp.types import Annotations, Icon if TYPE_CHECKING: @@ -33,6 +32,10 @@ class ResourceTemplate(BaseModel): fn: Callable[..., Any] = Field(exclude=True) parameters: dict[str, Any] = Field(description="JSON schema for function parameters") context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context") + fn_metadata: FuncMetadata = Field( + description="Metadata about the function including a pydantic model for arguments" + ) + is_async: bool = Field(description="Whether the function is async") @classmethod def from_function( @@ -56,6 +59,8 @@ def from_function( if context_kwarg is None: # pragma: no branch context_kwarg = find_context_parameter(fn) + is_async = is_async_callable(fn) + # Get schema from func_metadata, excluding context parameter func_arg_metadata = func_metadata( fn, @@ -63,9 +68,6 @@ def from_function( ) parameters = func_arg_metadata.arg_model.model_json_schema() - # ensure the arguments are properly cast - fn = validate_call(fn) - return cls( uri_template=uri_template, name=func_name, @@ -77,6 +79,8 @@ def from_function( fn=fn, parameters=parameters, context_kwarg=context_kwarg, + fn_metadata=func_arg_metadata, + is_async=is_async, ) def matches(self, uri: str) -> dict[str, Any] | None: @@ -96,13 +100,12 @@ async def create_resource( ) -> Resource: """Create a resource from the template with the given parameters.""" try: - # Add context to params if needed - params = inject_context(self.fn, params, context, self.context_kwarg) - - # Call function and check if result is a coroutine - result = self.fn(**params) - if inspect.iscoroutine(result): - result = await result + result = await self.fn_metadata.call_fn_with_arg_validation( + self.fn, + self.is_async, + params, + {self.context_kwarg: context} if self.context_kwarg is not None else None, + ) return FunctionResource( uri=uri, # type: ignore diff --git a/src/mcp/server/fastmcp/resources/types.py b/src/mcp/server/fastmcp/resources/types.py index 680e72dc09..c9eb2316fc 100644 --- a/src/mcp/server/fastmcp/resources/types.py +++ b/src/mcp/server/fastmcp/resources/types.py @@ -1,28 +1,39 @@ """Concrete resource implementations.""" +from __future__ import annotations + import inspect import json from collections.abc import Callable from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any import anyio import anyio.to_thread import httpx import pydantic import pydantic_core -from pydantic import AnyUrl, Field, ValidationInfo, validate_call +from pydantic import AnyUrl, Field, ValidationInfo from mcp.server.fastmcp.resources.base import Resource +from mcp.server.fastmcp.utilities.context_injection import find_context_parameter from mcp.types import Annotations, Icon +if TYPE_CHECKING: + from mcp.server.fastmcp.server import Context + from mcp.server.session import ServerSessionT + from mcp.shared.context import LifespanContextT, RequestT + class TextResource(Resource): """A resource that reads from a string.""" text: str = Field(description="Text content of the resource") - async def read(self) -> str: + async def read( + self, + context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, + ) -> str: """Read the text content.""" return self.text # pragma: no cover @@ -32,7 +43,10 @@ class BinaryResource(Resource): data: bytes = Field(description="Binary content of the resource") - async def read(self) -> bytes: + async def read( + self, + context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, + ) -> bytes: """Read the binary content.""" return self.data # pragma: no cover @@ -50,13 +64,22 @@ class FunctionResource(Resource): - other types will be converted to JSON """ - fn: Callable[[], Any] = Field(exclude=True) + fn: Callable[..., Any] = Field(exclude=True) + context_kwarg: str | None = Field(default=None, description="Name of the kwarg that should receive context") - async def read(self) -> str | bytes: + async def read( + self, + context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, + ) -> str | bytes: """Read the resource by calling the wrapped function.""" try: - # Call the function first to see if it returns a coroutine - result = self.fn() + # Inject context if needed + kwargs: dict[str, Any] = {} + if self.context_kwarg is not None and context is not None: + kwargs[self.context_kwarg] = context + + # Call the function + result = self.fn(**kwargs) # If it's a coroutine, await it if inspect.iscoroutine(result): result = await result @@ -83,14 +106,14 @@ def from_function( mime_type: str | None = None, icons: list[Icon] | None = None, annotations: Annotations | None = None, - ) -> "FunctionResource": + ) -> FunctionResource: """Create a FunctionResource from a function.""" func_name = name or fn.__name__ if func_name == "": # pragma: no cover raise ValueError("You must provide a name for lambda functions") - # ensure the arguments are properly cast - fn = validate_call(fn) + # Find context parameter if it exists + context_kwarg = find_context_parameter(fn) return cls( uri=AnyUrl(uri), @@ -99,6 +122,7 @@ def from_function( description=description or fn.__doc__ or "", mime_type=mime_type or "text/plain", fn=fn, + context_kwarg=context_kwarg, icons=icons, annotations=annotations, ) @@ -137,7 +161,10 @@ def set_binary_from_mime_type(cls, is_binary: bool, info: ValidationInfo) -> boo mime_type = info.data.get("mime_type", "text/plain") return not mime_type.startswith("text/") - async def read(self) -> str | bytes: + async def read( + self, + context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, + ) -> str | bytes: """Read the file content.""" try: if self.is_binary: @@ -153,7 +180,10 @@ class HttpResource(Resource): url: str = Field(description="URL to fetch content from") mime_type: str = Field(default="application/json", description="MIME type of the resource content") - async def read(self) -> str | bytes: + async def read( + self, + context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, + ) -> str | bytes: """Read the HTTP content.""" async with httpx.AsyncClient() as client: # pragma: no cover response = await client.get(self.url) @@ -191,7 +221,10 @@ def list_files(self) -> list[Path]: # pragma: no cover except Exception as e: raise ValueError(f"Error listing directory {self.path}: {e}") - async def read(self) -> str: # Always returns JSON string # pragma: no cover + async def read( + self, + context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, + ) -> str: # Always returns JSON string # pragma: no cover """Read the directory listing.""" try: files = await anyio.to_thread.run_sync(self.list_files) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 5d6781f83d..cda7b75763 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -372,7 +372,7 @@ async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContent raise ResourceError(f"Unknown resource: {uri}") try: - content = await resource.read() + content = await resource.read(context=context) return [ReadResourceContents(content=content, mime_type=resource.mime_type)] except Exception as e: # pragma: no cover logger.exception(f"Error reading resource {uri}") @@ -571,21 +571,18 @@ async def get_weather(city: str) -> str: ) def decorator(fn: AnyFunction) -> AnyFunction: - # Check if this should be a template + # Extract signature and parameters sig = inspect.signature(fn) - has_uri_params = "{" in uri and "}" in uri - has_func_params = bool(sig.parameters) + uri_params = set(re.findall(r"{(\w+)}", uri)) + context_param = find_context_parameter(fn) + func_params = {p for p in sig.parameters.keys() if p != context_param} - if has_uri_params or has_func_params: - # Check for Context parameter to exclude from validation - context_param = find_context_parameter(fn) - - # Validate that URI params match function params (excluding context) - uri_params = set(re.findall(r"{(\w+)}", uri)) - # We need to remove the context_param from the resource function if - # there is any. - func_params = {p for p in sig.parameters.keys() if p != context_param} + # Determine if this should be a template + has_uri_params = len(uri_params) != 0 + has_func_params = len(func_params) != 0 + if has_uri_params or has_func_params: + # Validate that URI params match function params if uri_params != func_params: raise ValueError( f"Mismatch between URI parameters {uri_params} and function parameters {func_params}" diff --git a/src/mcp/server/fastmcp/tools/base.py b/src/mcp/server/fastmcp/tools/base.py index cf89fc8aa1..94489a3473 100644 --- a/src/mcp/server/fastmcp/tools/base.py +++ b/src/mcp/server/fastmcp/tools/base.py @@ -1,7 +1,5 @@ from __future__ import annotations as _annotations -import functools -import inspect from collections.abc import Callable from functools import cached_property from typing import TYPE_CHECKING, Any @@ -10,7 +8,7 @@ from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.utilities.context_injection import find_context_parameter -from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata +from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata, is_async_callable from mcp.shared.tool_name_validation import validate_and_warn_tool_name from mcp.types import Icon, ToolAnnotations @@ -63,7 +61,7 @@ def from_function( raise ValueError("You must provide a name for lambda functions") func_doc = description or fn.__doc__ or "" - is_async = _is_async_callable(fn) + is_async = is_async_callable(fn) if context_kwarg is None: # pragma: no branch context_kwarg = find_context_parameter(fn) @@ -110,12 +108,3 @@ async def run( return result except Exception as e: raise ToolError(f"Error executing tool {self.name}: {e}") from e - - -def _is_async_callable(obj: Any) -> bool: - while isinstance(obj, functools.partial): # pragma: no cover - obj = obj.func - - return inspect.iscoroutinefunction(obj) or ( - callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None)) - ) diff --git a/src/mcp/server/fastmcp/utilities/func_metadata.py b/src/mcp/server/fastmcp/utilities/func_metadata.py index fa443d2fcb..390ffd33d1 100644 --- a/src/mcp/server/fastmcp/utilities/func_metadata.py +++ b/src/mcp/server/fastmcp/utilities/func_metadata.py @@ -1,3 +1,4 @@ +import functools import inspect import json from collections.abc import Awaitable, Callable, Sequence @@ -531,3 +532,13 @@ def _convert_to_content( result = pydantic_core.to_json(result, fallback=str, indent=2).decode() return [TextContent(type="text", text=result)] + + +def is_async_callable(obj: Any) -> bool: + """Check if an object is an async callable.""" + while isinstance(obj, functools.partial): # pragma: no cover + obj = obj.func + + return inspect.iscoroutinefunction(obj) or ( + callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None)) + ) diff --git a/tests/server/fastmcp/resources/test_function_resources.py b/tests/server/fastmcp/resources/test_function_resources.py index fccada4750..435bf5d639 100644 --- a/tests/server/fastmcp/resources/test_function_resources.py +++ b/tests/server/fastmcp/resources/test_function_resources.py @@ -1,7 +1,9 @@ import pytest from pydantic import AnyUrl, BaseModel +from mcp.server.fastmcp import Context, FastMCP from mcp.server.fastmcp.resources import FunctionResource +from mcp.server.session import ServerSession class TestFunctionResource: @@ -155,3 +157,79 @@ async def get_data() -> str: # pragma: no cover assert resource.mime_type == "text/plain" assert resource.name == "test" assert resource.uri == AnyUrl("function://test") + + +class TestFunctionResourceContextHandling: + """Test context injection in FunctionResource.""" + + def test_context_kwarg_detection(self): + """Test that from_function() correctly detects context parameters.""" + + def func_with_context(ctx: Context[ServerSession, None]) -> str: # pragma: no cover + return "test" + + resource = FunctionResource.from_function(fn=func_with_context, uri="test://uri") + assert resource.context_kwarg == "ctx" + + def test_context_kwarg_custom_name(self): + """Test detection of context with custom parameter names.""" + + def func_with_custom_ctx(my_context: Context[ServerSession, None]) -> str: # pragma: no cover + return "test" + + resource = FunctionResource.from_function(fn=func_with_custom_ctx, uri="test://uri") + assert resource.context_kwarg == "my_context" + + def test_no_context_kwarg(self): + """Test that functions without context have context_kwarg=None.""" + + def func_without_context() -> str: # pragma: no cover + return "test" + + resource = FunctionResource.from_function(fn=func_without_context, uri="test://uri") + assert resource.context_kwarg is None + + @pytest.mark.anyio + async def test_read_with_context_injection(self): + """Test that read(context=ctx) injects context into function.""" + received_context = None + + def func_with_context(ctx: Context[ServerSession, None]) -> str: + nonlocal received_context + received_context = ctx + return "result" + + resource = FunctionResource.from_function(fn=func_with_context, uri="test://uri") + mcp = FastMCP() + ctx = mcp.get_context() + result = await resource.read(context=ctx) + assert received_context is ctx + assert result == "result" + + @pytest.mark.anyio + async def test_read_without_context_when_not_needed(self): + """Test that functions without context work normally.""" + + def func_without_context() -> str: + return "no context needed" + + resource = FunctionResource.from_function(fn=func_without_context, uri="test://uri") + result = await resource.read() + assert result == "no context needed" + + @pytest.mark.anyio + async def test_read_async_with_context(self): + """Test async functions with context injection.""" + received_context = None + + async def async_func_with_context(ctx: Context[ServerSession, None]) -> str: + nonlocal received_context + received_context = ctx + return "async result" + + resource = FunctionResource.from_function(fn=async_func_with_context, uri="test://uri") + mcp = FastMCP() + ctx = mcp.get_context() + result = await resource.read(context=ctx) + assert received_context is ctx + assert result == "async result" diff --git a/tests/server/fastmcp/resources/test_resource_manager.py b/tests/server/fastmcp/resources/test_resource_manager.py index a0c06be86c..08eaddcad3 100644 --- a/tests/server/fastmcp/resources/test_resource_manager.py +++ b/tests/server/fastmcp/resources/test_resource_manager.py @@ -4,7 +4,9 @@ import pytest from pydantic import AnyUrl, FileUrl +from mcp.server.fastmcp import Context, FastMCP from mcp.server.fastmcp.resources import FileResource, FunctionResource, ResourceManager, ResourceTemplate +from mcp.server.session import ServerSession @pytest.fixture @@ -134,3 +136,30 @@ def test_list_resources(self, temp_file: Path): resources = manager.list_resources() assert len(resources) == 2 assert resources == [resource1, resource2] + + @pytest.mark.anyio + async def test_get_resource_passes_context_to_template(self): + """Test that get_resource() passes context to template's create_resource().""" + received_context = None + + def func_with_context(name: str, ctx: Context[ServerSession, None]) -> str: + nonlocal received_context + received_context = ctx + return f"Hello {name}" + + manager = ResourceManager() + template = ResourceTemplate.from_function( + fn=func_with_context, + uri_template="greet://{name}", + name="greeter", + ) + manager._templates[template.uri_template] = template + + mcp = FastMCP() + ctx = mcp.get_context() + resource = await manager.get_resource(AnyUrl("greet://world"), context=ctx) + + assert received_context is ctx + assert isinstance(resource, FunctionResource) + content = await resource.read() + assert content == "Hello world" diff --git a/tests/server/fastmcp/resources/test_resource_template.py b/tests/server/fastmcp/resources/test_resource_template.py index c910f8fa85..642f0b399d 100644 --- a/tests/server/fastmcp/resources/test_resource_template.py +++ b/tests/server/fastmcp/resources/test_resource_template.py @@ -4,8 +4,9 @@ import pytest from pydantic import BaseModel -from mcp.server.fastmcp import FastMCP +from mcp.server.fastmcp import Context, FastMCP from mcp.server.fastmcp.resources import FunctionResource, ResourceTemplate +from mcp.server.session import ServerSession from mcp.types import Annotations @@ -258,3 +259,65 @@ def get_item(item_id: str) -> str: # pragma: no cover # Verify the resource works correctly content = await resource.read() assert content == "Item 123" + + +class TestResourceTemplateContextHandling: + """Test context injection in ResourceTemplate.""" + + def test_template_context_kwarg_detection(self): + """Test that from_function() correctly detects context parameters.""" + + def func_with_context(name: str, ctx: Context[ServerSession, None]) -> str: # pragma: no cover + return f"Hello {name}" + + template = ResourceTemplate.from_function( + fn=func_with_context, + uri_template="test://{name}", + name="test", + ) + assert template.context_kwarg == "ctx" + + @pytest.mark.anyio + async def test_create_resource_with_context(self): + """Test that create_resource() passes context to function.""" + received_context = None + + def func_with_context(name: str, ctx: Context[ServerSession, None]) -> str: + nonlocal received_context + received_context = ctx + return f"Hello {name}" + + template = ResourceTemplate.from_function( + fn=func_with_context, + uri_template="test://{name}", + name="test", + ) + + mcp = FastMCP() + ctx = mcp.get_context() + resource = await template.create_resource("test://world", {"name": "world"}, context=ctx) + + assert received_context is ctx + assert isinstance(resource, FunctionResource) + content = await resource.read() + assert content == "Hello world" + + @pytest.mark.anyio + async def test_template_without_context(self): + """Test that templates without context work normally.""" + + def func_without_context(name: str) -> str: + return f"Hello {name}" + + template = ResourceTemplate.from_function( + fn=func_without_context, + uri_template="test://{name}", + name="test", + ) + + assert template.context_kwarg is None + resource = await template.create_resource("test://world", {"name": "world"}) + + assert isinstance(resource, FunctionResource) + content = await resource.read() + assert content == "Hello world"