diff --git a/eval_protocol/mcp_servers/tau2/airplane_environment/airline_environment.py b/eval_protocol/mcp_servers/tau2/airplane_environment/airline_environment.py
index 1e9b00dd..ce9d905a 100644
--- a/eval_protocol/mcp_servers/tau2/airplane_environment/airline_environment.py
+++ b/eval_protocol/mcp_servers/tau2/airplane_environment/airline_environment.py
@@ -9,8 +9,7 @@
 import json
 import logging
 import os
-import time
-from copy import deepcopy
+from functools import lru_cache
 from dataclasses import dataclass, field
 from enum import Enum
 from pathlib import Path
@@ -24,6 +23,20 @@
 from vendor.tau2.domains.airline.utils import AIRLINE_DB_PATH
 
 
+@lru_cache(maxsize=1)
+def _load_flight_db(path: str) -> FlightDB:
+    """Load the flight database from ``path``, caching the loaded object.
+
+    The cache holds a single entry, so repeated calls with the same path reuse
+    the object already loaded instead of reading from disk again.
+    """
+
+    logger.info("🗂️ Loading airline database from disk (cached)")
+    db_loaded = FlightDB.load(path)
+    assert isinstance(db_loaded, FlightDB)
+    return db_loaded
+
+
 class AirlineEnvironment:
     """
     Airline environment that integrates τ²-Bench simulation pattern
@@ -37,13 +50,10 @@ def __init__(self, config: Optional[Dict[str, Any]] = None):
 
     def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         """Reset the environment to initial state"""
-        logger.info("🔄 Resetting airline environment - reloading database from disk")
-        # FlightDB.load expects a str path
-        # Ensure type matches expected FlightDB
-        # FlightDB.load returns vendor.tau2.domains.airline.data_model.FlightDB which is compatible
-        db_loaded = FlightDB.load(str(AIRLINE_DB_PATH))
-        assert isinstance(db_loaded, FlightDB)
-        self.db = db_loaded
+        logger.info("🔄 Resetting airline environment - using cached airline database")
+        cached_db = _load_flight_db(str(AIRLINE_DB_PATH))
+        # Provide a fresh copy for each environment reset without re-reading from disk.
+        self.db = cached_db.model_copy(deep=True)
         self.airline_tools = AirlineTools(self.db)
 
         return {}, {}
diff --git a/tests/pytest/test_mcp_session_autocreate.py b/tests/pytest/test_mcp_session_autocreate.py
index df816f55..6cc76f50 100644
--- a/tests/pytest/test_mcp_session_autocreate.py
+++ b/tests/pytest/test_mcp_session_autocreate.py
@@ -16,7 +16,11 @@
 def _run_airline_server():
     import os
 
-    os.environ["PORT"] = "9780"
+    # Offset the base port 9780 by the last digit of PYTHON_VERSION
+    # (default "3.10"); presumably keeps parallel CI jobs apart - confirm.
+    python_version = os.environ.get("PYTHON_VERSION", "3.10").replace(".", "")
+    port = str(9780 + int(python_version[-1:]))
+    os.environ["PORT"] = port
     from eval_protocol.mcp_servers.tau2.tau2_mcp import AirlineDomainMcp
 
     server = AirlineDomainMcp(seed=None)
@@ -25,17 +29,26 @@ def _run_airline_server():
 
 @pytest.mark.asyncio
 async def test_tool_call_returns_json_without_prior_initial_state():
+    import os
+
     proc = Process(target=_run_airline_server, daemon=True)
     proc.start()
 
     try:
-        base_url = "http://127.0.0.1:9780/mcp"
+        # Must match the port _run_airline_server derives from PYTHON_VERSION.
+        python_version = os.environ.get("PYTHON_VERSION", "3.10").replace(".", "")
+        port = str(9780 + int(python_version[-1:]))
+
+        base_url = f"http://127.0.0.1:{port}/mcp"
         client = httpx.Client(timeout=1.0)
-        deadline = time.time() + 20
+        start_time = time.time()
+        deadline = start_time + 20
+        ready_time = None
         while time.time() < deadline:
             try:
                 r = client.get(base_url)
                 if r.status_code in (200, 307, 406):
+                    ready_time = time.time()
                     break
             except Exception:
                 pass
@@ -43,6 +56,9 @@ async def test_tool_call_returns_json_without_prior_initial_state():
         else:
-            pytest.fail("Server did not start on port 9780 in time")
+            pytest.fail(f"Server did not start on port {port} in time")
 
+        assert ready_time is not None, "Server did not return a successful status before exiting loop"
+        assert ready_time - start_time < 20, f"Server took too long to respond: {ready_time - start_time:.2f}s"
+
         session = MCPSession(base_url=base_url, session_id="test-autocreate", seed=None, model_id="test-model")
 
         mgr = MCPConnectionManager()