Skip to content

Commit 2865b79

Browse files
committed
add priority rollout scheduler
1 parent d9ab3d4 commit 2865b79

File tree

2 files changed

+340
-37
lines changed

2 files changed

+340
-37
lines changed

eval_protocol/pytest/priority_scheduler.py

Lines changed: 58 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,22 @@ def __init__(
4545
active_logger: DatasetLogger,
4646
eval_executor: Callable[[Union[EvaluationRow, List[EvaluationRow]]], Awaitable[Union[EvaluationRow, List[EvaluationRow]]]], # Callback to run evaluation
4747
mini_batch_data_buffer: Optional[MiniBatchDataBuffer] = None,
48+
max_concurrent_evaluations: Optional[int] = None,
4849
):
4950
self.rollout_processor = rollout_processor
5051
self.max_concurrent_rollouts = max_concurrent_rollouts
52+
self.max_concurrent_evaluations = max_concurrent_evaluations
5153
self.active_logger = active_logger
5254
self.eval_executor = eval_executor
5355
self.mini_batch_data_buffer = mini_batch_data_buffer
5456

5557
# Priority Queue: Stores RolloutTask
5658
self.queue: asyncio.PriorityQueue[RolloutTask] = asyncio.PriorityQueue()
5759

60+
# Concurrency Control
61+
self.rollout_sem = asyncio.Semaphore(max_concurrent_rollouts)
62+
self.eval_sem = asyncio.Semaphore(max_concurrent_evaluations) if max_concurrent_evaluations else None
63+
5864
self.num_runs = 0
5965
self.micro_batch_size = 0
6066

@@ -140,31 +146,48 @@ async def _process_task(self, task: RolloutTask):
140146
if task.run_indices:
141147
representative_run_idx = task.run_indices[0]
142148

143-
async for result_row in rollout_processor_with_retry(
144-
self.rollout_processor, current_batch_rows, task.config, representative_run_idx
145-
):
146-
batch_results.append(result_row)
149+
async with self.rollout_sem:
150+
async for result_row in rollout_processor_with_retry(
151+
self.rollout_processor, current_batch_rows, task.config, representative_run_idx
152+
):
153+
batch_results.append(result_row)
147154

148155
# 3. Evaluate and Collect History
149156
current_batch_history_updates = []
150157

151-
for res in batch_results:
152-
# Run Evaluation
153-
eval_res = await self.eval_executor(res)
154-
155-
# Depending on the execution mode, eval_executor might return a single row or a list
156-
# For pointwise, it's a single row. For groupwise, it's a list.
157-
# Since PriorityScheduler processes a batch of single-turn rollouts, we expect single rows back
158-
# But to be safe and type-correct, we handle both.
159-
160-
if isinstance(eval_res, list):
161-
# Should not happen in pointwise mode which is typically used with this scheduler
162-
# But if it does, we process each result
163-
for r in eval_res:
158+
async def _run_eval():
159+
for res in batch_results:
160+
# Run Evaluation
161+
eval_res = await self.eval_executor(res)
162+
163+
# Depending on the execution mode, eval_executor might return a single row or a list
164+
# For pointwise, it's a single row. For groupwise, it's a list.
165+
# Since PriorityScheduler processes a batch of single-turn rollouts, we expect single rows back
166+
# But to be safe and type-correct, we handle both.
167+
168+
if isinstance(eval_res, list):
169+
# Should not happen in pointwise mode which is typically used with this scheduler
170+
# But if it does, we process each result
171+
for r in eval_res:
172+
if self.mini_batch_data_buffer:
173+
await self.mini_batch_data_buffer.add_result(r)
174+
175+
last_msg = r.last_assistant_message()
176+
if last_msg and last_msg.content:
177+
content = last_msg.content
178+
if isinstance(content, list):
179+
text_parts = [p["text"] for p in content if p["type"] == "text"]
180+
current_batch_history_updates.append("".join(text_parts))
181+
else:
182+
current_batch_history_updates.append(str(content))
183+
else:
184+
current_batch_history_updates.append("")
185+
else:
164186
if self.mini_batch_data_buffer:
165-
await self.mini_batch_data_buffer.add_result(r)
166-
167-
last_msg = r.last_assistant_message()
187+
await self.mini_batch_data_buffer.add_result(eval_res)
188+
189+
# Extract prediction for history
190+
last_msg = eval_res.last_assistant_message()
168191
if last_msg and last_msg.content:
169192
content = last_msg.content
170193
if isinstance(content, list):
@@ -173,22 +196,13 @@ async def _process_task(self, task: RolloutTask):
173196
else:
174197
current_batch_history_updates.append(str(content))
175198
else:
176-
current_batch_history_updates.append("")
177-
else:
178-
if self.mini_batch_data_buffer:
179-
await self.mini_batch_data_buffer.add_result(eval_res)
199+
current_batch_history_updates.append("") # Empty string for failed turns
180200

181-
# Extract prediction for history
182-
last_msg = eval_res.last_assistant_message()
183-
if last_msg and last_msg.content:
184-
content = last_msg.content
185-
if isinstance(content, list):
186-
text_parts = [p["text"] for p in content if p["type"] == "text"]
187-
current_batch_history_updates.append("".join(text_parts))
188-
else:
189-
current_batch_history_updates.append(str(content))
190-
else:
191-
current_batch_history_updates.append("") # Empty string for failed turns
201+
if self.eval_sem:
202+
async with self.eval_sem:
203+
await _run_eval()
204+
else:
205+
await _run_eval()
192206

193207
# 4. Schedule Next Micro-batch (High Priority)
194208
last_run_idx = task.run_indices[-1]
@@ -220,7 +234,12 @@ async def run(self, dataset: List[EvaluationRow], num_runs: int, micro_batch_siz
220234
await self.schedule_dataset(dataset, base_config)
221235

222236
# 2. Start Workers
223-
workers = [asyncio.create_task(self.worker()) for _ in range(self.max_concurrent_rollouts)]
237+
# If we have separate limits, we need enough workers to saturate both stages
238+
num_workers = self.max_concurrent_rollouts
239+
if self.max_concurrent_evaluations:
240+
num_workers += self.max_concurrent_evaluations
241+
242+
workers = [asyncio.create_task(self.worker()) for _ in range(num_workers)]
224243

225244
# 3. Wait for completion
226245
await self.queue.join()
@@ -246,12 +265,14 @@ async def execute_priority_rollouts(
246265
active_logger: DatasetLogger,
247266
eval_executor: Callable[[Union[EvaluationRow, List[EvaluationRow]]], Awaitable[Union[EvaluationRow, List[EvaluationRow]]]],
248267
mini_batch_data_buffer: Optional[MiniBatchDataBuffer] = None,
268+
max_concurrent_evaluations: Optional[int] = None,
249269
):
250270
scheduler = PriorityRolloutScheduler(
251271
rollout_processor=rollout_processor,
252272
max_concurrent_rollouts=max_concurrent_rollouts,
253273
active_logger=active_logger,
254274
eval_executor=eval_executor,
255-
mini_batch_data_buffer=mini_batch_data_buffer
275+
mini_batch_data_buffer=mini_batch_data_buffer,
276+
max_concurrent_evaluations=max_concurrent_evaluations
256277
)
257278
return await scheduler.run(dataset, num_runs, micro_batch_size, config)

0 commit comments

Comments
 (0)