Skip to content

Commit 847ff69

Browse files
Benny Chen
authored and committed
type fix round 6
1 parent caf93cf commit 847ff69

File tree

5 files changed

+65
-18
lines changed

5 files changed

+65
-18
lines changed

eval_protocol/benchmarks/test_aime25.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
from typing import Any, Dict, List, Optional
22

3-
from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult
3+
from eval_protocol.models import (
4+
EvaluateResult,
5+
EvaluationRow,
6+
Message,
7+
MetricResult,
8+
ChatCompletionContentPartTextParam,
9+
)
410
from eval_protocol.pytest.default_single_turn_rollout_process import (
511
SingleTurnRolloutProcessor,
612
)
@@ -11,6 +17,14 @@
1117
)
1218

1319

20+
def _coerce_content_to_str(
21+
content: str | list[ChatCompletionContentPartTextParam] | None,
22+
) -> str:
23+
if isinstance(content, list):
24+
return "".join([getattr(p, "text", str(p)) for p in content])
25+
return str(content or "")
26+
27+
1428
def _extract_boxed_text(text: str) -> str:
1529
import re
1630

@@ -80,9 +94,10 @@ def aime2025_dataset_adapter(rows: List[Dict[str, Any]]) -> List[EvaluationRow]:
8094
)
8195
def test_aime25_pointwise(row: EvaluationRow) -> EvaluationRow:
8296
assistant_msgs = [m for m in row.messages if m.role == "assistant"]
83-
content = assistant_msgs[-1].content if assistant_msgs else ""
97+
raw_content = assistant_msgs[-1].content if assistant_msgs else ""
98+
content_str = _coerce_content_to_str(raw_content)
8499

85-
extracted_text = _extract_boxed_text(content or "")
100+
extracted_text = _extract_boxed_text(content_str)
86101
extracted_int = _normalize_to_int_or_none(extracted_text)
87102
gt_int = _normalize_to_int_or_none(row.ground_truth or "")
88103

eval_protocol/benchmarks/test_gpqa.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,13 @@
55

66
import requests
77

8-
from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult
8+
from eval_protocol.models import (
9+
EvaluateResult,
10+
EvaluationRow,
11+
Message,
12+
MetricResult,
13+
ChatCompletionContentPartTextParam,
14+
)
915
from eval_protocol.pytest.default_single_turn_rollout_process import (
1016
SingleTurnRolloutProcessor,
1117
)
@@ -47,6 +53,14 @@ def _load_gpqa_messages_from_csv() -> list[list[list[Message]]]:
4753
return [messages_list]
4854

4955

56+
def _coerce_content_to_str(
57+
content: str | list[ChatCompletionContentPartTextParam] | None,
58+
) -> str:
59+
if isinstance(content, list):
60+
return "".join([getattr(p, "text", str(p)) for p in content])
61+
return str(content or "")
62+
63+
5064
def _extract_abcd_letter(text: str) -> str | None:
5165
if not text:
5266
return None
@@ -58,9 +72,12 @@ def _extract_abcd_letter(text: str) -> str | None:
5872

5973

6074
def _strip_gt_messages(msgs: list[Message]) -> list[Message]:
61-
# assert that all the messages just have a plain .content string field
62-
assert all(isinstance(m.content, str) for m in msgs), "Messages must have a plain .content string field"
63-
return [m for m in msgs if not (m.role == "system" and (m.content or "").startswith("__GT__:"))]
75+
result: list[Message] = []
76+
for m in msgs:
77+
content_str = _coerce_content_to_str(m.content)
78+
if not (m.role == "system" and content_str.startswith("__GT__:")):
79+
result.append(m)
80+
return result
6481

6582

6683
class GPQAStripGTRolloutProcessor(RolloutProcessor):
@@ -75,15 +92,23 @@ def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) ->
7592
processed: list[EvaluationRow] = []
7693

