Skip to content

Commit f333d81

Browse files
author
Dylan Huang
authored
eval_protocol.data_loader (#210)
* savev * save * DynamicDataLoader and InlineDataLoader * use dynamic data loader wherever tests are being collected * fix test_parametrized_ids * fix pytest collection printing out a bunch of local UI urls * only print local ui URL once * remove unused imports
1 parent 6b31354 commit f333d81

24 files changed

+588
-165
lines changed

eval_protocol/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
rollout,
2323
test_mcp,
2424
)
25+
from .data_loader import DynamicDataLoader, InlineDataLoader
2526

2627
# Try to import FireworksPolicy if available
2728
try:
@@ -66,6 +67,8 @@
6667

6768
__all__ = [
6869
"DefaultParameterIdGenerator",
70+
"DynamicDataLoader",
71+
"InlineDataLoader",
6972
"aha_judge",
7073
"multi_turn_assistant_to_ground_truth",
7174
"assistant_to_ground_truth",
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
"""Public API of ``eval_protocol.data_loader``.

Re-exports the concrete loader implementations so callers can import them
directly from the package root.
"""

from .dynamic_data_loader import DynamicDataLoader
from .inline_data_loader import InlineDataLoader

__all__ = ["DynamicDataLoader", "InlineDataLoader"]
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from collections.abc import Callable, Sequence
2+
from dataclasses import dataclass
3+
4+
from eval_protocol.data_loader.models import (
5+
DataLoaderResult,
6+
DataLoaderVariant,
7+
EvaluationDataLoader,
8+
)
9+
from eval_protocol.models import EvaluationRow
10+
11+
12+
@dataclass(kw_only=True)
class DynamicDataLoader(EvaluationDataLoader):
    """Data loader for dynamic data generation.

    Produces one lazily-evaluated variant per generator callable; each variant
    is labelled with its generator's ``__name__`` and ``__doc__``.
    """

    generators: Sequence[Callable[[], list[EvaluationRow]]]
    """Dynamic data generation functions. These callables are invoked each time data
    needs to be loaded, allowing for dynamic data generation, lazy loading, or data that
    changes between evaluation runs. Each function should return a list of EvaluationRow
    objects. This is useful for scenarios like generating test data on-the-fly, loading
    data from external sources, or creating data with randomized elements for robust testing."""

    def variants(self) -> Sequence[DataLoaderVariant]:
        """Return one lazy variant per generator in ``self.generators``."""
        variants: list[DataLoaderVariant] = []
        for generator in self.generators:
            # Bind the loop variable as a default argument. A plain closure
            # late-binds `generator`, so once the loop finished every variant
            # would invoke the *last* generator — and report its name and
            # docstring — instead of its own.
            def _load(
                generator: Callable[[], list[EvaluationRow]] = generator,
            ) -> DataLoaderResult:
                resolved_rows = generator()
                return DataLoaderResult(
                    rows=resolved_rows,
                    type=self.__class__.__name__,
                    variant_id=generator.__name__,
                    variant_description=generator.__doc__,
                )

            variants.append(_load)

        return variants
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from collections.abc import Callable, Sequence
2+
from dataclasses import dataclass
3+
4+
from eval_protocol.data_loader.models import (
5+
DataLoaderResult,
6+
DataLoaderVariant,
7+
EvaluationDataLoader,
8+
)
9+
from eval_protocol.models import EvaluationRow
10+
11+
12+
@dataclass(kw_only=True)
class DynamicDataLoader(EvaluationDataLoader):
    """Data loader for dynamic data generation.

    Produces one lazily-evaluated variant per factory callable; each variant
    is labelled with its callable's ``__name__`` and ``__doc__``.
    """

    factory: Sequence[Callable[[], list[EvaluationRow]]]
    """Dynamic data generation functions. These callables are invoked each time data
    needs to be loaded, allowing for dynamic data generation, lazy loading, or data that
    changes between evaluation runs. Each function should return a list of EvaluationRow
    objects. This is useful for scenarios like generating test data on-the-fly, loading
    data from external sources, or creating data with randomized elements for robust testing."""

    def variants(self) -> Sequence[DataLoaderVariant]:
        """Return one lazy variant per callable in ``self.factory``."""
        variants: list[DataLoaderVariant] = []
        for fn in self.factory:
            # Bind the loop variable as a default argument. A plain closure
            # late-binds the loop variable, so once the loop finished every
            # variant would invoke the *last* factory callable — and report
            # its name and docstring — instead of its own.
            def _load(
                fn: Callable[[], list[EvaluationRow]] = fn,
            ) -> DataLoaderResult:
                resolved_rows = fn()
                return DataLoaderResult(
                    rows=resolved_rows,
                    type=self.__class__.__name__,
                    variant_id=fn.__name__,
                    variant_description=fn.__doc__,
                )

            variants.append(_load)

        return variants
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from collections.abc import Sequence
2+
from dataclasses import dataclass
3+
4+
from eval_protocol.data_loader.models import (
5+
DataLoaderResult,
6+
DataLoaderVariant,
7+
EvaluationDataLoader,
8+
)
9+
from eval_protocol.models import EvaluationRow, Message
10+
from eval_protocol.pytest.types import InputMessagesParam
11+
12+
13+
DEFAULT_VARIANT_ID: str = "inline"
14+
15+
16+
@dataclass(kw_only=True)
class InlineDataLoader(EvaluationDataLoader):
    """Data loader for inline ``EvaluationRow`` or message payloads."""

    rows: list[EvaluationRow] | None = None
    """Pre-defined evaluation rows with tools and metadata. Use this when you have complete
    EvaluationRow objects that include tools, input_metadata, and other structured data.
    This is the preferred option when working with tool-calling scenarios or when you need
    to provide additional metadata like row_id, dataset information, or custom fields."""

    messages: Sequence[InputMessagesParam] | None = None
    """Raw chat completion message history. Use this when you only have simple
    conversation history without tools or additional metadata. The messages will be
    automatically converted to EvaluationRow objects. InputMessagesParam is a list of
    Message objects representing the conversation flow (user, assistant, system messages)."""

    id: str = DEFAULT_VARIANT_ID
    """Unique identifier for this data loader variant. Used to label and distinguish
    different input data sources, versions, or configurations. This helps with tracking
    and organizing evaluation results from different data sources."""

    description: str | None = None
    """Optional human-readable description of this data loader. Provides additional
    context about the data source, purpose, or any special characteristics. Used for
    documentation and debugging purposes. If not provided, the variant_id will be used instead."""

    def __post_init__(self) -> None:
        # At least one inline payload must be supplied.
        if self.rows is None and self.messages is None:
            raise ValueError("InlineDataLoader requires rows or messages to be provided")

    def variants(self) -> Sequence[DataLoaderVariant]:
        """Return a single variant that materializes the inline payloads."""

        def _load() -> DataLoaderResult:
            collected: list[EvaluationRow] = []
            if self.rows is not None:
                # Deep-copy so callers' row objects are never mutated downstream.
                collected.extend(row.model_copy(deep=True) for row in self.rows)
            if self.messages is not None:
                # Each raw conversation becomes one EvaluationRow; entries that are
                # already Message instances are copied, anything else is validated.
                for conversation in self.messages:
                    converted: list[Message] = [
                        msg.model_copy(deep=True)
                        if isinstance(msg, Message)
                        else Message.model_validate(msg)
                        for msg in conversation
                    ]
                    collected.append(EvaluationRow(messages=converted))

            return DataLoaderResult(
                rows=collected,
                variant_id=self.id,
                variant_description=self.description,
                type=self.__class__.__name__,
            )

        return [_load]
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
"""Data loader abstractions"""
2+
3+
from __future__ import annotations
4+
5+
from collections.abc import Sequence
6+
from dataclasses import dataclass
7+
from typing import Callable
8+
from typing_extensions import Protocol
9+
from abc import ABC, abstractmethod
10+
11+
from pydantic import BaseModel, Field, field_validator
12+
13+
from eval_protocol.models import EvaluationRow
14+
15+
16+
class DataLoaderResult(BaseModel):
    """Rows and metadata returned by a loader variant."""

    rows: list[EvaluationRow] = Field(
        description="List of evaluation rows loaded from the data source. These are the "
        "processed and ready-to-use evaluation data that will be fed into the evaluation pipeline."
    )

    type: str = Field(
        ...,
        description="Type of the data loader that produced this result. Used for identification "
        "and debugging purposes (e.g., 'InlineDataLoader', 'DynamicDataLoader').",
    )

    variant_id: str = Field(
        ...,
        description="Unique identifier for the data loader variant that produced this result. "
        "Used for tracking and organizing evaluation results from different data sources.",
    )

    variant_description: str | None = Field(
        default=None,
        description="Human-readable description of the data loader variant that produced this result. "
        "Provides context about what this variant represents, its purpose, or any special characteristics that distinguish "
        "it from other variants.",
    )

    preprocessed: bool = Field(
        default=False,
        description="Whether the data has been preprocessed. This flag indicates if any "
        "preprocessing functions have been applied to the data, helping to avoid duplicate "
        "processing and track data transformation state.",
    )

    @staticmethod
    def _require_non_empty(value: str, field_name: str) -> str:
        """Reject empty or whitespace-only identifier strings.

        Shared by the `type` and `variant_id` validators so both fields apply
        the identical rule and emit the identical message shape.
        """
        if not value or not value.strip():
            raise ValueError(f"{field_name} must be non-empty")
        return value

    @field_validator("type")
    @classmethod
    def validate_type(cls, v: str) -> str:
        return cls._require_non_empty(v, "type")

    @field_validator("variant_id")
    @classmethod
    def validate_variant_id(cls, v: str) -> str:
        return cls._require_non_empty(v, "variant_id")
64+
65+
class DataLoaderVariant(Protocol):
    """Single parameterizable variant from a data loader.

    Structural protocol: any zero-argument callable returning a
    ``DataLoaderResult`` satisfies it, such as the ``_load`` closures
    built by concrete loaders.
    """

    def __call__(self) -> DataLoaderResult:
        """Load and return the dataset for this variant."""
        ...
71+
72+
73+
@dataclass(kw_only=True)
class EvaluationDataLoader(ABC):
    """Abstract base class for data loaders that can be consumed by ``evaluation_test``."""

    preprocess_fn: Callable[[list[EvaluationRow]], list[EvaluationRow]] | None = None
    """Optional preprocessing function for evaluation rows. This function is applied
    to the loaded data before it's returned, allowing for data cleaning, transformation,
    filtering, or other modifications. The function receives a list of EvaluationRow objects
    and should return a modified list of EvaluationRow objects."""

    @abstractmethod
    def variants(self) -> Sequence[DataLoaderVariant]:
        """Return parameterizable variants emitted by this loader."""
        ...

    def load(self) -> list[DataLoaderResult]:
        """Load all variants of this loader and return their processed results."""
        return [self._process_variant(variant()) for variant in self.variants()]

    def _process_variant(self, result: DataLoaderResult) -> DataLoaderResult:
        """Process a single variant: preprocess data and apply metadata."""
        original_count = len(result.rows)
        if self.preprocess_fn:
            result.rows = self.preprocess_fn(result.rows)
            result.preprocessed = True
        # If no preprocess_fn ran, rows are unchanged, so this equals original_count.
        processed_count = len(result.rows)

        self._apply_metadata(result, original_count, processed_count)
        return result

    def _apply_metadata(self, result: DataLoaderResult, original_count: int, processed_count: int) -> None:
        """Apply loader provenance metadata to every row in the result.

        Records each public attribute of ``result`` (except the rows themselves)
        plus before/after row counts under ``data_loader_*`` keys in each row's
        ``input_metadata.dataset_info``.
        """
        # The metadata is identical for every row, so build it once instead of
        # re-scanning vars(result) inside the per-row loop.
        metadata = {
            f"data_loader_{attr_name}": attr_value
            for attr_name, attr_value in vars(result).items()
            # Exclude the rows payload and private attributes from metadata.
            if attr_name != "rows" and not attr_name.startswith("_")
        }
        metadata["data_loader_num_rows"] = original_count
        metadata["data_loader_num_rows_after_preprocessing"] = processed_count

        for row in result.rows:
            if row.input_metadata.dataset_info is None:
                row.input_metadata.dataset_info = {}
            row.input_metadata.dataset_info.update(metadata)

0 commit comments

Comments
 (0)