Skip to content

Commit 003a140

Browse files
authored
Reduce Pip Install and Import Time (#244)
* pinning litellm * lower and upper bound * lower and upper bound * just upper bound * remove pandas in code * fixed imports and pip speeds
1 parent 33185d7 commit 003a140

File tree

9 files changed

+79
-97
lines changed

9 files changed

+79
-97
lines changed

eval_protocol/__init__.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,6 @@
2323
test_mcp,
2424
)
2525
from .data_loader import DynamicDataLoader, InlineDataLoader
26-
27-
# Try to import FireworksPolicy if available
28-
try:
29-
from .mcp_env import FireworksPolicy
30-
31-
_FIREWORKS_AVAILABLE = True
32-
except (ImportError, AttributeError):
33-
_FIREWORKS_AVAILABLE = False
34-
# Import submodules to make them available via eval_protocol.rewards, etc.
3526
from . import mcp, rewards
3627
from .models import EvaluateResult, Message, MetricResult, EvaluationRow, InputMetadata
3728
from .playback_policy import PlaybackPolicyBase
@@ -42,6 +33,13 @@
4233
from .pytest import evaluation_test, SingleTurnRolloutProcessor, RemoteRolloutProcessor
4334
from .pytest.parameterize import DefaultParameterIdGenerator
4435

36+
from .types.remote_rollout_processor import (
37+
InitRequest,
38+
RolloutMetadata,
39+
StatusResponse,
40+
create_langfuse_config_tags,
41+
)
42+
4543
try:
4644
from .adapters import OpenAIResponsesAdapter
4745
except ImportError:
@@ -62,14 +60,6 @@
6260
except ImportError:
6361
LangSmithAdapter = None
6462

65-
# Remote server types
66-
from .types.remote_rollout_processor import (
67-
InitRequest,
68-
RolloutMetadata,
69-
StatusResponse,
70-
create_langfuse_config_tags,
71-
)
72-
7363
warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol")
7464