7794
for r in rows:
78-
gt_tokens = [
79-
m.content for m in r.messages if m.role == "system" and (m.content or "").startswith("__GT__:")
80-
]
95+
gt_tokens: list[str] = []
96+
for m in r.messages:
97+
if m.role == "system":
98+
content_str = _coerce_content_to_str(m.content)
99+
if content_str.startswith("__GT__:"):
100+
gt_tokens.append(content_str)
81101
if gt_tokens:
82102
gt_val = gt_tokens[-1].split(":", 1)[1].strip()
83103
r.ground_truth = gt_val
84-
r.messages = [
85-
m for m in r.messages if not (m.role == "system" and (m.content or "").startswith("__GT__:"))
86-
]
104+
filtered: list[Message] = []
105+
for m in r.messages:
106+
if m.role == "system":
107+
content_str = _coerce_content_to_str(m.content)
108+
if content_str.startswith("__GT__:"):
109+
continue
110+
filtered.append(m)
111+
r.messages = filtered
87112
processed.append(r)
88113

89114
# Delegate to SingleTurnRolloutProcessor
@@ -103,9 +128,10 @@ def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) ->
103128
)
104129
def test_gpqa_pointwise(row: EvaluationRow) -> EvaluationRow:
105130
assistant_msgs = [m for m in row.messages if m.role == "assistant"]
106-
content = assistant_msgs[-1].content if assistant_msgs else ""
131+
raw_content = assistant_msgs[-1].content if assistant_msgs else ""
132+
content_str = _coerce_content_to_str(raw_content)
107133

108-
pred = _extract_abcd_letter(content or "")
134+
pred = _extract_abcd_letter(content_str)
109135
# GPQA diamond CSV constructs options so that the correct answer is always A
110136
gt = "A"
111137

eval_protocol/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ class MetricResult(BaseModel):
271271
is_score_valid: bool = True
272272
score: float = Field(..., ge=0.0, le=1.0)
273273
reason: str
274+
data: Dict[str, Any] = Field(default_factory=dict, description="Optional extra metric data for debugging.")
274275

275276
def __getitem__(self, key: str) -> Any:
276277
if key in self.__fields__: # Changed to __fields__ for Pydantic v1 compatibility

eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66

77
from mcp.types import CallToolResult, TextContent
88
from openai import NOT_GIVEN, NotGiven
9-
from openai.types.chat import ChatCompletionContentPartTextParam
9+
from openai.types.chat import ChatCompletionContentPartTextParam as OpenAIChatContentPart
1010
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
1111

1212
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
1313
from eval_protocol.mcp.execution.policy import LiteLLMPolicy
1414
from eval_protocol.mcp.mcp_multi_client import MCPMultiClient
15-
from eval_protocol.models import EvaluationRow, Message
15+
from eval_protocol.models import EvaluationRow, Message, ChatCompletionContentPartTextParam
1616
from eval_protocol.pytest.rollout_processor import RolloutProcessor
1717
from eval_protocol.pytest.types import Dataset, RolloutProcessorConfig
1818
from pydantic import BaseModel
@@ -215,6 +215,7 @@ def _format_tool_message_content(
215215
"""
216216
if len(content) == 1 and isinstance(content[0], TextContent):
217217
return content[0].text
218+
# Build our SDK's ChatCompletionContentPartTextParam instances, not OpenAI types
218219
return [ChatCompletionContentPartTextParam(text=c.text, type="text") for c in content]
219220

220221

eval_protocol/pytest/default_langchain_rollout_processor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,11 @@ def __init__(self, content: str):
4343
if row.messages:
4444
last_user = [m for m in row.messages if m.role == "user"]
4545
if last_user:
46-
lm_messages.append(HumanMessage(content=last_user[-1].content or ""))
46+
content = last_user[-1].content or ""
47+
if isinstance(content, list):
48+
# Flatten our SDK content parts into a single string for LangChain
49+
content = "".join([getattr(p, "text", str(p)) for p in content])
50+
lm_messages.append(HumanMessage(content=str(content)))
4751
if not lm_messages:
4852
lm_messages = [HumanMessage(content="")] # minimal
4953

0 commit comments

Comments
 (0)