Skip to content

Commit 3184377

Browse files
authored
log eval time (#392)
add
1 parent c260b5a commit 3184377

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,7 @@ async def execute_run(run_idx: int, config: RolloutProcessorConfig):
489489
async def _execute_pointwise_eval_with_semaphore(
490490
row: EvaluationRow,
491491
) -> EvaluationRow:
492+
start_time = time.perf_counter()
492493
async with semaphore:
493494
evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {}
494495
async with rollout_logging_context(
@@ -505,11 +506,15 @@ async def _execute_pointwise_eval_with_semaphore(
505506
raise ValueError(
506507
f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
507508
)
509+
result.execution_metadata.eval_duration_seconds = (
510+
time.perf_counter() - start_time
511+
)
508512
return result
509513

510514
async def _execute_groupwise_eval_with_semaphore(
511515
rows: list[EvaluationRow],
512516
) -> list[EvaluationRow]:
517+
start_time = time.perf_counter()
513518
async with semaphore:
514519
evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {}
515520
primary_rollout_id = rows[0].execution_metadata.rollout_id if rows else None
@@ -531,6 +536,9 @@ async def _execute_groupwise_eval_with_semaphore(
531536
raise ValueError(
532537
f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
533538
)
539+
eval_duration = time.perf_counter() - start_time
540+
for r in results:
541+
r.execution_metadata.eval_duration_seconds = eval_duration
534542
return results
535543

536544
if mode == "pointwise":
@@ -617,11 +625,16 @@ async def _collect_result(config, lst):
617625
run_id=run_id,
618626
rollout_ids=group_rollout_ids or None,
619627
):
628+
start_time = time.perf_counter()
620629
results = await execute_pytest_with_exception_handling(
621630
test_func=test_func,
622631
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
623632
processed_dataset=input_dataset,
624633
)
634+
if isinstance(results, list):
635+
eval_duration = time.perf_counter() - start_time
636+
for r in results:
637+
r.execution_metadata.eval_duration_seconds = eval_duration
625638
if (
626639
results is None
627640
or not isinstance(results, list)

0 commit comments

Comments
 (0)