Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
import json
import logging
import os
import time
from copy import deepcopy
from functools import lru_cache
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
Expand All @@ -24,6 +23,16 @@
from vendor.tau2.domains.airline.utils import AIRLINE_DB_PATH


@lru_cache(maxsize=1)
def _load_flight_db(path: str) -> FlightDB:
    """Return the flight database stored at *path*, loading it at most once.

    The result is memoized by ``lru_cache(maxsize=1)``, so repeated calls
    with the same path hand back the very same ``FlightDB`` object instead
    of reading from disk again. Because that object is shared between
    callers, it should be copied before being mutated.

    Args:
        path: Location of the database file, passed through to
            ``FlightDB.load``.

    Returns:
        The loaded (and cached) ``FlightDB`` instance.
    """
    logger.info("🗂️ Loading airline database from disk (cached)")
    flight_db = FlightDB.load(path)
    # Narrow the loader's return value to FlightDB for type checkers.
    assert isinstance(flight_db, FlightDB)
    return flight_db


class AirlineEnvironment:
"""
Airline environment that integrates τ²-Bench simulation pattern
Expand All @@ -37,13 +46,10 @@ def __init__(self, config: Optional[Dict[str, Any]] = None):

def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Reset the environment to its initial state.

    The airline database is taken from the cached loader
    (``_load_flight_db``) and deep-copied, so every reset starts from a
    pristine database without re-reading the file from disk and without
    this environment's mutations leaking into the cached instance.

    Args:
        seed: Accepted for interface compatibility; not used by this method.

    Returns:
        A pair of empty dicts (presumably ``(observation, info)`` — confirm
        against the caller's expectations).
    """
    logger.info("🔄 Resetting airline environment - using cached airline database")
    cached_db = _load_flight_db(str(AIRLINE_DB_PATH))
    # Provide a fresh copy for each environment reset without re-reading from disk.
    self.db = cached_db.model_copy(deep=True)
    # Rebuild the tools so they operate on the fresh database copy.
    self.airline_tools = AirlineTools(self.db)

    return {}, {}
Expand Down
19 changes: 16 additions & 3 deletions tests/pytest/test_mcp_session_autocreate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
def _run_airline_server():
import os

os.environ["PORT"] = "9780"
python_version = os.environ.get("PYTHON_VERSION", "3.10").replace(".", "")
port = str(9780 + int(python_version[-1:]))
os.environ["PORT"] = port
from eval_protocol.mcp_servers.tau2.tau2_mcp import AirlineDomainMcp

server = AirlineDomainMcp(seed=None)
Expand All @@ -25,24 +27,35 @@ def _run_airline_server():

@pytest.mark.asyncio
async def test_tool_call_returns_json_without_prior_initial_state():
import os

proc = Process(target=_run_airline_server, daemon=True)
proc.start()

try:
base_url = "http://127.0.0.1:9780/mcp"
python_version = os.environ.get("PYTHON_VERSION", "3.10").replace(".", "")
port = str(9780 + int(python_version[-1:]))

base_url = f"http://127.0.0.1:{port}/mcp"
client = httpx.Client(timeout=1.0)
deadline = time.time() + 20
start_time = time.time()
deadline = start_time + 20
ready_time = None
while time.time() < deadline:
try:
r = client.get(base_url)
if r.status_code in (200, 307, 406):
ready_time = time.time()
break
except Exception:
pass
time.sleep(0.2)
else:
pytest.fail("Server did not start on port 9780 in time")

assert ready_time is not None, "Server did not return a successful status before exiting loop"
assert ready_time - start_time < 20, f"Server took too long to respond: {ready_time - start_time:.2f}s"

session = MCPSession(base_url=base_url, session_id="test-autocreate", seed=None, model_id="test-model")

mgr = MCPConnectionManager()
Expand Down
Loading