diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index 753d454e..0b8f0668 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -39,9 +39,9 @@ def execute_rollouts( self, envs: "GeneralMCPVectorEnv", policy: Union["LLMBasePolicy", Callable], + semaphore: asyncio.Semaphore, steps: int = 512, openai_format_log_file: Optional[str] = None, - max_concurrent_rollouts: int = 8, evaluation_rows: Optional[List[EvaluationRow]] = None, ) -> List[asyncio.Task[EvaluationRow]]: """ @@ -57,7 +57,7 @@ def execute_rollouts( policy: Policy that takes tool schemas, observations, prompts and returns tool calls steps: Maximum steps per rollout openai_format_log_file: Optional file to log clean OpenAI format for terminated trajectories only - max_concurrent_rollouts: Maximum number of concurrent threads to run + semaphore: Semaphore to control concurrent rollout execution Environment Variable Control: EP_PLAYBACK_FILE: Controls record/playback mode @@ -90,15 +90,13 @@ def execute_rollouts( pass openai_logger = lambda data: self._log_openai_entry(openai_format_log_file, data) - logger.info(f"🧵 Starting {envs.n} rollouts with max {max_concurrent_rollouts} concurrent threads...") + logger.info(f"🧵 Starting {envs.n} rollouts with max {semaphore._value} concurrent threads...") if evaluation_rows is None: evaluation_rows = [EvaluationRow(messages=[], input_metadata=InputMetadata()) for _ in range(envs.n)] shared_tool_schema = envs.tool_schemas - semaphore = asyncio.Semaphore(max_concurrent_rollouts) - async def _execute_with_semaphore(idx): async with semaphore: evaluation_row: EvaluationRow = evaluation_rows[idx] diff --git a/eval_protocol/mcp_env.py b/eval_protocol/mcp_env.py index 35ad517b..aac1a127 100644 --- a/eval_protocol/mcp_env.py +++ b/eval_protocol/mcp_env.py @@ -310,8 +310,15 @@ async def rollout( # Use the new ExecutionManager for execution execution_manager = ExecutionManager() + rollout_semaphore = asyncio.Semaphore(max_concurrent_rollouts) + tasks = execution_manager.execute_rollouts( - envs, policy, steps, openai_format_log_file, max_concurrent_rollouts, evaluation_rows + envs, + policy, + semaphore=rollout_semaphore, + steps=steps, + openai_format_log_file=openai_format_log_file, + evaluation_rows=evaluation_rows, ) # Await all tasks and return concrete EvaluationRows diff --git a/eval_protocol/pytest/default_agent_rollout_processor.py b/eval_protocol/pytest/default_agent_rollout_processor.py index 33650185..fac02dd9 100644 --- a/eval_protocol/pytest/default_agent_rollout_processor.py +++ b/eval_protocol/pytest/default_agent_rollout_processor.py @@ -225,8 +225,7 @@ class AgentRolloutProcessor(RolloutProcessor): def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]: """Create agent rollout tasks and return them for external handling.""" - max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8 - semaphore = asyncio.Semaphore(max_concurrent) + semaphore = config.semaphore async def process_row(row: EvaluationRow) -> EvaluationRow: """Process a single row with agent rollout.""" diff --git a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py index b377d71c..42428b5c 100644 --- a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py +++ b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py @@ -256,11 +256,10 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> tasks = execution_manager.execute_rollouts( envs, policy=self.policy, + semaphore=config.semaphore, steps=config.steps, - max_concurrent_rollouts=config.max_concurrent_rollouts, evaluation_rows=rows, ) - return tasks def cleanup(self) -> None: diff --git a/eval_protocol/pytest/default_pydantic_ai_rollout_processor.py b/eval_protocol/pytest/default_pydantic_ai_rollout_processor.py index 1fc85fa5..26cf3915 100644 --- a/eval_protocol/pytest/default_pydantic_ai_rollout_processor.py +++ b/eval_protocol/pytest/default_pydantic_ai_rollout_processor.py @@ -40,8 +40,7 @@ def __init__(self): def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]: """Create agent rollout tasks and return them for external handling.""" - max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8 - semaphore = asyncio.Semaphore(max_concurrent) + semaphore = config.semaphore # validate that the "agent" field is present with a valid Pydantic AI Agent instance in the completion_params dict if "agent" not in config.kwargs: diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py index 161d8db9..48a12fa3 100644 --- a/eval_protocol/pytest/default_single_turn_rollout_process.py +++ b/eval_protocol/pytest/default_single_turn_rollout_process.py @@ -112,9 +112,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: default_logger.log(row) return row - # Process rows with bounded concurrency - max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8 - semaphore = asyncio.Semaphore(max_concurrent) + semaphore = config.semaphore async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow: async with semaphore: diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index 583fb29f..14a23dce 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -282,14 +282,17 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo # used to determine whether this eval has stopped row.pid = os.getpid() + # Create shared semaphore for unified concurrency control across all runs and rollouts + shared_semaphore = asyncio.Semaphore(max_concurrent_rollouts) + # Prepare rollout processor config once; we will generate fresh outputs per run config = RolloutProcessorConfig( completion_params=completion_params if completion_params is not None else {}, mcp_config_path=mcp_config_path or "", - max_concurrent_rollouts=max_concurrent_rollouts, server_script_path=server_script_path, steps=steps, logger=active_logger, + semaphore=shared_semaphore, kwargs=rollout_processor_kwargs or {}, exception_handler_config=exception_handler_config, ) @@ -372,10 +375,10 @@ async def _execute_groupwise_eval_with_semaphore( config = RolloutProcessorConfig( completion_params=cp if cp is not None else {}, mcp_config_path=mcp_config_path or "", - max_concurrent_rollouts=max_concurrent_rollouts, server_script_path=server_script_path, steps=steps, logger=active_logger, + semaphore=shared_semaphore, kwargs=rollout_processor_kwargs or {}, ) lst = [] @@ -474,6 +477,7 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete await task else: # For other processors, create all tasks at once and run in parallel + # Concurrency is now controlled by the shared semaphore in each rollout processor tasks = [] for i in range(num_runs): tasks.append(asyncio.create_task(execute_run(i, config))) # pyright: ignore[reportUnknownMemberType] diff --git a/eval_protocol/pytest/types.py b/eval_protocol/pytest/types.py index 6b52ae33..46719c1c 100644 --- a/eval_protocol/pytest/types.py +++ b/eval_protocol/pytest/types.py @@ -2,6 +2,7 @@ Parameter types """ +import asyncio from dataclasses import dataclass, field from typing import Any, Callable, Literal @@ -58,10 +59,10 @@ class RolloutProcessorConfig: completion_params: CompletionParams # input parameters for inference mcp_config_path: str + semaphore: asyncio.Semaphore # shared semaphore for unified concurrency control server_script_path: str | None = ( None # TODO: change from server_script_path to mcp_config_path for agent rollout processor ) - max_concurrent_rollouts: int = 8 # maximum number of concurrent rollouts steps: int = 30 # max number of rollout steps logger: DatasetLogger = default_logger # logger to use during rollout for mid-rollout logs kwargs: dict[str, Any] = field( # pyright: ignore[reportExplicitAny]