Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions eval_protocol/mcp/execution/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def execute_rollouts(
self,
envs: "GeneralMCPVectorEnv",
policy: Union["LLMBasePolicy", Callable],
semaphore: asyncio.Semaphore,
steps: int = 512,
openai_format_log_file: Optional[str] = None,
max_concurrent_rollouts: int = 8,
evaluation_rows: Optional[List[EvaluationRow]] = None,
) -> List[asyncio.Task[EvaluationRow]]:
"""
Expand All @@ -57,7 +57,7 @@ def execute_rollouts(
policy: Policy that takes tool schemas, observations, prompts and returns tool calls
steps: Maximum steps per rollout
openai_format_log_file: Optional file to log clean OpenAI format for terminated trajectories only
max_concurrent_rollouts: Maximum number of concurrent threads to run
semaphore: asyncio.Semaphore, created by the caller, that bounds how many rollouts execute concurrently

Environment Variable Control:
EP_PLAYBACK_FILE: Controls record/playback mode
Expand Down Expand Up @@ -90,15 +90,13 @@ def execute_rollouts(
pass
openai_logger = lambda data: self._log_openai_entry(openai_format_log_file, data)

logger.info(f"🧵 Starting {envs.n} rollouts with max {max_concurrent_rollouts} concurrent threads...")
logger.info(f"🧵 Starting {envs.n} rollouts with max {semaphore._value} concurrent threads...")

if evaluation_rows is None:
evaluation_rows = [EvaluationRow(messages=[], input_metadata=InputMetadata()) for _ in range(envs.n)]

shared_tool_schema = envs.tool_schemas

semaphore = asyncio.Semaphore(max_concurrent_rollouts)

async def _execute_with_semaphore(idx):
async with semaphore:
evaluation_row: EvaluationRow = evaluation_rows[idx]
Expand Down
9 changes: 8 additions & 1 deletion eval_protocol/mcp_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,15 @@ async def rollout(
# Use the new ExecutionManager for execution
execution_manager = ExecutionManager()

rollout_semaphore = asyncio.Semaphore(max_concurrent_rollouts)

tasks = execution_manager.execute_rollouts(
envs, policy, steps, openai_format_log_file, max_concurrent_rollouts, evaluation_rows
envs,
policy,
semaphore=rollout_semaphore,
steps=steps,
openai_format_log_file=openai_format_log_file,
evaluation_rows=evaluation_rows,
)

# Await all tasks and return concrete EvaluationRows
Expand Down
3 changes: 1 addition & 2 deletions eval_protocol/pytest/default_agent_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,7 @@ class AgentRolloutProcessor(RolloutProcessor):
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
"""Create agent rollout tasks and return them for external handling."""

max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
semaphore = asyncio.Semaphore(max_concurrent)
semaphore = config.semaphore

async def process_row(row: EvaluationRow) -> EvaluationRow:
"""Process a single row with agent rollout."""
Expand Down
3 changes: 1 addition & 2 deletions eval_protocol/pytest/default_mcp_gym_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,10 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
tasks = execution_manager.execute_rollouts(
envs,
policy=self.policy,
semaphore=config.semaphore,
steps=config.steps,
max_concurrent_rollouts=config.max_concurrent_rollouts,
evaluation_rows=rows,
)

return tasks

def cleanup(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ def __init__(self):
def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
"""Create agent rollout tasks and return them for external handling."""

max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
semaphore = asyncio.Semaphore(max_concurrent)
semaphore = config.semaphore

# validate that the "agent" field is present with a valid Pydantic AI Agent instance in the completion_params dict
if "agent" not in config.kwargs:
Expand Down
4 changes: 1 addition & 3 deletions eval_protocol/pytest/default_single_turn_rollout_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
default_logger.log(row)
return row

# Process rows with bounded concurrency
max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
semaphore = asyncio.Semaphore(max_concurrent)
semaphore = config.semaphore

async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
async with semaphore:
Expand Down
8 changes: 6 additions & 2 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,14 +282,17 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
# used to determine whether this eval has stopped
row.pid = os.getpid()

# Create shared semaphore for unified concurrency control across all runs and rollouts
shared_semaphore = asyncio.Semaphore(max_concurrent_rollouts)

# Prepare rollout processor config once; we will generate fresh outputs per run
config = RolloutProcessorConfig(
completion_params=completion_params if completion_params is not None else {},
mcp_config_path=mcp_config_path or "",
max_concurrent_rollouts=max_concurrent_rollouts,
server_script_path=server_script_path,
steps=steps,
logger=active_logger,
semaphore=shared_semaphore,
kwargs=rollout_processor_kwargs or {},
exception_handler_config=exception_handler_config,
)
Expand Down Expand Up @@ -372,10 +375,10 @@ async def _execute_groupwise_eval_with_semaphore(
config = RolloutProcessorConfig(
completion_params=cp if cp is not None else {},
mcp_config_path=mcp_config_path or "",
max_concurrent_rollouts=max_concurrent_rollouts,
server_script_path=server_script_path,
steps=steps,
logger=active_logger,
semaphore=shared_semaphore,
kwargs=rollout_processor_kwargs or {},
)
lst = []
Expand Down Expand Up @@ -474,6 +477,7 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
await task
else:
# For other processors, create all tasks at once and run in parallel
# Concurrency is bounded by the shared semaphore (config.semaphore) that each rollout processor acquires
tasks = []
for i in range(num_runs):
tasks.append(asyncio.create_task(execute_run(i, config))) # pyright: ignore[reportUnknownMemberType]
Expand Down
3 changes: 2 additions & 1 deletion eval_protocol/pytest/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Parameter types
"""

import asyncio
from dataclasses import dataclass, field
from typing import Any, Callable, Literal

Expand Down Expand Up @@ -58,10 +59,10 @@
class RolloutProcessorConfig:
completion_params: CompletionParams # input parameters for inference
mcp_config_path: str
semaphore: asyncio.Semaphore # shared semaphore for unified concurrency control
server_script_path: str | None = (
None # TODO: change from server_script_path to mcp_config_path for agent rollout processor
)
max_concurrent_rollouts: int = 8 # maximum number of concurrent rollouts
steps: int = 30 # max number of rollout steps
logger: DatasetLogger = default_logger # logger to use during rollout for mid-rollout logs
kwargs: dict[str, Any] = field( # pyright: ignore[reportExplicitAny]
Expand Down
Loading