Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 69 additions & 21 deletions eval_protocol/pytest/default_mcp_gym_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from eval_protocol.mcp.execution.manager import ExecutionManager
from eval_protocol.models import EvaluationRow
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig
from eval_protocol.pytest.types import RolloutProcessorConfig, ServerMode


class MCPServerManager:
Expand Down Expand Up @@ -207,37 +207,78 @@ class MCPGymRolloutProcessor(RolloutProcessor):
using the eval_protocol framework with proper cleanup handling.
"""

# Shared server state for "shared" mode
_shared_server_lock = threading.Lock()
_shared_server: Optional[MCPServerManager] = None
_shared_server_started: bool = False

def __init__(self):
self.server = None
# Instance-level server handle (used in "per_run" mode)
self.server: Optional[MCPServerManager] = None
self.policy = None
# Track which mode this instance last used ("per_run" or "shared")
self.server_mode: ServerMode = "per_run"

def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
"""Process evaluation rows with MCP gym environments."""
start_server = config.kwargs.get("start_server", True) if config.kwargs else True
server_kwargs = dict(config.kwargs or {})
start_server = bool(server_kwargs.pop("start_server", True))
server_mode: ServerMode = server_kwargs.pop("server_mode", "per_run")
port = int(server_kwargs.pop("port", 9700))

if start_server:
# Create fresh MCP server and environments for this run
if config.server_script_path is None:
raise ValueError("server_script_path is required for MCPGymRolloutProcessor")
self.server_mode = server_mode

self.server = MCPServerManager(config.server_script_path, port=9700, **(config.kwargs or {}))
if server_mode == "shared":
# Shared, class-level server used across calls
if start_server and not MCPGymRolloutProcessor._shared_server_started:
with MCPGymRolloutProcessor._shared_server_lock:
if not MCPGymRolloutProcessor._shared_server_started:
if config.server_script_path is None:
raise ValueError("server_script_path is required for MCPGymRolloutProcessor")

try:
self.server.start()
shared_server = MCPServerManager(config.server_script_path, port=port, **server_kwargs)

except Exception as e:
if self.server:
self.server.stop()
self.server = None
self.policy = None
raise e
try:
shared_server.start()
except Exception as e:
shared_server.stop()
raise e

else:
# Reuse existing MCP environments for retry
if not self.server:
MCPGymRolloutProcessor._shared_server = shared_server
MCPGymRolloutProcessor._shared_server_started = True

if MCPGymRolloutProcessor._shared_server is None:
raise RuntimeError(
"Cannot retry without existing server/environments. Call with start_server=True first."
"Shared MCP server not started. Call with server_mode='shared' and start_server=True first."
)
# Bind this instance to the shared server for this call
self.server = MCPGymRolloutProcessor._shared_server
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Shared Server Port Mismatch Breaks Connections

In shared server mode, the port used to create the shared server is not validated against the port requested in subsequent calls. If the first call creates a shared server on port 9700, but a second call requests port 8000, the code reuses the server on port 9700 but attempts to connect to port 8000 at line 304, causing connection failures.

Fix in Cursor Fix in Web


else:
# Default "per_run" behavior: fresh server per call, reused only for retries
if start_server:
# Create fresh MCP server and environments for this run
if config.server_script_path is None:
raise ValueError("server_script_path is required for MCPGymRolloutProcessor")

self.server = MCPServerManager(config.server_script_path, port=port, **server_kwargs)

try:
self.server.start()

except Exception as e:
if self.server:
self.server.stop()
self.server = None
self.policy = None
raise e

else:
# Reuse existing MCP environments for retry (per_run mode)
if not self.server:
raise RuntimeError(
"Cannot retry without existing server/environments. Call with start_server=True first."
)

model_id = str((config.completion_params.get("model") if config.completion_params else None) or "gpt-4o-mini")
temperature = config.completion_params.get("temperature", 0.0)
Expand All @@ -260,7 +301,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
)
# Create MCP environments directly from evaluation_rows
envs = ep.make(
"http://localhost:9700/mcp/",
f"http://localhost:{port}/mcp/",
evaluation_rows=rows,
model_id=self.policy.model_id,
)
Expand All @@ -278,6 +319,13 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->

def cleanup(self) -> None:
"""Cleanup MCP server and environments."""
# For shared mode, don't stop the shared server here; rely on global cleanup
# (atexit or an explicit class-level shutdown) so multiple users can share it.
if self.server_mode == "shared":
self.policy = None
return

# Per-run mode: stop this instance's server
if self.server:
self.server.stop()
self.server = None
Expand Down
10 changes: 10 additions & 0 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,11 +683,21 @@ async def _collect_result(config, lst):
)
pytest_wrapper = pytest.mark.asyncio(pytest_wrapper)

ep_params: dict[str, Any] = {
"rollout_processor": rollout_processor,
"server_script_path": server_script_path,
"mcp_config_path": mcp_config_path,
"rollout_processor_kwargs": rollout_processor_kwargs,
"mode": mode,
}

# Create the dual mode wrapper
dual_mode_wrapper = create_dual_mode_wrapper(
test_func, mode, max_concurrent_rollouts, max_concurrent_evaluations, pytest_wrapper
)

setattr(dual_mode_wrapper, "__ep_params__", ep_params)

# Make this pytest discoverable regardless of pytest configuration. So
# you can name your eval whatever you want, as long as it's decorated
# with @evaluation_test.
Expand Down
64 changes: 63 additions & 1 deletion eval_protocol/pytest/evaluation_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
import sys
from dataclasses import replace
from typing import Any, Literal, Callable, AsyncGenerator
from typing import Any, Literal, Callable, AsyncGenerator, Optional

from litellm.cost_calculator import cost_per_token
from tqdm import tqdm
Expand All @@ -22,8 +22,10 @@
from eval_protocol.data_loader import DynamicDataLoader
from eval_protocol.data_loader.models import EvaluationDataLoader
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
from eval_protocol.pytest.types import (
RolloutProcessorConfig,
ServerMode,
)
from eval_protocol.pytest.exception_config import get_default_exception_handler_config

Expand Down Expand Up @@ -530,3 +532,63 @@ def add_cost_metrics(row: EvaluationRow) -> None:
output_cost=output_cost,
total_cost_dollar=total_cost,
)


def build_rollout_processor_config(
    rollout_processor: RolloutProcessor,
    model: str,
    semaphore: asyncio.Semaphore,
    temperature: float = 0.0,
    max_tokens: int = 4096,
    steps: int = 30,
    mcp_config_path: str = "",
    server_script_path: Optional[str] = None,
    rollout_processor_kwargs: Optional[dict[str, Any]] = None,
    start_server: bool = True,
    server_mode: Optional[ServerMode] = None,
) -> RolloutProcessorConfig:
    """Build a rollout processor config with parameters appropriate to the processor type.

    Args:
        rollout_processor: The rollout processor instance.
        model: Model name/path placed in ``completion_params``.
        semaphore: Semaphore for concurrency control.
        temperature: Sampling temperature placed in ``completion_params``.
        max_tokens: Max completion tokens placed in ``completion_params``.
        steps: Number of rollout steps.
        mcp_config_path: Path to the MCP config file.
        server_script_path: Path to the server script (required by
            MCPGymRolloutProcessor; ignored for other processor types).
        rollout_processor_kwargs: Additional kwargs forwarded to the rollout processor.
        start_server: Whether to start the server (MCPGymRolloutProcessor only).
        server_mode: Optional server lifecycle mode ("per_run" or "shared") for
            MCPGymRolloutProcessor.

    Returns:
        RolloutProcessorConfig: Configured rollout processor config.
    """
    # Normalize once; the MCP branch below relies on this being a dict.
    rollout_processor_kwargs = rollout_processor_kwargs or {}

    completion_params = {"model": model, "temperature": temperature, "max_tokens": max_tokens}

    if isinstance(rollout_processor, MCPGymRolloutProcessor):
        # The explicit start_server parameter always wins over any value the
        # caller placed in rollout_processor_kwargs.
        base_kwargs = {**rollout_processor_kwargs, "start_server": start_server}
        # Inject server_mode only when explicitly requested and not already
        # supplied by the caller via rollout_processor_kwargs.
        if server_mode is not None and "server_mode" not in base_kwargs:
            base_kwargs["server_mode"] = server_mode

        return RolloutProcessorConfig(
            completion_params=completion_params,
            mcp_config_path=mcp_config_path,
            steps=steps,
            semaphore=semaphore,
            server_script_path=server_script_path,
            kwargs=base_kwargs,
        )

    # RemoteRolloutProcessor, SingleTurnRolloutProcessor, AgentRolloutProcessor, etc.
    # do not take a server script, so server_script_path is deliberately None.
    return RolloutProcessorConfig(
        completion_params=completion_params,
        mcp_config_path=mcp_config_path,
        steps=steps,
        semaphore=semaphore,
        server_script_path=None,
        kwargs=rollout_processor_kwargs,
    )
6 changes: 6 additions & 0 deletions eval_protocol/pytest/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@
"all": applies test function to the whole dataset.
"""

ServerMode = Literal["per_run", "shared"]
"""
"per_run": start a new MCP server for each eval run / training step, reusing the same server only for retries within that run.
"shared": start a single MCP server the first time it's needed, then reuse that same server across multiple eval runs / training steps.
"""

"""
Test function types
"""
Expand Down