Skip to content

Commit a866dde

Browse files
authored
Adding unified semaphore (#153)
1 parent 1d551c0 commit a866dde

File tree

8 files changed

+23
-18
lines changed

8 files changed

+23
-18
lines changed

eval_protocol/mcp/execution/manager.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ def execute_rollouts(
3939
self,
4040
envs: "GeneralMCPVectorEnv",
4141
policy: Union["LLMBasePolicy", Callable],
42+
semaphore: asyncio.Semaphore,
4243
steps: int = 512,
4344
openai_format_log_file: Optional[str] = None,
44-
max_concurrent_rollouts: int = 8,
4545
evaluation_rows: Optional[List[EvaluationRow]] = None,
4646
) -> List[asyncio.Task[EvaluationRow]]:
4747
"""
@@ -57,7 +57,7 @@ def execute_rollouts(
5757
policy: Policy that takes tool schemas, observations, prompts and returns tool calls
5858
steps: Maximum steps per rollout
5959
openai_format_log_file: Optional file to log clean OpenAI format for terminated trajectories only
60-
max_concurrent_rollouts: Maximum number of concurrent threads to run
60+
semaphore: Semaphore to control concurrent rollout execution
6161
6262
Environment Variable Control:
6363
EP_PLAYBACK_FILE: Controls record/playback mode
@@ -90,15 +90,13 @@ def execute_rollouts(
9090
pass
9191
openai_logger = lambda data: self._log_openai_entry(openai_format_log_file, data)
9292

93-
logger.info(f"🧵 Starting {envs.n} rollouts with max {max_concurrent_rollouts} concurrent threads...")
93+
logger.info(f"🧵 Starting {envs.n} rollouts with max {semaphore._value} concurrent threads...")
9494

9595
if evaluation_rows is None:
9696
evaluation_rows = [EvaluationRow(messages=[], input_metadata=InputMetadata()) for _ in range(envs.n)]
9797

9898
shared_tool_schema = envs.tool_schemas
9999

100-
semaphore = asyncio.Semaphore(max_concurrent_rollouts)
101-
102100
async def _execute_with_semaphore(idx):
103101
async with semaphore:
104102
evaluation_row: EvaluationRow = evaluation_rows[idx]

eval_protocol/mcp_env.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,15 @@ async def rollout(
310310
# Use the new ExecutionManager for execution
311311
execution_manager = ExecutionManager()
312312

313+
rollout_semaphore = asyncio.Semaphore(max_concurrent_rollouts)
314+
313315
tasks = execution_manager.execute_rollouts(
314-
envs, policy, steps, openai_format_log_file, max_concurrent_rollouts, evaluation_rows
316+
envs,
317+
policy,
318+
semaphore=rollout_semaphore,
319+
steps=steps,
320+
openai_format_log_file=openai_format_log_file,
321+
evaluation_rows=evaluation_rows,
315322
)
316323

317324
# Await all tasks and return concrete EvaluationRows

eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,7 @@ class AgentRolloutProcessor(RolloutProcessor):
225225
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
226226
"""Create agent rollout tasks and return them for external handling."""
227227

228-
max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
229-
semaphore = asyncio.Semaphore(max_concurrent)
228+
semaphore = config.semaphore
230229

231230
async def process_row(row: EvaluationRow) -> EvaluationRow:
232231
"""Process a single row with agent rollout."""

eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,11 +256,10 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
256256
tasks = execution_manager.execute_rollouts(
257257
envs,
258258
policy=self.policy,
259+
semaphore=config.semaphore,
259260
steps=config.steps,
260-
max_concurrent_rollouts=config.max_concurrent_rollouts,
261261
evaluation_rows=rows,
262262
)
263-
264263
return tasks
265264

266265
def cleanup(self) -> None:

eval_protocol/pytest/default_pydantic_ai_rollout_processor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ def __init__(self):
4040
def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
4141
"""Create agent rollout tasks and return them for external handling."""
4242

43-
max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
44-
semaphore = asyncio.Semaphore(max_concurrent)
43+
semaphore = config.semaphore
4544

4645
# validate that the "agent" field is present with a valid Pydantic AI Agent instance in the completion_params dict
4746
if "agent" not in config.kwargs:

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
112112
default_logger.log(row)
113113
return row
114114

115-
# Process rows with bounded concurrency
116-
max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
117-
semaphore = asyncio.Semaphore(max_concurrent)
115+
semaphore = config.semaphore
118116

119117
async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
120118
async with semaphore:

eval_protocol/pytest/evaluation_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,14 +282,17 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
282282
# used to determine whether this eval has stopped
283283
row.pid = os.getpid()
284284

285+
# Create shared semaphore for unified concurrency control across all runs and rollouts
286+
shared_semaphore = asyncio.Semaphore(max_concurrent_rollouts)
287+
285288
# Prepare rollout processor config once; we will generate fresh outputs per run
286289
config = RolloutProcessorConfig(
287290
completion_params=completion_params if completion_params is not None else {},
288291
mcp_config_path=mcp_config_path or "",
289-
max_concurrent_rollouts=max_concurrent_rollouts,
290292
server_script_path=server_script_path,
291293
steps=steps,
292294
logger=active_logger,
295+
semaphore=shared_semaphore,
293296
kwargs=rollout_processor_kwargs or {},
294297
exception_handler_config=exception_handler_config,
295298
)
@@ -372,10 +375,10 @@ async def _execute_groupwise_eval_with_semaphore(
372375
config = RolloutProcessorConfig(
373376
completion_params=cp if cp is not None else {},
374377
mcp_config_path=mcp_config_path or "",
375-
max_concurrent_rollouts=max_concurrent_rollouts,
376378
server_script_path=server_script_path,
377379
steps=steps,
378380
logger=active_logger,
381+
semaphore=shared_semaphore,
379382
kwargs=rollout_processor_kwargs or {},
380383
)
381384
lst = []
@@ -474,6 +477,7 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
474477
await task
475478
else:
476479
# For other processors, create all tasks at once and run in parallel
480+
# Concurrency is now controlled by the shared semaphore in each rollout processor
477481
tasks = []
478482
for i in range(num_runs):
479483
tasks.append(asyncio.create_task(execute_run(i, config))) # pyright: ignore[reportUnknownMemberType]

eval_protocol/pytest/types.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Parameter types
33
"""
44

5+
import asyncio
56
from dataclasses import dataclass, field
67
from typing import Any, Callable, Literal
78

@@ -58,10 +59,10 @@
5859
class RolloutProcessorConfig:
5960
completion_params: CompletionParams # input parameters for inference
6061
mcp_config_path: str
62+
semaphore: asyncio.Semaphore # shared semaphore for unified concurrency control
6163
server_script_path: str | None = (
6264
None # TODO: change from server_script_path to mcp_config_path for agent rollout processor
6365
)
64-
max_concurrent_rollouts: int = 8 # maximum number of concurrent rollouts
6566
steps: int = 30 # max number of rollout steps
6667
logger: DatasetLogger = default_logger # logger to use during rollout for mid-rollout logs
6768
kwargs: dict[str, Any] = field( # pyright: ignore[reportExplicitAny]

0 commit comments

Comments
 (0)