2424from concurrent .futures import ThreadPoolExecutor
2525from datetime import date , datetime
2626from enum import Enum
27- from typing import Any , Callable , Dict , Optional , Tuple
27+ from typing import Any , Callable , Dict , Optional , Tuple , Literal , cast
2828
2929import uvicorn
3030from mcp .server .fastmcp import Context , FastMCP
@@ -146,22 +146,20 @@ def _get_session_id(self, ctx: Context) -> str:
146146 print (f"🔍 _get_session_id: hasattr(ctx, 'session'): { hasattr (ctx , 'session' )} " )
147147
148148 # Use stable session ID based on client info (following simulation_server.py pattern)
149- if hasattr (ctx , "session" ) and hasattr ( ctx . session , "client_params" ) :
150- client_params = ctx .session . client_params
149+ if hasattr (ctx , "session" ):
150+ client_params = getattr ( ctx .session , " client_params" , None )
151151 print (f"🔍 _get_session_id: client_params type: { type (client_params )} " )
152- print (f"🔍 _get_session_id: hasattr(client_params, 'clientInfo'): { hasattr (client_params , 'clientInfo' )} " )
153-
154- if hasattr (client_params , "clientInfo" ):
155- client_info = client_params .clientInfo
152+ if client_params is not None and hasattr (client_params , "clientInfo" ):
153+ client_info = getattr (client_params , "clientInfo" , None )
156154 print (f"🔍 _get_session_id: client_info: { client_info } " )
157- print (f"🔍 _get_session_id: hasattr(client_info, '_extra'): { hasattr (client_info , '_extra' )} " )
158155
159- if client_info and hasattr (client_info , "_extra" ):
160- extra_data = client_info ._extra
156+ if client_info is not None :
157+ # Access private _extra with a cast to satisfy type checker
158+ extra_data = cast (Any , getattr (client_info , "_extra" , None ))
161159 print (f"🔍 _get_session_id: extra_data: { extra_data } " )
162160 print (f"🔍 _get_session_id: extra_data type: { type (extra_data )} " )
163161
164- if extra_data and isinstance (extra_data , dict ):
162+ if isinstance (extra_data , dict ):
165163 # use the client generated session id
166164 if "session_id" in extra_data :
167165 print (f"🔍 _get_session_id: using client generated session_id: { extra_data ['session_id' ]} " )
@@ -181,8 +179,8 @@ def _get_session_id(self, ctx: Context) -> str:
181179 "config" : config_value ,
182180 "dataset_row_id" : dataset_row_id_value ,
183181 "model_id" : model_id_value ,
184- "name" : client_info . name ,
185- "version" : client_info . version ,
182+ "name" : getattr ( client_info , " name" , None ) ,
183+ "version" : getattr ( client_info , " version" , None ) ,
186184 }
187185
188186 print (f"🔍 _get_session_id: stable_data: { stable_data } " )
@@ -576,7 +574,8 @@ async def run_with_high_concurrency():
576574 if not kwargs .get ("redirect_slashes" , True ) and hasattr (starlette_app , "router" ):
577575 starlette_app .router .redirect_slashes = False
578576
579- starlette_app .add_middleware (ProxyHeadersMiddleware , trusted_hosts = "*" )
577+ # Add middleware with proper type cast to satisfy basedpyright
578+ starlette_app .add_middleware (cast (Any , ProxyHeadersMiddleware ), trusted_hosts = "*" )
580579
581580 config = uvicorn .Config (
582581 starlette_app ,
@@ -598,7 +597,15 @@ async def run_with_high_concurrency():
598597 asyncio .run (run_with_high_concurrency ())
599598 else :
600599 # Use default FastMCP run for other transports
601- self .mcp .run (transport = transport , ** kwargs )
600+ # Constrain transport to the allowed literal values for type checker
601+ allowed_transport : Literal ["stdio" , "sse" , "streamable-http" ]
602+ if transport in ("stdio" , "sse" , "streamable-http" ):
603+ allowed_transport = cast (Literal ["stdio" , "sse" , "streamable-http" ], transport )
604+ else :
605+ # Default to streamable-http if unknown
606+ allowed_transport = cast (Literal ["stdio" , "sse" , "streamable-http" ], "streamable-http" )
607+
608+ self .mcp .run (transport = allowed_transport , ** kwargs )
602609
603610 def _to_json_serializable (self , obj : Any ) -> Any :
604611 """Convert any object to JSON-serializable format.
@@ -625,7 +632,8 @@ def _to_json_serializable(self, obj: Any) -> Any:
625632
626633 # Handle dataclasses
627634 elif dataclasses .is_dataclass (obj ):
628- return dataclasses .asdict (obj )
635+ # Cast for type checker because protocol uses ClassVar on __dataclass_fields__
636+ return dataclasses .asdict (cast (Any , obj ))
629637
630638 # Handle dictionaries
631639 elif isinstance (obj , dict ):
0 commit comments