Skip to content

Commit 37e0210

Browse files
committed
groupwise
1 parent 2865b79 commit 37e0210

File tree

2 files changed

+142
-41
lines changed

2 files changed

+142
-41
lines changed

eval_protocol/pytest/priority_scheduler.py

Lines changed: 92 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,17 @@ def __init__(
4444
max_concurrent_rollouts: int,
4545
active_logger: DatasetLogger,
4646
eval_executor: Callable[[Union[EvaluationRow, List[EvaluationRow]]], Awaitable[Union[EvaluationRow, List[EvaluationRow]]]], # Callback to run evaluation
47-
mini_batch_data_buffer: Optional[MiniBatchDataBuffer] = None,
47+
output_buffer: Optional[MiniBatchDataBuffer] = None,
4848
max_concurrent_evaluations: Optional[int] = None,
49+
mode: str = "pointwise",
4950
):
5051
self.rollout_processor = rollout_processor
5152
self.max_concurrent_rollouts = max_concurrent_rollouts
5253
self.max_concurrent_evaluations = max_concurrent_evaluations
5354
self.active_logger = active_logger
5455
self.eval_executor = eval_executor
55-
self.mini_batch_data_buffer = mini_batch_data_buffer
56+
self.output_buffer = output_buffer
57+
self.mode = mode
5658

