Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,9 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
results = data_loader.load()
for result in results:
data.extend(result.rows)
# Apply max_dataset_rows limit to data from data loaders
if max_dataset_rows is not None:
data = data[:max_dataset_rows]
elif "dataset_path" in kwargs and kwargs["dataset_path"] is not None:
ds_arg: list[str] = kwargs["dataset_path"]
# Support either a single path or a list of paths; if a list is provided,
Expand Down
27 changes: 27 additions & 0 deletions tests/data_loader/test_dynamic_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,30 @@ def test_dynamic_data_loader_lambda(row: EvaluationRow) -> EvaluationRow:
assert row.input_metadata.dataset_info.get("data_loader_type") == "DynamicDataLoader"
assert row.input_metadata.dataset_info.get("data_loader_preprocessed") is False
return row


def generate_many_rows(count: int = 10) -> list[EvaluationRow]:
    """Build *count* simple arithmetic evaluation rows for truncation tests.

    Used as a ``DynamicDataLoader`` generator to verify that
    ``max_dataset_rows`` caps the dataset. The default of 10 preserves the
    original behavior for existing callers (the loader invokes this with no
    arguments).

    Args:
        count: Number of rows to generate (defaults to 10).

    Returns:
        A list of ``EvaluationRow`` objects whose single user message is
        ``"What is i + i?"`` for ``i`` in ``range(count)``.
    """
    return [
        EvaluationRow(messages=[Message(role="user", content=f"What is {i} + {i}?")])
        for i in range(count)
    ]


@evaluation_test(
    data_loaders=DynamicDataLoader(
        generators=[generate_many_rows],
    ),
    max_dataset_rows=3,
)
def test_dynamic_data_loader_max_dataset_rows(row: EvaluationRow) -> EvaluationRow:
    """Verify that ``max_dataset_rows`` truncates a DynamicDataLoader dataset.

    The generator emits 10 rows, but with ``max_dataset_rows=3`` only the
    first 3 should survive, so every row reaching this body must be one of
    those first three.
    """
    allowed = {f"What is {i} + {i}?" for i in range(3)}
    assert row.messages[0].content in allowed

    info = row.input_metadata.dataset_info
    assert info is not None
    assert info.get("data_loader_variant_id") == "generate_many_rows"
    assert info.get("data_loader_type") == "DynamicDataLoader"
    assert info.get("data_loader_preprocessed") is False

    return row
21 changes: 21 additions & 0 deletions tests/data_loader/test_inline_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,24 @@ def test_inline_data_loader(row: EvaluationRow) -> EvaluationRow:
assert row.input_metadata.dataset_info.get("data_loader_variant_description") is None
assert row.input_metadata.dataset_info.get("data_loader_preprocessed") is False
return row


@evaluation_test(
    data_loaders=InlineDataLoader(
        messages=[[Message(role="user", content=f"What is {i} + {i}?")] for i in range(5)],
    ),
    max_dataset_rows=2,
)
def test_inline_data_loader_max_dataset_rows(row: EvaluationRow) -> EvaluationRow:
    """Verify that ``max_dataset_rows`` truncates an InlineDataLoader dataset.

    The loader declares 5 rows; with ``max_dataset_rows=2`` only the first
    2 should ever reach this test body.
    """
    allowed = {f"What is {i} + {i}?" for i in range(2)}
    assert row.messages[0].content in allowed

    info = row.input_metadata.dataset_info
    assert info is not None
    assert info.get("data_loader_variant_id") == "inline"
    assert info.get("data_loader_type") == "InlineDataLoader"
    assert info.get("data_loader_preprocessed") is False

    return row
Loading