Skip to content

Commit 68282d9

Browse files
committed
fixing mcp gym rollout processor
1 parent d88e155 commit 68282d9

File tree

3 files changed

+15
-6
lines changed

3 files changed

+15
-6
lines changed

eval_protocol/mcp_servers/frozen_lake/server.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,13 @@
1414
import sys
1515
from pathlib import Path
1616

17-
# Add root directory to path so we can import eval_protocol
18-
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
17+
# Add current directory first for local imports (frozen_lake_mcp)
18+
sys.path.insert(0, str(Path(__file__).parent))
19+
20+
# Add eval_protocol parent to path, but use append to avoid priority conflicts
21+
parent_dir = str(Path(__file__).parent.parent.parent)
22+
if parent_dir not in sys.path:
23+
sys.path.append(parent_dir)
1924

2025
from frozen_lake_mcp import FrozenLakeMcp
2126

eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class MCPServerManager:
2525
def __init__(self, server_script: str, port: int = 8000, **kwargs):
2626
self.server_script = server_script
2727
self.port = port
28-
self.domain = str(kwargs.get("domain", "airline"))
28+
self.domain = kwargs.get("domain", None)
2929
self.process: Optional[subprocess.Popen] = None
3030
self.base_dir = Path(".").resolve()
3131
self._log_file = None
@@ -59,11 +59,14 @@ def start(self) -> None:
5959
env = os.environ.copy()
6060
env["PORT"] = str(self.port)
6161

62-
# Start server process (no domain argument needed for tau2_mcp server)
63-
cmd = ["python", self.server_script, "--port", str(self.port), "--domain", self.domain]
62+
# Build command, add --domain only if provided (e.g. tau2 needs it, frozen_lake doesn't)
63+
cmd = ["python", self.server_script, "--port", str(self.port)]
64+
if self.domain:
65+
cmd.extend(["--domain", self.domain])
6466

6567
# Setup log file with cleanup
66-
log_file_path = os.path.join(self.base_dir, f"server_output_{self.domain}_{self.port}.log")
68+
domain_part = self.domain if self.domain else "server"
69+
log_file_path = os.path.join(self.base_dir, f"server_output_{domain_part}_{self.port}.log")
6770
if os.path.exists(log_file_path):
6871
os.remove(log_file_path)
6972

tests/pytest/test_tau_bench_airline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def tau_bench_airline_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Eval
7373
}
7474
],
7575
rollout_processor=MCPGymRolloutProcessor(),
76+
rollout_processor_kwargs={"domain": "airline"},
7677
passed_threshold={"success": 0.4, "standard_error": 0.02},
7778
num_runs=8,
7879
mode="pointwise",

0 commit comments

Comments
 (0)