Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion eval_protocol/data_loader/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def _process_variant(self, result: DataLoaderResult) -> DataLoaderResult:

def _apply_metadata(self, result: DataLoaderResult, original_count: int, processed_count: int) -> None:
"""Apply metadata to all rows in the result."""
for row in result.rows:
for idx, row in enumerate(result.rows):
if row.input_metadata.dataset_info is None:
row.input_metadata.dataset_info = {}

Expand All @@ -126,3 +126,4 @@ def _apply_metadata(self, result: DataLoaderResult, original_count: int, process
# Apply row counts
row.input_metadata.dataset_info["data_loader_num_rows"] = original_count
row.input_metadata.dataset_info["data_loader_num_rows_after_preprocessing"] = processed_count
row.input_metadata.dataset_info["data_loader_row_idx"] = idx
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Row index added after preprocessing instead of before

The data_loader_row_idx is enumerated from rows after preprocessing, but the PR aims to add the original example index. When preprocess_fn filters rows, the indices get renumbered (e.g., original rows 0, 2, 4 become indices 0, 1, 2), losing track of the original positions. To capture original indices, enumeration needs to happen before preprocessing in _process_variant and the index preserved through the preprocessing step.

Fix in Cursor Fix in Web

1 change: 1 addition & 0 deletions eval_protocol/pytest/tracing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def update_row_with_remote_trace(
row.messages = remote_row.messages
row.tools = remote_row.tools
row.input_metadata.session_data = remote_row.input_metadata.session_data
row.input_metadata.dataset_info = remote_row.input_metadata.dataset_info
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Original row index lost when copying remote dataset info

Completely overwriting row.input_metadata.dataset_info with remote_row.input_metadata.dataset_info discards the original row's data_loader_row_idx. The remote row typically carries data_loader_row_idx=0 (because filter_longest_conversation preprocessing returns a single-element list), so the assignment replaces the index that tracked the row's position in the source dataset. Any other original dataset metadata fields are lost as well; merging the two dicts instead of replacing would preserve the original keys.

Fix in Cursor Fix in Web

row.execution_metadata = remote_row.execution_metadata
return None
else:
Expand Down
22 changes: 22 additions & 0 deletions tests/data_loader/test_data_loader_stable_row_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
from eval_protocol.models import EvaluationRow, Message, EvaluateResult
from eval_protocol.pytest import evaluation_test
from typing import List

def generator() -> list[EvaluationRow]:
    """Build two EvaluationRows carrying the same user prompt.

    Each row gets its own freshly constructed Message so the rows share no
    mutable state; the duplicate content is intentional — the test checks
    that identical content still yields distinct, stable row ids.
    """
    rows: list[EvaluationRow] = []
    for _ in range(2):
        prompt = Message(role="user", content="What is 2 + 2?")
        rows.append(EvaluationRow(messages=[prompt]))
    return rows

@evaluation_test(
    data_loaders=DynamicDataLoader(
        generators=[generator],
    ),
    mode="all",
)
def test_data_loader_stable_row_id_with_same_content(rows: List[EvaluationRow]) -> List[EvaluationRow]:
    """Test that the row id is stable even when the data loader is called multiple times."""
    # Collect the distinct row ids up front; identical message content must
    # still produce one unique id per row.
    distinct_ids = {r.input_metadata.row_id for r in rows}
    # Attach a placeholder result so the harness treats every row as evaluated.
    for r in rows:
        r.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result")
    assert len(distinct_ids) == 2
    return rows
Loading