From 8a58a1a04433edc88a5a8743c9ed98b891b369be Mon Sep 17 00:00:00 2001 From: Benny Chen Date: Mon, 1 Sep 2025 21:07:25 +0800 Subject: [PATCH 1/2] revert mcp gym issue --- eval_protocol/mcp/client/connection.py | 20 ++----- eval_protocol/mcp/mcpgym.py | 18 ++++++ tests/pytest/test_mcp_session_autocreate.py | 61 +++++++++++++++++++++ 3 files changed, 84 insertions(+), 15 deletions(-) create mode 100644 tests/pytest/test_mcp_session_autocreate.py diff --git a/eval_protocol/mcp/client/connection.py b/eval_protocol/mcp/client/connection.py index f0c85ac6..a6fcd53d 100644 --- a/eval_protocol/mcp/client/connection.py +++ b/eval_protocol/mcp/client/connection.py @@ -441,19 +441,9 @@ async def call_tool(self, session: MCPSession, tool_name: str, arguments: Dict) # Extract data plane results (observation only) if tool_result.content and len(tool_result.content) > 0: content = tool_result.content[0] - # Safely attempt to read a "text" attribute if present across content types - text_attr = getattr(content, "text", None) - if isinstance(text_attr, str): - content_text = text_attr - elif isinstance(text_attr, list): - # text can also be an array of parts with optional .text fields - content_text = "".join([getattr(p, "text", "") for p in text_attr]) - else: - content_text = None - - if isinstance(content_text, str): + if hasattr(content, "text"): # Fix: Handle empty or invalid JSON responses gracefully - if content_text.strip() == "": + if not content.text or content.text.strip() == "": logger.warning(f"Session {session.session_id}: Empty tool response from {tool_name}") observation = { "observation": "empty_response", @@ -461,14 +451,14 @@ async def call_tool(self, session: MCPSession, tool_name: str, arguments: Dict) } else: try: - observation = json.loads(content_text) + observation = json.loads(content.text) except json.JSONDecodeError as e: logger.warning( - f"Session {session.session_id}: Invalid JSON from {tool_name}: {content_text}. Error: {e}" + f"Session {session.session_id}: Invalid JSON from {tool_name}: {content.text}. Error: {e}" ) # Create a structured response from the raw text observation = { - "observation": content_text, + "observation": content.text, "session_id": session.session_id, "error": "invalid_json_response", } diff --git a/eval_protocol/mcp/mcpgym.py b/eval_protocol/mcp/mcpgym.py index fb8d8caa..827d0394 100644 --- a/eval_protocol/mcp/mcpgym.py +++ b/eval_protocol/mcp/mcpgym.py @@ -205,6 +205,15 @@ def _get_or_create_session(self, ctx: Context) -> Dict[str, Any]: """ session_id = self._get_session_id(ctx) print(f"🔍 _get_or_create_session: session_id: {session_id}") + if session_id not in self.sessions: + env, obs, info = self._new_env(seed=None) + with self.session_lock: + self.sessions[session_id] = { + "env": env, + "obs": obs, + "session_data": {}, + "session_id": session_id, + } return self.sessions[session_id] def _register_session_reset_endpoint(self): @@ -400,6 +409,15 @@ def _execute_session_environment_step(self, session_id: str, action: Any) -> Dic Returns: Data plane response (observation only, no rewards) """ + if session_id not in self.sessions: + env, obs, info = self._new_env(seed=None) + with self.session_lock: + self.sessions[session_id] = { + "env": env, + "obs": obs, + "session_data": {}, + "session_id": session_id, + } session_data = self.sessions[session_id] env = session_data["env"] diff --git a/tests/pytest/test_mcp_session_autocreate.py b/tests/pytest/test_mcp_session_autocreate.py new file mode 100644 index 00000000..df816f55 --- /dev/null +++ b/tests/pytest/test_mcp_session_autocreate.py @@ -0,0 +1,61 @@ +""" +Regression test: ensure MCP-Gym auto-creates a session on first tool call +without requiring a prior initial state fetch, and returns JSON. +""" + +import time +from multiprocessing import Process + +import httpx +import pytest + +from eval_protocol.mcp.client.connection import MCPConnectionManager +from eval_protocol.types import MCPSession + + +def _run_airline_server(): + import os + + os.environ["PORT"] = "9780" + from eval_protocol.mcp_servers.tau2.tau2_mcp import AirlineDomainMcp + + server = AirlineDomainMcp(seed=None) + server.run(transport="streamable-http") + + +@pytest.mark.asyncio +async def test_tool_call_returns_json_without_prior_initial_state(): + proc = Process(target=_run_airline_server, daemon=True) + proc.start() + + try: + base_url = "http://127.0.0.1:9780/mcp" + client = httpx.Client(timeout=1.0) + deadline = time.time() + 20 + while time.time() < deadline: + try: + r = client.get(base_url) + if r.status_code in (200, 307, 406): + break + except Exception: + pass + time.sleep(0.2) + else: + pytest.fail("Server did not start on port 9780 in time") + + session = MCPSession(base_url=base_url, session_id="test-autocreate", seed=None, model_id="test-model") + + mgr = MCPConnectionManager() + await mgr.initialize_session(session) + await mgr.discover_tools(session) + + observation, reward, done, info = await mgr.call_tool(session, "list_all_airports", {}) + + assert isinstance(observation, dict), f"Expected JSON dict, got: {type(observation)} {observation}" + assert observation.get("error") != "invalid_json_response" + + await mgr.reset_session(session) + await mgr.close_session(session) + finally: + proc.terminate() + proc.join(timeout=5) From 9ffabcc84c7afa25378ca43870c829822104b3ac Mon Sep 17 00:00:00 2001 From: Benny Chen Date: Mon, 1 Sep 2025 21:13:21 +0800 Subject: [PATCH 2/2] also fix the type error --- eval_protocol/mcp/mcpgym.py | 40 ++++++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/eval_protocol/mcp/mcpgym.py b/eval_protocol/mcp/mcpgym.py index 827d0394..81e4dbdb 100644 --- a/eval_protocol/mcp/mcpgym.py +++ b/eval_protocol/mcp/mcpgym.py @@ -24,7 +24,7 @@ from concurrent.futures import ThreadPoolExecutor from datetime import date, datetime from enum import Enum -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, Optional, Tuple, Literal, cast import uvicorn from mcp.server.fastmcp import Context, FastMCP @@ -146,22 +146,20 @@ def _get_session_id(self, ctx: Context) -> str: print(f"🔍 _get_session_id: hasattr(ctx, 'session'): {hasattr(ctx, 'session')}") # Use stable session ID based on client info (following simulation_server.py pattern) - if hasattr(ctx, "session") and hasattr(ctx.session, "client_params"): - client_params = ctx.session.client_params + if hasattr(ctx, "session"): + client_params = getattr(ctx.session, "client_params", None) print(f"🔍 _get_session_id: client_params type: {type(client_params)}") - print(f"🔍 _get_session_id: hasattr(client_params, 'clientInfo'): {hasattr(client_params, 'clientInfo')}") - - if hasattr(client_params, "clientInfo"): - client_info = client_params.clientInfo + if client_params is not None and hasattr(client_params, "clientInfo"): + client_info = getattr(client_params, "clientInfo", None) print(f"🔍 _get_session_id: client_info: {client_info}") - print(f"🔍 _get_session_id: hasattr(client_info, '_extra'): {hasattr(client_info, '_extra')}") - if client_info and hasattr(client_info, "_extra"): - extra_data = client_info._extra + if client_info is not None: + # Access private _extra with a cast to satisfy type checker + extra_data = cast(Any, getattr(client_info, "_extra", None)) print(f"🔍 _get_session_id: extra_data: {extra_data}") print(f"🔍 _get_session_id: extra_data type: {type(extra_data)}") - if extra_data and isinstance(extra_data, dict): + if isinstance(extra_data, dict): # use the client generated session id if "session_id" in extra_data: print(f"🔍 _get_session_id: using client generated session_id: {extra_data['session_id']}") @@ -181,8 +179,8 @@ def _get_session_id(self, ctx: Context) -> str: "config": config_value, "dataset_row_id": dataset_row_id_value, "model_id": model_id_value, - "name": client_info.name, - "version": client_info.version, + "name": getattr(client_info, "name", None), + "version": getattr(client_info, "version", None), } print(f"🔍 _get_session_id: stable_data: {stable_data}") @@ -576,7 +574,8 @@ async def run_with_high_concurrency(): if not kwargs.get("redirect_slashes", True) and hasattr(starlette_app, "router"): starlette_app.router.redirect_slashes = False - starlette_app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="*") + # Add middleware with proper type cast to satisfy basedpyright + starlette_app.add_middleware(cast(Any, ProxyHeadersMiddleware), trusted_hosts="*") config = uvicorn.Config( starlette_app, @@ -598,7 +597,15 @@ async def run_with_high_concurrency(): asyncio.run(run_with_high_concurrency()) else: # Use default FastMCP run for other transports - self.mcp.run(transport=transport, **kwargs) + # Constrain transport to the allowed literal values for type checker + allowed_transport: Literal["stdio", "sse", "streamable-http"] + if transport in ("stdio", "sse", "streamable-http"): + allowed_transport = cast(Literal["stdio", "sse", "streamable-http"], transport) + else: + # Default to streamable-http if unknown + allowed_transport = cast(Literal["stdio", "sse", "streamable-http"], "streamable-http") + + self.mcp.run(transport=allowed_transport, **kwargs) def _to_json_serializable(self, obj: Any) -> Any: """Convert any object to JSON-serializable format. @@ -625,7 +632,8 @@ def _to_json_serializable(self, obj: Any) -> Any: # Handle dataclasses elif dataclasses.is_dataclass(obj): - return dataclasses.asdict(obj) + # Cast for type checker because protocol uses ClassVar on __dataclass_fields__ + return dataclasses.asdict(cast(Any, obj)) # Handle dictionaries elif isinstance(obj, dict):