Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 82 additions & 4 deletions eval_protocol/benchmarks/test_aime25.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
SingleTurnRolloutProcessor,
)
from eval_protocol.pytest.evaluation_test import evaluation_test
from eval_protocol.training import GEPATrainer
from eval_protocol.training.gepa_utils import build_reflection_lm

SYSTEM_PROMPT = (
"You are a helpful math assistant. Please reason step by step, and put your final answer within \\boxed{...}."
Expand Down Expand Up @@ -61,6 +63,44 @@ def _normalize_to_int_or_none(s: Optional[str]) -> Optional[int]:
return None


def _build_feedback_text(
*,
extracted_int: Optional[int],
gt_int: Optional[int],
is_valid: bool,
raw_model_answer: str,
ground_truth: Optional[str],
) -> str:
"""
Build a feedback string similar in spirit to the GEPA `metric_with_feedback`.

Cases:
- Parse failure (model or gold): explain integer formatting and show correct answer.
- Correct: "Your answer is correct. The correct answer is '...'."
- Incorrect: "Your answer is incorrect. The correct answer is '...'."
"""
correct_answer_display = str(gt_int if gt_int is not None else (ground_truth or ""))

if not is_valid:
# Could not parse either the model answer or the gold answer as an integer.
feedback_text = (
"The final answer must be a valid integer and nothing else. "
f"You responded with '{raw_model_answer}', which couldn't be parsed as a python integer. "
"Please ensure your answer is a valid integer without any additional text or formatting."
)
if correct_answer_display:
feedback_text += f" The correct answer is '{correct_answer_display}'."
return feedback_text

if extracted_int == gt_int:
return f"Your answer is correct. The correct answer is '{correct_answer_display}'."
else:
return f"Your answer is incorrect. The correct answer is '{correct_answer_display}'."

# TODO: Our dataset does not contain written solutions, so we cannot provide
# feedback on the solution itself; consider adding solutions later.
# (The GEPA reference uses https://huggingface.co/datasets/AI-MO/aimo-validation-aime,
# which does include solutions.)


def aime2025_dataset_adapter(rows: List[Dict[str, Any]]) -> List[EvaluationRow]:
converted: List[EvaluationRow] = []
for r in rows:
Expand All @@ -83,15 +123,14 @@ def aime2025_dataset_adapter(rows: List[Dict[str, Any]]) -> List[EvaluationRow]:
completion_params=[
{
"max_tokens": 131000,
"extra_body": {"reasoning_effort": "low"},
"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b",
"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1-terminus",
}
],
rollout_processor=SingleTurnRolloutProcessor(),
aggregation_method="mean",
passed_threshold=0.8,
num_runs=8,
max_dataset_rows=2,
max_dataset_rows=None, # Use full dataset
max_concurrent_rollouts=4,
mode="pointwise",
)
Expand Down Expand Up @@ -124,10 +163,49 @@ def test_aime25_pointwise(row: EvaluationRow) -> EvaluationRow:
)
}

feedback_text = _build_feedback_text(
extracted_int=extracted_int,
gt_int=gt_int,
is_valid=is_valid,
raw_model_answer=content_str,
ground_truth=str(row.ground_truth),
)

row.evaluation_result = EvaluateResult(
score=score,
reason=("Answer correct" if score == 1.0 else "Answer incorrect"),
reason=feedback_text,
is_score_valid=is_valid,
metrics=metrics,
)
return row


if __name__ == "__main__":
    # Run GEPA prompt optimization over the AIME25 eval, then report scores.
    # (Removed an unused `import asyncio` that was never referenced.)

    # Split the dataset rows for GEPA: train is used for reflection/optimization,
    # val for candidate selection, and the remainder is held out for testing.
    trainer = GEPATrainer(
        test_aime25_pointwise,
        train_ratio=0.5,  # 50% for training (15 problems)
        val_ratio=0.3,  # 30% for validation (9 problems)
        # test_ratio = 20% (6 problems) - calculated automatically
    )

    # Use same Fireworks model for both main and reflection
    reflection_lm = build_reflection_lm("fireworks_ai/accounts/fireworks/models/deepseek-v3p1-terminus")

    optimized_program = trainer.train(
        num_threads=4,  # Reduced from 32 to avoid API timeouts
        track_stats=True,
        reflection_minibatch_size=5,  # Reduced to limit concurrent requests
        reflection_lm=reflection_lm,
    )

    # Option 1: Quick DSPy evaluation (doesn't use EP infrastructure)
    print("\n=== DSPy Evaluation ===")
    print(trainer.evaluate(optimized_program))

    # Option 2: Full EP evaluation (uses LLM proxy, Fireworks tracing, etc.)
    # This goes through the normal @evaluation_test pipeline
    print("\n=== EP Evaluation (with tracing) ===")
    results = trainer.run_ep_evaluation(optimized_program)
    print(f"Final EP Score: {results['score']:.3f}")