5759
# Priority Queue: Stores RolloutTask
5860
self.queue: asyncio.PriorityQueue[RolloutTask] = asyncio.PriorityQueue()
@@ -61,6 +63,10 @@ def __init__(
6163
self.rollout_sem = asyncio.Semaphore(max_concurrent_rollouts)
6264
self.eval_sem = asyncio.Semaphore(max_concurrent_evaluations) if max_concurrent_evaluations else None
6365

66+
# Results storage
67+
self.results: List[EvaluationRow] = [] # for backward compatibility reason, we save all results here to return
68+
self.groups_buffer: Dict[int, List[EvaluationRow]] = defaultdict(list) # buffer for group results. only flush to output buffer when a whole group is ready
69+
6470
self.num_runs = 0
6571
self.micro_batch_size = 0
6672

@@ -155,24 +161,85 @@ async def _process_task(self, task: RolloutTask):
155161
# 3. Evaluate and Collect History
156162
current_batch_history_updates = []
157163

158-
async def _run_eval():
159-
for res in batch_results:
160-
# Run Evaluation
161-
eval_res = await self.eval_executor(res)
162-
163-
# Depending on the execution mode, eval_executor might return a single row or a list
164-
# For pointwise, it's a single row. For groupwise, it's a list.
165-
# Since PriorityScheduler processes a batch of single-turn rollouts, we expect single rows back
166-
# But to be safe and type-correct, we handle both.
164+
if self.mode == "groupwise":
165+
# Collect all results from this batch
166+
for res in batch_results:
167+
self.groupwise_buffer[task.row_index].append(res)
167168

168-
if isinstance(eval_res, list):
169-
# Should not happen in pointwise mode which is typically used with this scheduler
170-
# But if it does, we process each result
171-
for r in eval_res:
169+
# Update history from rollout result (assuming eval doesn't change content needed for history)
170+
last_msg = res.last_assistant_message()
171+
if last_msg and last_msg.content:
172+
content = last_msg.content
173+
if isinstance(content, list):
174+
text_parts = [p["text"] for p in content if p["type"] == "text"]
175+
current_batch_history_updates.append("".join(text_parts))
176+
else:
177+
current_batch_history_updates.append(str(content))
178+
else:
179+
current_batch_history_updates.append("")
180+
181+
# Check if this is the last batch for this sample
182+
last_run_idx = task.run_indices[-1]
183+
if last_run_idx + 1 >= self.num_runs:
184+
# Last batch: Execute Groupwise Evaluation
185+
full_group = self.groupwise_buffer[task.row_index]
186+
187+
async def _run_group_eval():
188+
eval_res = await self.eval_executor(full_group)
189+
# Handle result (could be list or single row wrapping list?)
190+
# Usually groupwise returns list of scored rows
191+
if isinstance(eval_res, list):
192+
self.results.extend(eval_res)
193+
if self.mini_batch_data_buffer:
194+
# Push the whole group at once if possible, or iterate
195+
for r in eval_res:
196+
await self.mini_batch_data_buffer.add_result(r)
197+
else:
198+
self.results.append(eval_res)
199+
if self.mini_batch_data_buffer:
200+
await self.mini_batch_data_buffer.add_result(eval_res)
201+
202+
if self.eval_sem:
203+
async with self.eval_sem:
204+
await _run_group_eval()
205+
else:
206+
await _run_group_eval()
207+
208+
# Clear buffer to free memory
209+
del self.groupwise_buffer[task.row_index]
210+
211+
else:
212+
# Pointwise: Process each result individually
213+
async def _run_eval():
214+
for res in batch_results:
215+
# Run Evaluation
216+
eval_res = await self.eval_executor(res)
217+
218+
if isinstance(eval_res, list):
219+
# Should not happen in pointwise mode which is typically used with this scheduler
220+
# But if it does, we process each result
221+
self.results.extend(eval_res)
222+
for r in eval_res:
223+
if self.mini_batch_data_buffer:
224+
await self.mini_batch_data_buffer.add_result(r)
225+
226+
last_msg = r.last_assistant_message()
227+
if last_msg and last_msg.content:
228+
content = last_msg.content
229+
if isinstance(content, list):
230+
text_parts = [p["text"] for p in content if p["type"] == "text"]
231+
current_batch_history_updates.append("".join(text_parts))
232+
else:
233+
current_batch_history_updates.append(str(content))
234+
else:
235+
current_batch_history_updates.append("")
236+
else:
237+
self.results.append(eval_res)
172238
if self.mini_batch_data_buffer:
173-
await self.mini_batch_data_buffer.add_result(r)
174-
175-
last_msg = r.last_assistant_message()
239+
await self.mini_batch_data_buffer.add_result(eval_res)
240+
241+
# Extract prediction for history
242+
last_msg = eval_res.last_assistant_message()
176243
if last_msg and last_msg.content:
177244
content = last_msg.content
178245
if isinstance(content, list):
@@ -181,28 +248,13 @@ async def _run_eval():
181248
else:
182249
current_batch_history_updates.append(str(content))
183250
else:
184-
current_batch_history_updates.append("")
185-
else:
186-
if self.mini_batch_data_buffer:
187-
await self.mini_batch_data_buffer.add_result(eval_res)
251+
current_batch_history_updates.append("") # Empty string for failed turns
188252

189-
# Extract prediction for history
190-
last_msg = eval_res.last_assistant_message()
191-
if last_msg and last_msg.content:
192-
content = last_msg.content
193-
if isinstance(content, list):
194-
text_parts = [p["text"] for p in content if p["type"] == "text"]
195-
current_batch_history_updates.append("".join(text_parts))
196-
else:
197-
current_batch_history_updates.append(str(content))
198-
else:
199-
current_batch_history_updates.append("") # Empty string for failed turns
200-
201-
if self.eval_sem:
202-
async with self.eval_sem:
253+
if self.eval_sem:
254+
async with self.eval_sem:
255+
await _run_eval()
256+
else:
203257
await _run_eval()
204-
else:
205-
await _run_eval()
206258

207259
# 4. Schedule Next Micro-batch (High Priority)
208260
last_run_idx = task.run_indices[-1]
@@ -248,12 +300,11 @@ async def run(self, dataset: List[EvaluationRow], num_runs: int, micro_batch_siz
248300
for w in workers:
249301
w.cancel()
250302

251-
# Ensure cancellation is complete
252303
if workers:
253304
await asyncio.gather(*workers, return_exceptions=True)
254305

255-
# Return empty dict as we rely on side effects (streaming buffer)
256-
return {}
306+
# Return collected results
307+
return self.results
257308

258309
async def execute_priority_rollouts(
259310
dataset: List[EvaluationRow],

tests/test_priority_scheduler.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,4 +279,54 @@ async def schedule_dataset(self, *args):
279279

280280
assert worker_start_count == expected_workers
281281

282+
@pytest.mark.asyncio
283+
async def test_groupwise_mode(
284+
mock_logger, mock_eval_executor, base_config
285+
):
286+
"""
287+
Test that groupwise mode collects all runs before evaluating.
288+
"""
289+
dataset = [create_mock_row("row-0")]
290+
num_runs = 4
291+
micro_batch_size = 2
292+
293+
# We expect 2 batches of 2 runs each.
294+
# Batch 1 (Runs 0,1): Should buffer and update history, NOT call eval.
295+
# Batch 2 (Runs 2,3): Should buffer, update history, AND call eval with all 4 runs.
296+
297+
eval_calls = []
298+
299+
async def mock_eval(rows):
300+
eval_calls.append(rows)
301+
return rows # Pass through
302+
303+
async def mock_rollout_gen(processor, rows, config, run_idx):
304+
for row in rows:
305+
yield row
306+
307+
mock_eval_executor.side_effect = mock_eval
308+
309+
with patch('eval_protocol.pytest.priority_scheduler.rollout_processor_with_retry', side_effect=mock_rollout_gen):
310+
processor_instance = MagicMock()
311+
312+
scheduler = PriorityRolloutScheduler(
313+
rollout_processor=processor_instance,
314+
max_concurrent_rollouts=1,
315+
active_logger=mock_logger,
316+
eval_executor=mock_eval_executor,
317+
mode="groupwise"
318+
)
319+
320+
results = await scheduler.run(dataset, num_runs, micro_batch_size, base_config)
321+
322+
# Verify evaluation was called EXACTLY ONCE
323+
assert len(eval_calls) == 1, f"Expected 1 eval call, got {len(eval_calls)}"
324+
325+
# Verify it was called with ALL 4 rows
326+
evaluated_rows = eval_calls[0]
327+
assert len(evaluated_rows) == 4, f"Expected 4 rows in group eval, got {len(evaluated_rows)}"
328+
329+
# Verify results contains all 4 rows
330+
assert len(results) == 4
331+
282332

0 commit comments

Comments
 (0)