diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index 13ca5422..4c8df2df 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -169,9 +169,14 @@ def _load_data(): else: raise ValueError("RemoteRolloutProcessor's output_data_loader should return exactly one row.") - for r in rows: - tasks.append(asyncio.create_task(_process_row(r))) + semaphore = config.semaphore + async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow: + async with semaphore: + result = await _process_row(r) + return result + + tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows] return tasks def cleanup(self) -> None: