Skip to content

Commit c7e3abb

Browse files
committed
fix unittest
1 parent 70b2e17 commit c7e3abb

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/test_priority_scheduler.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ async def test_scheduler_basic_execution(
5757
micro_batch_size = 1
5858

5959
# Mock rollout processor with delay
60-
async def delayed_rollout(processor, rows, config, run_idx):
60+
async def delayed_rollout(processor, rows, config, run_idx, **kwargs):
6161
await asyncio.sleep(0.01)
6262
for row in rows:
6363
yield row
@@ -110,7 +110,7 @@ async def test_concurrency_control(
110110
rollout_lock = asyncio.Lock()
111111
eval_lock = asyncio.Lock()
112112

113-
async def mock_rollout_gen(processor, rows, config, run_idx):
113+
async def mock_rollout_gen(processor, rows, config, run_idx, **kwargs):
114114
nonlocal active_rollouts, max_active_rollouts_seen
115115
async with rollout_lock:
116116
active_rollouts += 1
@@ -177,7 +177,7 @@ async def test_priority_scheduling(
177177

178178
execution_order = []
179179

180-
async def mock_rollout_gen(processor, rows, config, run_idx):
180+
async def mock_rollout_gen(processor, rows, config, run_idx, **kwargs):
181181
row_id = rows[0].input_metadata.row_id
182182
execution_order.append(f"{row_id}_run_{run_idx}")
183183
for row in rows:
@@ -290,7 +290,7 @@ async def mock_eval(rows):
290290
eval_calls.append(rows)
291291
return rows # Pass through
292292

293-
async def mock_rollout_gen(processor, rows, config, run_idx):
293+
async def mock_rollout_gen(processor, rows, config, run_idx, **kwargs):
294294
for row in rows:
295295
yield row
296296

0 commit comments

Comments
 (0)