From 4a8fd6c6faa036bb37d55e5baf21e1f8692f312a Mon Sep 17 00:00:00 2001 From: "Yufei (Benny) Chen" <1585539+benjibc@users.noreply.github.com> Date: Thu, 18 Sep 2025 17:02:50 -0700 Subject: [PATCH] Cache airline database between resets --- .../airline_environment.py | 23 +++++--- tests/pytest/test_mcp_session_autocreate.py | 58 ++++++++++++++++++- typings/langfuse/__init__.pyi | 30 ++++++++++ typings/langsmith/__init__.pyi | 8 +++ 4 files changed, 109 insertions(+), 10 deletions(-) create mode 100644 typings/langfuse/__init__.pyi create mode 100644 typings/langsmith/__init__.pyi diff --git a/eval_protocol/mcp_servers/tau2/airplane_environment/airline_environment.py b/eval_protocol/mcp_servers/tau2/airplane_environment/airline_environment.py index 1e9b00dd..c2a181ab 100644 --- a/eval_protocol/mcp_servers/tau2/airplane_environment/airline_environment.py +++ b/eval_protocol/mcp_servers/tau2/airplane_environment/airline_environment.py @@ -10,7 +10,6 @@ import logging import os import time -from copy import deepcopy from dataclasses import dataclass, field from enum import Enum from pathlib import Path @@ -24,6 +23,9 @@ from vendor.tau2.domains.airline.utils import AIRLINE_DB_PATH +_CACHED_FLIGHT_DB: Optional[FlightDB] = None + + class AirlineEnvironment: """ Airline environment that integrates τ²-Bench simulation pattern @@ -37,13 +39,18 @@ def __init__(self, config: Optional[Dict[str, Any]] = None): def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, Any], Dict[str, Any]]: """Reset the environment to initial state""" - logger.info("🔄 Resetting airline environment - reloading database from disk") - # FlightDB.load expects a str path - # Ensure type matches expected FlightDB - # FlightDB.load returns vendor.tau2.domains.airline.data_model.FlightDB which is compatible - db_loaded = FlightDB.load(str(AIRLINE_DB_PATH)) - assert isinstance(db_loaded, FlightDB) - self.db = db_loaded + global _CACHED_FLIGHT_DB + + if _CACHED_FLIGHT_DB is None: + logger.info("🔄 Resetting airline environment - loading flight database from disk") + db_loaded = FlightDB.load(str(AIRLINE_DB_PATH)) + assert isinstance(db_loaded, FlightDB) + _CACHED_FLIGHT_DB = db_loaded + else: + logger.info("🔄 Resetting airline environment - using cached flight database") + + assert isinstance(_CACHED_FLIGHT_DB, FlightDB) + self.db = _CACHED_FLIGHT_DB.model_copy(deep=True) self.airline_tools = AirlineTools(self.db) return {}, {} diff --git a/tests/pytest/test_mcp_session_autocreate.py b/tests/pytest/test_mcp_session_autocreate.py index df816f55..2cba6c40 100644 --- a/tests/pytest/test_mcp_session_autocreate.py +++ b/tests/pytest/test_mcp_session_autocreate.py @@ -1,8 +1,13 @@ """ -Regression test: ensure MCP-Gym auto-creates a session on first tool call -without requiring a prior initial state fetch, and returns JSON. +Regression tests for the airline MCP server. + +The tests in this module ensure we can exercise key behaviours from the +AirlineDomainMcp server. They also act as regression coverage for performance +fixes that impact readiness probes. """ +import importlib.util +from pathlib import Path import time from multiprocessing import Process @@ -13,6 +18,55 @@ from eval_protocol.types import MCPSession +def _load_airline_environment_module(): + module_name = "airline_environment_for_test" + module_path = ( + Path(__file__).resolve().parents[2] + / "eval_protocol" + / "mcp_servers" + / "tau2" + / "airplane_environment" + / "airline_environment.py" + ) + spec = importlib.util.spec_from_file_location(module_name, module_path) + assert spec and spec.loader is not None + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def test_airline_environment_reset_uses_cached_db(monkeypatch): + """AirlineEnvironment should only hit disk the first time it's reset.""" + + airline_module = _load_airline_environment_module() + + from vendor.tau2.environment import db as db_module + + load_file_calls = 0 + original_load_file = db_module.load_file + + def counting_load_file(path: str, *args, **kwargs): + nonlocal load_file_calls + load_file_calls += 1 + return original_load_file(path, *args, **kwargs) + + monkeypatch.setattr(db_module, "load_file", counting_load_file) + + env = airline_module.AirlineEnvironment() + + env.reset() + assert load_file_calls == 1 + assert env.db.users, "Expected seeded users in the airline database" + + user_id, user = next(iter(env.db.users.items())) + original_first_name = user.name.first_name + env.db.users[user_id].name.first_name = "Changed" + + env.reset() + assert load_file_calls == 1 + assert env.db.users[user_id].name.first_name == original_first_name + + def _run_airline_server(): import os diff --git a/typings/langfuse/__init__.pyi b/typings/langfuse/__init__.pyi new file mode 100644 index 00000000..70233d50 --- /dev/null +++ b/typings/langfuse/__init__.pyi @@ -0,0 +1,30 @@ +from typing import Any + +class LangfuseTraceAPI: + def list(self, *args: Any, **kwargs: Any) -> Any: ... + def get(self, *args: Any, **kwargs: Any) -> Any: ... + +class LangfuseApi: + trace: LangfuseTraceAPI + +class LangfuseClient: + api: LangfuseApi + + def create_score(self, *args: Any, **kwargs: Any) -> Any: ... + +class ObservationsView: + ... + +class Trace: + ... + +class TraceWithFullDetails: + ... + +class Traces: + ... + +class Langfuse: + ... + +def get_client(*args: Any, **kwargs: Any) -> LangfuseClient: ... diff --git a/typings/langsmith/__init__.pyi b/typings/langsmith/__init__.pyi new file mode 100644 index 00000000..35eb979a --- /dev/null +++ b/typings/langsmith/__init__.pyi @@ -0,0 +1,8 @@ +from typing import Any, Iterable + +class Run: + ... + +class Client: + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + def list_runs(self, *args: Any, **kwargs: Any) -> Iterable[Run]: ...