Skip to content

Commit 249358d

Browse files
committed
lazy loading
1 parent ce6277f commit 249358d

File tree

5 files changed

+241
-101
lines changed

5 files changed

+241
-101
lines changed

eval_protocol/pytest/__init__.py

Lines changed: 137 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,157 @@
1-
from .default_agent_rollout_processor import AgentRolloutProcessor
2-
from .default_dataset_adapter import default_dataset_adapter
3-
from .default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
4-
from .default_no_op_rollout_processor import NoOpRolloutProcessor
5-
from .default_single_turn_rollout_process import SingleTurnRolloutProcessor
6-
from .remote_rollout_processor import RemoteRolloutProcessor
7-
from .github_action_rollout_processor import GithubActionRolloutProcessor
8-
from .evaluation_test import evaluation_test
9-
from .exception_config import ExceptionHandlerConfig, BackoffConfig, get_default_exception_handler_config
10-
from .rollout_processor import RolloutProcessor
11-
from .rollout_result_post_processor import RolloutResultPostProcessor, NoOpRolloutResultPostProcessor
12-
from .types import RolloutProcessorConfig
13-
14-
# Conditional import for optional Klavis dependency
15-
try:
16-
from .default_klavis_sandbox_rollout_processor import KlavisSandboxRolloutProcessor
17-
18-
KLAVIS_AVAILABLE = True
19-
except ImportError:
20-
KLAVIS_AVAILABLE = False
21-
KlavisSandboxRolloutProcessor = None
22-
23-
# Conditional import for optional dependencies
24-
try:
25-
from .default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
26-
27-
PYDANTIC_AI_AVAILABLE = True
28-
except ImportError:
29-
PYDANTIC_AI_AVAILABLE = False
30-
PydanticAgentRolloutProcessor = None
31-
32-
# Conditional import for optional LangChain dependency
33-
try:
34-
from .default_langchain_rollout_processor import LangGraphRolloutProcessor
35-
36-
LANGCHAIN_AVAILABLE = True
37-
except ImportError:
38-
LANGCHAIN_AVAILABLE = False
39-
LangGraphRolloutProcessor = None
1+
"""
2+
eval_protocol.pytest - Pytest integration for evaluation testing.
3+
4+
This module uses lazy loading to minimize import time.
5+
Heavy dependencies (litellm, torch, etc.) are only loaded when needed.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import importlib
11+
from typing import TYPE_CHECKING
12+
13+
# Lazy imports mapping: name -> (module_path, attr_name)
14+
# These are loaded on-demand when accessed
15+
_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
16+
# Rollout processors
17+
"AgentRolloutProcessor": (".default_agent_rollout_processor", "AgentRolloutProcessor"),
18+
"MCPGymRolloutProcessor": (".default_mcp_gym_rollout_processor", "MCPGymRolloutProcessor"),
19+
"NoOpRolloutProcessor": (".default_no_op_rollout_processor", "NoOpRolloutProcessor"),
20+
"SingleTurnRolloutProcessor": (".default_single_turn_rollout_process", "SingleTurnRolloutProcessor"),
21+
"RemoteRolloutProcessor": (".remote_rollout_processor", "RemoteRolloutProcessor"),
22+
"GithubActionRolloutProcessor": (".github_action_rollout_processor", "GithubActionRolloutProcessor"),
23+
"RolloutProcessor": (".rollout_processor", "RolloutProcessor"),
24+
# Dataset adapter
25+
"default_dataset_adapter": (".default_dataset_adapter", "default_dataset_adapter"),
26+
# Core decorator
27+
"evaluation_test": (".evaluation_test", "evaluation_test"),
28+
# Exception handling
29+
"ExceptionHandlerConfig": (".exception_config", "ExceptionHandlerConfig"),
30+
"BackoffConfig": (".exception_config", "BackoffConfig"),
31+
"get_default_exception_handler_config": (".exception_config", "get_default_exception_handler_config"),
32+
# Post processors
33+
"RolloutResultPostProcessor": (".rollout_result_post_processor", "RolloutResultPostProcessor"),
34+
"NoOpRolloutResultPostProcessor": (".rollout_result_post_processor", "NoOpRolloutResultPostProcessor"),
35+
# Types
36+
"RolloutProcessorConfig": (".types", "RolloutProcessorConfig"),
37+
}
38+
39+
# Optional imports that may not be available
40+
_OPTIONAL_IMPORTS: dict[str, tuple[str, str]] = {
41+
"KlavisSandboxRolloutProcessor": (".default_klavis_sandbox_rollout_processor", "KlavisSandboxRolloutProcessor"),
42+
"PydanticAgentRolloutProcessor": (".default_pydantic_ai_rollout_processor", "PydanticAgentRolloutProcessor"),
43+
"LangGraphRolloutProcessor": (".default_langchain_rollout_processor", "LangGraphRolloutProcessor"),
44+
}
45+
46+
# Track which optional imports are available (set on first access)
47+
_optional_availability: dict[str, bool] = {}
48+
49+
50+
def __getattr__(name: str):
    """Resolve a public attribute of this package lazily (PEP 562 hook).

    Python calls this only when normal lookup in the module namespace fails.
    Three families of names are handled:

    * names in ``_LAZY_IMPORTS``: the backing submodule is imported and the
      attribute is cached in ``globals()`` so this hook is not hit again;
    * names in ``_OPTIONAL_IMPORTS``: same, except that an ``ImportError``
      while importing the submodule yields ``None`` instead of propagating;
    * the ``*_AVAILABLE`` flags: ``True``/``False`` depending on whether the
      matching optional symbol could be imported.

    Args:
        name: Attribute being looked up on the package.

    Returns:
        The resolved object, ``None`` for an optional symbol whose import
        failed, or a ``bool`` for an availability flag.

    Raises:
        AttributeError: If ``name`` is not one of the names handled here.
    """
    # NOTE(review): some lazy names equal their submodule's name (e.g.
    # "evaluation_test", "default_dataset_adapter"). The import system binds a
    # loaded submodule as an attribute of its parent package, so if such a
    # submodule is imported by another path before this hook runs for that
    # name, attribute lookup finds the *module* and this hook is never called.
    # The assignment to globals() below only repairs the name being resolved.
    if name in _LAZY_IMPORTS:
        module_path, attr_name = _LAZY_IMPORTS[name]
        # Resolve relative to this package rather than a hard-coded dotted name.
        module = importlib.import_module(module_path, package=__name__)
        value = getattr(module, attr_name)
        # Cache in the module namespace for future access.
        globals()[name] = value
        return value

    if name in _OPTIONAL_IMPORTS:
        module_path, attr_name = _OPTIONAL_IMPORTS[name]
        try:
            module = importlib.import_module(module_path, package=__name__)
        except ImportError:
            available, value = False, None
        else:
            available, value = True, getattr(module, attr_name)
        _optional_availability[name] = available
        # Cache the outcome even on failure (as None) so that a missing
        # optional dependency is not re-imported on every attribute access.
        globals()[name] = value
        return value

    # Availability flag -> optional symbol whose import outcome it reports.
    availability_flags = {
        "KLAVIS_AVAILABLE": "KlavisSandboxRolloutProcessor",
        "PYDANTIC_AI_AVAILABLE": "PydanticAgentRolloutProcessor",
        "LANGCHAIN_AVAILABLE": "LangGraphRolloutProcessor",
    }
    if name in availability_flags:
        symbol = availability_flags[name]
        if symbol not in _optional_availability:
            # Trigger the import to record availability.
            __getattr__(symbol)
        return _optional_availability.get(symbol, False)

    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
92+
93+
94+
def __dir__():
    """Return the names advertised for introspection and tab completion.

    The result is every entry of ``__all__`` followed by the three
    availability flags.
    """
    availability_flags = ["KLAVIS_AVAILABLE", "PYDANTIC_AI_AVAILABLE", "LANGCHAIN_AVAILABLE"]
    return [*__all__, *availability_flags]
97+
4098

4199
__all__ = [
100+
# Rollout processors
42101
"AgentRolloutProcessor",
43102
"MCPGymRolloutProcessor",
44103
"RolloutProcessor",
45104
"SingleTurnRolloutProcessor",
46105
"RemoteRolloutProcessor",
47106
"GithubActionRolloutProcessor",
48107
"NoOpRolloutProcessor",
108+
# Dataset
49109
"default_dataset_adapter",
110+
# Types
50111
"RolloutProcessorConfig",
112+
# Core
51113
"evaluation_test",
114+
# Exception handling
52115
"ExceptionHandlerConfig",
53116
"BackoffConfig",
54117
"get_default_exception_handler_config",
118+
# Post processors
55119
"RolloutResultPostProcessor",
56120
"NoOpRolloutResultPostProcessor",
121+
# Optional (may be None if dependencies not installed)
122+
"KlavisSandboxRolloutProcessor",
123+
"PydanticAgentRolloutProcessor",
124+
"LangGraphRolloutProcessor",
57125
]
58126

59-
# Only add to __all__ if available
60-
if KLAVIS_AVAILABLE:
61-
__all__.append("KlavisSandboxRolloutProcessor")
62-
63-
# Only add to __all__ if available
64-
if PYDANTIC_AI_AVAILABLE:
65-
__all__.append("PydanticAgentRolloutProcessor")
66127

67-
if LANGCHAIN_AVAILABLE:
68-
__all__.append("LangGraphRolloutProcessor")
128+
# Type hints for IDE support (not executed at runtime)
129+
if TYPE_CHECKING:
130+
from .default_agent_rollout_processor import AgentRolloutProcessor as AgentRolloutProcessor
131+
from .default_dataset_adapter import default_dataset_adapter as default_dataset_adapter
132+
from .default_mcp_gym_rollout_processor import MCPGymRolloutProcessor as MCPGymRolloutProcessor
133+
from .default_no_op_rollout_processor import NoOpRolloutProcessor as NoOpRolloutProcessor
134+
from .default_single_turn_rollout_process import SingleTurnRolloutProcessor as SingleTurnRolloutProcessor
135+
from .remote_rollout_processor import RemoteRolloutProcessor as RemoteRolloutProcessor
136+
from .github_action_rollout_processor import GithubActionRolloutProcessor as GithubActionRolloutProcessor
137+
from .evaluation_test import evaluation_test as evaluation_test
138+
from .exception_config import (
139+
ExceptionHandlerConfig as ExceptionHandlerConfig,
140+
BackoffConfig as BackoffConfig,
141+
get_default_exception_handler_config as get_default_exception_handler_config,
142+
)
143+
from .rollout_processor import RolloutProcessor as RolloutProcessor
144+
from .rollout_result_post_processor import (
145+
RolloutResultPostProcessor as RolloutResultPostProcessor,
146+
NoOpRolloutResultPostProcessor as NoOpRolloutResultPostProcessor,
147+
)
148+
from .types import RolloutProcessorConfig as RolloutProcessorConfig
149+
from .default_klavis_sandbox_rollout_processor import (
150+
KlavisSandboxRolloutProcessor as KlavisSandboxRolloutProcessor,
151+
)
152+
from .default_pydantic_ai_rollout_processor import (
153+
PydanticAgentRolloutProcessor as PydanticAgentRolloutProcessor,
154+
)
155+
from .default_langchain_rollout_processor import (
156+
LangGraphRolloutProcessor as LangGraphRolloutProcessor,
157+
)

eval_protocol/pytest/evaluation_test.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@
3434
from eval_protocol.pytest.parameterize import pytest_parametrize, create_dynamically_parameterized_wrapper
3535
from eval_protocol.pytest.validate_signature import validate_signature
3636
from eval_protocol.pytest.default_dataset_adapter import default_dataset_adapter
37-
from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
37+
38+
# Note: MCPGymRolloutProcessor and SingleTurnRolloutProcessor are imported lazily to avoid loading litellm (~1300ms)
3839
from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor
39-
from eval_protocol.pytest.default_single_turn_rollout_process import SingleTurnRolloutProcessor
4040
from eval_protocol.pytest.exception_config import ExceptionHandlerConfig
4141
from eval_protocol.pytest.rollout_processor import RolloutProcessor
4242
from eval_protocol.pytest.types import (
@@ -188,6 +188,9 @@ def evaluation_test(
188188
if os.environ.get("EP_USE_NO_OP_ROLLOUT_PROCESSOR") == "1":
189189
rollout_processor = NoOpRolloutProcessor()
190190
elif rollout_processor is None:
191+
# Lazy import to avoid loading litellm at decorator definition time
192+
from eval_protocol.pytest.default_single_turn_rollout_process import SingleTurnRolloutProcessor
193+
191194
rollout_processor = SingleTurnRolloutProcessor()
192195

193196
active_logger: DatasetLogger = logger if logger else default_logger
@@ -411,6 +414,9 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
411414

412415
rollout_processor.setup()
413416

417+
# Lazy import to avoid loading litellm at module load time
418+
from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
419+
414420
use_priority_scheduler = os.environ.get(
415421
"EP_USE_PRIORITY_SCHEDULER", "0"
416422
) == "1" and not isinstance(rollout_processor, MCPGymRolloutProcessor)
@@ -689,6 +695,9 @@ async def _collect_result(config, lst):
689695

690696
# If rollout_processor is an MCPGymRolloutProcessor, we execute runs sequentially since MCPGym does not support concurrent runs;
691697
# else, we execute runs in parallel
698+
# Lazy import (cached after first import above)
699+
from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
700+
692701
if isinstance(rollout_processor, MCPGymRolloutProcessor):
693702
# For MCPGymRolloutProcessor, create and execute tasks one at a time to avoid port conflicts
694703
for run_idx in range(num_runs):

eval_protocol/pytest/evaluation_test_utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
from dataclasses import replace
77
from typing import Any, Literal, Callable, AsyncGenerator, Optional
88

9-
from litellm.cost_calculator import cost_per_token
109
from tqdm import tqdm
1110

11+
# Note: litellm.cost_calculator.cost_per_token is imported lazily in add_cost_metrics()
12+
# to avoid ~1300ms import time at module load
13+
1214
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
1315
from eval_protocol.models import (
1416
CostMetrics,
@@ -22,7 +24,8 @@
2224
from eval_protocol.data_loader import DynamicDataLoader
2325
from eval_protocol.data_loader.models import EvaluationDataLoader
2426
from eval_protocol.pytest.rollout_processor import RolloutProcessor
25-
from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
27+
28+
# Note: MCPGymRolloutProcessor is imported lazily in build_rollout_processor_config() to avoid loading litellm
2629
from eval_protocol.pytest.types import (
2730
RolloutProcessorConfig,
2831
ServerMode,
@@ -542,6 +545,9 @@ def add_cost_metrics(row: EvaluationRow) -> None:
542545

543546
# Try to calculate costs, but gracefully handle unknown models
544547
try:
548+
# Lazy import to avoid ~1300ms import time at module load
549+
from litellm.cost_calculator import cost_per_token
550+
545551
input_cost, output_cost = cost_per_token(
546552
model=model_id, prompt_tokens=input_tokens, completion_tokens=output_tokens
547553
)
@@ -596,6 +602,9 @@ def build_rollout_processor_config(
596602

597603
completion_params = {"model": model, "temperature": temperature, "max_tokens": max_tokens}
598604

605+
# Lazy import to avoid loading litellm at module load time
606+
from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
607+
599608
if isinstance(rollout_processor, MCPGymRolloutProcessor):
600609
base_kwargs = {**(rollout_processor_kwargs or {}), "start_server": start_server}
601610
if server_mode is not None and "server_mode" not in base_kwargs:

0 commit comments

Comments
 (0)