Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/mcp/server/fastmcp/resources/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Base classes and interfaces for FastMCP resources."""

from __future__ import annotations

import abc
from typing import Annotated
from typing import TYPE_CHECKING, Annotated

from pydantic import (
AnyUrl,
Expand All @@ -15,6 +17,11 @@

from mcp.types import Annotations, Icon

if TYPE_CHECKING:
from mcp.server.fastmcp.server import Context
from mcp.server.session import ServerSessionT
from mcp.shared.context import LifespanContextT, RequestT


class Resource(BaseModel, abc.ABC):
"""Base class for all resources."""
Expand Down Expand Up @@ -44,6 +51,9 @@ def set_default_name(cls, name: str | None, info: ValidationInfo) -> str:
raise ValueError("Either name or uri must be provided")

@abc.abstractmethod
async def read(self) -> str | bytes:
async def read(
self,
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
) -> str | bytes:
"""Read the resource content."""
pass # pragma: no cover
31 changes: 17 additions & 14 deletions src/mcp/server/fastmcp/resources/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@

from __future__ import annotations

import inspect
import re
from collections.abc import Callable
from typing import TYPE_CHECKING, Any

from pydantic import BaseModel, Field, validate_call
from pydantic import BaseModel, Field

from mcp.server.fastmcp.resources.types import FunctionResource, Resource
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter, inject_context
from mcp.server.fastmcp.utilities.func_metadata import func_metadata
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata, is_async_callable
from mcp.types import Annotations, Icon

if TYPE_CHECKING:
Expand All @@ -33,6 +32,10 @@ class ResourceTemplate(BaseModel):
fn: Callable[..., Any] = Field(exclude=True)
parameters: dict[str, Any] = Field(description="JSON schema for function parameters")
context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context")
fn_metadata: FuncMetadata = Field(
description="Metadata about the function including a pydantic model for arguments"
)
is_async: bool = Field(description="Whether the function is async")

@classmethod
def from_function(
Expand All @@ -56,16 +59,15 @@ def from_function(
if context_kwarg is None: # pragma: no branch
context_kwarg = find_context_parameter(fn)

is_async = is_async_callable(fn)

# Get schema from func_metadata, excluding context parameter
func_arg_metadata = func_metadata(
fn,
skip_names=[context_kwarg] if context_kwarg is not None else [],
)
parameters = func_arg_metadata.arg_model.model_json_schema()

# ensure the arguments are properly cast
fn = validate_call(fn)

return cls(
uri_template=uri_template,
name=func_name,
Expand All @@ -77,6 +79,8 @@ def from_function(
fn=fn,
parameters=parameters,
context_kwarg=context_kwarg,
fn_metadata=func_arg_metadata,
is_async=is_async,
)

def matches(self, uri: str) -> dict[str, Any] | None:
Expand All @@ -96,13 +100,12 @@ async def create_resource(
) -> Resource:
"""Create a resource from the template with the given parameters."""
try:
# Add context to params if needed
params = inject_context(self.fn, params, context, self.context_kwarg)

# Call function and check if result is a coroutine
result = self.fn(**params)
if inspect.iscoroutine(result):
result = await result
result = await self.fn_metadata.call_fn_with_arg_validation(
self.fn,
self.is_async,
params,
{self.context_kwarg: context} if self.context_kwarg is not None else None,
)

return FunctionResource(
uri=uri, # type: ignore
Expand Down
61 changes: 47 additions & 14 deletions src/mcp/server/fastmcp/resources/types.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,39 @@
"""Concrete resource implementations."""

from __future__ import annotations

import inspect
import json
from collections.abc import Callable
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Any

import anyio
import anyio.to_thread
import httpx
import pydantic
import pydantic_core
from pydantic import AnyUrl, Field, ValidationInfo, validate_call
from pydantic import AnyUrl, Field, ValidationInfo

from mcp.server.fastmcp.resources.base import Resource
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter
from mcp.types import Annotations, Icon

if TYPE_CHECKING:
from mcp.server.fastmcp.server import Context
from mcp.server.session import ServerSessionT
from mcp.shared.context import LifespanContextT, RequestT


class TextResource(Resource):
"""A resource that reads from a string."""

text: str = Field(description="Text content of the resource")

async def read(self) -> str:
async def read(
self,
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
) -> str:
"""Read the text content."""
return self.text # pragma: no cover

Expand All @@ -32,7 +43,10 @@ class BinaryResource(Resource):

data: bytes = Field(description="Binary content of the resource")

async def read(self) -> bytes:
async def read(
self,
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
) -> bytes:
"""Read the binary content."""
return self.data # pragma: no cover

Expand All @@ -50,13 +64,22 @@ class FunctionResource(Resource):
- other types will be converted to JSON
"""

fn: Callable[[], Any] = Field(exclude=True)
fn: Callable[..., Any] = Field(exclude=True)
context_kwarg: str | None = Field(default=None, description="Name of the kwarg that should receive context")

async def read(self) -> str | bytes:
async def read(
self,
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
) -> str | bytes:
"""Read the resource by calling the wrapped function."""
try:
# Call the function first to see if it returns a coroutine
result = self.fn()
# Inject context if needed
kwargs: dict[str, Any] = {}
if self.context_kwarg is not None and context is not None:
kwargs[self.context_kwarg] = context

