Skip to content

Commit 2834bc9

Browse files
Benny Chen
authored and committed
fix more type errors
1 parent ad86c71 commit 2834bc9

File tree

8 files changed: 71 additions and 27 deletions.

eval_protocol/mcp/client/connection.py

Lines changed: 21 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -53,16 +53,26 @@ async def initialize_session(self, session: MCPSession) -> None:
5353

5454
exit_stack = AsyncExitStack()
5555

56-
client_info = Implementation(name="reward-kit", version="1.0.0", _extra={})
57-
client_info._extra["session_id"] = session.session_id
56+
# Attach client metadata for the server to consume (session_id, seed, config, etc.).
57+
# The server inspects a private `_extra` dict on client_info, so we populate it here.
58+
client_info = Implementation(name="reward-kit", version="1.0.0")
59+
extra_data: Dict[str, Any] = {"session_id": session.session_id}
5860
if session.seed is not None:
59-
client_info._extra["seed"] = session.seed
61+
extra_data["seed"] = session.seed
6062
if session.dataset_row and session.dataset_row.environment_context:
61-
client_info._extra["config"] = session.dataset_row.environment_context
63+
extra_data["config"] = session.dataset_row.environment_context
6264
if session.dataset_row and session.dataset_row.id:
63-
client_info._extra["dataset_row_id"] = session.dataset_row.id
65+
extra_data["dataset_row_id"] = session.dataset_row.id
6466
if session.model_id:
65-
client_info._extra["model_id"] = session.model_id
67+
extra_data["model_id"] = session.model_id
68+
69+
# Merge with any existing _extra dict instead of overwriting
70+
existing_extra = getattr(client_info, "_extra", None)
71+
merged_extra: Dict[str, Any] = {}
72+
if isinstance(existing_extra, dict):
73+
merged_extra.update(existing_extra)
74+
merged_extra.update(extra_data)
75+
setattr(client_info, "_extra", merged_extra)
6676

