From 2b820c59201f90dca24d4e51646323ea3a13522f Mon Sep 17 00:00:00 2001
From: Dylan Huang
Date: Mon, 29 Sep 2025 12:07:33 -0700
Subject: [PATCH] Add max_dataset_rows support to data loaders in evaluation tests

- Implemented max_dataset_rows limit in evaluation_test function to restrict
  the number of rows processed from data loaders.
- Added tests for DynamicDataLoader and InlineDataLoader to ensure they
  respect the max_dataset_rows parameter, verifying that only the specified
  number of rows are processed despite larger datasets being generated.
---
NOTE(review): line breaks and leading whitespace in this patch were
reconstructed from a flattened copy. The context indentation in the
evaluation_test.py hunk is inferred from block nesting; confirm it against
the target file (or apply with `git apply --ignore-whitespace`).

 eval_protocol/pytest/evaluation_test.py       |  3 +++
 tests/data_loader/test_dynamic_data_loader.py | 27 +++++++++++++++++++
 tests/data_loader/test_inline_data_loader.py  | 21 +++++++++++++++
 3 files changed, 51 insertions(+)

diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py
index 4625114a..2cf8ce0c 100644
--- a/eval_protocol/pytest/evaluation_test.py
+++ b/eval_protocol/pytest/evaluation_test.py
@@ -262,6 +262,9 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
                             results = data_loader.load()
                             for result in results:
                                 data.extend(result.rows)
+                        # Apply max_dataset_rows limit to data from data loaders
+                        if max_dataset_rows is not None:
+                            data = data[:max_dataset_rows]
                     elif "dataset_path" in kwargs and kwargs["dataset_path"] is not None:
                         ds_arg: list[str] = kwargs["dataset_path"]
                         # Support either a single path or a list of paths; if a list is provided,
diff --git a/tests/data_loader/test_dynamic_data_loader.py b/tests/data_loader/test_dynamic_data_loader.py
index 18780d9d..73e24bbb 100644
--- a/tests/data_loader/test_dynamic_data_loader.py
+++ b/tests/data_loader/test_dynamic_data_loader.py
@@ -46,3 +46,30 @@ def test_dynamic_data_loader_lambda(row: EvaluationRow) -> EvaluationRow:
     assert row.input_metadata.dataset_info.get("data_loader_type") == "DynamicDataLoader"
     assert row.input_metadata.dataset_info.get("data_loader_preprocessed") is False
     return row
+
+
+def generate_many_rows() -> list[EvaluationRow]:
+    """Factory function that generates many evaluation rows for testing max_dataset_rows."""
+    return [EvaluationRow(messages=[Message(role="user", content=f"What is {i} + {i}?")]) for i in range(10)]
+
+
+@evaluation_test(
+    data_loaders=DynamicDataLoader(
+        generators=[generate_many_rows],
+    ),
+    max_dataset_rows=3,
+)
+def test_dynamic_data_loader_max_dataset_rows(row: EvaluationRow) -> EvaluationRow:
+    """Dynamic data loader should respect max_dataset_rows parameter."""
+
+    # This test should only process 3 rows despite the generator creating 10
+    # The row content should be from the first 3 generated rows
+    content = row.messages[0].content
+    assert content in ["What is 0 + 0?", "What is 1 + 1?", "What is 2 + 2?"]
+
+    assert row.input_metadata.dataset_info is not None
+    assert row.input_metadata.dataset_info.get("data_loader_variant_id") == "generate_many_rows"
+    assert row.input_metadata.dataset_info.get("data_loader_type") == "DynamicDataLoader"
+    assert row.input_metadata.dataset_info.get("data_loader_preprocessed") is False
+
+    return row
diff --git a/tests/data_loader/test_inline_data_loader.py b/tests/data_loader/test_inline_data_loader.py
index a6cde17c..2df3fa24 100644
--- a/tests/data_loader/test_inline_data_loader.py
+++ b/tests/data_loader/test_inline_data_loader.py
@@ -21,3 +21,24 @@ def test_inline_data_loader(row: EvaluationRow) -> EvaluationRow:
     assert row.input_metadata.dataset_info.get("data_loader_variant_description") is None
     assert row.input_metadata.dataset_info.get("data_loader_preprocessed") is False
     return row
+
+
+@evaluation_test(
+    data_loaders=InlineDataLoader(
+        messages=[[Message(role="user", content=f"What is {i} + {i}?")] for i in range(5)],
+    ),
+    max_dataset_rows=2,
+)
+def test_inline_data_loader_max_dataset_rows(row: EvaluationRow) -> EvaluationRow:
+    """Inline data loader should respect max_dataset_rows parameter."""
+
+    # This test should only process 2 rows despite the loader having 5
+    content = row.messages[0].content
+    assert content in ["What is 0 + 0?", "What is 1 + 1?"]
+
+    assert row.input_metadata.dataset_info is not None
+    assert row.input_metadata.dataset_info.get("data_loader_variant_id") == "inline"
+    assert row.input_metadata.dataset_info.get("data_loader_type") == "InlineDataLoader"
+    assert row.input_metadata.dataset_info.get("data_loader_preprocessed") is False
+
+    return row