Skip to content

Commit 693274e

Browse files
committed
attempt at primitive conversion
1 parent c61de5b commit 693274e

File tree

3 files changed

+125
-5
lines changed

3 files changed

+125
-5
lines changed

eval_protocol/benchmarks/test_aime25.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,44 @@ def _normalize_to_int_or_none(s: Optional[str]) -> Optional[int]:
6363
return None
6464

6565

66+
def _build_feedback_text(
67+
*,
68+
extracted_int: Optional[int],
69+
gt_int: Optional[int],
70+
is_valid: bool,
71+
raw_model_answer: str,
72+
ground_truth: Optional[str],
73+
) -> str:
74+
"""
75+
Build a feedback string similar in spirit to the GEPA `metric_with_feedback`.
76+
77+
Cases:
78+
- Parse failure (model or gold): explain integer formatting and show correct answer.
79+
- Correct: "Your answer is correct. The correct answer is '...'."
80+
- Incorrect: "Your answer is incorrect. The correct answer is '...'."
81+
"""
82+
correct_answer_display = str(gt_int if gt_int is not None else (ground_truth or ""))
83+
84+
if not is_valid:
85+
# Could not parse either the model answer or the gold answer as an integer.
86+
feedback_text = (
87+
"The final answer must be a valid integer and nothing else. "
88+
f"You responded with '{raw_model_answer}', which couldn't be parsed as a python integer. "
89+
"Please ensure your answer is a valid integer without any additional text or formatting."
90+
)
91+
if correct_answer_display:
92+
feedback_text += f" The correct answer is '{correct_answer_display}'."
93+
return feedback_text
94+
95+
if extracted_int == gt_int:
96+
return f"Your answer is correct. The correct answer is '{correct_answer_display}'."
97+
else:
98+
return f"Your answer is incorrect. The correct answer is '{correct_answer_display}'."
99+
100+
# TODO: our dataset does not contain written solutions, so we cannot provide feedback on the solution. maybe need to add it later.
101+
# they're using https://huggingface.co/datasets/AI-MO/aimo-validation-aime
102+
103+
66104
def aime2025_dataset_adapter(rows: List[Dict[str, Any]]) -> List[EvaluationRow]:
67105
converted: List[EvaluationRow] = []
68106
for r in rows:
@@ -126,9 +164,17 @@ def test_aime25_pointwise(row: EvaluationRow) -> EvaluationRow:
126164
)
127165
}
128166

167+
feedback_text = _build_feedback_text(
168+
extracted_int=extracted_int,
169+
gt_int=gt_int,
170+
is_valid=is_valid,
171+
raw_model_answer=content_str,
172+
ground_truth=str(row.ground_truth),
173+
)
174+
129175
row.evaluation_result = EvaluateResult(
130176
score=score,
131-
reason=("Answer correct" if score == 1.0 else "Answer incorrect"),
177+
reason=feedback_text,
132178
is_score_valid=is_valid,
133179
metrics=metrics,
134180
)

eval_protocol/training/gepa_trainer.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from eval_protocol.pytest.types import TestFunction
1212
from eval_protocol.training.trainer import Trainer
1313
from eval_protocol.training.utils import build_ep_parameters_from_test
14+
from eval_protocol.training.gepa_utils import ep_test_to_gepa_metric
1415

1516

1617
class GEPATrainer(Trainer):
@@ -33,11 +34,9 @@ def __init__(self, test_fn: TestFunction) -> None:
3334
super().__init__(test_fn)
3435
self.ep_params: EPParameters = build_ep_parameters_from_test(test_fn)
3536

36-
self.metric = test_fn # TODO @derek. need to convert our ep test_fn to a GEPA metric. also need to inject the feedback text.
37+
self.metric = ep_test_to_gepa_metric(test_fn)
3738

38-
self.program = (
39-
...
40-
) # TODO @shreymodi1: converting between a program (dspy.Module) and an @evaluation_test is a bit tricky.
39+
self.program = ... # TODO @shreymodi1: converting between a program (dspy.Module) and rollout processors is a bit tricky. maybe start with single turn
4140

