-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathdefault_mcp_gym_rollout_processor.py
More file actions
270 lines (226 loc) · 9.66 KB
/
default_mcp_gym_rollout_processor.py
File metadata and controls
270 lines (226 loc) · 9.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
import asyncio
import atexit
import os
import signal
import socket
import subprocess
import sys
import time
from pathlib import Path
from typing import List, Optional

import eval_protocol as ep
from eval_protocol.mcp.execution.manager import ExecutionManager
from eval_protocol.models import EvaluationRow
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig
class MCPServerManager:
    """Manage the lifecycle of an MCP server subprocess for testing.

    Each instance launches ``server_script`` as a child process, redirects its
    stdout/stderr to a log file under ``base_dir`` and polls the TCP port until
    it accepts connections.  Instances are tracked in a class-level registry so
    that any server still running at interpreter exit, or when SIGINT/SIGTERM
    is received, gets stopped.

    The class is also a context manager: entering calls :meth:`start`, leaving
    calls :meth:`stop`.
    """

    # Class-level tracking of all server instances (intentionally shared
    # between instances; it is the registry used by the exit handlers).
    _active_servers = []
    _cleanup_registered = False

    def __init__(self, server_script: str, port: int = 8000, **kwargs):
        """Record the launch configuration; the process is not started here.

        Args:
            server_script: Path of the Python script that runs the server.
            port: TCP port passed to the server and polled for readiness.
            **kwargs: Only ``domain`` is read (default ``"airline"``); any
                other keys are ignored.
        """
        self.server_script = server_script
        self.port = port
        self.domain = str(kwargs.get("domain", "airline"))
        self.process: Optional[subprocess.Popen] = None
        self.base_dir = Path(".").resolve()
        self._log_file = None
        self._log_file_path = None

        # Register this server for cleanup
        MCPServerManager._active_servers.append(self)

        # Register cleanup handlers only once
        if not MCPServerManager._cleanup_registered:
            MCPServerManager._register_cleanup_handlers()
            MCPServerManager._cleanup_registered = True

    def start(self) -> None:
        """Start the MCP server and block until it accepts connections.

        Does nothing if a process is already attached to this manager.

        Raises:
            RuntimeError: If the port is already in use, or if the server does
                not become ready in time.  In the latter case the child
                process is stopped and the log file closed before raising.
            OSError: If the log file or the child process cannot be created.
        """
        if self.process:
            return

        # Refuse to start when something already listens on the port.  A
        # failure of the probe itself is treated as "port is free".
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                probe.settimeout(1)
                port_in_use = probe.connect_ex(("localhost", self.port)) == 0
        except OSError:
            port_in_use = False
        if port_in_use:
            raise RuntimeError(
                f"Port {self.port} is already in use! Please use a different port or kill the process using it."
            )

        # Set environment for server
        env = os.environ.copy()
        env["PORT"] = str(self.port)

        # Launch with the interpreter running this code so the server sees the
        # same environment/packages, instead of whatever "python" is on PATH.
        # Both the port and the domain are passed on the command line.
        cmd = [
            sys.executable or "python",
            self.server_script,
            "--port",
            str(self.port),
            "--domain",
            self.domain,
        ]

        # Setup log file with cleanup of any previous run's output
        log_file_path = os.path.join(self.base_dir, f"server_output_{self.domain}_{self.port}.log")
        if os.path.exists(log_file_path):
            os.remove(log_file_path)

        log_file = open(log_file_path, "w")
        try:
            self.process = subprocess.Popen(
                cmd,
                cwd=self.base_dir,
                env=env,
                stdout=log_file,
                stderr=log_file,
                text=True,
            )
        except Exception:
            # Do not leak the handle when the process could not be spawned.
            log_file.close()
            raise

        # Store log file reference for cleanup
        self._log_file = log_file
        self._log_file_path = log_file_path

        # stop() removes the instance from the registry; put it back so a
        # restarted server is still covered by the exit handlers.
        if self not in MCPServerManager._active_servers:
            MCPServerManager._active_servers.append(self)

        # Wait for server to be ready with proper health check
        if self._wait_for_server_ready(timeout=15):
            print(f"✅ Server started successfully on port {self.port}")
            return

        # Startup failed: show the log, then tear the process down.  The
        # process is never waited on with communicate() here because it may
        # still be alive (just not listening), which would block forever.
        log_error = None
        try:
            with open(log_file_path, "r", errors="replace") as f:
                log_content = f.read()
        except OSError as e:
            log_error = e
        else:
            print("❌ Server failed to start!")
            print(f"📋 Server log ({log_file_path}):")
            print("=" * 50)
            print(log_content)
            print("=" * 50)

        self.stop()

        if log_error is not None:
            raise RuntimeError(f"Server failed to start or become ready. log error: {log_error}")
        raise RuntimeError("Server failed to start or become ready. Check log above for details.")

    def _wait_for_server_ready(self, timeout: int = 15) -> bool:
        """Poll the server port until it accepts a TCP connection.

        Args:
            timeout: Maximum number of seconds to keep polling.

        Returns:
            True once a connection to ``localhost:port`` succeeds; False if the
            child process exits first or the timeout elapses.
        """
        # Monotonic clock: the deadline must not move with wall-clock changes.
        deadline = time.monotonic() + timeout
        health_check_failures = 0

        while time.monotonic() < deadline:
            # Check if process is still running
            if self.process and self.process.poll() is not None:
                print("Server process exited early")
                return False

            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.settimeout(1)
                    result = s.connect_ex(("localhost", self.port))
                    if result == 0:
                        # Short grace period after the port opens.
                        time.sleep(0.5)
                        return True
            except Exception as e:
                health_check_failures += 1
                # Print first few failures for debugging
                if health_check_failures <= 3:
                    print(f"Health check failed: {e}")

            # Wait before next check
            time.sleep(0.1)

        print(f"Server failed to become ready within {timeout} seconds")
        return False

    def stop(self) -> None:
        """Stop the server process, close its log file and unregister it.

        The process is terminated first and killed if it has not exited within
        five seconds.  Safe to call when nothing is running.
        """
        if self.process:
            print(f"🛑 Stopping server on port {self.port}...")
            self.process.terminate()
            try:
                self.process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                print(f"⚡ Force killing server on port {self.port}...")
                self.process.kill()
                self.process.wait()
            self.process = None

        # Clean up log file (best effort: a failing close must not mask stop)
        if self._log_file:
            try:
                self._log_file.close()
            except Exception:
                pass
            self._log_file = None

        # Remove from active servers list
        if self in MCPServerManager._active_servers:
            MCPServerManager._active_servers.remove(self)

    @classmethod
    def _cleanup_all_servers(cls):
        """Stop every registered server; errors are reported, not raised."""
        print(f"\n🧹 Cleaning up {len(cls._active_servers)} active servers...")
        # Iterate over a copy: stop() mutates the registry.
        for server in cls._active_servers.copy():
            try:
                server.stop()
            except Exception as e:
                print(f"⚠️  Error stopping server: {e}")
        cls._active_servers.clear()

    @classmethod
    def _signal_handler(cls, signum, frame):
        """Handle SIGINT/SIGTERM: stop all servers, then exit with status 1."""
        print(f"\n🛑 Received signal {signum}, cleaning up...")
        cls._cleanup_all_servers()
        # Raise SystemExit directly; the ``exit`` builtin is injected by the
        # ``site`` module and is not guaranteed to exist.
        raise SystemExit(1)

    @classmethod
    def _register_cleanup_handlers(cls):
        """Register the atexit hook and signal handlers - called only once."""
        atexit.register(cls._cleanup_all_servers)
        for sig in (signal.SIGINT, signal.SIGTERM):  # Ctrl+C / termination
            try:
                signal.signal(sig, cls._signal_handler)
            except ValueError:
                # signal.signal() only works in the main thread.  When the
                # manager is created elsewhere, rely on the atexit hook alone.
                pass

    def __enter__(self):
        """Context manager entry: start the server and return the manager."""
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit - ensures cleanup even on exceptions."""
        self.stop()
        if exc_type:
            print(f"⚠️  Server cleanup after exception: {exc_type.__name__}")
        return False  # Don't suppress exceptions
class MCPGymRolloutProcessor(RolloutProcessor):
    """
    Rollout processor for tau bench environments.

    This processor starts an MCP server, creates tau bench environments, and returns rollout tasks
    using the eval_protocol framework with proper cleanup handling.
    """

    # Port used for the MCP server unless ``config.kwargs["port"]`` overrides it.
    DEFAULT_PORT = 9700

    def __init__(self):
        # Both are populated by __call__ (when start_server is true) and
        # reset by cleanup().
        self.server = None
        self.policy = None

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Process evaluation rows with MCP gym environments.

        Args:
            rows: Evaluation rows the environments are created from.
            config: Rollout configuration.  ``config.kwargs`` may contain
                ``start_server`` (default True; False reuses the server and
                policy from a previous call) and ``port`` (default
                ``DEFAULT_PORT``); the remaining kwargs are forwarded to
                ``MCPServerManager``.  ``config.completion_params`` may be
                None, in which case the policy defaults are used.

        Returns:
            The tasks returned by ``ExecutionManager.execute_rollouts``.

        Raises:
            ValueError: If a server must be started but
                ``config.server_script_path`` is None.
            RuntimeError: If ``start_server`` is false and there is no
                existing server/policy to reuse.
        """
        start_server = config.kwargs.get("start_server", True) if config.kwargs else True

        if start_server:
            # Create fresh MCP server and environments for this run
            if config.server_script_path is None:
                raise ValueError("server_script_path is required for MCPGymRolloutProcessor")

            # Pull "port" out of the kwargs so it is not passed twice.
            server_kwargs = dict(config.kwargs or {})
            port = int(server_kwargs.pop("port", self.DEFAULT_PORT))
            self.server = MCPServerManager(config.server_script_path, port=port, **server_kwargs)

            # completion_params may be None; use the same guard for every key.
            completion_params = config.completion_params or {}

            try:
                self.server.start()

                self.policy = ep.LiteLLMPolicy(
                    model_id=str(completion_params.get("model") or "gpt-4o-mini"),
                    temperature=completion_params.get("temperature", 0.0),
                    max_tokens=completion_params.get("max_tokens", 4096),
                    **(completion_params.get("extra_body", {}) or {}),
                )

            except Exception:
                # Do not leave a half-initialised server behind.
                if self.server:
                    self.server.stop()
                self.server = None
                self.policy = None
                raise

        else:
            # Reuse existing MCP environments for retry
            if not self.server or not self.policy:
                raise RuntimeError(
                    "Cannot retry without existing server/environments. Call with start_server=True first."
                )

        # Create MCP environments directly from evaluation_rows
        assert self.policy is not None, "Policy must be initialized before rollout"
        envs = ep.make(
            # Derived from the running server so URL and port cannot drift apart.
            f"http://localhost:{self.server.port}/mcp/",
            evaluation_rows=rows,
            model_id=self.policy.model_id,
        )

        # TODO: chat with benny/dylan about when they're back. can we just bypass ep.rollout()? i don't really see the point of it anymore. or turn it into a return list of tasks.
        execution_manager = ExecutionManager()
        tasks = execution_manager.execute_rollouts(
            envs,
            policy=self.policy,
            semaphore=config.semaphore,
            steps=config.steps,
            evaluation_rows=rows,
        )
        return tasks

    def cleanup(self) -> None:
        """Cleanup MCP server and environments."""
        if self.server:
            self.server.stop()
            self.server = None
        self.policy = None