Skip to content

Commit 55b494e

Browse files
committed
Document stub imports for dataset prep tests
1 parent c40dc2e commit 55b494e

File tree

3 files changed

+355
-38
lines changed

3 files changed

+355
-38
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""Utilities for preparing datasets for evaluation tests."""
2+
3+
from collections.abc import Callable
4+
from typing import Any
5+
6+
from eval_protocol.human_id import generate_id, num_combinations
7+
from eval_protocol.models import EvaluationRow
8+
from eval_protocol.pytest.generate_parameter_combinations import ParameterizedTestKwargs
9+
from eval_protocol.pytest.types import Dataset
10+
11+
from ..common_utils import load_jsonl
12+
13+
14+
def load_and_prepare_rows(
    kwargs: ParameterizedTestKwargs,
    *,
    dataset_adapter: Callable[[list[dict[str, Any]]], Dataset],
    preprocess_fn: Callable[[list[EvaluationRow]], list[EvaluationRow]] | None,
    max_dataset_rows: int | None,
) -> list[EvaluationRow]:
    """Load and preprocess evaluation rows based on parameterized pytest kwargs.

    This helper consolidates the logic that loads input data from various sources
    (dataset paths, raw messages, or pre-built :class:`EvaluationRow` objects),
    applies optional preprocessing, and ensures each row has a stable
    ``row_id``. The behavior mirrors the original inline implementation inside
    :func:`eval_protocol.pytest.evaluation_test.evaluation_test`.
    """
    # Resolve the input source in priority order: dataset files first, then raw
    # message lists, then pre-constructed rows.
    if kwargs.get("dataset_path") is not None:
        # Concatenate JSONL records from every provided path, in order.
        records: list[dict[str, Any]] = []
        for source in kwargs["dataset_path"] or []:
            records.extend(load_jsonl(source))
        # Optional cap on the number of dataset rows before adaptation.
        if max_dataset_rows is not None:
            records = records[:max_dataset_rows]
        rows = dataset_adapter(records)
    elif kwargs.get("input_messages") is not None:
        # One EvaluationRow per message list.
        rows = [EvaluationRow(messages=conversation) for conversation in (kwargs["input_messages"] or [])]
    elif kwargs.get("input_rows") is not None:
        # Deep-copy pre-built rows so callers' objects are never mutated.
        rows = [prebuilt.model_copy(deep=True) for prebuilt in (kwargs["input_rows"] or [])]
    else:
        raise ValueError("No input dataset, input messages, or input rows provided")

    if preprocess_fn:
        rows = preprocess_fn(rows)

    # Stamp a stable, deterministic row_id onto any row that lacks one, derived
    # from the row's hash reduced into the valid id-index range.
    for row in rows:
        if row.input_metadata.row_id is None:
            bucket_count = num_combinations()
            stable_index = abs(hash(row)) % bucket_count
            row.input_metadata.row_id = generate_id(seed=0, index=stable_index)

    return rows

eval_protocol/pytest/evaluation_test.py

Lines changed: 7 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
Status,
2424
)
2525
from eval_protocol.pytest.dual_mode_wrapper import create_dual_mode_wrapper
26+
from eval_protocol.pytest.dataset_preparation import load_and_prepare_rows
2627
from eval_protocol.pytest.evaluation_test_postprocess import postprocess
2728
from eval_protocol.pytest.execution import execute_pytest
2829
from eval_protocol.pytest.generate_parameter_combinations import (
@@ -60,9 +61,6 @@
6061
rollout_processor_with_retry,
6162
)
6263

63-
from ..common_utils import load_jsonl
64-
65-
6664
def evaluation_test(
6765
*,
6866
completion_params: Sequence[CompletionParams | None] | None = None,
@@ -223,43 +221,14 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
223221
log_eval_status_and_rows(eval_metadata, rows, status, passed, active_logger)
224222

225223
try:
226-
# Handle dataset loading
227-
data: list[EvaluationRow] = []
228224
# Track all rows processed in the current run for error logging
229225
processed_rows_in_run: list[EvaluationRow] = []
230-
if "dataset_path" in kwargs and kwargs["dataset_path"] is not None:
231-
ds_arg: list[str] = kwargs["dataset_path"]
232-
# Support either a single path or a list of paths; if a list is provided,
233-
# concatenate the rows from each file in order.
234-
data_jsonl: list[dict[str, object]] = []
235-
for p in ds_arg:
236-
data_jsonl.extend(load_jsonl(p))
237-
# Apply override for max rows if present
238-
if max_dataset_rows is not None:
239-
data_jsonl = data_jsonl[:max_dataset_rows]
240-
data = dataset_adapter(data_jsonl)
241-
elif "input_messages" in kwargs and kwargs["input_messages"] is not None:
242-
# Support either a single row (List[Message]) or many rows (List[List[Message]])
243-
im = kwargs["input_messages"]
244-
data = [EvaluationRow(messages=dataset_messages) for dataset_messages in im]
245-
elif "input_rows" in kwargs and kwargs["input_rows"] is not None:
246-
# Deep copy pre-constructed EvaluationRow objects
247-
data = [row.model_copy(deep=True) for row in kwargs["input_rows"]]
248-
else:
249-
raise ValueError("No input dataset, input messages, or input rows provided")
250-
251-
if preprocess_fn:
252-
data = preprocess_fn(data)
253-
254-
for row in data:
255-
# generate a stable row_id for each row
256-
if row.input_metadata.row_id is None:
257-
# Generate a stable, deterministic row_id using the row's hash and num_combinations
258-
index = hash(row)
259-
max_index = num_combinations() - 1
260-
# Ensure index is a non-negative integer within [0, max_index]
261-
index = abs(index) % (max_index + 1)
262-
row.input_metadata.row_id = generate_id(seed=0, index=index)
226+
data = load_and_prepare_rows(
227+
kwargs,
228+
dataset_adapter=dataset_adapter,
229+
preprocess_fn=preprocess_fn,
230+
max_dataset_rows=max_dataset_rows,
231+
)
263232

264233
completion_params = kwargs["completion_params"]
265234
# Create eval metadata with test function info and current commit hash

0 commit comments

Comments
 (0)