7565
__all__ = [

eval_protocol/adapters/huggingface.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,9 @@
1313
logger = logging.getLogger(__name__)
1414

1515
try:
16-
from datasets import Dataset, DatasetDict, load_dataset
17-
18-
DATASETS_AVAILABLE = True
16+
from datasets import Dataset, DatasetDict, load_dataset # pyright: ignore[reportAttributeAccessIssue]
1917
except ImportError:
20-
DATASETS_AVAILABLE = False
21-
logger.warning("HuggingFace datasets not installed. Install with: pip install 'eval-protocol[huggingface]'")
18+
raise ImportError("HuggingFace datasets not installed. Install with: pip install 'eval-protocol[huggingface]'")
2219

2320
# Type alias for transformation function
2421
TransformFunction = Callable[[Dict[str, Any]], Dict[str, Any]]
@@ -80,11 +77,6 @@ def __init__(
8077
revision: Optional dataset revision/commit hash
8178
**load_dataset_kwargs: Additional arguments to pass to load_dataset
8279
"""
83-
if not DATASETS_AVAILABLE:
84-
raise ImportError(
85-
"HuggingFace datasets not installed. Install with: pip install 'eval-protocol[huggingface]'"
86-
)
87-
8880
self.dataset_id = dataset_id
8981
self.transform_fn = transform_fn
9082
self.config_name = config_name

eval_protocol/execution/pipeline.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
import aiohttp
1414
import hydra
15-
from datasets import Dataset, DatasetDict
1615
from hydra.errors import InstantiationException
1716
from omegaconf import DictConfig, OmegaConf
1817

@@ -24,6 +23,14 @@
2423
from eval_protocol.utils.module_loader import load_function as load_reward_function
2524
from eval_protocol.utils.packaging_utils import install_requirements
2625

26+
try:
27+
from datasets import Dataset, DatasetDict # pyright: ignore[reportAttributeAccessIssue]
28+
except ImportError:
29+
raise ImportError(
30+
"The 'datasets' package is required to use this function. "
31+
"Please install it with 'pip install \"eval-protocol[huggingface]\"'"
32+
)
33+
2734
logger = logging.getLogger(__name__)
2835

2936

eval_protocol/mcp/execution/manager.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,6 @@
1717
import anyio
1818
from openai.types import CompletionUsage
1919

20-
from vendor.tau2.data_model.message import AssistantMessage, UserMessage
21-
from vendor.tau2.user.user_simulator import UserSimulator
22-
2320
from ...models import EvaluationRow, InputMetadata, Message, Status
2421
from ...types import TerminationReason, Trajectory, NonSkippableException
2522

@@ -234,6 +231,10 @@ def extract_text_content(msg_dict):
234231

235232
# If user simulation is enabled, initial message is from the simulated user
236233
if dataset_row.user_simulation and dataset_row.user_simulation.get("enabled", False):
234+
# Lazy import vendor.tau2 - only load when user simulation is actually used
235+
from vendor.tau2.data_model.message import AssistantMessage, UserMessage
236+
from vendor.tau2.user.user_simulator import UserSimulator
237+
237238
user_simulator = UserSimulator(
238239
instructions=dataset_row.user_simulation.get("system_prompt"),
239240
llm=dataset_row.user_simulation.get("llm", "gpt-4.1"),
@@ -598,6 +599,9 @@ def _get_user_simulator_messages(self, conversation_history: List[Dict[str, Any]
598599
"""
599600
Filter conversation history for user simulator and convert to tau2-bench format.
600601
"""
602+
# Lazy import vendor.tau2 types
603+
from vendor.tau2.data_model.message import AssistantMessage, UserMessage
604+
601605
tau2_messages = []
602606

603607
for message in conversation_history:

eval_protocol/pytest/utils.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929

3030
import logging
3131
import json
32-
import pandas as pd
32+
import random
33+
import statistics
3334

3435

3536
AggregationMethod = Literal["mean", "max", "min", "bootstrap"]
@@ -122,30 +123,25 @@ async def execute_run_with_progress(run_idx: int, config: RolloutProcessorConfig
122123
raise
123124

124125

125-
def calculate_bootstrap_scores(all_scores: list[float]) -> float:
126+
def calculate_bootstrap_scores(all_scores: list[float], n_boot: int = 100, seed: int | None = None) -> float:
126127
"""
127-
Calculate bootstrap confidence intervals for individual scores.
128+
Calculate the mean of bootstrap sample means for a list of scores.
128129
129130
Args:
130-
all_scores: List of individual scores from all rows
131+
all_scores: List of individual scores from all rows.
132+
n_boot: Number of bootstrap resamples to draw (default 100).
133+
seed: Optional RNG seed for reproducibility.
131134
132135
Returns:
133-
Mean bootstrap score
136+
Mean bootstrap score (float). Returns 0.0 if all_scores is empty.
134137
"""
135138
if not all_scores:
136139
return 0.0
137140

138-
# Create DataFrame (single column of scores)
139-
battles = pd.DataFrame({"score": all_scores})
140-
141-
# Bootstrap sampling for calculating relative performance
142-
bootstrap_means = [battles.sample(frac=1.0, replace=True)["score"].mean() for _ in range(100)]
143-
144-
# Calculate final scores
145-
bootstraps = pd.Series(bootstrap_means)
146-
mean_score = bootstraps.mean()
147-
148-
return float(mean_score)
141+
rng = random.Random(seed) if seed is not None else random
142+
k = len(all_scores)
143+
bootstrap_means = [statistics.fmean(rng.choices(all_scores, k=k)) for _ in range(n_boot)]
144+
return float(statistics.fmean(bootstrap_means))
149145

150146

151147
def aggregate(scores: list[float], method: AggregationMethod) -> float:

eval_protocol/quickstart/utils.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,10 @@
33
"""
44

55
import os
6-
from datetime import datetime
76
import re
8-
from typing import List, Dict, Any, Optional
9-
from openai import AsyncOpenAI
10-
import pandas as pd
7+
from typing import Dict, Any, Optional
118

12-
from eval_protocol.models import EvaluationRow, Message, EvaluateResult, MetricResult
13-
import asyncio
14-
from openai import OpenAI
9+
from eval_protocol.models import EvaluationRow, Message
1510

1611
OG_ARENA_HARD_PROMPT = """Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user prompt displayed below. You will be given assistant A's answer and assistant B's answer. Your job is to evaluate which assistant's answer is better.
1712

pyproject.toml

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,33 +27,26 @@ dependencies = [
2727
"aiohttp",
2828
"mcp>=1.9.2",
2929
"PyYAML>=5.0",
30-
# Pin minimum datasets to avoid pyarrow API mismatch (PyExtensionType removal in pyarrow>=21)
31-
"datasets>=3.0.0",
32-
"fsspec",
3330
"hydra-core>=1.3.2",
3431
"omegaconf>=2.3.0",
35-
"gymnasium>=0.29.0",
3632
"httpx>=0.24.0",
3733
"anthropic>=0.59.0",
38-
"ipykernel>=6.30.0",
39-
"jupyter>=1.1.1",
34+
"litellm<1.75.0",
35+
"pytest>=6.0.0",
36+
"pytest-asyncio>=0.21.0",
37+
"peewee>=3.18.2",
38+
"backoff>=2.2.0",
39+
"questionary>=2.0.0",
4040
# Dependencies for vendored tau2 package
4141
"toml>=0.10.0",
4242
"loguru>=0.6.0",
4343
"docstring-parser>=0.15",
4444
"rich>=12.0.0",
4545
"psutil>=5.8.0",
46-
"litellm<1.75.0",
4746
"addict>=2.4.0",
4847
"deepdiff>=6.0.0",
49-
"pandas>=1.5.0",
5048
"websockets>=15.0.1",
5149
"fastapi>=0.116.1",
52-
"pytest>=6.0.0",
53-
"pytest-asyncio>=0.21.0",
54-
"peewee>=3.18.2",
55-
"backoff>=2.2.0",
56-
"questionary>=2.0.0",
5750
]
5851

5952
[project.urls]
@@ -67,6 +60,7 @@ dev = [
6760
"werkzeug>=2.0.0",
6861
"ruff>=0.5.0",
6962
"transformers>=4.0.0",
63+
"pandas>=1.5.0",
7064
"types-setuptools",
7165
"types-requests",
7266
"types-PyYAML",
@@ -110,12 +104,6 @@ huggingface = [
110104
"datasets>=3.0.0",
111105
"transformers>=4.0.0",
112106
]
113-
adapters = [
114-
"langfuse>=2.0.0",
115-
# Keep in sync with core dependency to ensure compatibility with latest pyarrow
116-
"datasets>=3.0.0",
117-
"transformers>=4.0.0",
118-
]
119107
langsmith = [
120108
"langsmith>=0.1.86",
121109
]

tests/test_evaluation_postprocess.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,33 @@ def test_all_invalid_scores(self):
208208
assert mock_logger.log.call_count == 2
209209

210210

211+
class TestBootstrapEquivalence:
212+
def test_bootstrap_equivalence_pandas_vs_pure_python(self):
213+
import random
214+
import pandas as pd
215+
from eval_protocol.pytest.utils import calculate_bootstrap_scores as py_bootstrap
216+
217+
# Deterministic synthetic scores
218+
rng = random.Random(123)
219+
scores = [rng.random() for _ in range(100)]
220+
221+
n_boot = 1000
222+
seed = 42
223+
224+
# Old (pandas) style bootstrap: resample full column with replacement
225+
df = pd.DataFrame({"score": scores})
226+
pandas_means = [
227+
df.sample(frac=1.0, replace=True, random_state=seed + i)["score"].mean() for i in range(n_boot)
228+
]
229+
pandas_boot_mean = sum(pandas_means) / len(pandas_means)
230+
231+
# New pure-python implementation
232+
py_boot_mean = py_bootstrap(scores, n_boot=n_boot, seed=seed)
233+
234+
# They estimate the same quantity; allow small Monte Carlo tolerance
235+
assert abs(pandas_boot_mean - py_boot_mean) < 0.02
236+
237+
211238
class TestComputeFixedSetMuCi:
212239
"""Tests for compute_fixed_set_mu_ci function."""
213240

0 commit comments

Comments
 (0)