-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathtest_gpqa.py
More file actions
126 lines (103 loc) · 4.47 KB
/
test_gpqa.py
File metadata and controls
126 lines (103 loc) · 4.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import asyncio
import csv
import io
import re
import requests
from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult
from eval_protocol.pytest.default_single_turn_rollout_process import (
SingleTurnRolloutProcessor,
)
from eval_protocol.pytest.evaluation_test import evaluation_test
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig
SYSTEM_PROMPT = (
"You are a helpful assistant. Read the question and options carefully. "
"Express your final answer strictly as a single letter: A, B, C, or D."
)
def _load_gpqa_messages_from_csv() -> list[list[list[Message]]]:
url = "https://openaipublic.blob.core.windows.net/simple-evals/gpqa_diamond.csv"
resp = requests.get(url, timeout=60)
resp.raise_for_status()
messages_list: list[list[Message]] = []
reader = csv.DictReader(io.StringIO(resp.text))
for ex in reader:
q = str(ex.get("Question", ""))
correct = str(ex.get("Correct Answer", "")).strip()
inc1 = str(ex.get("Incorrect Answer 1", ""))
inc2 = str(ex.get("Incorrect Answer 2", ""))
inc3 = str(ex.get("Incorrect Answer 3", ""))
choices = [correct, inc1, inc2, inc3]
user_content = (
f"{q}\n\n(A) {choices[0]}\n(B) {choices[1]}\n(C) {choices[2]}\n(D) {choices[3]}\n\nAnswer with one letter."
)
messages_list.append(
[
Message(role="system", content=SYSTEM_PROMPT),
Message(role="user", content=user_content),
]
)
if not messages_list:
raise RuntimeError("Failed to load GPQA messages: no rows found from source")
return [messages_list]
def _extract_abcd_letter(text: str) -> str | None:
if not text:
return None
m = re.search(r"\b([ABCD])\b", text.upper())
return m.group(1) if m else None
_GPQA_INPUT_MESSAGES = _load_gpqa_messages_from_csv()
def _strip_gt_messages(msgs: list[Message]) -> list[Message]:
return [m for m in msgs if not (m.role == "system" and (m.content or "").startswith("__GT__:"))]
class GPQAStripGTRolloutProcessor(RolloutProcessor):
"""Preprocess rows to set ground_truth and remove __GT__ messages, then delegate to SingleTurnRolloutProcessor."""
def __init__(self):
super().__init__()
self.single_turn_processor = SingleTurnRolloutProcessor()
def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
"""Preprocess rows and delegate to SingleTurnRolloutProcessor."""
processed: list[EvaluationRow] = []
for r in rows:
gt_tokens = [
m.content for m in r.messages if m.role == "system" and (m.content or "").startswith("__GT__:")
]
if gt_tokens:
gt_val = gt_tokens[-1].split(":", 1)[1].strip()
r.ground_truth = gt_val
r.messages = [
m for m in r.messages if not (m.role == "system" and (m.content or "").startswith("__GT__:"))
]
processed.append(r)
# Delegate to SingleTurnRolloutProcessor
return self.single_turn_processor(processed, config)
@evaluation_test(
input_messages=_GPQA_INPUT_MESSAGES,
completion_params=[
{"extra_body": {"reasoning_effort": "low"}, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}
],
rollout_processor=GPQAStripGTRolloutProcessor(),
aggregation_method="mean",
passed_threshold=0.6,
num_runs=8,
mode="pointwise",
)
def test_gpqa_pointwise(row: EvaluationRow) -> EvaluationRow:
assistant_msgs = [m for m in row.messages if m.role == "assistant"]
content = assistant_msgs[-1].content if assistant_msgs else ""
pred = _extract_abcd_letter(content or "")
# GPQA diamond CSV constructs options so that the correct answer is always A
gt = "A"
is_valid = pred is not None and gt in {"A", "B", "C", "D"}
score = 1.0 if (is_valid and pred == gt) else 0.0
row.evaluation_result = EvaluateResult(
score=score,
reason=("Correct option" if score == 1.0 else "Incorrect option"),
is_score_valid=is_valid,
metrics={
"exact_match": MetricResult(
score=score,
is_score_valid=is_valid,
reason=("Matched" if score == 1.0 else "Not matched"),
data={"pred": pred, "gt": gt},
)
},
)
return row