Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 39 additions & 13 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import asyncio
import inspect
import os
import sys
from collections import defaultdict
from typing import Any, Callable
from typing_extensions import Unpack
from collections.abc import Sequence

import pytest
from tqdm import tqdm

from eval_protocol.dataset_logger import default_logger
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
Expand Down Expand Up @@ -297,7 +299,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
exception_handler_config=exception_handler_config,
)

async def execute_run(i: int, config: RolloutProcessorConfig):
async def execute_run(run_idx: int, config: RolloutProcessorConfig):
nonlocal all_results

# Regenerate outputs each run by deep-copying the pristine dataset
Expand Down Expand Up @@ -357,13 +359,15 @@ async def _execute_groupwise_eval_with_semaphore(
# Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution
pointwise_tasks: list[asyncio.Task[EvaluationRow]] = []
# Use wrapper that handles retry logic internally
async for row in rollout_processor_with_retry(rollout_processor, fresh_dataset, config):
async for row in rollout_processor_with_retry(
rollout_processor, fresh_dataset, config, run_idx
):
pointwise_tasks.append(
asyncio.create_task(_execute_pointwise_eval_with_semaphore(row=row))
)
results = await asyncio.gather(*pointwise_tasks)

all_results[i] = results
all_results[run_idx] = results
elif mode == "groupwise":
# rollout all the completion_params for the same row at once, and then send the output to the test_func
row_groups = defaultdict( # pyright: ignore[reportUnknownVariableType]
Expand All @@ -385,7 +389,9 @@ async def _execute_groupwise_eval_with_semaphore(

# Collect every rollout row for a single completion_params config into a list,
# preserving arrival order from the retry-wrapped rollout stream.
# NOTE(review): this span is a diff capture — the next two `async for` statements
# are the pre-change and post-change versions of the SAME loop; only the
# three-line form (which threads `run_idx` through) exists in the merged file.
async def _collect_result(config, lst): # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
result = []
async for row in rollout_processor_with_retry(rollout_processor, lst, config): # pyright: ignore[reportUnknownArgumentType]
async for row in rollout_processor_with_retry(
rollout_processor, lst, config, run_idx
): # pyright: ignore[reportUnknownArgumentType]
result.append(row) # pyright: ignore[reportUnknownMemberType]
return result # pyright: ignore[reportUnknownVariableType]

Expand All @@ -409,11 +415,13 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
for task in tasks:
res = await task
results.extend(res) # pyright: ignore[reportUnknownMemberType]
all_results[i] = results
all_results[run_idx] = results
else:
# Batch mode: collect all results first, then evaluate (no pipelining)
input_dataset = []
async for row in rollout_processor_with_retry(rollout_processor, fresh_dataset, config):
async for row in rollout_processor_with_retry(
rollout_processor, fresh_dataset, config, run_idx
):
input_dataset.append(row) # pyright: ignore[reportUnknownMemberType]
# NOTE: we will still evaluate errored rows (give users control over this)
# i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
Expand All @@ -438,7 +446,7 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
raise ValueError(
f"Test function {test_func.__name__} returned a list containing non-EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
)
all_results[i] = results
all_results[run_idx] = results

for r in results:
if r.eval_metadata is not None:
Expand Down Expand Up @@ -472,16 +480,34 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
# else, we execute runs in parallel
if isinstance(rollout_processor, MCPGymRolloutProcessor):
# For MCPGymRolloutProcessor, create and execute tasks one at a time to avoid port conflicts
for i in range(num_runs):
task = asyncio.create_task(execute_run(i, config))
# For now, no tqdm progress bar because logs override it, we can revisit this later
for run_idx in range(num_runs):
task = asyncio.create_task(execute_run(run_idx, config))
await task
else:
# For other processors, create all tasks at once and run in parallel
# Concurrency is now controlled by the shared semaphore in each rollout processor
tasks = []
for i in range(num_runs):
tasks.append(asyncio.create_task(execute_run(i, config))) # pyright: ignore[reportUnknownMemberType]
await asyncio.gather(*tasks) # pyright: ignore[reportUnknownArgumentType]
with tqdm(
total=num_runs,
desc="Runs (Parallel)",
unit="run",
file=sys.__stderr__,
position=0,
leave=True,
dynamic_ncols=True,
miniters=1,
bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
) as run_pbar:

# Thin wrapper over `execute_run` that advances the shared run-level progress
# bar (`run_pbar`, bound by the enclosing `with tqdm(...)` block) once the run
# completes, then passes the run's result through unchanged.
# NOTE(review): the bar is updated only on successful completion — if
# `execute_run` raises, the tick is skipped; presumably acceptable since
# `asyncio.gather` will propagate the failure anyway. TODO confirm.
async def execute_run_with_progress(run_idx: int, config):
result = await execute_run(run_idx, config)
run_pbar.update(1)
return result

tasks = []
for run_idx in range(num_runs):
tasks.append(asyncio.create_task(execute_run_with_progress(run_idx, config)))
await asyncio.gather(*tasks) # pyright: ignore[reportUnknownArgumentType]

# for groupwise mode, the result contains eval output from multiple completion_params, so we need to differentiate them
# rollout_id is used to differentiate the result from different completion_params
Expand Down
26 changes: 22 additions & 4 deletions eval_protocol/pytest/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
from collections.abc import Sequence
import os
import re
import sys
from dataclasses import replace
from typing import Any, Literal

from tqdm import tqdm

from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
from eval_protocol.models import (
EvalMetadata,
Expand Down Expand Up @@ -157,6 +160,7 @@ async def rollout_processor_with_retry(
rollout_processor: RolloutProcessor,
fresh_dataset: list[EvaluationRow],
config: RolloutProcessorConfig,
run_idx: int = 0,
):
"""
Wrapper around rollout_processor that handles retry logic using the Python backoff library.
Expand Down Expand Up @@ -240,10 +244,24 @@ async def execute_row_with_backoff_and_log(task: asyncio.Task, row: EvaluationRo
for i, task in enumerate(base_tasks)
]

# Yield results as they complete
for task in asyncio.as_completed(retry_tasks):
result = await task
yield result
position = run_idx + 1 # Position 0 is reserved for main run bar, so shift up by 1
with tqdm(
total=len(retry_tasks),
desc=f" Run {position}",
unit="rollout",
file=sys.__stderr__,
leave=False,
position=position,
dynamic_ncols=True,
miniters=1,
mininterval=0.1,
bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
) as rollout_pbar:
# Yield results as they complete
for task in asyncio.as_completed(retry_tasks):
result = await task
rollout_pbar.update(1)
yield result

finally:
rollout_processor.cleanup()
Expand Down
Loading