Skip to content

Commit 5fc935c

Browse files
committed
fix
1 parent b556f4e commit 5fc935c

File tree

3 files changed

+12
-34
lines changed

3 files changed

+12
-34
lines changed

eval_protocol/pytest/priority_scheduler.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -294,9 +294,8 @@ def _post_process_result(self, res: EvaluationRow):
294294
pass
295295
self.active_logger.log(res)
296296

297-
async def run(self, dataset: List[EvaluationRow], num_runs: int, micro_batch_size: int, base_config: RolloutProcessorConfig):
297+
async def run(self, dataset: List[EvaluationRow], num_runs: int, base_config: RolloutProcessorConfig):
298298
self.num_runs = num_runs
299-
self.micro_batch_size = micro_batch_size
300299

301300
# 1. Schedule initial tasks
302301
await self.schedule_dataset(dataset, base_config)
@@ -349,4 +348,4 @@ async def execute_priority_rollouts(
349348
in_group_minibatch_size=(num_runs // 2),
350349
evaluation_test_kwargs=evaluation_test_kwargs,
351350
)
352-
return await scheduler.run(dataset, num_runs, micro_batch_size, config)
351+
return await scheduler.run(dataset, num_runs, config)

pytest.ini

Lines changed: 0 additions & 21 deletions
This file was deleted.

tests/test_priority_scheduler.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,10 @@ async def mock_eval(row):
7676
eval_executor=mock_eval,
7777
max_concurrent_evaluations=2,
7878
rollout_n=num_runs,
79-
in_group_microbatch_size=micro_batch_size
79+
in_group_minibatch_size=micro_batch_size
8080
)
8181

82-
results = await scheduler.run(dataset, num_runs, micro_batch_size, base_config)
82+
results = await scheduler.run(dataset, num_runs, base_config)
8383

8484
assert len(results) == 5 * num_runs
8585
for res in results:
@@ -151,10 +151,10 @@ async def mock_eval(row):
151151
eval_executor=mock_eval,
152152
max_concurrent_evaluations=max_evals,
153153
rollout_n=num_runs,
154-
in_group_microbatch_size=micro_batch_size
154+
in_group_minibatch_size=micro_batch_size
155155
)
156156

157-
await scheduler.run(dataset, num_runs, micro_batch_size, base_config)
157+
await scheduler.run(dataset, num_runs, base_config)
158158

159159
# Verify limits were respected
160160
assert max_active_rollouts_seen <= max_rollouts, f"Rollout concurrency exceeded: {max_active_rollouts_seen} > {max_rollouts}"
@@ -196,10 +196,10 @@ async def mock_eval(row):
196196
eval_executor=mock_eval,
197197
max_concurrent_evaluations=1,
198198
rollout_n=num_runs,
199-
in_group_microbatch_size=micro_batch_size
199+
in_group_minibatch_size=micro_batch_size
200200
)
201201

202-
await scheduler.run(dataset, num_runs, micro_batch_size, base_config)
202+
await scheduler.run(dataset, num_runs, base_config)
203203

204204
# Expected order: row-0_run_0, row-0_run_1, row-1_run_0, row-1_run_1
205205
# Note: Since row-0_run_0 finishes, it schedules row-0_run_1 with HIGH priority (0).
@@ -262,10 +262,10 @@ async def schedule_dataset(self, *args):
262262
eval_executor=mock_eval_executor,
263263
max_concurrent_evaluations=max_evals,
264264
rollout_n=1,
265-
in_group_microbatch_size=1
265+
in_group_minibatch_size=1
266266
)
267267

268-
await scheduler.run(dataset, 1, 1, base_config)
268+
await scheduler.run(dataset, 1, base_config)
269269

270270
assert worker_start_count == expected_workers
271271

@@ -305,10 +305,10 @@ async def mock_rollout_gen(processor, rows, config, run_idx):
305305
max_concurrent_evaluations=1,
306306
mode="groupwise",
307307
rollout_n=num_runs,
308-
in_group_microbatch_size=micro_batch_size
308+
in_group_minibatch_size=micro_batch_size
309309
)
310310

311-
results = await scheduler.run(dataset, num_runs, micro_batch_size, base_config)
311+
results = await scheduler.run(dataset, num_runs, base_config)
312312

313313
# Verify evaluation was called EXACTLY ONCE
314314
assert len(eval_calls) == 1, f"Expected 1 eval call, got {len(eval_calls)}"

0 commit comments

Comments
 (0)