Skip to content

Commit a2e232a

Browse files
benjibc (Benny Chen)
and a co-author authored
fix pyright round 2 (#141)
* fix more type errors * fix langchain and properly fix messages * nits * fix more tests --------- Co-authored-by: Benny Chen <bchen@Bennys-MacBook-Air.local>
1 parent ad86c71 commit a2e232a

File tree

11 files changed

+127
-41
lines changed

11 files changed

+127
-41
lines changed

eval_protocol/mcp/client/connection.py

Lines changed: 21 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -53,16 +53,26 @@ async def initialize_session(self, session: MCPSession) -> None:
5353

5454
exit_stack = AsyncExitStack()
5555

56-
client_info = Implementation(name="reward-kit", version="1.0.0", _extra={})
57-
client_info._extra["session_id"] = session.session_id
56+
# Attach client metadata for the server to consume (session_id, seed, config, etc.).
57+
# The server inspects a private `_extra` dict on client_info, so we populate it here.
58+
client_info = Implementation(name="reward-kit", version="1.0.0")
59+
extra_data: Dict[str, Any] = {"session_id": session.session_id}
5860
if session.seed is not None:
59-
client_info._extra["seed"] = session.seed
61+
extra_data["seed"] = session.seed
6062
if session.dataset_row and session.dataset_row.environment_context:
61-
client_info._extra["config"] = session.dataset_row.environment_context
63+
extra_data["config"] = session.dataset_row.environment_context
6264
if session.dataset_row and session.dataset_row.id:
63-
client_info._extra["dataset_row_id"] = session.dataset_row.id
65+
extra_data["dataset_row_id"] = session.dataset_row.id
6466
if session.model_id:
65-
client_info._extra["model_id"] = session.model_id
67+
extra_data["model_id"] = session.model_id
68+
69+
# Merge with any existing _extra dict instead of overwriting
70+
existing_extra = getattr(client_info, "_extra", None)
71+
merged_extra: Dict[str, Any] = {}
72+
if isinstance(existing_extra, dict):
73+
merged_extra.update(existing_extra)
74+
merged_extra.update(extra_data)
75+
setattr(client_info, "_extra", merged_extra)
6676

6777
read_stream, write_stream, _ = await exit_stack.enter_async_context(
6878
streamablehttp_client(session.base_url, terminate_on_close=True)
@@ -92,7 +102,10 @@ async def _prewarm_tools_cache(self, session: MCPSession) -> None:
92102
# Only fetch tools if not already cached for this base_url
93103
if cache_key not in self._tools_cache:
94104
logger.debug(f"Pre-warming tools cache for {cache_key}")
95-
tools_response = await session._mcp_session.list_tools()
105+
mcp_session_local = session._mcp_session
106+
if mcp_session_local is None:
107+
raise RuntimeError("Session not initialized during prewarm")
108+
tools_response = await mcp_session_local.list_tools()
96109
tools = tools_response.tools if hasattr(tools_response, "tools") else []
97110

98111
tool_schemas = []
@@ -213,7 +226,7 @@ async def get_initial_state(self, session: MCPSession) -> Any:
213226
try:
214227
# Use shorter timeout for playback mode, longer timeout for high-concurrency initialization
215228
# (50+ concurrent sessions need more time for initial state setup)
216-
timeout = 3.0 if hasattr(session, "_is_playback_mode") and session._is_playback_mode else 15.0
229+
timeout = 3.0 if bool(getattr(session, "_is_playback_mode", False)) else 15.0
217230
async with httpx.AsyncClient(timeout=timeout) as client:
218231
initial_state_response = await client.get(
219232
f"{base_url}/control/initial_state",

eval_protocol/mcp/clients.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@ def __init__(self, intermediary_server_url: str):
2929

3030
async def connect(self):
3131
"""Establishes connection and MCP session."""
32-
if self._mcp_session and not self._mcp_session.is_closed:
32+
if self._mcp_session is not None and not self._mcp_session.is_closed:
3333
logger.debug("Already connected.")
3434
return
3535

eval_protocol/mcp/mcp_multi_client.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -90,8 +90,8 @@ async def _connect_to_server(
9090
if env_config:
9191
self._validate_environment_variables(server_name, env_config)
9292

93-
# Use the current system environment (os.environ) - don't override with config
94-
server_params = StdioServerParameters(command=command, args=args, env=os.environ)
93+
# Use the current system environment (os.environ) - convert to plain dict for typing compatibility
94+
server_params = StdioServerParameters(command=command, args=args, env=dict(os.environ))
9595

9696
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
9797
stdio, write = stdio_transport

eval_protocol/mcp_agent/orchestration/local_docker_client.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -244,6 +244,10 @@ async def provision_instances(
244244
logger.info(
245245
f"Creating template container for commit: {temp_cont_name} from {backend_config.docker_image}"
246246
)
247+
if not backend_config.docker_image:
248+
raise ValueError(
249+
f"docker_image is required for template commit for backend {backend_config.backend_name_ref}"
250+
)
247251
temp_c = self.docker_client.containers.run( # type: ignore
248252
image=backend_config.docker_image,
249253
name=temp_cont_name,
@@ -322,6 +326,11 @@ async def provision_instances(
322326
logger.info(
323327
f"Provisioning instance {container_name} (transport: {backend_config.mcp_transport}) from image {image_to_run_from}"
324328
)
329+
# Ensure the image reference is present before using it in Docker APIs
330+
if not image_to_run_from:
331+
raise ValueError(
332+
f"docker_image is required to provision instance {container_name} for backend {backend_config.backend_name_ref}"
333+
)
325334
if backend_config.mcp_transport == "http":
326335
# ... (HTTP provisioning logic, ensure it uses current_container_volumes) ...
327336
if not self.docker_client:

eval_protocol/mcp_agent/orchestration/remote_http_client.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -263,6 +263,9 @@ async def call_tool_on_instance(
263263
logger.debug(f"Proxying tool {tool_name} to {target_url} for instance {instance.instance_id}")
264264
else:
265265
# Call tool directly on the instance's MCP endpoint
266+
# mypy/pyright: instance.mcp_endpoint_url is Optional[str]; validate before assignment
267+
if not instance.mcp_endpoint_url:
268+
raise ValueError(f"Instance {instance.instance_id} missing mcp_endpoint_url for direct tool call")
266269
target_url = instance.mcp_endpoint_url
267270
logger.debug(f"Calling tool {tool_name} directly on {target_url} for instance {instance.instance_id}")
268271

eval_protocol/models.py

Lines changed: 32 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -243,8 +243,19 @@ class Message(BaseModel):
243243

244244
@classmethod
245245
def model_validate(cls, obj, *args, **kwargs):
246-
if isinstance(obj, dict) and "role" not in obj:
247-
raise ValueError("Role is required")
246+
if isinstance(obj, dict):
247+
if "role" not in obj:
248+
raise ValueError("Role is required")
249+
# Be lenient: if tool_calls entries are missing required 'id', synthesize one
250+
tool_calls_obj = obj.get("tool_calls")
251+
if isinstance(tool_calls_obj, list):
252+
fixed_tool_calls = []
253+
for tc in tool_calls_obj:
254+
if isinstance(tc, dict):
255+
if not tc.get("id"):
256+
tc = {**tc, "id": generate_id()}
257+
fixed_tool_calls.append(tc)
258+
obj = {**obj, "tool_calls": fixed_tool_calls}
248259
return super().model_validate(obj, *args, **kwargs)
249260

250261

@@ -611,27 +622,35 @@ def get_steps(self) -> int:
611622
def get_total_reward(self) -> float:
612623
"""Get total reward from control_plane_step data."""
613624
messages_with_control_plane = [msg for msg in self.messages if msg.control_plane_step]
614-
return (
615-
sum(msg.control_plane_step["reward"] for msg in messages_with_control_plane)
616-
if messages_with_control_plane
617-
else 0.0
618-
)
625+
if not messages_with_control_plane:
626+
return 0.0
627+
total = 0.0
628+
for msg in messages_with_control_plane:
629+
step = msg.control_plane_step or {}
630+
try:
631+
total += float(step.get("reward", 0.0))
632+
except (TypeError, ValueError):
633+
continue
634+
return total
619635

620636
def get_terminated(self) -> bool:
621637
"""Get termination status from control_plane_step data."""
622638
messages_with_control_plane = [msg for msg in self.messages if msg.control_plane_step]
623-
return (
624-
any(msg.control_plane_step["terminated"] for msg in messages_with_control_plane)
625-
if messages_with_control_plane
626-
else False
627-
)
639+
if not messages_with_control_plane:
640+
return False
641+
for msg in messages_with_control_plane:
642+
step = msg.control_plane_step or {}
643+
if bool(step.get("terminated", False)):
644+
return True
645+
return False
628646

629647
def get_termination_reason(self) -> str:
630648
"""Get termination reason from the final control_plane_step data."""
631649
# Find the last message with control_plane_step that has termination_reason
632650
for msg in reversed(self.messages):
633651
if msg.control_plane_step and msg.control_plane_step.get("termination_reason"):
634-
return msg.control_plane_step["termination_reason"]
652+
reason = msg.control_plane_step.get("termination_reason")
653+
return str(reason)
635654
return "unknown"
636655

637656
def __hash__(self) -> int:

eval_protocol/pytest/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
from .rollout_processor import RolloutProcessor
99
from .types import RolloutProcessorConfig
1010

11-
# Conditional import for optional dependency
11+
# Conditional import for optional dependencies
1212
try:
1313
from .default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
1414

eval_protocol/pytest/default_langchain_rollout_processor.py

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,13 @@
11
import asyncio
22
from typing import List
33

4-
from langchain_core.messages import BaseMessage
4+
try:
5+
from langchain_core.messages import BaseMessage
6+
except Exception: # pragma: no cover - optional dependency path
7+
# Minimal fallback base type to satisfy typing when langchain is not present
8+
class BaseMessage: # type: ignore
9+
pass
10+
511

612
from eval_protocol.models import EvaluationRow, Message
713
from eval_protocol.pytest.rollout_processor import RolloutProcessor
@@ -25,7 +31,13 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig):
2531

2632
async def _process_row(row: EvaluationRow) -> EvaluationRow:
2733
# Build LC messages from EP row
28-
from langchain_core.messages import HumanMessage
34+
try:
35+
from langchain_core.messages import HumanMessage
36+
except Exception:
37+
# Fallback minimal message if langchain_core is unavailable
38+
class HumanMessage: # type: ignore
39+
def __init__(self, content: str):
40+
self.content = content
2941

3042
lm_messages: List[BaseMessage] = []
3143
if row.messages:

eval_protocol/rewards/code_execution.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -171,7 +171,11 @@ def local_code_execution_reward(
171171
},
172172
)
173173

174-
response_content = messages[-1].content
174+
# Normalize content to string; Message.content may be str or list of content parts
175+
last_content = messages[-1].content
176+
response_content = (
177+
last_content if isinstance(last_content, str) else "".join([p.text for p in (last_content or [])])
178+
)
175179
expected_output_str = ground_truth
176180

177181
code_blocks = extract_code_blocks(response_content, language)
@@ -935,7 +939,10 @@ def e2b_code_execution_reward(
935939
},
936940
)
937941

938-
response_content = messages[-1].content
942+
last_content = messages[-1].content
943+
response_content = (
944+
last_content if isinstance(last_content, str) else "".join([p.text for p in (last_content or [])])
945+
)
939946
expected_output_str = ground_truth
940947

941948
code_blocks = extract_code_blocks(response_content, language)

eval_protocol/rewards/lean_prover.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,8 @@ def lean_prover_reward(
5757
},
5858
)
5959

60-
response = messages[-1].content
60+
last_content = messages[-1].content
61+
response = last_content if isinstance(last_content, str) else "".join([p.text for p in (last_content or [])])
6162
if not response:
6263
return EvaluateResult(
6364
score=0.0,
@@ -230,7 +231,10 @@ def deepseek_prover_v2_reward(
230231
and messages[-1].role == "assistant"
231232
and messages[-1].content is not None
232233
):
233-
response_content = messages[-1].content
234+
last_content = messages[-1].content
235+
response_content = (
236+
last_content if isinstance(last_content, str) else "".join([p.text for p in (last_content or [])])
237+
)
234238

235239
final_score = base_score
236240
subgoal_count = 0

0 commit comments

Comments (0)