Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 5 additions & 15 deletions eval_protocol/mcp/client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,34 +441,24 @@ async def call_tool(self, session: MCPSession, tool_name: str, arguments: Dict)
# Extract data plane results (observation only)
if tool_result.content and len(tool_result.content) > 0:
content = tool_result.content[0]
# Safely attempt to read a "text" attribute if present across content types
text_attr = getattr(content, "text", None)
if isinstance(text_attr, str):
content_text = text_attr
elif isinstance(text_attr, list):
# text can also be an array of parts with optional .text fields
content_text = "".join([getattr(p, "text", "") for p in text_attr])
else:
content_text = None

if isinstance(content_text, str):
if hasattr(content, "text"):
# Fix: Handle empty or invalid JSON responses gracefully
if content_text.strip() == "":
if not content.text or content.text.strip() == "":
logger.warning(f"Session {session.session_id}: Empty tool response from {tool_name}")
observation = {
"observation": "empty_response",
"session_id": session.session_id,
}
else:
try:
observation = json.loads(content_text)
observation = json.loads(content.text)
except json.JSONDecodeError as e:
logger.warning(
f"Session {session.session_id}: Invalid JSON from {tool_name}: {content_text}. Error: {e}"
f"Session {session.session_id}: Invalid JSON from {tool_name}: {content.text}. Error: {e}"
)
# Create a structured response from the raw text
observation = {
"observation": content_text,
"observation": content.text,
"session_id": session.session_id,
"error": "invalid_json_response",
}
Expand Down
58 changes: 42 additions & 16 deletions eval_protocol/mcp/mcpgym.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from concurrent.futures import ThreadPoolExecutor
from datetime import date, datetime
from enum import Enum
from typing import Any, Callable, Dict, Optional, Tuple
from typing import Any, Callable, Dict, Optional, Tuple, Literal, cast

import uvicorn
from mcp.server.fastmcp import Context, FastMCP
Expand Down Expand Up @@ -146,22 +146,20 @@ def _get_session_id(self, ctx: Context) -> str:
print(f"🔍 _get_session_id: hasattr(ctx, 'session'): {hasattr(ctx, 'session')}")

# Use stable session ID based on client info (following simulation_server.py pattern)
if hasattr(ctx, "session") and hasattr(ctx.session, "client_params"):
client_params = ctx.session.client_params
if hasattr(ctx, "session"):
client_params = getattr(ctx.session, "client_params", None)
print(f"🔍 _get_session_id: client_params type: {type(client_params)}")
print(f"🔍 _get_session_id: hasattr(client_params, 'clientInfo'): {hasattr(client_params, 'clientInfo')}")

if hasattr(client_params, "clientInfo"):
client_info = client_params.clientInfo
if client_params is not None and hasattr(client_params, "clientInfo"):
client_info = getattr(client_params, "clientInfo", None)
print(f"🔍 _get_session_id: client_info: {client_info}")
print(f"🔍 _get_session_id: hasattr(client_info, '_extra'): {hasattr(client_info, '_extra')}")

if client_info and hasattr(client_info, "_extra"):
extra_data = client_info._extra
if client_info is not None:
# Access private _extra with a cast to satisfy type checker
extra_data = cast(Any, getattr(client_info, "_extra", None))
print(f"🔍 _get_session_id: extra_data: {extra_data}")
print(f"🔍 _get_session_id: extra_data type: {type(extra_data)}")

if extra_data and isinstance(extra_data, dict):
if isinstance(extra_data, dict):
# use the client generated session id
if "session_id" in extra_data:
print(f"🔍 _get_session_id: using client generated session_id: {extra_data['session_id']}")
Expand All @@ -181,8 +179,8 @@ def _get_session_id(self, ctx: Context) -> str:
"config": config_value,
"dataset_row_id": dataset_row_id_value,
"model_id": model_id_value,
"name": client_info.name,
"version": client_info.version,
"name": getattr(client_info, "name", None),
"version": getattr(client_info, "version", None),
}

print(f"🔍 _get_session_id: stable_data: {stable_data}")
Expand All @@ -205,6 +203,15 @@ def _get_or_create_session(self, ctx: Context) -> Dict[str, Any]:
"""
session_id = self._get_session_id(ctx)
print(f"🔍 _get_or_create_session: session_id: {session_id}")
if session_id not in self.sessions:
env, obs, info = self._new_env(seed=None)
with self.session_lock:
self.sessions[session_id] = {
"env": env,
"obs": obs,
"session_data": {},
"session_id": session_id,
}
return self.sessions[session_id]

def _register_session_reset_endpoint(self):
Expand Down Expand Up @@ -400,6 +407,15 @@ def _execute_session_environment_step(self, session_id: str, action: Any) -> Dic
Returns:
Data plane response (observation only, no rewards)
"""
if session_id not in self.sessions:
env, obs, info = self._new_env(seed=None)
with self.session_lock:
self.sessions[session_id] = {
"env": env,
"obs": obs,
"session_data": {},
"session_id": session_id,
}
session_data = self.sessions[session_id]
env = session_data["env"]

