Skip to content

Commit db701ab

Browse files
committed
Cache airline flight DB and harden adapters
1 parent 64abf2d commit db701ab

File tree

4 files changed

+52
-19
lines changed

4 files changed

+52
-19
lines changed

eval_protocol/adapters/langfuse.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import random
1010
import time
1111
from datetime import datetime, timedelta
12-
from typing import Any, Dict, List, Optional, Protocol
12+
from typing import Any, Callable, Dict, List, Optional, Protocol, cast
1313

1414
from eval_protocol.models import EvaluationRow, InputMetadata, Message
1515
from .base import BaseAdapter
@@ -44,14 +44,19 @@ def __call__(
4444
...
4545

4646

47+
LangfuseClient = Any
48+
49+
_get_langfuse_client: Callable[[], Any] | None
50+
4751
try:
48-
from langfuse import get_client # pyright: ignore[reportPrivateImportUsage]
52+
from langfuse import get_client as _get_langfuse_client # type: ignore[attr-defined, reportPrivateImportUsage]
4953
from langfuse.api.resources.trace.types.traces import Traces
5054
from langfuse.api.resources.commons.types.trace import Trace
5155
from langfuse.api.resources.commons.types.trace_with_full_details import TraceWithFullDetails
5256

5357
LANGFUSE_AVAILABLE = True
54-
except ImportError:
58+
except ImportError: # pragma: no cover - optional dependency
59+
_get_langfuse_client = None
5560
LANGFUSE_AVAILABLE = False
5661

5762

@@ -219,7 +224,11 @@ def __init__(self):
219224
if not LANGFUSE_AVAILABLE:
220225
raise ImportError("Langfuse not installed. Install with: pip install 'eval-protocol[langfuse]'")
221226

222-
self.client = get_client()
227+
if _get_langfuse_client is None:
228+
raise ImportError("Langfuse not installed. Install with: pip install 'eval-protocol[langfuse]'")
229+
230+
client_factory = cast(Callable[[], LangfuseClient], _get_langfuse_client)
231+
self.client: LangfuseClient = client_factory()
223232

224233
def get_evaluation_rows(
225234
self,

eval_protocol/adapters/langsmith.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,23 @@
1010
from __future__ import annotations
1111

1212
import logging
13-
from typing import Any, Dict, List, Optional, Iterable
13+
from typing import Any, Callable, Dict, Iterable, List, Optional, cast
1414

1515
from eval_protocol.models import EvaluationRow, InputMetadata, Message
1616
from .base import BaseAdapter
1717

1818
logger = logging.getLogger(__name__)
1919

20+
LangSmithClient = Any
21+
22+
_LANGSMITH_CLIENT_CTOR: Callable[..., LangSmithClient] | None
23+
2024
try:
21-
from langsmith import Client # type: ignore
25+
from langsmith import Client as _LANGSMITH_CLIENT_CTOR # type: ignore[attr-defined]
2226

2327
LANGSMITH_AVAILABLE = True
24-
except ImportError:
28+
except ImportError: # pragma: no cover - optional dependency
29+
_LANGSMITH_CLIENT_CTOR = None
2530
LANGSMITH_AVAILABLE = False
2631

2732

@@ -35,10 +40,17 @@ class LangSmithAdapter(BaseAdapter):
3540
- outputs: { messages: [...] } | { content } | { result } | { answer } | { output } | str | list[dict]
3641
"""
3742

38-
def __init__(self, client: Optional[Client] = None) -> None:
43+
def __init__(self, client: Optional[LangSmithClient] = None) -> None:
3944
if not LANGSMITH_AVAILABLE:
4045
raise ImportError("LangSmith not installed. Install with: pip install 'eval-protocol[langsmith]'")
41-
self.client = client or Client()
46+
if client is not None:
47+
self.client = client
48+
return
49+
50+
if _LANGSMITH_CLIENT_CTOR is None:
51+
raise ImportError("LangSmith client constructor unavailable despite successful import check")
52+
53+
self.client: LangSmithClient = cast(LangSmithClient, _LANGSMITH_CLIENT_CTOR())
4254

4355
def get_evaluation_rows(
4456
self,

eval_protocol/mcp_servers/tau2/airplane_environment/airline_environment.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
import json
1010
import logging
1111
import os
12-
import time
13-
from copy import deepcopy
12+
from functools import lru_cache
1413
from dataclasses import dataclass, field
1514
from enum import Enum
1615
from pathlib import Path
@@ -24,6 +23,16 @@
2423
from vendor.tau2.domains.airline.utils import AIRLINE_DB_PATH
2524

2625

26+
@lru_cache(maxsize=1)
27+
def _load_flight_db(path: str) -> FlightDB:
28+
"""Load and cache the flight database for reuse across resets."""
29+
30+
logger.info("🗂️ Loading airline database from disk (cached)")
31+
db_loaded = FlightDB.load(path)
32+
assert isinstance(db_loaded, FlightDB)
33+
return db_loaded
34+
35+
2736
class AirlineEnvironment:
2837
"""
2938
Airline environment that integrates τ²-Bench simulation pattern
@@ -37,13 +46,10 @@ def __init__(self, config: Optional[Dict[str, Any]] = None):
3746

3847
def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, Any], Dict[str, Any]]:
3948
"""Reset the environment to initial state"""
40-
logger.info("🔄 Resetting airline environment - reloading database from disk")
41-
# FlightDB.load expects a str path
42-
# Ensure type matches expected FlightDB
43-
# FlightDB.load returns vendor.tau2.domains.airline.data_model.FlightDB which is compatible
44-
db_loaded = FlightDB.load(str(AIRLINE_DB_PATH))
45-
assert isinstance(db_loaded, FlightDB)
46-
self.db = db_loaded
49+
logger.info("🔄 Resetting airline environment - using cached airline database")
50+
cached_db = _load_flight_db(str(AIRLINE_DB_PATH))
51+
# Provide a fresh copy for each environment reset without re-reading from disk.
52+
self.db = cached_db.model_copy(deep=True)
4753
self.airline_tools = AirlineTools(self.db)
4854

4955
return {}, {}

tests/pytest/test_mcp_session_autocreate.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,24 @@ async def test_tool_call_returns_json_without_prior_initial_state():
3131
try:
3232
base_url = "http://127.0.0.1:9780/mcp"
3333
client = httpx.Client(timeout=1.0)
34-
deadline = time.time() + 20
34+
start_time = time.time()
35+
deadline = start_time + 20
36+
ready_time = None
3537
while time.time() < deadline:
3638
try:
3739
r = client.get(base_url)
3840
if r.status_code in (200, 307, 406):
41+
ready_time = time.time()
3942
break
4043
except Exception:
4144
pass
4245
time.sleep(0.2)
4346
else:
4447
pytest.fail("Server did not start on port 9780 in time")
4548

49+
assert ready_time is not None, "Server did not return a successful status before exiting loop"
50+
assert ready_time - start_time < 20, f"Server took too long to respond: {ready_time - start_time:.2f}s"
51+
4652
session = MCPSession(base_url=base_url, session_id="test-autocreate", seed=None, model_id="test-model")
4753

4854
mgr = MCPConnectionManager()

0 commit comments

Comments
 (0)