# Call the function
result = self.fn(**kwargs)
# If it's a coroutine, await it
if inspect.iscoroutine(result):
result = await result
Expand All @@ -83,14 +106,14 @@ def from_function(
mime_type: str | None = None,
icons: list[Icon] | None = None,
annotations: Annotations | None = None,
) -> "FunctionResource":
) -> FunctionResource:
"""Create a FunctionResource from a function."""
func_name = name or fn.__name__
if func_name == "<lambda>": # pragma: no cover
raise ValueError("You must provide a name for lambda functions")

# ensure the arguments are properly cast
fn = validate_call(fn)
# Find context parameter if it exists
context_kwarg = find_context_parameter(fn)

return cls(
uri=AnyUrl(uri),
Expand All @@ -99,6 +122,7 @@ def from_function(
description=description or fn.__doc__ or "",
mime_type=mime_type or "text/plain",
fn=fn,
context_kwarg=context_kwarg,
icons=icons,
annotations=annotations,
)
Expand Down Expand Up @@ -137,7 +161,10 @@ def set_binary_from_mime_type(cls, is_binary: bool, info: ValidationInfo) -> boo
mime_type = info.data.get("mime_type", "text/plain")
return not mime_type.startswith("text/")

async def read(self) -> str | bytes:
async def read(
self,
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
) -> str | bytes:
"""Read the file content."""
try:
if self.is_binary:
Expand All @@ -153,7 +180,10 @@ class HttpResource(Resource):
url: str = Field(description="URL to fetch content from")
mime_type: str = Field(default="application/json", description="MIME type of the resource content")

async def read(self) -> str | bytes:
async def read(
self,
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
) -> str | bytes:
"""Read the HTTP content."""
async with httpx.AsyncClient() as client: # pragma: no cover
response = await client.get(self.url)
Expand Down Expand Up @@ -191,7 +221,10 @@ def list_files(self) -> list[Path]: # pragma: no cover
except Exception as e:
raise ValueError(f"Error listing directory {self.path}: {e}")

async def read(self) -> str: # Always returns JSON string # pragma: no cover
async def read(
self,
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
) -> str: # Always returns JSON string # pragma: no cover
"""Read the directory listing."""
try:
files = await anyio.to_thread.run_sync(self.list_files)
Expand Down
23 changes: 10 additions & 13 deletions src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContent
raise ResourceError(f"Unknown resource: {uri}")

try:
content = await resource.read()
content = await resource.read(context=context)
return [ReadResourceContents(content=content, mime_type=resource.mime_type)]
except Exception as e: # pragma: no cover
logger.exception(f"Error reading resource {uri}")
Expand Down Expand Up @@ -571,21 +571,18 @@ async def get_weather(city: str) -> str:
)

def decorator(fn: AnyFunction) -> AnyFunction:
# Check if this should be a template
# Extract signature and parameters
sig = inspect.signature(fn)
has_uri_params = "{" in uri and "}" in uri
has_func_params = bool(sig.parameters)
uri_params = set(re.findall(r"{(\w+)}", uri))
context_param = find_context_parameter(fn)
func_params = {p for p in sig.parameters.keys() if p != context_param}

if has_uri_params or has_func_params:
# Check for Context parameter to exclude from validation
context_param = find_context_parameter(fn)

# Validate that URI params match function params (excluding context)
uri_params = set(re.findall(r"{(\w+)}", uri))
# We need to remove the context_param from the resource function if
# there is any.
func_params = {p for p in sig.parameters.keys() if p != context_param}
# Determine if this should be a template
has_uri_params = len(uri_params) != 0
has_func_params = len(func_params) != 0

if has_uri_params or has_func_params:
# Validate that URI params match function params
if uri_params != func_params:
raise ValueError(
f"Mismatch between URI parameters {uri_params} and function parameters {func_params}"
Expand Down
15 changes: 2 additions & 13 deletions src/mcp/server/fastmcp/tools/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations as _annotations

import functools
import inspect
from collections.abc import Callable
from functools import cached_property
from typing import TYPE_CHECKING, Any
Expand All @@ -10,7 +8,7 @@

from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata, is_async_callable
from mcp.shared.tool_name_validation import validate_and_warn_tool_name
from mcp.types import Icon, ToolAnnotations

Expand Down Expand Up @@ -63,7 +61,7 @@ def from_function(
raise ValueError("You must provide a name for lambda functions")

func_doc = description or fn.__doc__ or ""
is_async = _is_async_callable(fn)
is_async = is_async_callable(fn)

if context_kwarg is None: # pragma: no branch
context_kwarg = find_context_parameter(fn)
Expand Down Expand Up @@ -110,12 +108,3 @@ async def run(
return result
except Exception as e:
raise ToolError(f"Error executing tool {self.name}: {e}") from e


def _is_async_callable(obj: Any) -> bool:
while isinstance(obj, functools.partial): # pragma: no cover
obj = obj.func

return inspect.iscoroutinefunction(obj) or (
callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None))
)
11 changes: 11 additions & 0 deletions src/mcp/server/fastmcp/utilities/func_metadata.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import inspect
import json
from collections.abc import Awaitable, Callable, Sequence
Expand Down Expand Up @@ -531,3 +532,13 @@ def _convert_to_content(
result = pydantic_core.to_json(result, fallback=str, indent=2).decode()

return [TextContent(type="text", text=result)]


def is_async_callable(obj: Any) -> bool:
"""Check if an object is an async callable."""
while isinstance(obj, functools.partial): # pragma: no cover
obj = obj.func

return inspect.iscoroutinefunction(obj) or (
callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None))
)
Loading