Skip to content

Commit 56c7bdd

Browse files
benjibc (Benny Chen)
and authored
type fix round 6 (#147)
* type fix round 6 * fix more tests * fix more errors * fixes * fix tests --------- Co-authored-by: Benny Chen <bchen@Bennys-MacBook-Air.local>
1 parent 28d8b3e commit 56c7bdd

File tree

16 files changed

+198
-66
lines changed

16 files changed

+198
-66
lines changed

eval_protocol/benchmarks/test_aime25.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
from typing import Any, Dict, List, Optional
22

3-
from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult
3+
from eval_protocol.models import (
4+
EvaluateResult,
5+
EvaluationRow,
6+
Message,
7+
MetricResult,
8+
ChatCompletionContentPartTextParam,
9+
)
410
from eval_protocol.pytest.default_single_turn_rollout_process import (
511
SingleTurnRolloutProcessor,
612
)
@@ -11,6 +17,14 @@
1117
)
1218

1319

20+
def _coerce_content_to_str(
21+
content: str | list[ChatCompletionContentPartTextParam] | None,
22+
) -> str:
23+
if isinstance(content, list):
24+
return "".join([getattr(p, "text", str(p)) for p in content])
25+
return str(content or "")
26+
27+
1428
def _extract_boxed_text(text: str) -> str:
1529
import re
1630

@@ -80,9 +94,10 @@ def aime2025_dataset_adapter(rows: List[Dict[str, Any]]) -> List[EvaluationRow]:
8094
)
8195
def test_aime25_pointwise(row: EvaluationRow) -> EvaluationRow:
8296
assistant_msgs = [m for m in row.messages if m.role == "assistant"]
83-
content = assistant_msgs[-1].content if assistant_msgs else ""
97+
raw_content = assistant_msgs[-1].content if assistant_msgs else ""
98+
content_str = _coerce_content_to_str(raw_content)
8499

85-
extracted_text = _extract_boxed_text(content or "")
100+
extracted_text = _extract_boxed_text(content_str)
86101
extracted_int = _normalize_to_int_or_none(extracted_text)
87102
gt_int = _normalize_to_int_or_none(row.ground_truth or "")
88103

eval_protocol/benchmarks/test_gpqa.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,13 @@
55

66
import requests
77

8-
from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult
8+
from eval_protocol.models import (
9+
EvaluateResult,
10+
EvaluationRow,
11+
Message,
12+
MetricResult,
13+
ChatCompletionContentPartTextParam,
14+
)
915
from eval_protocol.pytest.default_single_turn_rollout_process import (
1016
SingleTurnRolloutProcessor,
1117
)
@@ -47,6 +53,14 @@ def _load_gpqa_messages_from_csv() -> list[list[list[Message]]]:
4753
return [messages_list]
4854

4955

56+
def _coerce_content_to_str(
57+
content: str | list[ChatCompletionContentPartTextParam] | None,
58+
) -> str:
59+
if isinstance(content, list):
60+
return "".join([getattr(p, "text", str(p)) for p in content])
61+
return str(content or "")
62+
63+
5064
def _extract_abcd_letter(text: str) -> str | None:
5165
if not text:
5266
return None
@@ -58,9 +72,12 @@ def _extract_abcd_letter(text: str) -> str | None:
5872

5973

6074
def _strip_gt_messages(msgs: list[Message]) -> list[Message]:
61-
# assert that all the messages just have a plain .content string field
62-
assert all(isinstance(m.content, str) for m in msgs), "Messages must have a plain .content string field"
63-
return [m for m in msgs if not (m.role == "system" and (m.content or "").startswith("__GT__:"))]
75+
result: list[Message] = []
76+
for m in msgs:
77+
content_str = _coerce_content_to_str(m.content)
78+
if not (m.role == "system" and content_str.startswith("__GT__:")):
79+
result.append(m)
80+
return result
6481

6582

6683
class GPQAStripGTRolloutProcessor(RolloutProcessor):
@@ -75,15 +92,23 @@ def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) ->
7592
processed: list[EvaluationRow] = []
7693

