Skip to content

Commit 9ffabcc

Browse files
Benny ChenBenny Chen
authored andcommitted
also fix the type error
1 parent 8a58a1a commit 9ffabcc

File tree

1 file changed

+24
-16
lines changed

1 file changed

+24
-16
lines changed

eval_protocol/mcp/mcpgym.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from concurrent.futures import ThreadPoolExecutor
2525
from datetime import date, datetime
2626
from enum import Enum
27-
from typing import Any, Callable, Dict, Optional, Tuple
27+
from typing import Any, Callable, Dict, Optional, Tuple, Literal, cast
2828

2929
import uvicorn
3030
from mcp.server.fastmcp import Context, FastMCP
@@ -146,22 +146,20 @@ def _get_session_id(self, ctx: Context) -> str:
146146
print(f"🔍 _get_session_id: hasattr(ctx, 'session'): {hasattr(ctx, 'session')}")
147147

148148
# Use stable session ID based on client info (following simulation_server.py pattern)
149-
if hasattr(ctx, "session") and hasattr(ctx.session, "client_params"):
150-
client_params = ctx.session.client_params
149+
if hasattr(ctx, "session"):
150+
client_params = getattr(ctx.session, "client_params", None)
151151
print(f"🔍 _get_session_id: client_params type: {type(client_params)}")
152-
print(f"🔍 _get_session_id: hasattr(client_params, 'clientInfo'): {hasattr(client_params, 'clientInfo')}")
153-
154-
if hasattr(client_params, "clientInfo"):
155-
client_info = client_params.clientInfo
152+
if client_params is not None and hasattr(client_params, "clientInfo"):
153+
client_info = getattr(client_params, "clientInfo", None)
156154
print(f"🔍 _get_session_id: client_info: {client_info}")
157-
print(f"🔍 _get_session_id: hasattr(client_info, '_extra'): {hasattr(client_info, '_extra')}")
158155

159-
if client_info and hasattr(client_info, "_extra"):
160-
extra_data = client_info._extra
156+
if client_info is not None:
157+
# Access private _extra with a cast to satisfy type checker
158+
extra_data = cast(Any, getattr(client_info, "_extra", None))
161159
print(f"🔍 _get_session_id: extra_data: {extra_data}")
162160
print(f"🔍 _get_session_id: extra_data type: {type(extra_data)}")
163161

164-
if extra_data and isinstance(extra_data, dict):
162+
if isinstance(extra_data, dict):
165163
# use the client generated session id
166164
if "session_id" in extra_data:
167165
print(f"🔍 _get_session_id: using client generated session_id: {extra_data['session_id']}")
@@ -181,8 +179,8 @@ def _get_session_id(self, ctx: Context) -> str:
181179
"config": config_value,
182180
"dataset_row_id": dataset_row_id_value,
183181
"model_id": model_id_value,
184-
"name": client_info.name,
185-
"version": client_info.version,
182+
"name": getattr(client_info, "name", None),
183+
"version": getattr(client_info, "version", None),
186184
}
187185

188186
print(f"🔍 _get_session_id: stable_data: {stable_data}")
@@ -576,7 +574,8 @@ async def run_with_high_concurrency():
576574
if not kwargs.get("redirect_slashes", True) and hasattr(starlette_app, "router"):
577575
starlette_app.router.redirect_slashes = False
578576

579-
starlette_app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="*")
577+
# Add middleware with proper type cast to satisfy basedpyright
578+
starlette_app.add_middleware(cast(Any, ProxyHeadersMiddleware), trusted_hosts="*")
580579

581580
config = uvicorn.Config(
582581
starlette_app,
@@ -598,7 +597,15 @@ async def run_with_high_concurrency():
598597
asyncio.run(run_with_high_concurrency())
599598
else:
600599
# Use default FastMCP run for other transports
601-
self.mcp.run(transport=transport, **kwargs)
600+
# Constrain transport to the allowed literal values for type checker
601+
allowed_transport: Literal["stdio", "sse", "streamable-http"]
602+
if transport in ("stdio", "sse", "streamable-http"):
603+
allowed_transport = cast(Literal["stdio", "sse", "streamable-http"], transport)
604+
else:
605+
# Default to streamable-http if unknown
606+
allowed_transport = cast(Literal["stdio", "sse", "streamable-http"], "streamable-http")
607+
608+
self.mcp.run(transport=allowed_transport, **kwargs)
602609

603610
def _to_json_serializable(self, obj: Any) -> Any:
604611
"""Convert any object to JSON-serializable format.
@@ -625,7 +632,8 @@ def _to_json_serializable(self, obj: Any) -> Any:
625632

626633
# Handle dataclasses
627634
elif dataclasses.is_dataclass(obj):
628-
return dataclasses.asdict(obj)
635+
# Cast for type checker because protocol uses ClassVar on __dataclass_fields__
636+
return dataclasses.asdict(cast(Any, obj))
629637

630638
# Handle dictionaries
631639
elif isinstance(obj, dict):

0 commit comments

Comments
 (0)