Skip to content

Commit 153eacd

Browse files
committed
Fix tests from semaphore change
1 parent a866dde commit 153eacd

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

tests/test_rollout_control_plane_integration.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
feature in the rollout execution pipeline.
1414
"""
1515

16+
import asyncio
1617
import json
1718
import sys
1819
import tempfile
@@ -238,7 +239,8 @@ def mock_step_side_effect(env_index, tool_call):
238239
policy = MockPolicy(["right", "down", "right"])
239240

240241
# Execute rollout
241-
tasks = self.execution_manager.execute_rollouts(mock_env, policy, steps=10)
242+
semaphore = asyncio.Semaphore(1) # Create semaphore for test
243+
tasks = self.execution_manager.execute_rollouts(mock_env, policy, semaphore, steps=10)
242244
evaluation_rows = []
243245
for task in tasks:
244246
row = await task
@@ -459,7 +461,8 @@ async def test_rollout_handles_control_plane_failure_gracefully(self):
459461

460462
# Execute rollout with control plane failure
461463
policy = MockPolicy(["right"])
462-
tasks = self.execution_manager.execute_rollouts(mock_env, policy, steps=1)
464+
semaphore = asyncio.Semaphore(1) # Create semaphore for test
465+
tasks = self.execution_manager.execute_rollouts(mock_env, policy, semaphore, steps=1)
463466
evaluation_rows = []
464467
for task in tasks:
465468
row = await task
@@ -515,7 +518,6 @@ async def mock_task():
515518

516519
def mock_execute_rollouts(*args, **kwargs):
517520
call_args.append((args, kwargs))
518-
import asyncio
519521

520522
return [asyncio.create_task(mock_task())]
521523

@@ -541,9 +543,11 @@ def mock_execute_rollouts(*args, **kwargs):
541543
# Verify execute_rollouts was called with correct arguments
542544
assert len(call_args) == 1, "execute_rollouts should be called once"
543545
args, kwargs = call_args[0]
546+
544547
assert args[0] == mock_make.return_value, "First arg should be mock env"
545548
assert args[1] == policy, "Second arg should be policy"
546-
assert args[2] == 5, "Third arg should be steps"
549+
assert isinstance(kwargs.get("semaphore"), asyncio.Semaphore), "semaphore should be in kwargs"
550+
assert kwargs.get("steps") == 5, "steps should be in kwargs"
547551

548552
assert result == ["ok"]
549553

0 commit comments

Comments
 (0)