Skip to content

Commit 0ec8ba5

Browse files
author
Dylan Huang
authored
Add max_dataset_rows support to data loaders in evaluation tests (#232)
- Implemented the max_dataset_rows limit in the evaluation_test function to restrict the number of rows processed from data loaders.
- Added tests for DynamicDataLoader and InlineDataLoader to ensure they respect the max_dataset_rows parameter, verifying that only the specified number of rows is processed even when a larger dataset is generated.
1 parent 2d97758 commit 0ec8ba5

File tree

3 files changed

+51
-0
lines changed

3 files changed

+51
-0
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,9 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
262262
results = data_loader.load()
263263
for result in results:
264264
data.extend(result.rows)
265+
# Apply max_dataset_rows limit to data from data loaders
266+
if max_dataset_rows is not None:
267+
data = data[:max_dataset_rows]
265268
elif "dataset_path" in kwargs and kwargs["dataset_path"] is not None:
266269
ds_arg: list[str] = kwargs["dataset_path"]
267270
# Support either a single path or a list of paths; if a list is provided,

tests/data_loader/test_dynamic_data_loader.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,30 @@ def test_dynamic_data_loader_lambda(row: EvaluationRow) -> EvaluationRow:
4646
assert row.input_metadata.dataset_info.get("data_loader_type") == "DynamicDataLoader"
4747
assert row.input_metadata.dataset_info.get("data_loader_preprocessed") is False
4848
return row
49+
50+
51+
def generate_many_rows() -> list[EvaluationRow]:
    """Build ten single-message evaluation rows for the max_dataset_rows test.

    Each row carries one user message of the form "What is i + i?" for
    i in 0..9, so a caller that truncates the dataset can be checked against
    the leading prompts.
    """
    rows: list[EvaluationRow] = []
    for i in range(10):
        question = Message(role="user", content=f"What is {i} + {i}?")
        rows.append(EvaluationRow(messages=[question]))
    return rows
54+
55+
56+
@evaluation_test(
    data_loaders=DynamicDataLoader(generators=[generate_many_rows]),
    max_dataset_rows=3,
)
def test_dynamic_data_loader_max_dataset_rows(row: EvaluationRow) -> EvaluationRow:
    """Check that a DynamicDataLoader dataset is capped by max_dataset_rows.

    The decorator is configured with max_dataset_rows=3, so every row this
    test receives is expected to carry one of the first three prompts.
    """
    # Only the leading three prompts are acceptable once the cap is applied.
    allowed_prompts = ("What is 0 + 0?", "What is 1 + 1?", "What is 2 + 2?")
    first_message = row.messages[0]
    assert first_message.content in allowed_prompts

    # The loader is expected to have stamped its provenance onto the row.
    info = row.input_metadata.dataset_info
    assert info is not None
    assert info.get("data_loader_variant_id") == "generate_many_rows"
    assert info.get("data_loader_type") == "DynamicDataLoader"
    assert info.get("data_loader_preprocessed") is False

    return row

tests/data_loader/test_inline_data_loader.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,24 @@ def test_inline_data_loader(row: EvaluationRow) -> EvaluationRow:
2121
assert row.input_metadata.dataset_info.get("data_loader_variant_description") is None
2222
assert row.input_metadata.dataset_info.get("data_loader_preprocessed") is False
2323
return row
24+
25+
26+
@evaluation_test(
    data_loaders=InlineDataLoader(
        messages=[[Message(role="user", content=f"What is {i} + {i}?")] for i in range(5)],
    ),
    max_dataset_rows=2,
)
def test_inline_data_loader_max_dataset_rows(row: EvaluationRow) -> EvaluationRow:
    """Check that an InlineDataLoader dataset is capped by max_dataset_rows.

    The loader is given five inline conversations and the decorator sets
    max_dataset_rows=2, so every row this test receives is expected to carry
    one of the first two prompts.
    """
    # Only the leading two prompts are acceptable once the cap is applied.
    allowed_prompts = ("What is 0 + 0?", "What is 1 + 1?")
    first_message = row.messages[0]
    assert first_message.content in allowed_prompts

    # The loader is expected to have stamped its provenance onto the row.
    info = row.input_metadata.dataset_info
    assert info is not None
    assert info.get("data_loader_variant_id") == "inline"
    assert info.get("data_loader_type") == "InlineDataLoader"
    assert info.get("data_loader_preprocessed") is False

    return row

0 commit comments

Comments
 (0)