Skip to content

Commit b0cbc25

Browse files
authored
respect dataset adapter (#439)
1 parent 3c8d8f2 commit b0cbc25

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -211,7 +211,7 @@ def evaluation_test(
211211
completion_params = parse_ep_completion_params_overwrite(completion_params)
212212
original_completion_params = completion_params
213213
passed_threshold = parse_ep_passed_threshold(passed_threshold)
214-
data_loaders = parse_ep_dataloaders(data_loaders)
214+
data_loaders = parse_ep_dataloaders(data_loaders, dataset_adapter=dataset_adapter)
215215
custom_invocation_id = os.environ.get("EP_INVOCATION_ID", None)
216216

217217
# ignore other data input params when dataloader is provided

eval_protocol/pytest/evaluation_test_utils.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
EvaluationThresholdDict,
2222
Status,
2323
)
24+
from eval_protocol.common_utils import load_jsonl
2425
from eval_protocol.data_loader import DynamicDataLoader
2526
from eval_protocol.data_loader.models import EvaluationDataLoader
2627
from eval_protocol.pytest.rollout_processor import RolloutProcessor
@@ -288,10 +289,21 @@ def _rows_from_jsonl(path: str) -> list[EvaluationRow]:
288289

289290
def parse_ep_dataloaders(
290291
dataloaders: Sequence[EvaluationDataLoader] | EvaluationDataLoader | None,
292+
*,
293+
dataset_adapter: Callable[[list[dict[str, Any]]], list[EvaluationRow]] | None = None,
291294
) -> Sequence[EvaluationDataLoader] | EvaluationDataLoader | None:
295+
"""When ``EP_JSONL_PATH`` is set, load JSONL as raw dicts and run ``dataset_adapter`` if provided.
296+
297+
Without ``dataset_adapter``, rows are built with ``EvaluationRow(**dict)`` (legacy behavior),
298+
which skips custom label fields that adapters normally attach.
299+
"""
292300
try:
293301
load_from_jsonl_path = os.getenv("EP_JSONL_PATH")
294302
if load_from_jsonl_path:
303+
if dataset_adapter is not None:
304+
return DynamicDataLoader(
305+
generators=[lambda path=load_from_jsonl_path, da=dataset_adapter: da(load_jsonl(path))]
306+
)
295307
return DynamicDataLoader(generators=[lambda path=load_from_jsonl_path: _rows_from_jsonl(path)])
296308
except Exception:
297309
pass

0 commit comments

Comments
 (0)