diff --git a/pyproject.toml b/pyproject.toml index 3928cd8..23bcbd7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "mcpcat" -version = "0.1.13" +version = "0.1.14" description = "Analytics Tool for MCP Servers - provides insights into MCP tool usage patterns" authors = [ { name = "MCPCat", email = "support@mcpcat.io" }, diff --git a/src/mcpcat/modules/overrides/community/tool_manager.py b/src/mcpcat/modules/overrides/community/tool_manager.py index e56dd60..5e6d328 100644 --- a/src/mcpcat/modules/overrides/community/tool_manager.py +++ b/src/mcpcat/modules/overrides/community/tool_manager.py @@ -15,6 +15,33 @@ from fastmcp import FastMCP +def _ensure_context_parameter(tool: Any, description: str) -> None: + """Add or overwrite the 'context' parameter in a tool's schema. + + Ensures the tool has a valid parameters dict with a 'context' property + marked as required. + """ + if not hasattr(tool, "parameters") or not tool.parameters: + tool.parameters = {"type": "object", "properties": {}, "required": []} + + if "properties" not in tool.parameters: + tool.parameters["properties"] = {} + + tool.parameters["properties"]["context"] = { + "type": "string", + "description": description, + } + + if "required" not in tool.parameters: + tool.parameters["required"] = [] + + if isinstance(tool.parameters["required"], list): + if "context" not in tool.parameters["required"]: + tool.parameters["required"].append("context") + else: + tool.parameters["required"] = ["context"] + + def patch_community_fastmcp_tool_manager(server: Any) -> None: """Patch the community FastMCP tool manager to add MCPCat tracking. 
@@ -39,24 +66,25 @@ def patch_community_fastmcp_tool_manager(server: Any) -> None: # Add get_more_tools if enabled if data.options.enable_report_missing: try: - # Create the get_more_tools function that returns the proper format - async def get_more_tools(context: str | None = "") -> str: + async def get_more_tools(context: str) -> str: """Check for additional tools whenever your task might benefit from specialized capabilities.""" - # Handle None values - if context is None: - context = "" result = await handle_report_missing({"context": context}) - # Return just the text content for community FastMCP - if result.content and len(result.content) > 0: + if result.content: return result.content[0].text return "No additional tools available" - # Register it with the server server.tool( get_more_tools, name="get_more_tools", description="Check for additional tools whenever your task might benefit from specialized capabilities - even if existing tools could work as a fallback.", ) + + # Force the correct schema - Pydantic's TypeAdapter can mangle + # the type on async closures into anyOf: [string, null] + from mcpcat.modules.tools import GET_MORE_TOOLS_SCHEMA + if hasattr(server._tool_manager, "_tools") and "get_more_tools" in server._tool_manager._tools: + server._tool_manager._tools["get_more_tools"].parameters = GET_MORE_TOOLS_SCHEMA + write_to_log("Added get_more_tools tool to community FastMCP server") except Exception as e: write_to_log(f"Error adding get_more_tools: {e}") @@ -86,35 +114,10 @@ def patch_existing_tools(server: FastMCP) -> None: return for tool_name, tool in tool_manager._tools.items(): - # Skip get_more_tools if tool_name == "get_more_tools": continue - # Ensure tool has parameters - if not hasattr(tool, "parameters"): - tool.parameters = {"type": "object", "properties": {}, "required": []} - elif not tool.parameters: - tool.parameters = {"type": "object", "properties": {}, "required": []} - - # Ensure properties exists - if "properties" not in 
tool.parameters: - tool.parameters["properties"] = {} - - # Always overwrite the context property with MCPCat's version - tool.parameters["properties"]["context"] = { - "type": "string", - "description": data.options.custom_context_description, - } - - # Add to required array - if "required" not in tool.parameters: - tool.parameters["required"] = [] - if isinstance(tool.parameters["required"], list): - if "context" not in tool.parameters["required"]: - tool.parameters["required"].append("context") - else: - tool.parameters["required"] = ["context"] - + _ensure_context_parameter(tool, data.options.custom_context_description) write_to_log(f"Added/updated context parameter for existing tool: {tool_name}") except Exception as e: @@ -153,34 +156,9 @@ def patched_add_tool(tool: Any) -> Any: # Add context parameter if it's not get_more_tools if tool_name != "get_more_tools": - # Get tracking data to check if context injection is enabled data = get_server_tracking_data(server._mcp_server) if data and data.options.enable_tool_call_context: - # Ensure tool has parameters - if not hasattr(tool, "parameters"): - tool.parameters = {"type": "object", "properties": {}, "required": []} - elif not tool.parameters: - tool.parameters = {"type": "object", "properties": {}, "required": []} - - # Ensure properties exists - if "properties" not in tool.parameters: - tool.parameters["properties"] = {} - - # Always overwrite the context property with MCPCat's version - tool.parameters["properties"]["context"] = { - "type": "string", - "description": data.options.custom_context_description - } - - # Add to required array - if "required" not in tool.parameters: - tool.parameters["required"] = [] - if isinstance(tool.parameters["required"], list): - if "context" not in tool.parameters["required"]: - tool.parameters["required"].append("context") - else: - tool.parameters["required"] = ["context"] - + _ensure_context_parameter(tool, data.options.custom_context_description) 
write_to_log(f"Added/updated context parameter for new tool: {tool_name}") return result diff --git a/src/mcpcat/modules/overrides/community_v3/integration.py b/src/mcpcat/modules/overrides/community_v3/integration.py index de85e08..3806d23 100644 --- a/src/mcpcat/modules/overrides/community_v3/integration.py +++ b/src/mcpcat/modules/overrides/community_v3/integration.py @@ -4,9 +4,9 @@ FastMCP v3 servers using the middleware system. """ -from __future__ import annotations +from typing import Annotated, Any -from typing import Any +from pydantic import Field from mcpcat.modules.logging import write_to_log from mcpcat.modules.overrides.community_v3.middleware import MCPCatMiddleware @@ -57,62 +57,38 @@ def _register_get_more_tools_v3(server: Any, mcpcat_data: MCPCatData) -> None: server: A Community FastMCP v3 server instance. mcpcat_data: MCPCat tracking configuration. """ + from fastmcp.tools.tool import Tool + from mcpcat.modules.tools import handle_report_missing # Define the get_more_tools function - async def get_more_tools(context: str | None = None) -> str: - """Check for additional tools when your task might benefit from them. - - Args: - context: A description of your goal and what kind of tool would help. - - Returns: - A response message indicating the result. - """ - # Handle None values - context_str = context if context is not None else "" - - result = await handle_report_missing({"context": context_str}) - - # Return text content for FastMCP v3 - # The result.content is a list of TextContent objects - if result.content and len(result.content) > 0: - content_item = result.content[0] - if hasattr(content_item, "text"): - return content_item.text - + async def get_more_tools( + context: Annotated[ + str, + Field( + description="A description of your goal and what kind of tool would help accomplish it." 
+ ), + ], + ) -> str: + """Check for additional tools when your task might benefit from them.""" + result = await handle_report_missing({"context": context}) + + if result.content and hasattr(result.content[0], "text"): + return result.content[0].text return "No additional tools available." try: - # Note: We don't check if get_more_tools already exists because - # FastMCP v3's list_tools is async and we're in a sync context. - # The tool decorator handles duplicates gracefully. - - get_more_tools_desc = ( - "Check for additional tools whenever your task might benefit from " - "specialized capabilities - even if existing tools could work as a " - "fallback." + tool = Tool.from_function( + get_more_tools, + name="get_more_tools", + description=( + "Check for additional tools whenever your task might benefit from " + "specialized capabilities - even if existing tools could work as a " + "fallback." + ), ) - - # Register the tool using the server's tool decorator or add_tool method - if hasattr(server, "tool"): - server.tool( - name="get_more_tools", - description=get_more_tools_desc, - )(get_more_tools) - write_to_log("Registered get_more_tools using server.tool() decorator") - elif hasattr(server, "add_tool"): - from fastmcp.tools.tool import Tool - - tool = Tool.from_function( - get_more_tools, - name="get_more_tools", - description=get_more_tools_desc, - ) - server.add_tool(tool) - write_to_log("Registered get_more_tools using server.add_tool()") - else: - write_to_log("Warning: Could not find method to register get_more_tools") + server.add_tool(tool) + write_to_log("Registered get_more_tools using server.add_tool()") except Exception as e: write_to_log(f"Error registering get_more_tools: {e}") diff --git a/src/mcpcat/modules/overrides/official/monkey_patch.py b/src/mcpcat/modules/overrides/official/monkey_patch.py index eabe4a6..b8aa1ec 100644 --- a/src/mcpcat/modules/overrides/official/monkey_patch.py +++ b/src/mcpcat/modules/overrides/official/monkey_patch.py 
@@ -7,7 +7,9 @@ import inspect from collections.abc import Callable from datetime import datetime, timezone -from typing import Any, List +from typing import Annotated, Any, List + +from pydantic import Field from mcpcat.modules import event_queue from mcpcat.modules.compatibility import is_official_fastmcp_server, is_mcp_error_response @@ -64,13 +66,17 @@ def patch_fastmcp_tool_manager(server: Any, mcpcat_data: MCPCatData) -> bool: # Add the get_more_tools tool if enabled if mcpcat_data.options.enable_report_missing: # Create the get_more_tools function that returns CallToolResult - async def get_more_tools(context: str | None = "") -> List[Any]: + async def get_more_tools( + context: Annotated[ + str, + Field( + description="A description of your goal and what kind of tool would help accomplish it." + ), + ], + ) -> List[Any]: """Check for additional tools whenever your task might benefit from specialized capabilities.""" from mcpcat.modules.tools import handle_report_missing - # Handle None values - if context is None: - context = "" result = await handle_report_missing({"context": context}) # Return just the content list for FastMCP return result.content @@ -261,11 +267,11 @@ async def patched_call_tool( # Extract user intent (non-critical) user_intent = None try: - if ( - current_data - and current_data.options.enable_tool_call_context - and name != "get_more_tools" - ): + should_capture_intent = ( + name == "get_more_tools" + or (current_data and current_data.options.enable_tool_call_context) + ) + if should_capture_intent: user_intent = arguments.get("context", None) except Exception as e: write_to_log(f"Error extracting user intent: {e}") diff --git a/src/mcpcat/modules/tools.py b/src/mcpcat/modules/tools.py index 5b9f9bc..e26e935 100644 --- a/src/mcpcat/modules/tools.py +++ b/src/mcpcat/modules/tools.py @@ -6,6 +6,21 @@ from .logging import write_to_log +# Correct schema for the get_more_tools tool parameter. 
+# Defined explicitly because Pydantic's TypeAdapter generates a broken schema +# (anyOf: [string, null], default: "") for Annotated[str, Field(description=...)] +# on async closure functions used by Tool.from_function(). +GET_MORE_TOOLS_SCHEMA = { + "type": "object", + "properties": { + "context": { + "type": "string", + "description": "A description of your goal and what kind of tool would help accomplish it.", + } + }, + "required": ["context"], +} + if TYPE_CHECKING or has_fastmcp_support(): try: from mcp.server import FastMCP @@ -19,7 +34,7 @@ async def handle_report_missing(arguments: dict[str, Any]) -> CallToolResult: content=[ TextContent( type="text", - text=f"Unfortunately, we have shown you the full tool list. We have noted your feedback and will work to improve the tool list in the future.", + text="Unfortunately, we have shown you the full tool list. We have noted your feedback and will work to improve the tool list in the future.", ) ] ) diff --git a/src/mcpcat/modules/truncation.py b/src/mcpcat/modules/truncation.py index 5ce57f5..2c43da2 100644 --- a/src/mcpcat/modules/truncation.py +++ b/src/mcpcat/modules/truncation.py @@ -7,7 +7,7 @@ """ from datetime import date, datetime -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, TYPE_CHECKING if TYPE_CHECKING: from mcpcat.types import UnredactedEvent @@ -18,6 +18,10 @@ MAX_STRING_BYTES = 10_240 # 10 KB per individual string MAX_DEPTH = 5 # Max nesting depth for dicts/lists MAX_BREADTH = 500 # Max items per dict/list +MIN_DEPTH = 1 # Never reduce depth below this to avoid type mismatches + +# Only these fields get truncated; all other top-level event fields pass through untouched. 
+TRUNCATABLE_FIELDS = {"parameters", "response", "error", "identify_data", "user_intent", "additional_properties"} def _truncate_string(value: str, max_bytes: int = MAX_STRING_BYTES) -> str: @@ -42,14 +46,12 @@ def _truncate_value( *, max_depth: int = MAX_DEPTH, max_string_bytes: int = MAX_STRING_BYTES, + max_breadth: int = MAX_BREADTH, _depth: int = 0, _seen: set[int] | None = None, ) -> Any: """Recursively walk a value and apply truncation limits.""" - if value is None or isinstance(value, (bool, int, float)): - return value - - if isinstance(value, (datetime, date)): + if value is None or isinstance(value, (bool, int, float, datetime, date)): return value if isinstance(value, str): @@ -70,24 +72,18 @@ def _truncate_value( items = list(value.items()) result = {} for i, (k, v) in enumerate(items): - if i >= MAX_BREADTH: - remaining = len(items) - MAX_BREADTH + if i >= max_breadth: + remaining = len(items) - max_breadth result["__truncated__"] = ( f"[... {remaining} more items truncated by MCPcat]" ) break - if at_depth_limit: - result[str(k)] = ( - f"[nested content truncated by MCPcat at depth {max_depth}]" - if isinstance(v, (dict, list, tuple)) - else _truncate_value( - v, max_depth=max_depth, max_string_bytes=max_string_bytes, - _depth=_depth + 1, _seen=_seen, - ) - ) + if at_depth_limit and isinstance(v, (dict, list, tuple)): + result[str(k)] = f"[nested content truncated by MCPcat at depth {max_depth}]" else: result[str(k)] = _truncate_value( v, max_depth=max_depth, max_string_bytes=max_string_bytes, + max_breadth=max_breadth, _depth=_depth + 1, _seen=_seen, ) return result @@ -98,13 +94,14 @@ def _truncate_value( result_list = [ _truncate_value( item, max_depth=max_depth, max_string_bytes=max_string_bytes, + max_breadth=max_breadth, _depth=_depth + 1, _seen=_seen, ) for i, item in enumerate(value) - if i < MAX_BREADTH + if i < max_breadth ] - if len(value) > MAX_BREADTH: - remaining = len(value) - MAX_BREADTH + if len(value) > max_breadth: + remaining = 
len(value) - max_breadth result_list.append( f"[... {remaining} more items truncated by MCPcat]" ) @@ -119,19 +116,21 @@ def _truncate_value( _seen.discard(obj_id) -def truncate_event(event: Optional["UnredactedEvent"]) -> Optional["UnredactedEvent"]: +def truncate_event(event: "UnredactedEvent | None") -> "UnredactedEvent | None": """Return a truncated copy of *event* if it exceeds MAX_EVENT_BYTES. Uses size-targeted normalization strategy: normalize with the current limits, check JSON byte size, and if still over the limit tighten limits and re-normalize until it fits. - Each pass reduces depth by 1 and halves the per-string byte limit. + Each pass halves the per-string byte limit and (once MIN_DEPTH is reached) + reduces breadth. Depth never goes below MIN_DEPTH to avoid replacing + dict-typed fields with string markers that fail model validation. - Checks serialized JSON byte size first (fast path) - Never mutates the original event - Returns original event unchanged if under limit - - Returns original event unchanged if truncation itself fails + - Returns last valid truncated candidate if loop exhausts limits """ if event is None: return None @@ -147,26 +146,41 @@ def truncate_event(event: Optional["UnredactedEvent"]) -> Optional["UnredactedEv f"({byte_size} bytes), truncating" ) - truncated_dict = event.model_dump() + event_cls = type(event) depth = MAX_DEPTH string_bytes = MAX_STRING_BYTES - - event_cls = type(event) - - while depth >= 0: - truncated_dict = _truncate_value( - truncated_dict, max_depth=depth, max_string_bytes=string_bytes, - ) - candidate = event_cls.model_validate(truncated_dict) + breadth = MAX_BREADTH + candidate = None + + while string_bytes >= 1: + # Always start from a fresh dump to avoid compounding artifacts + event_dict = event.model_dump() + for field_name in TRUNCATABLE_FIELDS: + if field_name in event_dict and event_dict[field_name] is not None: + if isinstance(event_dict[field_name], str): + event_dict[field_name] = 
_truncate_string(event_dict[field_name], max_bytes=string_bytes) + else: + event_dict[field_name] = _truncate_value( + event_dict[field_name], + max_depth=depth, + max_string_bytes=string_bytes, + max_breadth=breadth, + ) + candidate = event_cls.model_validate(event_dict) result_bytes = len(candidate.model_dump_json().encode("utf-8")) if result_bytes <= MAX_EVENT_BYTES: return candidate write_to_log( f"Event still {result_bytes} bytes at depth={depth} " - f"string_limit={string_bytes}, tightening limits" + f"string_limit={string_bytes} breadth={breadth}, tightening limits" ) - depth -= 1 + # Tighten: reduce depth (down to MIN_DEPTH), halve string limit + if depth > MIN_DEPTH: + depth -= 1 string_bytes //= 2 + # Breadth reduction as fallback once depth is at minimum + if depth <= MIN_DEPTH and breadth > 1: + breadth //= 2 return candidate diff --git a/tests/community/test_community_dynamic_tracking.py b/tests/community/test_community_dynamic_tracking.py index 539e460..e30bfbd 100644 --- a/tests/community/test_community_dynamic_tracking.py +++ b/tests/community/test_community_dynamic_tracking.py @@ -297,9 +297,10 @@ async def test_report_missing_tool_with_dynamic_tracking(self): result2 = await client.call_tool("get_more_tools", {"context": ""}) assert "Unfortunately" in str(result2) - # Test with missing context parameter - result3 = await client.call_tool("get_more_tools", {}) - assert "Unfortunately" in str(result3) + # Test with missing context parameter - should raise ToolError + # since context is a required parameter + with pytest.raises(Exception, match="(?i)required"): + await client.call_tool("get_more_tools", {}) # List tools tools = await get_server_tools(server) diff --git a/tests/community/test_community_report_missing.py b/tests/community/test_community_report_missing.py index 53ff418..e6570ba 100644 --- a/tests/community/test_community_report_missing.py +++ b/tests/community/test_community_report_missing.py @@ -47,6 +47,36 @@ async def 
test_report_missing_tool_injection(self): # Verify report_missing tool was injected assert "get_more_tools" in tool_names + @pytest.mark.asyncio + async def test_report_missing_tool_schema(self): + """Test that get_more_tools has context as a required string, not anyOf.""" + server = create_community_todo_server() + options = MCPCatOptions(enable_report_missing=True) + track(server, "test_project", options) + + async with create_community_test_client(server) as client: + tools_result = await client.list_tools() + + get_more_tools = next(t for t in tools_result if t.name == "get_more_tools") + schema = get_more_tools.inputSchema + + # context must be a simple {"type": "string"}, not anyOf/oneOf + context_prop = schema["properties"]["context"] + assert context_prop["type"] == "string", ( + f"Expected context type 'string', got: {context_prop}" + ) + assert "anyOf" not in context_prop, ( + f"context should not use anyOf: {context_prop}" + ) + assert "default" not in context_prop, ( + f"context should not have a default: {context_prop}" + ) + + # context must be required + assert "context" in schema.get("required", []), ( + f"context should be required, got required={schema.get('required')}" + ) + @pytest.mark.asyncio async def test_report_missing_disabled_by_default(self): """Verify tool is NOT injected when enable_report_missing=False.""" @@ -110,12 +140,12 @@ async def test_report_missing_with_missing_params(self): track(server, "test_project", options) async with create_community_test_client(server) as client: - # Test with missing parameters - should still work but with empty strings - result = await client.call_tool("get_more_tools", {}) - result_str = str(result) - assert "Unfortunately" in result_str + # Test with missing context - should raise a validation error + # since context is a required parameter + with pytest.raises(Exception, match="(?i)required"): + await client.call_tool("get_more_tools", {}) - # Test with only one parameter + # Test with valid 
context result = await client.call_tool("get_more_tools", {"context": "test_tool"}) result_str = str(result) assert "Unfortunately" in result_str @@ -204,11 +234,10 @@ async def test_report_missing_with_null_values(self): track(server, "test_project", options) async with create_community_test_client(server) as client: - # Test with None values - they should be treated as empty strings - result = await client.call_tool("get_more_tools", {"context": None}) - # Should still return a valid response - result_str = str(result) - assert "Unfortunately" in result_str + # Test with None context - should raise a validation error + # since context is required as a string + with pytest.raises(Exception, match="(?i)string"): + await client.call_tool("get_more_tools", {"context": None}) @pytest.mark.asyncio async def test_report_missing_publishes_event(self): diff --git a/tests/test_dynamic_tracking.py b/tests/test_dynamic_tracking.py index be50b73..4d0bbb1 100644 --- a/tests/test_dynamic_tracking.py +++ b/tests/test_dynamic_tracking.py @@ -259,10 +259,10 @@ async def test_report_missing_tool_with_dynamic_tracking(self, fastmcp_server): result2_text = result2[0].text if result2 else "" assert "Unfortunately" in result2_text, f"Expected 'Unfortunately' in result, got: {result2_text}" - # Test with missing context parameter - result3, _ = await fastmcp_server.call_tool("get_more_tools", {}) - result3_text = result3[0].text if result3 else "" - assert "Unfortunately" in result3_text, f"Expected 'Unfortunately' in result, got: {result3_text}" + # Test with missing context parameter - should raise validation error + # since context is a required parameter + with pytest.raises(Exception, match="(?i)required"): + await fastmcp_server.call_tool("get_more_tools", {}) @pytest.mark.asyncio async def test_lowlevel_server_dynamic_tracking(self, lowlevel_server): diff --git a/tests/test_report_missing.py b/tests/test_report_missing.py index 99fbceb..cbe5e89 100644 --- 
a/tests/test_report_missing.py +++ b/tests/test_report_missing.py @@ -110,12 +110,13 @@ async def test_report_missing_with_missing_params(self): track(server, "test_project", options) async with create_test_client(server) as client: - # Test with missing parameters - should still work but with empty strings + # Test with missing context - should return a validation error + # since context is a required parameter result = await client.call_tool("get_more_tools", {}) assert result.content[0].text - assert "Unfortunately" in result.content[0].text + assert result.isError is True - # Test with only one parameter + # Test with valid context result = await client.call_tool("get_more_tools", {"context": "test_tool"}) assert result.content[0].text assert "Unfortunately" in result.content[0].text @@ -236,11 +237,11 @@ async def test_report_missing_with_null_values(self): track(server, "test_project", options) async with create_test_client(server) as client: - # Test with None values - they should be treated as empty strings + # Test with None context - should return a validation error + # since context is required as a string result = await client.call_tool("get_more_tools", {"context": None}) - # Should still return a valid response assert result.content[0].text - assert "Unfortunately" in result.content[0].text + assert result.isError is True @pytest.mark.asyncio async def test_report_missing_publishes_event(self): diff --git a/tests/test_tool_context.py b/tests/test_tool_context.py index 15974bd..e17016e 100644 --- a/tests/test_tool_context.py +++ b/tests/test_tool_context.py @@ -1,5 +1,8 @@ """Test tool context functionality.""" +import time +from unittest.mock import MagicMock + import pytest from mcp.server import Server from mcp.server.fastmcp import FastMCP @@ -7,6 +10,7 @@ from mcpcat import MCPCatOptions, track from mcpcat.modules.constants import DEFAULT_CONTEXT_DESCRIPTION +from mcpcat.modules.event_queue import EventQueue, set_event_queue from 
.test_utils.client import create_test_client from .test_utils.todo_server import create_todo_server @@ -803,3 +807,135 @@ async def test_custom_context_with_tool_call(self): # Should succeed assert result.content assert "Added todo" in result.content[0].text + + +class TestGetMoreToolsContextSchema: + """Test that get_more_tools has a proper context parameter schema.""" + + @pytest.mark.asyncio + async def test_get_more_tools_context_has_string_type(self): + """get_more_tools context parameter should have type 'string', not a union type.""" + server = create_todo_server() + options = MCPCatOptions( + enable_report_missing=True, + enable_tool_call_context=True, + ) + track(server, "test_project", options) + + async with create_test_client(server) as client: + tools_result = await client.list_tools() + + get_more_tools_tool = next( + t for t in tools_result.tools if t.name == "get_more_tools" + ) + + context_schema = get_more_tools_tool.inputSchema["properties"]["context"] + # Context should be a simple string type, not a union/anyOf + assert "anyOf" not in context_schema, ( + f"get_more_tools context should not use anyOf (union type), got: {context_schema}" + ) + assert context_schema.get("type") == "string", ( + f"get_more_tools context should be type 'string', got: {context_schema}" + ) + + @pytest.mark.asyncio + async def test_get_more_tools_context_has_description(self): + """get_more_tools context parameter should have a meaningful description.""" + server = create_todo_server() + options = MCPCatOptions( + enable_report_missing=True, + enable_tool_call_context=True, + ) + track(server, "test_project", options) + + async with create_test_client(server) as client: + tools_result = await client.list_tools() + + get_more_tools_tool = next( + t for t in tools_result.tools if t.name == "get_more_tools" + ) + + context_schema = get_more_tools_tool.inputSchema["properties"]["context"] + assert "description" in context_schema, ( + "get_more_tools context parameter 
should have a description" + ) + assert len(context_schema["description"]) > 10, ( + "get_more_tools context description should be meaningful" + ) + + @pytest.mark.asyncio + async def test_get_more_tools_context_is_required(self): + """get_more_tools context parameter should be required.""" + server = create_todo_server() + options = MCPCatOptions( + enable_report_missing=True, + enable_tool_call_context=True, + ) + track(server, "test_project", options) + + async with create_test_client(server) as client: + tools_result = await client.list_tools() + + get_more_tools_tool = next( + t for t in tools_result.tools if t.name == "get_more_tools" + ) + + required = get_more_tools_tool.inputSchema.get("required", []) + assert "context" in required, ( + "get_more_tools context parameter should be required" + ) + + +class TestUserIntentCaptureInEvents: + """Test that user_intent is captured in events for the official FastMCP monkey patching.""" + + @pytest.fixture(autouse=True) + def setup_and_teardown(self): + """Set up and tear down mock event queue.""" + from mcpcat.modules.event_queue import event_queue as original_queue + + yield + set_event_queue(original_queue) + + @pytest.mark.asyncio + async def test_get_more_tools_captures_user_intent_in_event(self): + """get_more_tools calls should capture user_intent from context in the event.""" + mock_api_client = MagicMock() + captured_events = [] + + def capture_event(publish_event_request): + captured_events.append(publish_event_request) + + mock_api_client.publish_event = MagicMock(side_effect=capture_event) + + test_queue = EventQueue(api_client=mock_api_client) + set_event_queue(test_queue) + + server = create_todo_server() + options = MCPCatOptions( + enable_tracing=True, + enable_report_missing=True, + enable_tool_call_context=True, + ) + track(server, "test_project", options) + + async with create_test_client(server) as client: + # Call get_more_tools with context + await client.call_tool( + "get_more_tools", + 
{"context": "I need a tool to send emails"}, + ) + time.sleep(1.0) + + tool_events = [ + e + for e in captured_events + if e.event_type == "mcp:tools/call" + and e.resource_name == "get_more_tools" + ] + assert len(tool_events) > 0, "No get_more_tools event captured" + + event = tool_events[0] + assert event.user_intent == "I need a tool to send emails", ( + f"user_intent should be captured from get_more_tools context, got: {event.user_intent}" + ) diff --git a/tests/test_truncation.py b/tests/test_truncation.py index 22ed364..d19d47f 100644 --- a/tests/test_truncation.py +++ b/tests/test_truncation.py @@ -14,6 +14,8 @@ MAX_DEPTH, MAX_BREADTH, MAX_EVENT_BYTES, + MIN_DEPTH, + TRUNCATABLE_FIELDS, ) from mcpcat.types import UnredactedEvent @@ -234,6 +236,34 @@ def test_depth_reduces_progressively(self): result_bytes = len(result.model_dump_json().encode("utf-8")) assert result_bytes <= MAX_EVENT_BYTES + def test_1mb_single_string_under_limit(self): + """A single 1 MB string is truncated to fit.""" + big = "x" * 1_048_576 + event = _make_event(parameters={"data": big}) + result = truncate_event(event) + result_bytes = len(result.model_dump_json().encode("utf-8")) + assert result_bytes <= MAX_EVENT_BYTES + + def test_multiple_1mb_strings_under_limit(self): + """Multiple 1 MB strings across fields all fit after truncation.""" + big = "x" * 1_048_576 + event = _make_event( + user_intent=big, + parameters={"a": big, "b": big}, + response={"out": big}, + ) + result = truncate_event(event) + result_bytes = len(result.model_dump_json().encode("utf-8")) + assert result_bytes <= MAX_EVENT_BYTES + + def test_extreme_breadth_1000_keys_under_limit(self): + """1000 keys with moderate values exercises breadth reduction.""" + params = {f"key_{i}": "x" * 500 for i in range(1000)} + event = _make_event(parameters=params) + result = truncate_event(event) + result_bytes = len(result.model_dump_json().encode("utf-8")) + assert result_bytes <= MAX_EVENT_BYTES + class 
def _serialized_size(event):
    """Return the UTF-8 byte length of *event*'s JSON serialization.

    This is the quantity every test below compares against MAX_EVENT_BYTES.
    """
    return len(event.model_dump_json().encode("utf-8"))


class TestMegabyteStrings:
    """1 MB strings in various fields are truncated to fit under the limit."""

    ONE_MB = "x" * 1_048_576  # 1 MB

    def test_1mb_user_intent(self):
        """A 1 MB user_intent yields a new, marked, size-compliant event."""
        event = _make_event(user_intent=self.ONE_MB)
        result = truncate_event(event)
        assert result is not event
        assert _serialized_size(result) <= MAX_EVENT_BYTES
        assert "truncated by MCPcat" in result.user_intent

    def test_1mb_in_parameters(self):
        """A 1 MB parameter value yields a new, marked, size-compliant event."""
        event = _make_event(parameters={"context": self.ONE_MB})
        result = truncate_event(event)
        assert result is not event
        assert _serialized_size(result) <= MAX_EVENT_BYTES
        assert "truncated by MCPcat" in result.parameters["context"]

    def test_1mb_in_response(self):
        """A 1 MB response value is marked as truncated and fits the limit."""
        event = _make_event(response={"output": self.ONE_MB})
        result = truncate_event(event)
        assert _serialized_size(result) <= MAX_EVENT_BYTES
        assert "truncated by MCPcat" in result.response["output"]

    def test_1mb_in_error(self):
        """A 1 MB error stack is marked as truncated and fits the limit."""
        event = _make_event(error={"message": "fail", "stack": self.ONE_MB})
        result = truncate_event(event)
        assert _serialized_size(result) <= MAX_EVENT_BYTES
        assert "truncated by MCPcat" in result.error["stack"]

    def test_1mb_in_all_fields_simultaneously(self):
        """1 MB strings in all four payload fields at once still fit."""
        event = _make_event(
            user_intent=self.ONE_MB,
            parameters={"context": self.ONE_MB},
            response={"output": self.ONE_MB},
            error={"message": "fail", "stack": self.ONE_MB},
        )
        result = truncate_event(event)
        assert _serialized_size(result) <= MAX_EVENT_BYTES


class TestManyKeysRegression:
    """Regression tests for the depth=0 crash bug.

    Events with many moderate-sized keys used to cause depth to reach 0,
    which replaced dict-typed fields with string markers and crashed
    model_validate(). The fix keeps depth >= MIN_DEPTH and uses breadth
    reduction as a fallback.
    """

    def test_500_keys_x_50kb_stays_under_limit(self):
        """500 keys * 50 KB = ~25 MB raw — exercises aggressive truncation."""
        params = {f"key_{i}": "x" * 50_000 for i in range(500)}
        event = _make_event(parameters=params)
        result = truncate_event(event)
        assert _serialized_size(result) <= MAX_EVENT_BYTES

    def test_200_keys_x_1kb_stays_under_limit(self):
        """200 keys * 1 KB = 200 KB — just over the limit, previously crashed."""
        params = {f"key_{i}": "x" * 1_000 for i in range(200)}
        event = _make_event(parameters=params)
        result = truncate_event(event)
        assert _serialized_size(result) <= MAX_EVENT_BYTES

    def test_200_keys_x_10kb_stays_under_limit(self):
        """200 keys * 10 KB = 2 MB — needs multiple passes."""
        params = {f"key_{i}": "x" * 10_000 for i in range(200)}
        event = _make_event(parameters=params)
        result = truncate_event(event)
        assert _serialized_size(result) <= MAX_EVENT_BYTES

    def test_dict_fields_remain_dicts_after_truncation(self):
        """Verify parameters/response/error/identify_data stay as dicts, not strings."""
        params = {f"key_{i}": "x" * 1_000 for i in range(200)}
        event = _make_event(
            parameters=params,
            response={"output": "x" * 50_000},
            error={"message": "fail", "stack": "x" * 50_000},
            identify_data={"bio": "x" * 50_000},
        )
        result = truncate_event(event)
        assert isinstance(result.parameters, dict), "parameters should remain a dict"
        assert isinstance(result.response, dict), "response should remain a dict"
        assert isinstance(result.error, dict), "error should remain a dict"
        assert isinstance(result.identify_data, dict), "identify_data should remain a dict"
        assert _serialized_size(result) <= MAX_EVENT_BYTES

    def test_many_keys_across_multiple_dict_fields(self):
        """Many keys spread across parameters + response + error."""
        params = {f"p_{i}": "x" * 2_000 for i in range(100)}
        resp = {f"r_{i}": "x" * 2_000 for i in range(100)}
        err = {f"e_{i}": "x" * 2_000 for i in range(100)}
        event = _make_event(parameters=params, response=resp, error=err)
        result = truncate_event(event)
        assert _serialized_size(result) <= MAX_EVENT_BYTES
        assert isinstance(result.parameters, dict)
        assert isinstance(result.response, dict)
        assert isinstance(result.error, dict)

    def test_top_level_fields_not_dropped_under_extreme_key_pressure(self):
        """Top-level event metadata should survive aggressive truncation."""
        # The pressure here comes from the keys (20 keys of ~20 KB each),
        # while every value is a single character.
        long_key = "k" * 20_000
        params = {f"{long_key}{i}": "x" for i in range(20)}
        event = _make_event(
            event_type="mcp:tools/call",
            resource_name="test_tool",
            session_id="test-session-id",
            parameters=params,
        )

        result = truncate_event(event)

        assert _serialized_size(result) <= MAX_EVENT_BYTES
        assert result.event_type == "mcp:tools/call"
        assert result.resource_name == "test_tool"
        # session_id is set above, so check it too; otherwise the test would
        # not notice it being dropped or altered.
        assert result.session_id == "test-session-id"
        assert isinstance(result.parameters, dict)
        assert len(result.parameters) > 0


class TestMetadataProtection:
    """Top-level metadata fields must never be truncated, even under extreme payload pressure."""

    def test_metadata_fields_preserved_when_payload_forces_extreme_truncation(self):
        """Top-level metadata strings must never be truncated, even with huge payloads."""
        event = _make_event(
            event_type="mcp:tools/call",
            resource_name="my_important_tool",
            session_id="sess-12345",
            actor_id="actor-67890",
            user_intent="short intent",
            parameters={"data": "x" * 1_048_576},  # 1 MB forces aggressive truncation
        )
        result = truncate_event(event)

        # Metadata fields must be EXACTLY preserved
        assert result.event_type == "mcp:tools/call"
        assert result.resource_name == "my_important_tool"
        assert result.session_id == "sess-12345"
        assert result.actor_id == "actor-67890"

        # user_intent IS truncatable, but "short intent" is small enough to survive
        assert result.user_intent == "short intent"

        # Payload was truncated
        assert "truncated by MCPcat" in result.parameters["data"]

        # Still under size limit
        assert _serialized_size(result) <= MAX_EVENT_BYTES

    def test_large_user_intent_truncated_but_metadata_preserved(self):
        """Large user_intent is truncated while metadata stays intact."""
        event = _make_event(
            event_type="mcp:tools/call",
            resource_name="my_tool",
            user_intent="x" * 200_000,
            parameters={"key": "value"},
        )
        result = truncate_event(event)
        assert result.event_type == "mcp:tools/call"
        assert result.resource_name == "my_tool"
        assert "truncated by MCPcat" in result.user_intent
        assert _serialized_size(result) <= MAX_EVENT_BYTES