7794
for r in rows:
78-
gt_tokens = [
79-
m.content for m in r.messages if m.role == "system" and (m.content or "").startswith("__GT__:")
80-
]
95+
gt_tokens: list[str] = []
96+
for m in r.messages:
97+
if m.role == "system":
98+
content_str = _coerce_content_to_str(m.content)
99+
if content_str.startswith("__GT__:"):
100+
gt_tokens.append(content_str)
81101
if gt_tokens:
82102
gt_val = gt_tokens[-1].split(":", 1)[1].strip()
83103
r.ground_truth = gt_val
84-
r.messages = [
85-
m for m in r.messages if not (m.role == "system" and (m.content or "").startswith("__GT__:"))
86-
]
104+
filtered: list[Message] = []
105+
for m in r.messages:
106+
if m.role == "system":
107+
content_str = _coerce_content_to_str(m.content)
108+
if content_str.startswith("__GT__:"):
109+
continue
110+
filtered.append(m)
111+
r.messages = filtered
87112
processed.append(r)
88113

89114
# Delegate to SingleTurnRolloutProcessor
@@ -103,9 +128,10 @@ def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) ->
103128
)
104129
def test_gpqa_pointwise(row: EvaluationRow) -> EvaluationRow:
105130
assistant_msgs = [m for m in row.messages if m.role == "assistant"]
106-
content = assistant_msgs[-1].content if assistant_msgs else ""
131+
raw_content = assistant_msgs[-1].content if assistant_msgs else ""
132+
content_str = _coerce_content_to_str(raw_content)
107133

108-
pred = _extract_abcd_letter(content or "")
134+
pred = _extract_abcd_letter(content_str)
109135
# GPQA diamond CSV constructs options so that the correct answer is always A
110136
gt = "A"
111137

eval_protocol/benchmarks/test_livebench_data_analysis.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
import re
44
from typing import Any, Dict, List, Optional
55

6-
from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult
6+
from eval_protocol.models import (
7+
EvaluateResult,
8+
EvaluationRow,
9+
Message,
10+
MetricResult,
11+
ChatCompletionContentPartTextParam,
12+
)
713
from eval_protocol.pytest.default_single_turn_rollout_process import (
814
SingleTurnRolloutProcessor,
915
)
@@ -31,6 +37,12 @@ def _extract_last_boxed_segment(text: str) -> Optional[str]:
3137
return matches[-1]
3238

3339

40+
def _coerce_content_to_str(content: str | list[ChatCompletionContentPartTextParam] | None) -> str:
41+
if isinstance(content, list):
42+
return "".join([getattr(p, "text", str(p)) for p in content])
43+
return str(content or "")
44+
45+
3446
def _cta_process_results(ground_truth: str, llm_answer: str) -> int:
3547
parsed_answer = llm_answer
3648
if "\\boxed{" in parsed_answer or "\\framebox{" in parsed_answer:
@@ -275,6 +287,8 @@ def _read_jsonl_table_from_text(text: str, header_cols: List[str]):
275287
return 0
276288

277289
# Compare
290+
assert llm_df is not None, "LLM dataframe is None"
291+
assert gt_df is not None, "GT dataframe is None"
278292
try:
279293
gt_df.columns = [str(s).strip() for s in gt_df.columns]
280294
if "index" in gt_df.columns:
@@ -420,7 +434,8 @@ def _extract_gt(row: EvaluationRow) -> Dict[str, Any]:
420434
)
421435
def test_livebench_cta_pointwise(row: EvaluationRow) -> EvaluationRow:
422436
assistant_msgs = [m for m in row.messages if m.role == "assistant"]
423-
content = assistant_msgs[-1].content if assistant_msgs else ""
437+
raw_content = assistant_msgs[-1].content if assistant_msgs else ""
438+
content = _coerce_content_to_str(raw_content)
424439
payload = _extract_gt(row)
425440
gt = payload.get("ground_truth")
426441
gt_str = str(gt) if gt is not None else ""
@@ -462,9 +477,9 @@ def test_livebench_cta_pointwise(row: EvaluationRow) -> EvaluationRow:
462477
)
463478
def test_livebench_tablejoin_pointwise(row: EvaluationRow) -> EvaluationRow:
464479
user_msgs = [m for m in row.messages if m.role == "user"]
465-
question = user_msgs[-1].content if user_msgs else ""
480+
question = _coerce_content_to_str(user_msgs[-1].content if user_msgs else "")
466481
assistant_msgs = [m for m in row.messages if m.role == "assistant"]
467-
content = assistant_msgs[-1].content if assistant_msgs else ""
482+
content = _coerce_content_to_str(assistant_msgs[-1].content if assistant_msgs else "")
468483
payload = _extract_gt(row)
469484
gt = payload.get("ground_truth")
470485

