@@ -44,15 +44,17 @@ def __init__(
4444 max_concurrent_rollouts : int ,
4545 active_logger : DatasetLogger ,
4646 eval_executor : Callable [[Union [EvaluationRow , List [EvaluationRow ]]], Awaitable [Union [EvaluationRow , List [EvaluationRow ]]]], # Callback to run evaluation
47- mini_batch_data_buffer : Optional [MiniBatchDataBuffer ] = None ,
47+ output_buffer : Optional [MiniBatchDataBuffer ] = None ,
4848 max_concurrent_evaluations : Optional [int ] = None ,
49+ mode : str = "pointwise" ,
4950 ):
5051 self .rollout_processor = rollout_processor
5152 self .max_concurrent_rollouts = max_concurrent_rollouts
5253 self .max_concurrent_evaluations = max_concurrent_evaluations
5354 self .active_logger = active_logger
5455 self .eval_executor = eval_executor
55- self .mini_batch_data_buffer = mini_batch_data_buffer
56+ self .output_buffer = output_buffer
57+ self .mode = mode
5658
5759 # Priority Queue: Stores RolloutTask
5860 self .queue : asyncio .PriorityQueue [RolloutTask ] = asyncio .PriorityQueue ()
@@ -61,6 +63,10 @@ def __init__(
6163 self .rollout_sem = asyncio .Semaphore (max_concurrent_rollouts )
6264 self .eval_sem = asyncio .Semaphore (max_concurrent_evaluations ) if max_concurrent_evaluations else None
6365
66+ # Results storage
67+ self .results : List [EvaluationRow ] = [] # for backward compatibility reason, we save all results here to return
 67+ self .groupwise_buffer : Dict [int , List [EvaluationRow ]] = defaultdict (list ) # buffer for group results. only flushed to the output buffer when a whole group is ready
69+
6470 self .num_runs = 0
6571 self .micro_batch_size = 0
6672
@@ -155,24 +161,85 @@ async def _process_task(self, task: RolloutTask):
155161 # 3. Evaluate and Collect History
156162 current_batch_history_updates = []
157163
158- async def _run_eval ():
159- for res in batch_results :
160- # Run Evaluation
161- eval_res = await self .eval_executor (res )
162-
163- # Depending on the execution mode, eval_executor might return a single row or a list
164- # For pointwise, it's a single row. For groupwise, it's a list.
165- # Since PriorityScheduler processes a batch of single-turn rollouts, we expect single rows back
166- # But to be safe and type-correct, we handle both.
164+ if self .mode == "groupwise" :
165+ # Collect all results from this batch
166+ for res in batch_results :
167+ self .groupwise_buffer [task .row_index ].append (res )
167168
168- if isinstance (eval_res , list ):
169- # Should not happen in pointwise mode which is typically used with this scheduler
170- # But if it does, we process each result
171- for r in eval_res :
169+ # Update history from rollout result (assuming eval doesn't change content needed for history)
170+ last_msg = res .last_assistant_message ()
171+ if last_msg and last_msg .content :
172+ content = last_msg .content
173+ if isinstance (content , list ):
174+ text_parts = [p ["text" ] for p in content if p ["type" ] == "text" ]
175+ current_batch_history_updates .append ("" .join (text_parts ))
176+ else :
177+ current_batch_history_updates .append (str (content ))
178+ else :
179+ current_batch_history_updates .append ("" )
180+
181+ # Check if this is the last batch for this sample
182+ last_run_idx = task .run_indices [- 1 ]
183+ if last_run_idx + 1 >= self .num_runs :
184+ # Last batch: Execute Groupwise Evaluation
185+ full_group = self .groupwise_buffer [task .row_index ]
186+
187+ async def _run_group_eval ():
188+ eval_res = await self .eval_executor (full_group )
189+ # Handle result (could be list or single row wrapping list?)
190+ # Usually groupwise returns list of scored rows
191+ if isinstance (eval_res , list ):
192+ self .results .extend (eval_res )
 193+ if self .output_buffer :
 194+ # Push the whole group at once if possible, or iterate
 195+ for r in eval_res :
 196+ await self .output_buffer .add_result (r )
 197+ else :
 198+ self .results .append (eval_res )
 199+ if self .output_buffer :
 200+ await self .output_buffer .add_result (eval_res )
201+
202+ if self .eval_sem :
203+ async with self .eval_sem :
204+ await _run_group_eval ()
205+ else :
206+ await _run_group_eval ()
207+
208+ # Clear buffer to free memory
209+ del self .groupwise_buffer [task .row_index ]
210+
211+ else :
212+ # Pointwise: Process each result individually
213+ async def _run_eval ():
214+ for res in batch_results :
215+ # Run Evaluation
216+ eval_res = await self .eval_executor (res )
217+
218+ if isinstance (eval_res , list ):
219+ # Should not happen in pointwise mode which is typically used with this scheduler
220+ # But if it does, we process each result
221+ self .results .extend (eval_res )
222+ for r in eval_res :
 223+ if self .output_buffer :
 224+ await self .output_buffer .add_result (r )
225+
226+ last_msg = r .last_assistant_message ()
227+ if last_msg and last_msg .content :
228+ content = last_msg .content
229+ if isinstance (content , list ):
230+ text_parts = [p ["text" ] for p in content if p ["type" ] == "text" ]
231+ current_batch_history_updates .append ("" .join (text_parts ))
232+ else :
233+ current_batch_history_updates .append (str (content ))
234+ else :
235+ current_batch_history_updates .append ("" )
236+ else :
237+ self .results .append (eval_res )
 172- if self .mini_batch_data_buffer :
 173- await self .mini_batch_data_buffer .add_result (r )
 174-
 175- last_msg = r .last_assistant_message ()
 238+ if self .output_buffer :
 239+ await self .output_buffer .add_result (eval_res )
240+
241+ # Extract prediction for history
242+ last_msg = eval_res .last_assistant_message ()
176243 if last_msg and last_msg .content :
177244 content = last_msg .content
178245 if isinstance (content , list ):
@@ -181,28 +248,13 @@ async def _run_eval():
181248 else :
182249 current_batch_history_updates .append (str (content ))
183250 else :
184- current_batch_history_updates .append ("" )
185- else :
186- if self .mini_batch_data_buffer :
187- await self .mini_batch_data_buffer .add_result (eval_res )
251+ current_batch_history_updates .append ("" ) # Empty string for failed turns
188252
189- # Extract prediction for history
190- last_msg = eval_res .last_assistant_message ()
191- if last_msg and last_msg .content :
192- content = last_msg .content
193- if isinstance (content , list ):
194- text_parts = [p ["text" ] for p in content if p ["type" ] == "text" ]
195- current_batch_history_updates .append ("" .join (text_parts ))
196- else :
197- current_batch_history_updates .append (str (content ))
198- else :
199- current_batch_history_updates .append ("" ) # Empty string for failed turns
200-
201- if self .eval_sem :
202- async with self .eval_sem :
253+ if self .eval_sem :
254+ async with self .eval_sem :
255+ await _run_eval ()
256+ else :
203257 await _run_eval ()
204- else :
205- await _run_eval ()
206258
207259 # 4. Schedule Next Micro-batch (High Priority)
208260 last_run_idx = task .run_indices [- 1 ]
@@ -248,12 +300,11 @@ async def run(self, dataset: List[EvaluationRow], num_runs: int, micro_batch_siz
248300 for w in workers :
249301 w .cancel ()
250302
251- # Ensure cancellation is complete
252303 if workers :
253304 await asyncio .gather (* workers , return_exceptions = True )
254305
255- # Return empty dict as we rely on side effects (streaming buffer)
256- return {}
306+ # Return collected results
307+ return self . results
257308
258309async def execute_priority_rollouts (
259310 dataset : List [EvaluationRow ],
0 commit comments