6777
read_stream, write_stream, _ = await exit_stack.enter_async_context(
6878
streamablehttp_client(session.base_url, terminate_on_close=True)
@@ -92,7 +102,10 @@ async def _prewarm_tools_cache(self, session: MCPSession) -> None:
92102
# Only fetch tools if not already cached for this base_url
93103
if cache_key not in self._tools_cache:
94104
logger.debug(f"Pre-warming tools cache for {cache_key}")
95-
tools_response = await session._mcp_session.list_tools()
105+
mcp_session_local = session._mcp_session
106+
if mcp_session_local is None:
107+
raise RuntimeError("Session not initialized during prewarm")
108+
tools_response = await mcp_session_local.list_tools()
96109
tools = tools_response.tools if hasattr(tools_response, "tools") else []
97110

98111
tool_schemas = []
@@ -213,7 +226,7 @@ async def get_initial_state(self, session: MCPSession) -> Any:
213226
try:
214227
# Use shorter timeout for playback mode, longer timeout for high-concurrency initialization
215228
# (50+ concurrent sessions need more time for initial state setup)
216-
timeout = 3.0 if hasattr(session, "_is_playback_mode") and session._is_playback_mode else 15.0
229+
timeout = 3.0 if bool(getattr(session, "_is_playback_mode", False)) else 15.0
217230
async with httpx.AsyncClient(timeout=timeout) as client:
218231
initial_state_response = await client.get(
219232
f"{base_url}/control/initial_state",

eval_protocol/mcp/clients.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@ def __init__(self, intermediary_server_url: str):
2929

3030
async def connect(self):
3131
"""Establishes connection and MCP session."""
32-
if self._mcp_session and not self._mcp_session.is_closed:
32+
if self._mcp_session is not None:
3333
logger.debug("Already connected.")
3434
return
3535

eval_protocol/mcp/mcp_multi_client.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -90,8 +90,8 @@ async def _connect_to_server(
9090
if env_config:
9191
self._validate_environment_variables(server_name, env_config)
9292

93-
# Use the current system environment (os.environ) - don't override with config
94-
server_params = StdioServerParameters(command=command, args=args, env=os.environ)
93+
# Use the current system environment (os.environ) - convert to plain dict for typing compatibility
94+
server_params = StdioServerParameters(command=command, args=args, env=dict(os.environ))
9595

9696
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
9797
stdio, write = stdio_transport

eval_protocol/mcp_agent/orchestration/local_docker_client.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -244,6 +244,10 @@ async def provision_instances(
244244
logger.info(
245245
f"Creating template container for commit: {temp_cont_name} from {backend_config.docker_image}"
246246
)
247+
if not backend_config.docker_image:
248+
raise ValueError(
249+
f"docker_image is required for template commit for backend {backend_config.backend_name_ref}"
250+
)
247251
temp_c = self.docker_client.containers.run( # type: ignore
248252
image=backend_config.docker_image,
249253
name=temp_cont_name,
@@ -322,6 +326,11 @@ async def provision_instances(
322326
logger.info(
323327
f"Provisioning instance {container_name} (transport: {backend_config.mcp_transport}) from image {image_to_run_from}"
324328
)
329+
# Ensure the image reference is present before using it in Docker APIs
330+
if not image_to_run_from:
331+
raise ValueError(
332+
f"docker_image is required to provision instance {container_name} for backend {backend_config.backend_name_ref}"
333+
)
325334
if backend_config.mcp_transport == "http":
326335
# ... (HTTP provisioning logic, ensure it uses current_container_volumes) ...
327336
if not self.docker_client:
@@ -640,7 +649,7 @@ async def list_tools_on_instance(self, instance: ManagedInstanceInfo) -> types.L
640649
)
641650
target_base_url = instance.mcp_endpoint_url.rstrip("/")
642651
try:
643-
async with streamablehttp_client(base_url=target_base_url) as (
652+
async with streamablehttp_client(target_base_url) as (
644653
read_s,
645654
write_s,
646655
_, # get_session_id_func usually not needed for a single call

eval_protocol/mcp_agent/orchestration/remote_http_client.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -263,6 +263,9 @@ async def call_tool_on_instance(
263263
logger.debug(f"Proxying tool {tool_name} to {target_url} for instance {instance.instance_id}")
264264
else:
265265
# Call tool directly on the instance's MCP endpoint
266+
# mypy/pyright: instance.mcp_endpoint_url is Optional[str]; validate before assignment
267+
if not instance.mcp_endpoint_url:
268+
raise ValueError(f"Instance {instance.instance_id} missing mcp_endpoint_url for direct tool call")
266269
target_url = instance.mcp_endpoint_url
267270
logger.debug(f"Calling tool {tool_name} directly on {target_url} for instance {instance.instance_id}")
268271

eval_protocol/models.py

Lines changed: 19 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -611,27 +611,35 @@ def get_steps(self) -> int:
611611
def get_total_reward(self) -> float:
612612
"""Get total reward from control_plane_step data."""
613613
messages_with_control_plane = [msg for msg in self.messages if msg.control_plane_step]
614-
return (
615-
sum(msg.control_plane_step["reward"] for msg in messages_with_control_plane)
616-
if messages_with_control_plane
617-
else 0.0
618-
)
614+
if not messages_with_control_plane:
615+
return 0.0
616+
total = 0.0
617+
for msg in messages_with_control_plane:
618+
step = msg.control_plane_step or {}
619+
try:
620+
total += float(step.get("reward", 0.0))
621+
except (TypeError, ValueError):
622+
continue
623+
return total
619624

620625
def get_terminated(self) -> bool:
621626
"""Get termination status from control_plane_step data."""
622627
messages_with_control_plane = [msg for msg in self.messages if msg.control_plane_step]
623-
return (
624-
any(msg.control_plane_step["terminated"] for msg in messages_with_control_plane)
625-
if messages_with_control_plane
626-
else False
627-
)
628+
if not messages_with_control_plane:
629+
return False
630+
for msg in messages_with_control_plane:
631+
step = msg.control_plane_step or {}
632+
if bool(step.get("terminated", False)):
633+
return True
634+
return False
628635

629636
def get_termination_reason(self) -> str:
630637
"""Get termination reason from the final control_plane_step data."""
631638
# Find the last message with control_plane_step that has termination_reason
632639
for msg in reversed(self.messages):
633640
if msg.control_plane_step and msg.control_plane_step.get("termination_reason"):
634-
return msg.control_plane_step["termination_reason"]
641+
reason = msg.control_plane_step.get("termination_reason")
642+
return str(reason)
635643
return "unknown"
636644

637645
def __hash__(self) -> int:

eval_protocol/rewards/code_execution.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -171,7 +171,11 @@ def local_code_execution_reward(
171171
},
172172
)
173173

174-
response_content = messages[-1].content
174+
# Normalize content to string; Message.content may be str or list of content parts
175+
last_content = messages[-1].content
176+
response_content = (
177+
last_content if isinstance(last_content, str) else "".join([p.text for p in (last_content or [])])
178+
)
175179
expected_output_str = ground_truth
176180

177181
code_blocks = extract_code_blocks(response_content, language)
@@ -935,7 +939,10 @@ def e2b_code_execution_reward(
935939
},
936940
)
937941

938-
response_content = messages[-1].content
942+
last_content = messages[-1].content
943+
response_content = (
944+
last_content if isinstance(last_content, str) else "".join([p.text for p in (last_content or [])])
945+
)
939946
expected_output_str = ground_truth
940947

941948
code_blocks = extract_code_blocks(response_content, language)

eval_protocol/rewards/lean_prover.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,8 @@ def lean_prover_reward(
5757
},
5858
)
5959

60-
response = messages[-1].content
60+
last_content = messages[-1].content
61+
response = last_content if isinstance(last_content, str) else "".join([p.text for p in (last_content or [])])
6162
if not response:
6263
return EvaluateResult(
6364
score=0.0,
@@ -230,7 +231,10 @@ def deepseek_prover_v2_reward(
230231
and messages[-1].role == "assistant"
231232
and messages[-1].content is not None
232233
):
233-
response_content = messages[-1].content
234+
last_content = messages[-1].content
235+
response_content = (
236+
last_content if isinstance(last_content, str) else "".join([p.text for p in (last_content or [])])
237+
)
234238

235239
final_score = base_score
236240
subgoal_count = 0

0 commit comments

Comments
 (0)