Skip to content

Commit 7b3c420

Browse files
committed
updated
1 parent 3336b90 commit 7b3c420

File tree

3 files changed

+59
-5
lines changed

3 files changed

+59
-5
lines changed

eval_protocol/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,6 +1205,7 @@ class EPParameters(BaseModel):
12051205
dataset_adapter: Optional[Callable[..., Any]] = None
12061206
rollout_processor: Any = None
12071207
rollout_processor_kwargs: Dict[str, Any] | None = None
1208+
evaluation_test_kwargs: Any = None
12081209
aggregation_method: Any = Field(default="mean")
12091210
passed_threshold: Any = None
12101211
disable_browser_open: bool = False

eval_protocol/pytest/evaluation_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,7 @@ async def _collect_result(config, lst):
706706
dataset_adapter=dataset_adapter,
707707
rollout_processor=rollout_processor,
708708
rollout_processor_kwargs=rollout_processor_kwargs,
709+
evaluation_test_kwargs=evaluation_test_kwargs,
709710
aggregation_method=aggregation_method,
710711
passed_threshold=passed_threshold,
711712
disable_browser_open=disable_browser_open,

eval_protocol/training/gepa_trainer.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,14 +157,29 @@ def _load_dataset(self) -> List[EvaluationRow]:
157157
158158
Supports:
159159
- input_rows: Pre-constructed EvaluationRow objects
160+
- Can be List[EvaluationRow] (direct usage)
161+
- Or Sequence[list[EvaluationRow]] (parameterized usage)
160162
- input_dataset: Paths to JSONL files (requires dataset_adapter)
161163
- input_messages: Raw message lists
164+
- data_loaders: EvaluationDataLoader instances
162165
"""
163166
ep = self.ep_params
164167

165168
# Case 1: Pre-constructed rows
169+
# Handle both direct List[EvaluationRow] and parameterized Sequence[list[EvaluationRow]]
166170
if ep.input_rows:
167-
return list(ep.input_rows)
171+
rows_input = ep.input_rows
172+
# Check if it's a list of EvaluationRows (direct) or list of lists (parameterized)
173+
if rows_input and isinstance(rows_input[0], EvaluationRow):
174+
# Direct usage: List[EvaluationRow]
175+
return list(rows_input)
176+
else:
177+
# Parameterized usage: Sequence[list[EvaluationRow]]
178+
all_rows: List[EvaluationRow] = []
179+
for rows_list in rows_input:
180+
if rows_list is not None:
181+
all_rows.extend(rows_list)
182+
return all_rows
168183

169184
# Case 2: Dataset paths with adapter
170185
if ep.input_dataset and ep.dataset_adapter:
@@ -183,17 +198,54 @@ def _load_dataset(self) -> List[EvaluationRow]:
183198
return ep.dataset_adapter(all_data)
184199

185200
# Case 3: Input messages (convert to rows)
201+
# Handle both direct List[List[Message]] and parameterized Sequence[list[list[Message]] | None]
186202
if ep.input_messages:
187-
from eval_protocol.models import Message
203+
rows: List[EvaluationRow] = []
204+
messages_input = ep.input_messages
205+
206+
# Check if first element is a Message (direct list of conversations) or a list (parameterized)
207+
if messages_input and messages_input[0]:
208+
first_elem = messages_input[0]
209+
# Check if it's List[Message] (a single conversation) or List[List[Message]]
210+
if hasattr(first_elem, "role"):
211+
# It's a Message - so input is a single conversation List[Message]
212+
rows.append(EvaluationRow(messages=list(messages_input)))
213+
elif first_elem and hasattr(first_elem[0], "role"):
214+
# It's List[List[Message]] - direct usage with multiple conversations
215+
for messages in messages_input:
216+
if messages:
217+
rows.append(EvaluationRow(messages=messages))
218+
else:
219+
# Parameterized usage: Sequence[list[list[Message]] | None]
220+
for messages_list in messages_input:
221+
if messages_list is not None:
222+
for messages in messages_list:
223+
rows.append(EvaluationRow(messages=messages))
224+
return rows
225+
226+
# Case 4: Data loaders
227+
if ep.data_loaders:
228+
from eval_protocol.data_loader.models import EvaluationDataLoader
188229

189230
rows = []
190-
for messages in ep.input_messages:
191-
rows.append(EvaluationRow(messages=messages))
231+
data_loaders = ep.data_loaders
232+
data_loaders_list = (
233+
[data_loaders] if isinstance(data_loaders, EvaluationDataLoader) else list(data_loaders)
234+
)
235+
for data_loader in data_loaders_list:
236+
results = data_loader.load()
237+
for result in results:
238+
rows.extend(result.rows)
239+
240+
# Apply max_dataset_rows limit
241+
if ep.max_dataset_rows:
242+
rows = rows[: ep.max_dataset_rows]
243+
192244
return rows
193245

194246
raise ValueError(
195247
"No dataset found in ep_params. "
196-
"Provide input_rows, input_dataset (with dataset_adapter), or input_messages."
248+
"Provide input_rows, input_dataset (with dataset_adapter), input_messages, or data_loaders."
197249
)
198250

199251
@property

0 commit comments

Comments
 (0)