Skip to content

Commit 4a8fd6c

Browse files
committed
Cache airline database between resets
1 parent 64abf2d commit 4a8fd6c

File tree

4 files changed

+109
-10
lines changed

4 files changed

+109
-10
lines changed

eval_protocol/mcp_servers/tau2/airplane_environment/airline_environment.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import logging
1111
import os
1212
import time
13-
from copy import deepcopy
1413
from dataclasses import dataclass, field
1514
from enum import Enum
1615
from pathlib import Path
@@ -24,6 +23,9 @@
2423
from vendor.tau2.domains.airline.utils import AIRLINE_DB_PATH
2524

2625

26+
# Module-level cache of the flight database. It starts empty and is filled by
# AirlineEnvironment.reset() the first time the database is loaded from disk;
# reset() gives each environment a deep copy rather than this instance.
_CACHED_FLIGHT_DB: Optional[FlightDB] = None
27+
28+
2729
class AirlineEnvironment:
2830
"""
2931
Airline environment that integrates τ²-Bench simulation pattern
@@ -37,13 +39,18 @@ def __init__(self, config: Optional[Dict[str, Any]] = None):
3739

3840
def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Reset the environment to initial state.

    The flight database is read from disk only when the module-level
    ``_CACHED_FLIGHT_DB`` is still ``None``; afterwards the cached instance is
    reused. Every reset, including the first, assigns ``self.db`` a deep copy
    of the cached database, so changes made through ``self.db`` do not reach
    the cache. This replaces the earlier behaviour of reloading the database
    from disk on every reset.

    Args:
        seed: Not used by this method.

    Returns:
        A pair of empty dicts.
    """
    global _CACHED_FLIGHT_DB

    if _CACHED_FLIGHT_DB is None:
        logger.info("🔄 Resetting airline environment - loading flight database from disk")
        # FlightDB.load is given the path as a str.
        db_loaded = FlightDB.load(str(AIRLINE_DB_PATH))
        # Narrow the loaded object to FlightDB before caching it.
        assert isinstance(db_loaded, FlightDB)
        _CACHED_FLIGHT_DB = db_loaded
    else:
        logger.info("🔄 Resetting airline environment - using cached flight database")

    assert isinstance(_CACHED_FLIGHT_DB, FlightDB)
    # Deep copy: the environment must never mutate the shared cached instance.
    self.db = _CACHED_FLIGHT_DB.model_copy(deep=True)
    self.airline_tools = AirlineTools(self.db)

    return {}, {}

tests/pytest/test_mcp_session_autocreate.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
"""
2-
Regression test: ensure MCP-Gym auto-creates a session on first tool call
3-
without requiring a prior initial state fetch, and returns JSON.
2+
Regression tests for the airline MCP server.
3+
4+
The tests in this module ensure we can exercise key behaviours from the
5+
AirlineDomainMcp server. They also act as regression coverage for performance
6+
fixes that impact readiness probes.
47
"""
58

9+
import importlib.util
10+
from pathlib import Path
611
import time
712
from multiprocessing import Process
813

@@ -13,6 +18,55 @@
1318
from eval_protocol.types import MCPSession
1419

1520

21+
def _load_airline_environment_module():
    """Load ``airline_environment.py`` from its file path as a fresh module.

    The path is built from ``parents[2]`` of this file's resolved location
    followed by ``eval_protocol/mcp_servers/tau2/airplane_environment``. The
    source is executed again on every call under the name
    ``airline_environment_for_test``, so module-level state in the returned
    module starts from its initial values.

    Returns:
        The newly executed module object.

    Raises:
        AssertionError: If no import spec or loader can be built for the path.
    """
    import sys  # local import: only needed for the sys.modules bookkeeping below

    module_name = "airline_environment_for_test"
    module_path = (
        Path(__file__).resolve().parents[2]
        / "eval_protocol"
        / "mcp_servers"
        / "tau2"
        / "airplane_environment"
        / "airline_environment.py"
    )
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    assert spec and spec.loader is not None
    module = importlib.util.module_from_spec(spec)

    # Register before executing, as the importlib "import a source file
    # directly" recipe does. Code that resolves its own module through
    # sys.modules (the loaded file imports `dataclass`, whose machinery can do
    # this for string annotations) would otherwise find nothing there.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind for later lookups.
        sys.modules.pop(module_name, None)
        raise
    return module
36+
37+
38+
def test_airline_environment_reset_uses_cached_db(monkeypatch):
    """AirlineEnvironment should only hit disk the first time it's reset.

    Also checks that a mutation made to the environment's database between
    resets is gone after the next reset, i.e. the cached copy is not shared
    with the environment.
    """
    # Fresh module object, so its module-level cache starts out empty.
    airline_module = _load_airline_environment_module()

    from vendor.tau2.environment import db as db_module

    load_file_calls = 0
    original_load_file = db_module.load_file

    def counting_load_file(path: str, *args, **kwargs):
        """Count calls, then delegate to the real loader."""
        nonlocal load_file_calls
        load_file_calls += 1
        return original_load_file(path, *args, **kwargs)

    monkeypatch.setattr(db_module, "load_file", counting_load_file)

    env = airline_module.AirlineEnvironment()

    # First reset: exactly one load from disk.
    env.reset()
    assert load_file_calls == 1
    assert env.db.users, "Expected seeded users in the airline database"

    user_id, user = next(iter(env.db.users.items()))
    original_first_name = user.name.first_name
    # Derive the new value from the old one so it is guaranteed to differ;
    # a fixed literal would make the final assertion vacuous if the seeded
    # name happened to equal it.
    env.db.users[user_id].name.first_name = f"{original_first_name}-Changed"

    # Second reset: no further disk access, and the mutation is discarded.
    env.reset()
    assert load_file_calls == 1
    assert env.db.users[user_id].name.first_name == original_first_name
68+
69+
1670
def _run_airline_server():
1771
import os
1872

typings/langfuse/__init__.pyi

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from typing import Any
2+
3+
class LangfuseTraceAPI:
    # Loosely typed stub: both methods accept any arguments and return Any.
    def list(self, *args: Any, **kwargs: Any) -> Any: ...
    def get(self, *args: Any, **kwargs: Any) -> Any: ...
6+
7+
class LangfuseApi:
    # The trace-related calls are reached through this attribute.
    trace: LangfuseTraceAPI
9+
10+
class LangfuseClient:
    # Declared return type of get_client() in this stub.
    api: LangfuseApi

    def create_score(self, *args: Any, **kwargs: Any) -> Any: ...
14+
15+
# Placeholder declarations: each name is declared as a class with no members
# described in this stub.
class ObservationsView:
    ...


class Trace:
    ...


class TraceWithFullDetails:
    ...


class Traces:
    ...


class Langfuse:
    ...
29+
30+
# Accepts any arguments; declared to return a LangfuseClient.
def get_client(
    *args: Any,
    **kwargs: Any,
) -> LangfuseClient: ...

typings/langsmith/__init__.pyi

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from typing import Any, Iterable
2+
3+
class Run:
    # Placeholder: no members are described in this stub.
    ...
5+
6+
class Client:
    # Loosely typed stub: the constructor and list_runs accept any arguments;
    # list_runs is declared to yield Run objects.
    def __init__(self, *args: Any, **kwargs: Any) -> None: ...
    def list_runs(self, *args: Any, **kwargs: Any) -> Iterable[Run]: ...

0 commit comments

Comments
 (0)