@@ -45,16 +45,22 @@ def __init__(
4545 active_logger : DatasetLogger ,
4646 eval_executor : Callable [[Union [EvaluationRow , List [EvaluationRow ]]], Awaitable [Union [EvaluationRow , List [EvaluationRow ]]]], # Callback to run evaluation
4747 mini_batch_data_buffer : Optional [MiniBatchDataBuffer ] = None ,
48+ max_concurrent_evaluations : Optional [int ] = None ,
4849 ):
4950 self .rollout_processor = rollout_processor
5051 self .max_concurrent_rollouts = max_concurrent_rollouts
52+ self .max_concurrent_evaluations = max_concurrent_evaluations
5153 self .active_logger = active_logger
5254 self .eval_executor = eval_executor
5355 self .mini_batch_data_buffer = mini_batch_data_buffer
5456
5557 # Priority Queue: Stores RolloutTask
5658 self .queue : asyncio .PriorityQueue [RolloutTask ] = asyncio .PriorityQueue ()
5759
60+ # Concurrency Control
61+ self .rollout_sem = asyncio .Semaphore (max_concurrent_rollouts )
62+ self .eval_sem = asyncio .Semaphore (max_concurrent_evaluations ) if max_concurrent_evaluations else None
63+
5864 self .num_runs = 0
5965 self .micro_batch_size = 0
6066
@@ -140,31 +146,48 @@ async def _process_task(self, task: RolloutTask):
140146 if task .run_indices :
141147 representative_run_idx = task .run_indices [0 ]
142148
143- async for result_row in rollout_processor_with_retry (
144- self .rollout_processor , current_batch_rows , task .config , representative_run_idx
145- ):
146- batch_results .append (result_row )
149+ async with self .rollout_sem :
150+ async for result_row in rollout_processor_with_retry (
151+ self .rollout_processor , current_batch_rows , task .config , representative_run_idx
152+ ):
153+ batch_results .append (result_row )
147154
148155 # 3. Evaluate and Collect History
149156 current_batch_history_updates = []
150157
151- for res in batch_results :
152- # Run Evaluation
153- eval_res = await self .eval_executor (res )
154-
155- # Depending on the execution mode, eval_executor might return a single row or a list
156- # For pointwise, it's a single row. For groupwise, it's a list.
157- # Since PriorityScheduler processes a batch of single-turn rollouts, we expect single rows back
158- # But to be safe and type-correct, we handle both.
159-
160- if isinstance (eval_res , list ):
161- # Should not happen in pointwise mode which is typically used with this scheduler
162- # But if it does, we process each result
163- for r in eval_res :
158+ async def _run_eval ():
159+ for res in batch_results :
160+ # Run Evaluation
161+ eval_res = await self .eval_executor (res )
162+
163+ # Depending on the execution mode, eval_executor might return a single row or a list
164+ # For pointwise, it's a single row. For groupwise, it's a list.
165+ # Since PriorityScheduler processes a batch of single-turn rollouts, we expect single rows back
166+ # But to be safe and type-correct, we handle both.
167+
168+ if isinstance (eval_res , list ):
169+ # Should not happen in pointwise mode which is typically used with this scheduler
170+ # But if it does, we process each result
171+ for r in eval_res :
172+ if self .mini_batch_data_buffer :
173+ await self .mini_batch_data_buffer .add_result (r )
174+
175+ last_msg = r .last_assistant_message ()
176+ if last_msg and last_msg .content :
177+ content = last_msg .content
178+ if isinstance (content , list ):
179+ text_parts = [p ["text" ] for p in content if p ["type" ] == "text" ]
180+ current_batch_history_updates .append ("" .join (text_parts ))
181+ else :
182+ current_batch_history_updates .append (str (content ))
183+ else :
184+ current_batch_history_updates .append ("" )
185+ else :
164186 if self .mini_batch_data_buffer :
165- await self .mini_batch_data_buffer .add_result (r )
166-
167- last_msg = r .last_assistant_message ()
187+ await self .mini_batch_data_buffer .add_result (eval_res )
188+
189+ # Extract prediction for history
190+ last_msg = eval_res .last_assistant_message ()
168191 if last_msg and last_msg .content :
169192 content = last_msg .content
170193 if isinstance (content , list ):
@@ -173,22 +196,13 @@ async def _process_task(self, task: RolloutTask):
173196 else :
174197 current_batch_history_updates .append (str (content ))
175198 else :
176- current_batch_history_updates .append ("" )
177- else :
178- if self .mini_batch_data_buffer :
179- await self .mini_batch_data_buffer .add_result (eval_res )
199+ current_batch_history_updates .append ("" ) # Empty string for failed turns
180200
181- # Extract prediction for history
182- last_msg = eval_res .last_assistant_message ()
183- if last_msg and last_msg .content :
184- content = last_msg .content
185- if isinstance (content , list ):
186- text_parts = [p ["text" ] for p in content if p ["type" ] == "text" ]
187- current_batch_history_updates .append ("" .join (text_parts ))
188- else :
189- current_batch_history_updates .append (str (content ))
190- else :
191- current_batch_history_updates .append ("" ) # Empty string for failed turns
201+ if self .eval_sem :
202+ async with self .eval_sem :
203+ await _run_eval ()
204+ else :
205+ await _run_eval ()
192206
193207 # 4. Schedule Next Micro-batch (High Priority)
194208 last_run_idx = task .run_indices [- 1 ]
@@ -220,7 +234,12 @@ async def run(self, dataset: List[EvaluationRow], num_runs: int, micro_batch_siz
220234 await self .schedule_dataset (dataset , base_config )
221235
222236 # 2. Start Workers
223- workers = [asyncio .create_task (self .worker ()) for _ in range (self .max_concurrent_rollouts )]
237+ # If we have separate limits, we need enough workers to saturate both stages
238+ num_workers = self .max_concurrent_rollouts
239+ if self .max_concurrent_evaluations :
240+ num_workers += self .max_concurrent_evaluations
241+
242+ workers = [asyncio .create_task (self .worker ()) for _ in range (num_workers )]
224243
225244 # 3. Wait for completion
226245 await self .queue .join ()
@@ -246,12 +265,14 @@ async def execute_priority_rollouts(
246265 active_logger : DatasetLogger ,
247266 eval_executor : Callable [[Union [EvaluationRow , List [EvaluationRow ]]], Awaitable [Union [EvaluationRow , List [EvaluationRow ]]]],
248267 mini_batch_data_buffer : Optional [MiniBatchDataBuffer ] = None ,
268+ max_concurrent_evaluations : Optional [int ] = None ,
249269):
250270 scheduler = PriorityRolloutScheduler (
251271 rollout_processor = rollout_processor ,
252272 max_concurrent_rollouts = max_concurrent_rollouts ,
253273 active_logger = active_logger ,
254274 eval_executor = eval_executor ,
255- mini_batch_data_buffer = mini_batch_data_buffer
275+ mini_batch_data_buffer = mini_batch_data_buffer ,
276+ max_concurrent_evaluations = max_concurrent_evaluations
256277 )
257278 return await scheduler .run (dataset , num_runs , micro_batch_size , config )
0 commit comments