Skip to content

Commit c04acf8

Browse files
committed
fixes
1 parent ce61cad commit c04acf8

File tree

3 files changed

+43
-26
lines changed

3 files changed

+43
-26
lines changed

eval_protocol/pytest/integrations/openenv_trl_vllm.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,13 @@ def rollout_func(prompts: List[str], trainer) -> Dict[str, List]:
121121

122122
eval_func = candidate_tests[0]
123123
ep_eval_func = eval_func # used later after rollouts complete
124-
ep_params: Dict[str, Any] = getattr(eval_func, "__ep_params__", {})
125-
ep_rollout_processor = ep_params.get("rollout_processor")
126-
ep_rollout_processor_kwargs = ep_params.get("rollout_processor_kwargs") or {}
127-
ep_mcp_config_path = ep_params.get("mcp_config_path") or ""
124+
ep_params = getattr(eval_func, "__ep_params__", None)
125+
# ep_params is an EPParameters model (Pydantic), not a dict, so use attribute access instead of .get()
126+
ep_rollout_processor = getattr(ep_params, "rollout_processor", None) if ep_params else None
127+
ep_rollout_processor_kwargs = (
128+
(getattr(ep_params, "rollout_processor_kwargs", None) or {}) if ep_params else {}
129+
)
130+
ep_mcp_config_path = (getattr(ep_params, "mcp_config_path", None) or "") if ep_params else ""
128131
logger.info(
129132
"[OpenEnvVLLM] Loaded eval test '%s' with rollout_processor=%s",
130133
getattr(eval_func, "__name__", str(eval_func)),

eval_protocol/training/gepa_trainer.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010

1111
from eval_protocol.models import EPParameters, EvaluationRow, Message
1212
from eval_protocol.pytest.types import TestFunction, RolloutProcessorConfig
13+
from eval_protocol.pytest.default_single_turn_rollout_process import SingleTurnRolloutProcessor
14+
from eval_protocol.pytest.execution import execute_pytest
15+
from eval_protocol.dataset_logger import default_logger
1316
from eval_protocol.training.trainer import Trainer
1417
from eval_protocol.training.utils import build_ep_parameters_from_test
1518
from eval_protocol.training.gepa_utils import (
@@ -98,12 +101,15 @@ def __init__(
98101
# Store configuration
99102
self._input_field = input_field
100103
self._output_field = output_field
104+
self._train_ratio = train_ratio
105+
self._val_ratio = val_ratio
106+
self._seed = seed
101107

102108
# Configure DSPy to use the same LLM as EP
103109
configure_dspy_lm(self.ep_params)
104110

105-
# Wrap the EP test function as a GEPA metric
106-
self.metric = ep_test_to_gepa_metric(test_fn)
111+
# Wrap the EP test function as a GEPA metric (with configured field names)
112+
self.metric = ep_test_to_gepa_metric(test_fn, input_field, output_field)
107113

108114
# Load and split the dataset
109115
self._rows: List[EvaluationRow] = self._load_dataset()
@@ -113,6 +119,10 @@ def __init__(
113119
val_ratio=val_ratio,
114120
seed=seed,
115121
)
122+
# Store original EvaluationRow objects for later use in evaluate_with_ep
123+
self._train_rows: List[EvaluationRow] = train_rows
124+
self._val_rows: List[EvaluationRow] = val_rows
125+
self._test_rows: List[EvaluationRow] = test_rows
116126

117127
# Extract the system prompt from the dataset (this is what GEPA will optimize!)
118128
self._initial_system_prompt = extract_system_prompt_from_rows(self._rows)
@@ -372,23 +382,13 @@ async def evaluate_with_ep(
372382
- 'score': Aggregate score
373383
- 'optimized_prompt': The prompt used for evaluation
374384
"""
375-
from eval_protocol.pytest.default_single_turn_rollout_process import SingleTurnRolloutProcessor
376-
from eval_protocol.pytest.execution import execute_pytest
377-
from eval_protocol.logging import default_logger
378-
379385
# Get optimized system prompt
380386
optimized_prompt = self.get_optimized_system_prompt(optimized_program)
381387

382388
# Get rows to evaluate
383389
if use_test_set:
384-
# Reconstruct test rows from test_set examples
385-
_, _, test_rows = train_val_test_split(
386-
self._rows,
387-
train_ratio=0.5, # Match the ratio used in training
388-
val_ratio=0.3,
389-
seed=42,
390-
)
391-
rows_to_eval = test_rows
390+
# Use stored test rows (same split from __init__)
391+
rows_to_eval = self._test_rows
392392
else:
393393
rows_to_eval = self._rows
394394

eval_protocol/training/gepa_utils.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,23 +91,30 @@ def build_reflection_lm(reflection_lm_name: str) -> LM:
9191
return dspy.LM(model=reflection_lm_name)
9292

9393

94-
def gold_and_pred_to_row(gold: Example, pred: Prediction) -> EvaluationRow:
94+
def gold_and_pred_to_row(
95+
gold: Example,
96+
pred: Prediction,
97+
input_field: str = "problem",
98+
output_field: str = "answer",
99+
) -> EvaluationRow:
95100
"""
96101
Convert a GEPA (gold, pred) pair into an EvaluationRow for an EP `@evaluation_test`.
97102
98-
Assumptions (aligned with common DSPy usage):
99-
- `gold.answer` holds the ground-truth answer.
100-
- `pred.answer` holds the model's final answer text.
103+
Args:
104+
gold: The ground-truth example
105+
pred: The model's prediction
106+
input_field: Name of the input field in the DSPy signature
107+
output_field: Name of the output field in the DSPy signature
101108
102109
Note: ground_truth is preserved in its original type (list, dict, str, etc.)
103110
to support structured comparisons like SQL result matching.
104111
"""
105-
gt = gold.get("answer", None)
112+
gt = gold.get(output_field, None)
106113
# Preserve original type - don't convert to string!
107114
# This is important for SQL evaluators that expect list[dict] results
108115
ground_truth = gt
109116

110-
content = pred.get("answer", "")
117+
content = pred.get(output_field, "")
111118

112119
return EvaluationRow(
113120
messages=[
@@ -135,13 +142,20 @@ def row_to_prediction(row: EvaluationRow) -> ScoreWithFeedback:
135142

136143
def ep_test_to_gepa_metric(
137144
test_fn: TestFunction,
145+
input_field: str = "problem",
146+
output_field: str = "answer",
138147
) -> GEPAFeedbackMetric:
139148
"""
140149
Adapter: convert an EP-style `test_fn(row: EvaluationRow) -> EvaluationRow` into
141150
a GEPAFeedbackMetric-compatible callable.
142151
152+
Args:
153+
test_fn: The EP evaluation test function
154+
input_field: Name of the input field in the DSPy signature (default: "problem")
155+
output_field: Name of the output field in the DSPy signature (default: "answer")
156+
143157
The resulting metric:
144-
- Constructs an EvaluationRow from (gold, pred) using a simple heuristic.
158+
- Constructs an EvaluationRow from (gold, pred) using the configured field names.
145159
- Applies the EP test_fn to populate `row.evaluation_result`.
146160
- Returns a dspy.Prediction(score, feedback) derived from that result.
147161
@@ -158,7 +172,7 @@ def metric(
158172
pred_name: Optional[str] = None,
159173
pred_trace: Optional[DSPyTrace] = None,
160174
) -> ScoreWithFeedback:
161-
row = gold_and_pred_to_row(gold, pred)
175+
row = gold_and_pred_to_row(gold, pred, input_field, output_field)
162176

163177
# Call the test function - handle both sync and async
164178
result = test_fn(row) # pyright: ignore

0 commit comments

Comments
 (0)