31 changes: 30 additions & 1 deletion eval_protocol/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import importlib
from datetime import datetime, timezone
from enum import Enum
from typing import Any, ClassVar, Dict, List, Literal, Optional, TypedDict, Union
from typing import Any, ClassVar, Dict, List, Literal, Optional, TypedDict, Union, Callable, Sequence

JSONType = Union[Dict[str, Any], List[Any], str, int, float, bool, None]

Expand Down Expand Up @@ -1190,3 +1190,32 @@ class MCPMultiClientConfiguration(BaseModel):
"""Represents a MCP configuration."""

mcpServers: Dict[str, Union[MCPConfigurationServerStdio, MCPConfigurationServerUrl]]


class EPParameters(BaseModel):
    """The parameters of an `@evaluation_test`. Used for trainable integrations.

    Mirrors the keyword arguments accepted by the `@evaluation_test` decorator
    so trainers can re-run an evaluation with the same configuration. Many
    fields are typed `Any` because the decorator accepts heterogeneous values
    (callables, config objects, model identifiers).
    """

    # --- Input / dataset configuration ---
    completion_params: Any = None
    input_messages: Any = None
    input_dataset: Any = None
    input_rows: Any = None
    data_loaders: Any = None
    dataset_adapter: Optional[Callable[..., Any]] = None
    # --- Rollout configuration ---
    rollout_processor: Any = None
    # Normalized to Optional[...] for consistency with the other fields
    # (was `Dict[str, Any] | None`).
    rollout_processor_kwargs: Optional[Dict[str, Any]] = None
    # --- Scoring / aggregation ---
    aggregation_method: Any = Field(default="mean")
    passed_threshold: Any = None
    disable_browser_open: bool = False
    num_runs: int = 1
    filtered_row_ids: Optional[Sequence[str]] = None
    max_dataset_rows: Optional[int] = None
    mcp_config_path: Optional[str] = None
    # --- Concurrency limits ---
    max_concurrent_rollouts: int = 8
    max_concurrent_evaluations: int = 64
    server_script_path: Optional[str] = None
    steps: int = 30
    mode: Any = Field(default="pointwise")
    combine_datasets: bool = True
    # Normalized to `List` to match this module's `typing` imports
    # (was lowercase `list[EvaluationRow]`).
    preprocess_fn: Optional[Callable[[List[EvaluationRow]], List[EvaluationRow]]] = None
    logger: Any = None
    exception_handler_config: Any = None
Comment thread
shreymodi1 marked this conversation as resolved.
Comment thread
cursor[bot] marked this conversation as resolved.
Comment thread
shreymodi1 marked this conversation as resolved.
35 changes: 28 additions & 7 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
EvaluationThresholdDict,
EvaluateResult,
Status,
EPParameters,
)
from eval_protocol.pytest.dual_mode_wrapper import create_dual_mode_wrapper
from eval_protocol.pytest.evaluation_test_postprocess import postprocess
Expand Down Expand Up @@ -695,13 +696,33 @@ async def _collect_result(config, lst):
)
pytest_wrapper = pytest.mark.asyncio(pytest_wrapper)

ep_params: dict[str, Any] = {
"rollout_processor": rollout_processor,
"server_script_path": server_script_path,
"mcp_config_path": mcp_config_path,
"rollout_processor_kwargs": rollout_processor_kwargs,
"mode": mode,
}
# Attach full evaluation parameter metadata for training integrations
ep_params: EPParameters = EPParameters(
completion_params=completion_params,
input_messages=input_messages,
input_dataset=input_dataset,
input_rows=input_rows,
data_loaders=data_loaders,
dataset_adapter=dataset_adapter,
rollout_processor=rollout_processor,
rollout_processor_kwargs=rollout_processor_kwargs,
aggregation_method=aggregation_method,
passed_threshold=passed_threshold,
disable_browser_open=disable_browser_open,
num_runs=num_runs,
filtered_row_ids=filtered_row_ids,
max_dataset_rows=max_dataset_rows,
mcp_config_path=mcp_config_path,
max_concurrent_rollouts=max_concurrent_rollouts,
max_concurrent_evaluations=max_concurrent_evaluations,
server_script_path=server_script_path,
steps=steps,
mode=mode,
combine_datasets=combine_datasets,
preprocess_fn=preprocess_fn,
logger=logger,
exception_handler_config=exception_handler_config,
)
Comment thread
cursor[bot] marked this conversation as resolved.

# Create the dual mode wrapper
dual_mode_wrapper = create_dual_mode_wrapper(
Expand Down
Loading
Loading