diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index cafcb978..e2fabddd 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -167,7 +167,16 @@ def _load_data(): elif len(output_rows) == 1: # Return the Langfuse row langfuse_row = output_rows[0] langfuse_row.input_metadata.completion_params = row.input_metadata.completion_params + # merge dataset_info dicts on input_metadata + if langfuse_row.input_metadata.dataset_info and row.input_metadata.dataset_info: + langfuse_row.input_metadata.dataset_info = { + **row.input_metadata.dataset_info, + **langfuse_row.input_metadata.dataset_info, + } + elif row.input_metadata.dataset_info: + langfuse_row.input_metadata.dataset_info = row.input_metadata.dataset_info langfuse_row.eval_metadata = row.eval_metadata + langfuse_row.ground_truth = row.ground_truth return langfuse_row else: raise ValueError("RemoteRolloutProcessor's output_data_loader should return exactly one row.")