4241
self.train_set, self.val_set, self.test_set = (
4342
...,

eval_protocol/training/gepa_utils.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
11
import os
2+
from typing import Optional
23

34
import dspy
45
from dspy.clients.lm import LM
6+
from dspy.primitives import Example, Prediction
7+
from dspy.teleprompt.gepa.gepa_utils import DSPyTrace, ScoreWithFeedback
8+
from dspy.teleprompt.gepa.gepa import GEPAFeedbackMetric
9+
10+
from eval_protocol.pytest.types import TestFunction
11+
from eval_protocol.models import EvaluationRow, Message
12+
513

614
REFLECTION_LM_CONFIGS = {
715
"gpt-5": {
@@ -30,3 +38,70 @@ def build_reflection_lm(reflection_lm_name: str) -> LM:
3038
api_key=reflection_lm_config["api_key"],
3139
base_url=reflection_lm_config["base_url"],
3240
)
41+
42+
43+
def gold_and_pred_to_row(gold: Example, pred: Prediction) -> EvaluationRow:
    """
    Convert a GEPA (gold, pred) pair into an EvaluationRow for an EP `@evaluation_test`.

    Assumptions (aligned with common DSPy usage):
    - `gold.answer` holds the ground-truth answer.
    - `pred.answer` holds the model's final answer text.
    """
    raw_gold = gold.get("answer", None)
    raw_pred = pred.get("answer", "")

    # Keep a missing gold answer as None (downstream code distinguishes
    # "no ground truth" from an empty string); otherwise stringify it.
    gold_str: Optional[str] = None if raw_gold is None else str(raw_gold)

    assistant_message = Message(
        role="assistant", content=str(raw_pred)
    )  # TODO: for some evals, you might need system / user message too.
    return EvaluationRow(messages=[assistant_message], ground_truth=gold_str)
62+
63+
64+
def row_to_prediction(row: EvaluationRow) -> ScoreWithFeedback:
    """
    Convert an EvaluationRow into a GEPA-compatible ScoreWithFeedback
    (implemented as a dspy.Prediction subclass in dspy.teleprompt.gepa).
    """
    result = row.evaluation_result
    if result is None:
        # The evaluation_test never populated a result; surface that as feedback
        # with a zero score rather than crashing the optimizer loop.
        return dspy.Prediction(
            score=0.0,
            feedback="No evaluation_result was produced by the evaluation_test.",
        )

    # A None score is treated the same as 0.0.
    numeric_score = float(result.score or 0.0)
    feedback_text = result.reason or f"This trajectory got a score of {numeric_score}."
    # NOTE(review): annotated as ScoreWithFeedback but a plain dspy.Prediction
    # with score/feedback fields is returned — confirm GEPA accepts this.
    return dspy.Prediction(score=numeric_score, feedback=feedback_text)
78+
79+
80+
def ep_test_to_gepa_metric(
    test_fn: TestFunction,
) -> GEPAFeedbackMetric:
    """
    Adapter: convert an EP-style `test_fn(row: EvaluationRow) -> EvaluationRow` into
    a GEPAFeedbackMetric-compatible callable.

    The resulting metric:
    - Constructs an EvaluationRow from (gold, pred) using a simple heuristic.
    - Applies the EP test_fn to populate `row.evaluation_result`.
    - Returns a dspy.Prediction(score, feedback) derived from that result.
    """

    def metric(
        gold: Example,
        pred: Prediction,
        trace: Optional[DSPyTrace] = None,
        pred_name: Optional[str] = None,
        pred_trace: Optional[DSPyTrace] = None,
    ) -> ScoreWithFeedback:
        candidate_row = gold_and_pred_to_row(gold, pred)

        # TODO: this is problematic. for groupwise, we will have to extend this to handle list[EvaluationRow]
        scored_row: EvaluationRow = test_fn(candidate_row)  # pyright: ignore

        return row_to_prediction(scored_row)

    return metric

0 commit comments

Comments
 (0)