Skip to content

Commit d91a56f

Browse files
committed
move shared mcp server logic
1 parent bd27ffd commit d91a56f

File tree

4 files changed

+148
-22
lines changed

4 files changed

+148
-22
lines changed

eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 69 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from eval_protocol.mcp.execution.manager import ExecutionManager
1414
from eval_protocol.models import EvaluationRow
1515
from eval_protocol.pytest.rollout_processor import RolloutProcessor
16-
from eval_protocol.pytest.types import RolloutProcessorConfig
16+
from eval_protocol.pytest.types import RolloutProcessorConfig, ServerMode
1717

1818

1919
class MCPServerManager:
@@ -207,37 +207,78 @@ class MCPGymRolloutProcessor(RolloutProcessor):
207207
using the eval_protocol framework with proper cleanup handling.
208208
"""
209209

210+
# Shared server state for "shared" mode
211+
_shared_server_lock = threading.Lock()
212+
_shared_server: Optional[MCPServerManager] = None
213+
_shared_server_started: bool = False
214+
210215
def __init__(self):
211-
self.server = None
216+
# Instance-level server handle (used in "per_run" mode)
217+
self.server: Optional[MCPServerManager] = None
212218
self.policy = None
219+
# Track which mode this instance last used ("per_run" or "shared")
220+
self.server_mode: ServerMode = "per_run"
213221

214222
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
215223
"""Process evaluation rows with MCP gym environments."""
216-
start_server = config.kwargs.get("start_server", True) if config.kwargs else True
224+
server_kwargs = dict(config.kwargs or {})
225+
start_server = bool(server_kwargs.pop("start_server", True))
226+
server_mode: ServerMode = server_kwargs.pop("server_mode", "per_run")
227+
port = int(server_kwargs.pop("port", 9700))
217228

218-
if start_server:
219-
# Create fresh MCP server and environments for this run
220-
if config.server_script_path is None:
221-
raise ValueError("server_script_path is required for MCPGymRolloutProcessor")
229+
self.server_mode = server_mode
222230

223-
self.server = MCPServerManager(config.server_script_path, port=9700, **(config.kwargs or {}))
231+
if server_mode == "shared":
232+
# Shared, class-level server used across calls
233+
if start_server and not MCPGymRolloutProcessor._shared_server_started:
234+
with MCPGymRolloutProcessor._shared_server_lock:
235+
if not MCPGymRolloutProcessor._shared_server_started:
236+
if config.server_script_path is None:
237+
raise ValueError("server_script_path is required for MCPGymRolloutProcessor")
224238

225-
try:
226-
self.server.start()
239+
shared_server = MCPServerManager(config.server_script_path, port=port, **server_kwargs)
227240

228-
except Exception as e:
229-
if self.server:
230-
self.server.stop()
231-
self.server = None
232-
self.policy = None
233-
raise e
241+
try:
242+
shared_server.start()
243+
except Exception as e:
244+
shared_server.stop()
245+
raise e
234246

235-
else:
236-
# Reuse existing MCP environments for retry
237-
if not self.server:
247+
MCPGymRolloutProcessor._shared_server = shared_server
248+
MCPGymRolloutProcessor._shared_server_started = True
249+
250+
if MCPGymRolloutProcessor._shared_server is None:
238251
raise RuntimeError(
239-
"Cannot retry without existing server/environments. Call with start_server=True first."
252+
"Shared MCP server not started. Call with server_mode='shared' and start_server=True first."
240253
)
254+
# Bind this instance to the shared server for this call
255+
self.server = MCPGymRolloutProcessor._shared_server
256+
257+
else:
258+
# Default "per_run" behavior: fresh server per call, reused only for retries
259+
if start_server:
260+
# Create fresh MCP server and environments for this run
261+
if config.server_script_path is None:
262+
raise ValueError("server_script_path is required for MCPGymRolloutProcessor")
263+
264+
self.server = MCPServerManager(config.server_script_path, port=port, **server_kwargs)
265+
266+
try:
267+
self.server.start()
268+
269+
except Exception as e:
270+
if self.server:
271+
self.server.stop()
272+
self.server = None
273+
self.policy = None
274+
raise e
275+
276+
else:
277+
# Reuse existing MCP environments for retry (per_run mode)
278+
if not self.server:
279+
raise RuntimeError(
280+
"Cannot retry without existing server/environments. Call with start_server=True first."
281+
)
241282

242283
model_id = str((config.completion_params.get("model") if config.completion_params else None) or "gpt-4o-mini")
243284
temperature = config.completion_params.get("temperature", 0.0)
@@ -260,7 +301,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
260301
)
261302
# Create MCP environments directly from evaluation_rows
262303
envs = ep.make(
263-
"http://localhost:9700/mcp/",
304+
f"http://localhost:{port}/mcp/",
264305
evaluation_rows=rows,
265306
model_id=self.policy.model_id,
266307
)
@@ -278,6 +319,13 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
278319

279320
def cleanup(self) -> None:
280321
"""Cleanup MCP server and environments."""
322+
# For shared mode, don't stop the shared server here; rely on global cleanup
323+
# (atexit or an explicit class-level shutdown) so multiple users can share it.
324+
if self.server_mode == "shared":
325+
self.policy = None
326+
return
327+
328+
# Per-run mode: stop this instance's server
281329
if self.server:
282330
self.server.stop()
283331
self.server = None

eval_protocol/pytest/evaluation_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -683,11 +683,21 @@ async def _collect_result(config, lst):
683683
)
684684
pytest_wrapper = pytest.mark.asyncio(pytest_wrapper)
685685

686+
ep_params: dict[str, Any] = {
687+
"rollout_processor": rollout_processor,
688+
"server_script_path": server_script_path,
689+
"mcp_config_path": mcp_config_path,
690+
"rollout_processor_kwargs": rollout_processor_kwargs,
691+
"mode": mode,
692+
}
693+
686694
# Create the dual mode wrapper
687695
dual_mode_wrapper = create_dual_mode_wrapper(
688696
test_func, mode, max_concurrent_rollouts, max_concurrent_evaluations, pytest_wrapper
689697
)
690698

699+
setattr(dual_mode_wrapper, "__ep_params__", ep_params)
700+
691701
# Make this pytest discoverable regardless of pytest configuration. So
692702
# you can name your eval whatever you want, as long as it's decorated
693703
# with @evaluation_test.

eval_protocol/pytest/evaluation_test_utils.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import re
55
import sys
66
from dataclasses import replace
7-
from typing import Any, Literal, Callable, AsyncGenerator
7+
from typing import Any, Literal, Callable, AsyncGenerator, Optional
88

99
from litellm.cost_calculator import cost_per_token
1010
from tqdm import tqdm
@@ -22,8 +22,10 @@
2222
from eval_protocol.data_loader import DynamicDataLoader
2323
from eval_protocol.data_loader.models import EvaluationDataLoader
2424
from eval_protocol.pytest.rollout_processor import RolloutProcessor
25+
from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
2526
from eval_protocol.pytest.types import (
2627
RolloutProcessorConfig,
28+
ServerMode,
2729
)
2830
from eval_protocol.pytest.exception_config import get_default_exception_handler_config
2931

@@ -530,3 +532,63 @@ def add_cost_metrics(row: EvaluationRow) -> None:
530532
output_cost=output_cost,
531533
total_cost_dollar=total_cost,
532534
)
535+
536+
537+
def build_rollout_processor_config(
538+
rollout_processor: RolloutProcessor,
539+
model: str,
540+
semaphore: asyncio.Semaphore,
541+
temperature: float = 0.0,
542+
max_tokens: int = 4096,
543+
steps: int = 30,
544+
mcp_config_path: str = "",
545+
server_script_path: Optional[str] = None,
546+
rollout_processor_kwargs: Optional[dict[str, Any]] = None,
547+
start_server: bool = True,
548+
server_mode: Optional[ServerMode] = None,
549+
) -> RolloutProcessorConfig:
550+
"""Build rollout processor config with appropriate parameters for different processor types.
551+
552+
Args:
553+
rollout_processor: The rollout processor instance
554+
model: Model name/path for completion_params
555+
semaphore: Semaphore for concurrency control
556+
temperature: Temperature for completion_params
557+
max_tokens: Max tokens for completion_params
558+
steps: Number of rollout steps
559+
mcp_config_path: Path to MCP config file
560+
server_script_path: Path to server script (required for MCPGymRolloutProcessor)
561+
rollout_processor_kwargs: Additional kwargs to pass to rollout processor
562+
start_server: Whether to start server (for MCPGymRolloutProcessor)
563+
server_mode: Optional server lifecycle mode ("per_run" or "shared") for MCPGymRolloutProcessor
564+
565+
Returns:
566+
RolloutProcessorConfig: Configured rollout processor config
567+
"""
568+
rollout_processor_kwargs = rollout_processor_kwargs or {}
569+
570+
completion_params = {"model": model, "temperature": temperature, "max_tokens": max_tokens}
571+
572+
if isinstance(rollout_processor, MCPGymRolloutProcessor):
573+
base_kwargs = {**(rollout_processor_kwargs or {}), "start_server": start_server}
574+
if server_mode is not None and "server_mode" not in base_kwargs:
575+
base_kwargs["server_mode"] = server_mode
576+
577+
return RolloutProcessorConfig(
578+
completion_params=completion_params,
579+
mcp_config_path=mcp_config_path,
580+
steps=steps,
581+
semaphore=semaphore,
582+
server_script_path=server_script_path,
583+
kwargs=base_kwargs,
584+
)
585+
586+
# RemoteRolloutProcessor, SingleTurnRolloutProcessor, AgentRolloutProcessor, etc.
587+
return RolloutProcessorConfig(
588+
completion_params=completion_params,
589+
mcp_config_path=mcp_config_path,
590+
steps=steps,
591+
semaphore=semaphore,
592+
server_script_path=None,
593+
kwargs=rollout_processor_kwargs,
594+
)

eval_protocol/pytest/types.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@
2727
"all": applies test function to the whole dataset.
2828
"""
2929

30+
ServerMode = Literal["per_run", "shared"]
31+
"""
32+
"per_run": start a new MCP server for each eval run / training step, only reuse the same server only for retries within that run.
33+
"shared": start a single MCP server the first time it's needed, then reuse that same server across multiple eval runs / training steps.
34+
"""
35+
3036
"""
3137
Test function types
3238
"""

0 commit comments

Comments
 (0)