@@ -282,14 +282,17 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
282282 # used to determine whether this eval has stopped
283283 row .pid = os .getpid ()
284284
285+ # Create shared semaphore for unified concurrency control across all runs and rollouts
286+ shared_semaphore = asyncio .Semaphore (max_concurrent_rollouts )
287+
285288 # Prepare rollout processor config once; we will generate fresh outputs per run
286289 config = RolloutProcessorConfig (
287290 completion_params = completion_params if completion_params is not None else {},
288291 mcp_config_path = mcp_config_path or "" ,
289- max_concurrent_rollouts = max_concurrent_rollouts ,
290292 server_script_path = server_script_path ,
291293 steps = steps ,
292294 logger = active_logger ,
295+ semaphore = shared_semaphore ,
293296 kwargs = rollout_processor_kwargs or {},
294297 exception_handler_config = exception_handler_config ,
295298 )
@@ -372,10 +375,10 @@ async def _execute_groupwise_eval_with_semaphore(
372375 config = RolloutProcessorConfig (
373376 completion_params = cp if cp is not None else {},
374377 mcp_config_path = mcp_config_path or "" ,
375- max_concurrent_rollouts = max_concurrent_rollouts ,
376378 server_script_path = server_script_path ,
377379 steps = steps ,
378380 logger = active_logger ,
381+ semaphore = shared_semaphore ,
379382 kwargs = rollout_processor_kwargs or {},
380383 )
381384 lst = []
@@ -474,6 +477,7 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
474477 await task
475478 else :
476479 # For other processors, create all tasks at once and run in parallel
480+ # Concurrency is bounded by the shared semaphore carried in the config given to each rollout processor
477481 tasks = []
478482 for i in range (num_runs ):
479483 tasks .append (asyncio .create_task (execute_run (i , config ))) # pyright: ignore[reportUnknownMemberType]
0 commit comments