diff --git a/eval_protocol/data_loader/models.py b/eval_protocol/data_loader/models.py index 0179272e..a2c0b8af 100644 --- a/eval_protocol/data_loader/models.py +++ b/eval_protocol/data_loader/models.py @@ -111,7 +111,7 @@ def _process_variant(self, result: DataLoaderResult) -> DataLoaderResult: def _apply_metadata(self, result: DataLoaderResult, original_count: int, processed_count: int) -> None: """Apply metadata to all rows in the result.""" - for row in result.rows: + for idx, row in enumerate(result.rows): if row.input_metadata.dataset_info is None: row.input_metadata.dataset_info = {} @@ -126,3 +126,4 @@ def _apply_metadata(self, result: DataLoaderResult, original_count: int, process # Apply row counts row.input_metadata.dataset_info["data_loader_num_rows"] = original_count row.input_metadata.dataset_info["data_loader_num_rows_after_preprocessing"] = processed_count + row.input_metadata.dataset_info["data_loader_row_idx"] = idx diff --git a/eval_protocol/pytest/tracing_utils.py b/eval_protocol/pytest/tracing_utils.py index 0382ba40..0eea0c2e 100644 --- a/eval_protocol/pytest/tracing_utils.py +++ b/eval_protocol/pytest/tracing_utils.py @@ -171,6 +171,7 @@ def update_row_with_remote_trace( row.messages = remote_row.messages row.tools = remote_row.tools row.input_metadata.session_data = remote_row.input_metadata.session_data + row.input_metadata.dataset_info = remote_row.input_metadata.dataset_info row.execution_metadata = remote_row.execution_metadata return None else: diff --git a/tests/data_loader/test_data_loader_stable_row_id.py b/tests/data_loader/test_data_loader_stable_row_id.py new file mode 100644 index 00000000..d9aaab96 --- /dev/null +++ b/tests/data_loader/test_data_loader_stable_row_id.py @@ -0,0 +1,22 @@ +from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader +from eval_protocol.models import EvaluationRow, Message, EvaluateResult +from eval_protocol.pytest import evaluation_test +from typing import List + +def generator() -> 
list[EvaluationRow]: + return [EvaluationRow(messages=[Message(role="user", content="What is 2 + 2?")]) for _ in range(2)] + +@evaluation_test( + data_loaders=DynamicDataLoader( + generators=[generator], + ), + mode="all", +) +def test_data_loader_stable_row_id_with_same_content(rows: List[EvaluationRow]) -> List[EvaluationRow]: + """Test that rows with identical content from a data loader still get distinct row ids.""" + row_ids = set() + for row in rows: + row_ids.add(row.input_metadata.row_id) + row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result") + assert len(row_ids) == 2 + return rows