Expand Down Expand Up @@ -558,7 +574,8 @@ async def run_with_high_concurrency():
if not kwargs.get("redirect_slashes", True) and hasattr(starlette_app, "router"):
starlette_app.router.redirect_slashes = False

starlette_app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="*")
# Add middleware with proper type cast to satisfy basedpyright
starlette_app.add_middleware(cast(Any, ProxyHeadersMiddleware), trusted_hosts="*")

config = uvicorn.Config(
starlette_app,
Expand All @@ -580,7 +597,15 @@ async def run_with_high_concurrency():
asyncio.run(run_with_high_concurrency())
else:
# Use default FastMCP run for other transports
self.mcp.run(transport=transport, **kwargs)
# Constrain transport to the allowed literal values for type checker
allowed_transport: Literal["stdio", "sse", "streamable-http"]
if transport in ("stdio", "sse", "streamable-http"):
allowed_transport = cast(Literal["stdio", "sse", "streamable-http"], transport)
else:
# Default to streamable-http if unknown
allowed_transport = cast(Literal["stdio", "sse", "streamable-http"], "streamable-http")

self.mcp.run(transport=allowed_transport, **kwargs)

def _to_json_serializable(self, obj: Any) -> Any:
"""Convert any object to JSON-serializable format.
Expand All @@ -607,7 +632,8 @@ def _to_json_serializable(self, obj: Any) -> Any:

# Handle dataclasses
elif dataclasses.is_dataclass(obj):
return dataclasses.asdict(obj)
# Cast for type checker because protocol uses ClassVar on __dataclass_fields__
return dataclasses.asdict(cast(Any, obj))

# Handle dictionaries
elif isinstance(obj, dict):
Expand Down
61 changes: 61 additions & 0 deletions tests/pytest/test_mcp_session_autocreate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""
Regression test: ensure MCP-Gym auto-creates a session on first tool call
without requiring a prior initial state fetch, and returns JSON.
"""

import time
from multiprocessing import Process

import httpx
import pytest

from eval_protocol.mcp.client.connection import MCPConnectionManager
from eval_protocol.types import MCPSession


def _run_airline_server():
    """Entry point for the child process that hosts the airline MCP server.

    Sets the ``PORT`` environment variable to ``9780``, then builds an
    ``AirlineDomainMcp`` with no seed and runs it over the
    ``streamable-http`` transport. The call to ``run`` is the last statement,
    so this function returns only when ``run`` does.
    """
    import os

    # PORT is set before the server module is imported; presumably the server
    # reads it at import or construction time -- confirm against tau2_mcp.
    os.environ["PORT"] = "9780"

    from eval_protocol.mcp_servers.tau2.tau2_mcp import AirlineDomainMcp

    AirlineDomainMcp(seed=None).run(transport="streamable-http")


@pytest.mark.asyncio
async def test_tool_call_returns_json_without_prior_initial_state():
    """Check that the first tool call on a fresh session yields parsed JSON.

    Launches the airline server in a child process, polls its ``/mcp`` URL
    until it answers with one of the status codes treated as "up"
    (200, 307 or 406), then initializes a session, discovers tools and calls
    ``list_all_airports`` without fetching initial state first. The
    observation must be a dict and must not carry the
    ``invalid_json_response`` error marker.

    The child process is always stopped on exit, escalating from
    ``terminate()`` to ``kill()`` if it does not stop within the join timeout.
    """
    proc = Process(target=_run_airline_server, daemon=True)
    proc.start()

    try:
        base_url = "http://127.0.0.1:9780/mcp"
        # Monotonic clock: the startup deadline must not shift if the wall
        # clock is adjusted while the test is waiting.
        deadline = time.monotonic() + 20
        # The context manager closes the probe client's connection pool as
        # soon as polling ends, whether it succeeded or failed.
        with httpx.Client(timeout=1.0) as client:
            while time.monotonic() < deadline:
                try:
                    r = client.get(base_url)
                    if r.status_code in (200, 307, 406):
                        break
                except httpx.HTTPError:
                    # Server is not accepting requests yet (connection
                    # refused, timeout, ...); retry until the deadline.
                    pass
                time.sleep(0.2)
            else:
                pytest.fail("Server did not start on port 9780 in time")

        session = MCPSession(base_url=base_url, session_id="test-autocreate", seed=None, model_id="test-model")

        mgr = MCPConnectionManager()
        await mgr.initialize_session(session)
        await mgr.discover_tools(session)

        observation, reward, done, info = await mgr.call_tool(session, "list_all_airports", {})

        assert isinstance(observation, dict), f"Expected JSON dict, got: {type(observation)} {observation}"
        assert observation.get("error") != "invalid_json_response"

        await mgr.reset_session(session)
        await mgr.close_session(session)
    finally:
        proc.terminate()
        proc.join(timeout=5)
        if proc.is_alive():
            # terminate() was not enough within the timeout; force the child
            # down so it cannot linger and keep port 9780 bound.
            proc.kill()
            proc.join(timeout=5)
Loading