-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathtest_aime25.py
More file actions
133 lines (114 loc) · 4.06 KB
/
test_aime25.py
File metadata and controls
133 lines (114 loc) · 4.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from typing import Any, Dict, List, Optional
from eval_protocol.models import (
EvaluateResult,
EvaluationRow,
Message,
MetricResult,
ChatCompletionContentPartParam,
ChatCompletionContentPartTextParam,
)
from eval_protocol.pytest.default_single_turn_rollout_process import (
SingleTurnRolloutProcessor,
)
from eval_protocol.pytest.evaluation_test import evaluation_test
# System prompt prepended to every AIME 2025 conversation. It asks the model to
# reason step by step and to wrap its final answer in \boxed{...}, which is the
# marker the answer extractor searches for when scoring.
SYSTEM_PROMPT = (
    "You are a helpful math assistant. "
    "Please reason step by step, and put your final answer within \\boxed{...}."
)
def _coerce_content_to_str(
    content: str | list[ChatCompletionContentPartParam] | None,
) -> str:
    """Flatten a message's ``content`` into a single plain-text string.

    - A list of content parts is reduced to the concatenated ``text`` of every
      ``ChatCompletionContentPartTextParam`` part. Parts of any other type
      contribute nothing.
    - A falsy value (``None``, ``""``) becomes ``""``.
    - Anything else is passed through ``str()``.
    """
    if not isinstance(content, list):
        return str(content or "")

    text_pieces: list[str] = []
    for part in content:
        if isinstance(part, ChatCompletionContentPartTextParam):
            # Fall back to the part's string form if it has no ``text`` attribute.
            text_pieces.append(getattr(part, "text", str(part)))
    return "".join(text_pieces)
def _extract_boxed_text(text: str) -> str:
import re
if not text:
return ""
pattern_boxed = r"boxed{(.*?)}|framebox{(.*?)}"
matches = re.findall(pattern_boxed, text, re.DOTALL)
if matches:
for match in matches[::-1]:
for group in match:
if group:
return group.split(",")[-1].strip()
matches_digits = re.findall(r"\d+", text, re.DOTALL)
if matches_digits:
return matches_digits[-1]
return ""
def _normalize_to_int_or_none(s: Optional[str]) -> Optional[int]:
import re
if s is None:
return None
m = re.match(r"\d+", str(s).strip())
if not m:
return None
try:
return int(m.group(0))
except ValueError:
return None
def aime2025_dataset_adapter(rows: List[Dict[str, Any]]) -> List[EvaluationRow]:
    """Convert raw AIME 2025 JSONL records into ``EvaluationRow`` objects.

    Each record becomes a two-message conversation: ``SYSTEM_PROMPT`` as the
    system turn, then the record's ``question`` (stringified, ``""`` if
    absent) as the user turn. The record's ``answer`` is stored as a string in
    ``ground_truth``. When ``answer`` is missing or ``None``, ``ground_truth``
    is ``None``.
    """

    def _row_from_record(record: Dict[str, Any]) -> EvaluationRow:
        # Build a single EvaluationRow from one raw dataset record.
        question = record.get("question", "")
        answer = record.get("answer", None)
        conversation = [
            Message(role="system", content=SYSTEM_PROMPT),
            Message(role="user", content=str(question)),
        ]
        return EvaluationRow(
            messages=conversation,
            ground_truth=None if answer is None else str(answer),
        )

    return [_row_from_record(record) for record in rows]
@evaluation_test(
    input_dataset=[
        "https://huggingface.co/datasets/opencompass/AIME2025/raw/main/aime2025-I.jsonl",
        "https://huggingface.co/datasets/opencompass/AIME2025/raw/main/aime2025-II.jsonl",
    ],
    dataset_adapter=aime2025_dataset_adapter,
    completion_params=[
        {
            "max_tokens": 131000,
            "extra_body": {"reasoning_effort": "low"},
            "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b",
        }
    ],
    rollout_processor=SingleTurnRolloutProcessor(),
    aggregation_method="mean",
    passed_threshold=0.8,
    num_runs=8,
    max_dataset_rows=2,
    max_concurrent_rollouts=4,
    mode="pointwise",
)
def test_aime25_pointwise(row: EvaluationRow) -> EvaluationRow:
    """Score one AIME 2025 rollout by exact integer match with the ground truth.

    Steps:

    1. Flatten the content of the last assistant message to text (``""`` if
       there is no assistant message).
    2. Extract the final answer with ``_extract_boxed_text`` and parse it with
       ``_normalize_to_int_or_none``.
    3. Parse the stringified ``row.ground_truth`` the same way.

    The score is 1.0 only when both sides parse and are equal, otherwise 0.0.
    If either side fails to parse, the result is marked
    ``is_score_valid=False``. The row's ``evaluation_result`` is set in place
    and the same row is returned.
    """
    # Keep the last assistant message seen while scanning forward.
    last_assistant = None
    for message in row.messages:
        if message.role == "assistant":
            last_assistant = message
    response_text = _coerce_content_to_str(
        last_assistant.content if last_assistant is not None else ""
    )

    extracted_text = _extract_boxed_text(response_text)
    extracted_int = _normalize_to_int_or_none(extracted_text)
    gt_int = _normalize_to_int_or_none(str(row.ground_truth))

    is_valid = extracted_int is not None and gt_int is not None
    matched = is_valid and extracted_int == gt_int
    score = 1.0 if matched else 0.0

    if not is_valid:
        metric_reason = "Failed to parse integer"
    elif matched:
        metric_reason = "Parsed both integers and they matched"
    else:
        metric_reason = "Parsed integers did not match"

    exact_match = MetricResult(
        score=score,
        is_score_valid=is_valid,
        reason=metric_reason,
        data={
            "extracted_text": extracted_text,
            "extracted_int": extracted_int,
            "ground_truth_int": gt_int,
        },
    )

    row.evaluation_result = EvaluateResult(
        score=score,
        reason="Answer correct" if matched else "Answer incorrect",
        is_score_valid=is_valid,
        metrics={"exact_match": exact_match},
    )
    return row