From d91a56fd6f8753aab9db500ec424421a779ca73b Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Thu, 13 Nov 2025 13:46:07 -0800 Subject: [PATCH] move shared mcp server logic --- .../default_mcp_gym_rollout_processor.py | 90 ++++++++++++++----- eval_protocol/pytest/evaluation_test.py | 10 +++ eval_protocol/pytest/evaluation_test_utils.py | 64 ++++++++++++- eval_protocol/pytest/types.py | 6 ++ 4 files changed, 148 insertions(+), 22 deletions(-) diff --git a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py index cd869bd7..4587be0a 100644 --- a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py +++ b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py @@ -13,7 +13,7 @@ from eval_protocol.mcp.execution.manager import ExecutionManager from eval_protocol.models import EvaluationRow from eval_protocol.pytest.rollout_processor import RolloutProcessor -from eval_protocol.pytest.types import RolloutProcessorConfig +from eval_protocol.pytest.types import RolloutProcessorConfig, ServerMode class MCPServerManager: @@ -207,37 +207,78 @@ class MCPGymRolloutProcessor(RolloutProcessor): using the eval_protocol framework with proper cleanup handling. """ + # Shared server state for "shared" mode + _shared_server_lock = threading.Lock() + _shared_server: Optional[MCPServerManager] = None + _shared_server_started: bool = False + def __init__(self): - self.server = None + # Instance-level server handle (used in "per_run" mode) + self.server: Optional[MCPServerManager] = None self.policy = None + # Track which mode this instance last used ("per_run" or "shared") + self.server_mode: ServerMode = "per_run" def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]: """Process evaluation rows with MCP gym environments.""" - start_server = config.kwargs.get("start_server", True) if config.kwargs else True + server_kwargs = dict(config.kwargs or {}) + start_server = bool(server_kwargs.pop("start_server", True)) + server_mode: ServerMode = server_kwargs.pop("server_mode", "per_run") + port = int(server_kwargs.pop("port", 9700)) - if start_server: - # Create fresh MCP server and environments for this run - if config.server_script_path is None: - raise ValueError("server_script_path is required for MCPGymRolloutProcessor") + self.server_mode = server_mode - self.server = MCPServerManager(config.server_script_path, port=9700, **(config.kwargs or {})) + if server_mode == "shared": + # Shared, class-level server used across calls + if start_server and not MCPGymRolloutProcessor._shared_server_started: + with MCPGymRolloutProcessor._shared_server_lock: + if not MCPGymRolloutProcessor._shared_server_started: + if config.server_script_path is None: + raise ValueError("server_script_path is required for MCPGymRolloutProcessor") - try: - self.server.start() + shared_server = MCPServerManager(config.server_script_path, port=port, **server_kwargs) - except Exception as e: - if self.server: - self.server.stop() - self.server = None - self.policy = None - raise e + try: + shared_server.start() + except Exception as e: + shared_server.stop() + raise e - else: - # Reuse existing MCP environments for retry - if not self.server: + MCPGymRolloutProcessor._shared_server = shared_server + MCPGymRolloutProcessor._shared_server_started = True + + if MCPGymRolloutProcessor._shared_server is None: raise RuntimeError( - "Cannot retry without existing server/environments. Call with start_server=True first." + "Shared MCP server not started. Call with server_mode='shared' and start_server=True first." ) + # Bind this instance to the shared server for this call + self.server = MCPGymRolloutProcessor._shared_server + + else: + # Default "per_run" behavior: fresh server per call, reused only for retries + if start_server: + # Create fresh MCP server and environments for this run + if config.server_script_path is None: + raise ValueError("server_script_path is required for MCPGymRolloutProcessor") + + self.server = MCPServerManager(config.server_script_path, port=port, **server_kwargs) + + try: + self.server.start() + + except Exception as e: + if self.server: + self.server.stop() + self.server = None + self.policy = None + raise e + + else: + # Reuse existing MCP environments for retry (per_run mode) + if not self.server: + raise RuntimeError( + "Cannot retry without existing server/environments. Call with start_server=True first." + ) model_id = str((config.completion_params.get("model") if config.completion_params else None) or "gpt-4o-mini") temperature = config.completion_params.get("temperature", 0.0) @@ -260,7 +301,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> ) # Create MCP environments directly from evaluation_rows envs = ep.make( - "http://localhost:9700/mcp/", + f"http://localhost:{port}/mcp/", evaluation_rows=rows, model_id=self.policy.model_id, ) @@ -278,6 +319,13 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> def cleanup(self) -> None: """Cleanup MCP server and environments.""" + # For shared mode, don't stop the shared server here; rely on global cleanup + # (atexit or an explicit class-level shutdown) so multiple users can share it. + if self.server_mode == "shared": + self.policy = None + return + + # Per-run mode: stop this instance's server if self.server: self.server.stop() self.server = None diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index 763e3081..8f96021b 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -683,11 +683,21 @@ async def _collect_result(config, lst): ) pytest_wrapper = pytest.mark.asyncio(pytest_wrapper) + ep_params: dict[str, Any] = { + "rollout_processor": rollout_processor, + "server_script_path": server_script_path, + "mcp_config_path": mcp_config_path, + "rollout_processor_kwargs": rollout_processor_kwargs, + "mode": mode, + } + # Create the dual mode wrapper dual_mode_wrapper = create_dual_mode_wrapper( test_func, mode, max_concurrent_rollouts, max_concurrent_evaluations, pytest_wrapper ) + setattr(dual_mode_wrapper, "__ep_params__", ep_params) + # Make this pytest discoverable regardless of pytest configuration. So # you can name your eval whatever you want, as long as it's decorated # with @evaluation_test. diff --git a/eval_protocol/pytest/evaluation_test_utils.py b/eval_protocol/pytest/evaluation_test_utils.py index 26b0d799..b4d1c218 100644 --- a/eval_protocol/pytest/evaluation_test_utils.py +++ b/eval_protocol/pytest/evaluation_test_utils.py @@ -4,7 +4,7 @@ import re import sys from dataclasses import replace -from typing import Any, Literal, Callable, AsyncGenerator +from typing import Any, Literal, Callable, AsyncGenerator, Optional from litellm.cost_calculator import cost_per_token from tqdm import tqdm @@ -22,8 +22,10 @@ from eval_protocol.data_loader import DynamicDataLoader from eval_protocol.data_loader.models import EvaluationDataLoader from eval_protocol.pytest.rollout_processor import RolloutProcessor +from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor from eval_protocol.pytest.types import ( RolloutProcessorConfig, + ServerMode, ) from eval_protocol.pytest.exception_config import get_default_exception_handler_config @@ -530,3 +532,63 @@ def add_cost_metrics(row: EvaluationRow) -> None: output_cost=output_cost, total_cost_dollar=total_cost, ) + + +def build_rollout_processor_config( + rollout_processor: RolloutProcessor, + model: str, + semaphore: asyncio.Semaphore, + temperature: float = 0.0, + max_tokens: int = 4096, + steps: int = 30, + mcp_config_path: str = "", + server_script_path: Optional[str] = None, + rollout_processor_kwargs: Optional[dict[str, Any]] = None, + start_server: bool = True, + server_mode: Optional[ServerMode] = None, +) -> RolloutProcessorConfig: + """Build rollout processor config with appropriate parameters for different processor types. + + Args: + rollout_processor: The rollout processor instance + model: Model name/path for completion_params + semaphore: Semaphore for concurrency control + temperature: Temperature for completion_params + max_tokens: Max tokens for completion_params + steps: Number of rollout steps + mcp_config_path: Path to MCP config file + server_script_path: Path to server script (required for MCPGymRolloutProcessor) + rollout_processor_kwargs: Additional kwargs to pass to rollout processor + start_server: Whether to start server (for MCPGymRolloutProcessor) + server_mode: Optional server lifecycle mode ("per_run" or "shared") for MCPGymRolloutProcessor + + Returns: + RolloutProcessorConfig: Configured rollout processor config + """ + rollout_processor_kwargs = rollout_processor_kwargs or {} + + completion_params = {"model": model, "temperature": temperature, "max_tokens": max_tokens} + + if isinstance(rollout_processor, MCPGymRolloutProcessor): + base_kwargs = {**(rollout_processor_kwargs or {}), "start_server": start_server} + if server_mode is not None and "server_mode" not in base_kwargs: + base_kwargs["server_mode"] = server_mode + + return RolloutProcessorConfig( + completion_params=completion_params, + mcp_config_path=mcp_config_path, + steps=steps, + semaphore=semaphore, + server_script_path=server_script_path, + kwargs=base_kwargs, + ) + + # RemoteRolloutProcessor, SingleTurnRolloutProcessor, AgentRolloutProcessor, etc. + return RolloutProcessorConfig( + completion_params=completion_params, + mcp_config_path=mcp_config_path, + steps=steps, + semaphore=semaphore, + server_script_path=None, + kwargs=rollout_processor_kwargs, + ) diff --git a/eval_protocol/pytest/types.py b/eval_protocol/pytest/types.py index 46719c1c..9603c7b9 100644 --- a/eval_protocol/pytest/types.py +++ b/eval_protocol/pytest/types.py @@ -27,6 +27,12 @@ "all": applies test function to the whole dataset. """ +ServerMode = Literal["per_run", "shared"] +""" +"per_run": start a new MCP server for each eval run / training step, only reuse the same server only for retries within that run. +"shared": start a single MCP server the first time it's needed, then reuse that same server across multiple eval runs / training steps. +""" + """ Test function types """