@@ -505,9 +520,9 @@ def test_livebench_tablejoin_pointwise(row: EvaluationRow) -> EvaluationRow:
505520
)
506521
def test_livebench_tablereformat_pointwise(row: EvaluationRow) -> EvaluationRow:
507522
user_msgs = [m for m in row.messages if m.role == "user"]
508-
question = user_msgs[-1].content if user_msgs else ""
523+
question = _coerce_content_to_str(user_msgs[-1].content if user_msgs else "")
509524
assistant_msgs = [m for m in row.messages if m.role == "assistant"]
510-
content = assistant_msgs[-1].content if assistant_msgs else ""
525+
content = _coerce_content_to_str(assistant_msgs[-1].content if assistant_msgs else "")
511526
payload = _extract_gt(row)
512527
gt = payload.get("ground_truth")
513528
release = payload.get("release") or ""

eval_protocol/integrations/braintrust.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ def scorer_to_reward_fn(
1818
"""Wrap a Braintrust scorer as an Eval Protocol reward function."""
1919

2020
@reward_function
21-
def reward_fn(messages: List[Message], ground_truth: Optional[List[Message]] = None, **kwargs) -> EvaluateResult:
21+
def reward_fn(
22+
messages: List[Message], ground_truth: Optional[List[Message]] = None, **kwargs: Any
23+
) -> EvaluateResult:
2224
input_val = messages_to_input(messages) if messages_to_input else messages[0].content
2325
output_val = messages[-1].content
2426
expected_val = None

eval_protocol/integrations/deepeval.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,10 @@ def _build_case_kwargs() -> Dict[str, Any]:
7979
case_kwargs["actual_output"] = output
8080
return case_kwargs
8181

82-
if isinstance(metric, BaseConversationalMetric):
82+
if BaseConversationalMetric is not None and isinstance(metric, BaseConversationalMetric):
83+
# Narrow types for optional imports to satisfy the type checker
84+
assert LLMTestCase is not None
85+
assert ConversationalTestCase is not None
8386
turns = []
8487
for i, msg in enumerate(messages):
8588
turn_input = messages[i - 1].get("content", "") if i > 0 else ""
@@ -93,10 +96,16 @@ def _build_case_kwargs() -> Dict[str, Any]:
9396
output = messages[-1].get("content", "")
9497
test_case = ConversationalTestCase(turns=turns)
9598
else:
99+
# Narrow types for optional imports to satisfy the type checker
100+
assert LLMTestCase is not None
96101
case_kwargs = _build_case_kwargs()
97102
test_case = LLMTestCase(**case_kwargs)
98103

99-
metric.measure(test_case, **kwargs)
104+
# Guard against metric.measure being None or non-callable
105+
measure_fn = getattr(metric, "measure", None)
106+
if not callable(measure_fn):
107+
raise TypeError("Provided metric does not have a callable 'measure' method")
108+
measure_fn(test_case, **kwargs)
100109
score = float(metric.score or 0.0)
101110
reason = getattr(metric, "reason", None)
102111
name = _metric_name(metric)

eval_protocol/mcp/mcpgym.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@ def format_observation(self, obs: Any, env: Any) -> Dict[str, Any]:
563563
else:
564564
return {"observation": serialized_obs}
565565

566-
def run(self, transport: str = "streamable-http", **kwargs):
566+
def run(self, transport: Literal["stdio", "sse", "streamable-http"] = "streamable-http", **kwargs):
567567
"""Run the unified MCP-Gym server with high concurrency settings."""
568568
if transport == "streamable-http":
569569
# Run with custom high-concurrency uvicorn config

eval_protocol/mcp/simulation_server.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ def reset_environment(self, env, seed): ...
3030
from abc import ABC, abstractmethod
3131
from collections.abc import AsyncIterator
3232
from contextlib import asynccontextmanager
33-
from typing import Any, Callable, Dict, List, Optional, Tuple
33+
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Iterable
34+
from pydantic import AnyUrl
3435

3536
import uvicorn
3637
from mcp.server.lowlevel import Server

eval_protocol/mcp_agent/orchestration/local_docker_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ async def startup(self) -> None:
5757
except docker.errors.DockerException as e:
5858
logger.warning(f"docker.from_env() failed: {e}. Trying explicit base_url.")
5959
try:
60+
# docker.from_env is preferred, but as a fallback use DockerClient with url param name 'base_url'
6061
self.docker_client = docker.DockerClient(base_url="unix://var/run/docker.sock")
6162
if not self.docker_client.ping(): # type: ignore
6263
raise ConnectionError("Failed to connect to Docker daemon with explicit base_url.")
@@ -649,7 +650,7 @@ async def list_tools_on_instance(self, instance: ManagedInstanceInfo) -> types.L
649650
)
650651
target_base_url = instance.mcp_endpoint_url.rstrip("/")
651652
try:
652-
async with streamablehttp_client(base_url=target_base_url) as (
653+
async with streamablehttp_client(base_url=target_base_url) as ( # type: ignore
653654
read_s,
654655
write_s,
655656
_, # get_session_id_func usually not needed for a single call

eval_protocol/mcp_servers/tau2/tau2_mcp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def __init__(self, seed: Optional[int] = None, **kwargs):
4343

4444
self.adapter = EnvironmentAdapter(env_class=AirlineEnvironment, default_config=default_config)
4545

46+
# Ensure name is a str and not None
4647
super().__init__("airline", self.adapter, seed, **kwargs)
4748

4849
def _register_tools(self):
@@ -421,7 +422,7 @@ def _register_tools(self):
421422
"""Register mock-specific MCP tools matching τ²-Bench schemas"""
422423

423424
@self.mcp.tool(name="create_task", description="Create a new task for a user.")
424-
def create_task(user_id: str, title: str, ctx: Context, description: str = None) -> Dict[str, Any]:
425+
def create_task(user_id: str, title: str, ctx: Context, description: Optional[str] = None) -> Dict[str, Any]:
425426
"""Create a new task for a user"""
426427
session_id = self._get_session_id(ctx)
427428

eval_protocol/models.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,33 @@ class ChatCompletionContentPartTextParam(BaseModel):
224224
text: str = Field(..., description="The text content.")
225225
type: Literal["text"] = Field("text", description="The type of the content part.")
226226

227+
# Provide dict-like access for tests and ergonomic usage
228+
def __getitem__(self, key: str) -> Any:
229+
if key == "text":
230+
return self.text
231+
if key == "type":
232+
return self.type
233+
raise KeyError(key)
234+
235+
def get(self, key: str, default: Any = None) -> Any:
236+
try:
237+
return self[key]
238+
except KeyError:
239+
return default
240+
241+
def keys(self):
242+
return (k for k in ("text", "type"))
243+
244+
def values(self):
245+
return (self.text, self.type)
246+
247+
def items(self):
248+
return [("text", self.text), ("type", self.type)]
249+
250+
def __iter__(self):
251+
# Iterate over keys only
252+
return iter(["text", "type"])
253+
227254

228255
class Message(BaseModel):
229256
"""Chat message model with trajectory evaluation support."""
@@ -271,6 +298,7 @@ class MetricResult(BaseModel):
271298
is_score_valid: bool = True
272299
score: float = Field(..., ge=0.0, le=1.0)
273300
reason: str
301+
data: Dict[str, Any] = Field(default_factory=dict, description="Optional extra metric data for debugging.")
274302

275303
def __getitem__(self, key: str) -> Any:
276304
if key in self.__fields__: # Changed to __fields__ for Pydantic v1 compatibility
@@ -292,10 +320,12 @@ def values(self):
292320
return [getattr(self, key) for key in self.__fields__.keys()] # Changed to __fields__
293321

294322
def items(self):
295-
return [(key, getattr(self, key)) for key in self.__fields__.keys()] # Changed to __fields__
323+
# Exclude 'data' from items to keep items hashable and match tests
324+
return [(key, getattr(self, key)) for key in self.__fields__.keys() if key != "data"] # Changed to __fields__
296325

297326
def __iter__(self):
298-
return iter(self.__fields__.keys()) # Changed to __fields__
327+
# Exclude 'data' to match expectations in tests
328+
return iter([k for k in self.__fields__.keys() if k != "data"]) # Changed to __fields__
299329

300330

301331
class StepOutput(BaseModel):

0 commit comments

Comments (0)