Skip to content

Commit 68eea88

Browse files
author
Dylan Huang
committed
save
1 parent 71f588f commit 68eea88

File tree

4 files changed

+184
-43
lines changed

4 files changed

+184
-43
lines changed
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from collections.abc import Callable, Sequence
2+
3+
from eval_protocol.data_loader.models import (
4+
DataLoaderContext,
5+
DataLoaderResult,
6+
DataLoaderVariant,
7+
EvaluationDataLoader,
8+
)
9+
from eval_protocol.models import EvaluationRow
10+
11+
12+
class FactoryDataLoader(EvaluationDataLoader):
13+
"""Data loader for factory of list[EvaluationRow]"""
14+
15+
description: str | None = None
16+
"""Optional human-readable description of this data loader. Provides additional
17+
context about the data source, purpose, or any special characteristics. Used for
18+
documentation and debugging purposes. If not provided, the variant_id will be used instead."""
19+
20+
factory: Sequence[Callable[[], list[EvaluationRow]]]
21+
"""Factory function that generates evaluation rows dynamically. This callable
22+
is invoked each time data needs to be loaded, allowing for dynamic data generation,
23+
lazy loading, or data that changes between evaluation runs. The factory should return
24+
a list of EvaluationRow objects. This is useful for scenarios like generating test
25+
data on-the-fly, loading data from external sources, or creating data with randomized
26+
elements for robust testing."""
27+
28+
def variants(self) -> Sequence[DataLoaderVariant]:
29+
variants: Sequence[DataLoaderVariant] = []
30+
for factory in self.factory:
31+
32+
def _load(ctx: DataLoaderContext) -> DataLoaderResult:
33+
resolved_rows = factory()
34+
return DataLoaderResult(
35+
rows=resolved_rows,
36+
num_rows=len(resolved_rows),
37+
type="factory",
38+
variant_id=ctx.variant_id,
39+
variant_description=ctx.variant_description,
40+
)
41+
42+
variants.append(
43+
DataLoaderVariant(
44+
id=factory.__name__,
45+
description=factory.__doc__,
46+
loader=_load,
47+
)
48+
)
49+
50+
return variants

eval_protocol/data_loader/inline_data_loader.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,33 @@
1010
from eval_protocol.pytest.types import InputMessagesParam
1111

1212

13+
DEFAULT_VARIANT_ID: str = "inline"
14+
15+
1316
class InlineDataLoader(EvaluationDataLoader):
1417
"""Data loader for inline ``EvaluationRow`` or message payloads."""
1518

16-
rows: Sequence[EvaluationRow] | None = None
17-
messages: Sequence[InputMessagesParam] | None = None
18-
variant_id: str = "inline"
19+
rows: list[EvaluationRow] | None = None
20+
"""Pre-defined evaluation rows with tools and metadata. Use this when you have complete
21+
EvaluationRow objects that include tools, input_metadata, and other structured data.
22+
This is the preferred option when working with tool-calling scenarios or when you need
23+
to provide additional metadata like row_id, dataset information, or custom fields."""
24+
25+
messages: InputMessagesParam | None = None
26+
"""Raw chat completion message history. Use this when you only have simple
27+
conversation history without tools or additional metadata. The messages will be
28+
automatically converted to EvaluationRow objects. InputMessagesParam is a list of
29+
Message objects representing the conversation flow (user, assistant, system messages)."""
30+
31+
variant_id: str = DEFAULT_VARIANT_ID
32+
"""Unique identifier for this data loader variant. Used to label and distinguish
33+
different input data sources, versions, or configurations. This helps with tracking
34+
and organizing evaluation results from different data sources."""
35+
1936
description: str | None = None
37+
"""Optional human-readable description of this data loader. Provides additional
38+
context about the data source, purpose, or any special characteristics. Used for
39+
documentation and debugging purposes. If not provided, the variant_id will be used instead."""
2040

2141
def __post_init__(self) -> None:
2242
if self.rows is None and self.messages is None:
@@ -26,7 +46,7 @@ def variants(self) -> Sequence[DataLoaderVariant]:
2646
def _load(ctx: DataLoaderContext) -> DataLoaderResult:
2747
resolved_rows: list[EvaluationRow] = []
2848
if self.rows is not None:
29-
resolved_rows.extend(row.model_copy(deep=True) for row in self.rows)
49+
resolved_rows = [row.model_copy(deep=True) for row in self.rows]
3050
if self.messages is not None:
3151
for dataset_messages in self.messages:
3252
row_messages: list[Message] = []
@@ -37,19 +57,11 @@ def _load(ctx: DataLoaderContext) -> DataLoaderResult:
3757
row_messages.append(Message.model_validate(msg))
3858
resolved_rows.append(EvaluationRow(messages=row_messages))
3959

40-
if ctx.max_rows is not None:
41-
resolved_rows = resolved_rows[: ctx.max_rows]
42-
43-
metadata = {
44-
"data_loader_variant_id": self.variant_id,
45-
"data_loader_type": "inline",
46-
"row_count": len(resolved_rows),
47-
}
48-
4960
return DataLoaderResult(
5061
rows=resolved_rows,
51-
source_id=self.variant_id,
52-
source_metadata=metadata,
62+
num_rows=len(resolved_rows),
63+
variant_id=ctx.variant_id,
64+
type=self.__class__.__name__,
5365
)
5466

5567
description = self.description or self.variant_id
@@ -58,6 +70,5 @@ def _load(ctx: DataLoaderContext) -> DataLoaderResult:
5870
id=self.variant_id,
5971
description=description,
6072
loader=_load,
61-
metadata={"type": "inline"},
6273
)
6374
]

eval_protocol/data_loader/models.py

Lines changed: 105 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,56 +3,129 @@
33
from __future__ import annotations
44

55
from collections.abc import Sequence
6-
from typing import Any, Callable
6+
from typing import Callable
77
from typing_extensions import Protocol
88

9-
from pydantic import BaseModel, Field
9+
from pydantic import BaseModel, Field, field_validator
1010

1111
from eval_protocol.models import EvaluationRow
12-
from eval_protocol.pytest.types import EvaluationTestMode
13-
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
1412

1513

1614
class DataLoaderContext(BaseModel):
17-
"""Context provided to loader variants when materializing data."""
15+
"""Context provided to loader variants when materializing data. This is mainly used internally by eval-protocol."""
1816

19-
max_rows: int | None = Field(default=None, ge=1, description="Maximum number of rows to load")
2017
preprocess_fn: Callable[[list[EvaluationRow]], list[EvaluationRow]] | None = Field(
21-
default=None, description="Optional preprocessing function for evaluation rows"
18+
default=None,
19+
description="Optional preprocessing function for evaluation rows. This function is applied "
20+
"to the loaded data before it's returned, allowing for data cleaning, transformation, "
21+
"filtering, or other modifications. The function receives a list of EvaluationRow objects "
22+
"and should return a modified list of EvaluationRow objects.",
23+
)
24+
variant_id: str = Field(
25+
...,
26+
description="Unique identifier for the data loader variant. Used to distinguish between "
27+
"different variants of the same data loader and for tracking purposes in evaluation results.",
28+
)
29+
variant_description: str | None = Field(
30+
default=None,
31+
description="Human-readable description of the data loader variant. Provides context about what "
32+
"this variant represents, its purpose, or any special characteristics that distinguish "
33+
"it from other variants.",
2234
)
23-
logger: DatasetLogger = Field(description="Dataset logger for tracking operations")
24-
invocation_id: str = Field(description="Unique identifier for this invocation")
25-
experiment_id: str = Field(description="Unique identifier for this experiment")
26-
mode: EvaluationTestMode = Field(description="The evaluation test mode")
2735

28-
class Config:
29-
arbitrary_types_allowed = True # For Callable and DatasetLogger types
36+
@field_validator("variant_id")
37+
@classmethod
38+
def validate_variant_id(cls, v: str) -> str:
39+
if not v or not v.strip():
40+
raise ValueError("variant_id must be non-empty")
41+
return v
3042

3143

3244
class DataLoaderResult(BaseModel):
3345
"""Rows and metadata returned by a loader variant."""
3446

35-
rows: list[EvaluationRow] = Field(description="List of evaluation rows loaded")
36-
source_id: str = Field(description="Unique identifier for the data source")
37-
source_metadata: dict[str, Any] = Field(
38-
default_factory=dict, description="Additional metadata about the data source"
47+
rows: list[EvaluationRow] = Field(
48+
description="List of evaluation rows loaded from the data source. These are the "
49+
"processed and ready-to-use evaluation data that will be fed into the evaluation pipeline."
50+
)
51+
num_rows: int = Field(
52+
...,
53+
description="Number of rows loaded. This should match the length of the rows list "
54+
"and is used for validation and reporting purposes.",
55+
)
56+
type: str = Field(
57+
...,
58+
description="Type of the data loader that produced this result. Used for identification "
59+
"and debugging purposes (e.g., 'InlineDataLoader', 'FactoryDataLoader').",
60+
)
61+
variant_id: str = Field(
62+
...,
63+
description="Unique identifier for the data loader variant that produced this result. "
64+
"Used for tracking and organizing evaluation results from different data sources.",
3965
)
40-
raw_payload: Any | None = Field(default=None, description="Raw payload data if available")
41-
preprocessed: bool = Field(default=False, description="Whether the data has been preprocessed")
4266

43-
class Config:
44-
arbitrary_types_allowed = True # For Any type in raw_payload
67+
variant_description: str | None = Field(
68+
default=None,
69+
description="Human-readable description of the data loader variant that produced this result. "
70+
"Provides context about what this variant represents, its purpose, or any special characteristics that distinguish "
71+
"it from other variants.",
72+
)
73+
74+
preprocessed: bool = Field(
75+
default=False,
76+
description="Whether the data has been preprocessed. This flag indicates if any "
77+
"preprocessing functions have been applied to the data, helping to avoid duplicate "
78+
"processing and track data transformation state.",
79+
)
80+
81+
@field_validator("type")
82+
@classmethod
83+
def validate_type(cls, v: str) -> str:
84+
if not v or not v.strip():
85+
raise ValueError("type must be non-empty")
86+
return v
87+
88+
@field_validator("num_rows")
89+
@classmethod
90+
def validate_num_rows(cls, v: int) -> int:
91+
if v <= 0:
92+
raise ValueError("num_rows must be greater than 0")
93+
return v
94+
95+
@field_validator("variant_id")
96+
@classmethod
97+
def validate_variant_id(cls, v: str) -> str:
98+
if not v or not v.strip():
99+
raise ValueError("variant_id must be non-empty")
100+
return v
45101

46102

47103
class DataLoaderVariant(BaseModel):
48104
"""Single parameterizable variant from a data loader."""
49105

50-
id: str = Field(description="Unique identifier for this variant")
51-
description: str = Field(description="Human-readable description of this variant")
106+
id: str = Field(
107+
description="Unique identifier for this variant. Used to distinguish between different "
108+
"variants of the same data loader and for tracking purposes in evaluation results."
109+
)
110+
description: str | None = Field(
111+
default=None,
112+
description="Human-readable description of this variant. Provides context about what "
113+
"this variant represents, its purpose, or any special characteristics that distinguish "
114+
"it from other variants.",
115+
)
52116
loader: Callable[[DataLoaderContext], DataLoaderResult] = Field(
53-
description="Function that loads data for this variant"
117+
description="Function that loads data for this variant. This callable is invoked with "
118+
"a DataLoaderContext and should return a DataLoaderResult containing the loaded "
119+
"evaluation rows and associated metadata. The loader function is responsible for "
120+
"the actual data retrieval and any necessary processing."
54121
)
55-
metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata for this variant")
122+
123+
@field_validator("id")
124+
@classmethod
125+
def validate_id(cls, v: str) -> str:
126+
if not v or not v.strip():
127+
raise ValueError("DataLoaderVariant.id must be non-empty")
128+
return v
56129

57130
class Config:
58131
arbitrary_types_allowed = True # For Callable type
@@ -69,3 +142,10 @@ class EvaluationDataLoader(Protocol):
69142
def variants(self) -> Sequence[DataLoaderVariant]:
70143
"""Return parameterizable variants emitted by this loader."""
71144
...
145+
146+
def load(self, ctx: DataLoaderContext) -> list[DataLoaderResult]:
147+
"""
148+
Loads all variants of this data loader and return a list of DataLoaderResult.
149+
"""
150+
variants = self.variants()
151+
return [variant.load(ctx) for variant in variants]

eval_protocol/pytest/evaluation_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def evaluation_test(
7272
input_messages: Sequence[list[InputMessagesParam] | None] | None = None,
7373
input_dataset: Sequence[DatasetPathParam] | None = None,
7474
input_rows: Sequence[list[EvaluationRow]] | None = None,
75-
input_data_loaders: Sequence[EvaluationDataLoader] | EvaluationDataLoader | None = None,
75+
data_loaders: Sequence[EvaluationDataLoader] | EvaluationDataLoader | None = None,
7676
dataset_adapter: Callable[[list[dict[str, Any]]], Dataset] = default_dataset_adapter, # pyright: ignore[reportExplicitAny]
7777
rollout_processor: RolloutProcessor | None = None,
7878
evaluation_test_kwargs: Sequence[EvaluationInputParam | None] | None = None,
@@ -176,7 +176,7 @@ def evaluation_test(
176176

177177
active_logger: DatasetLogger = logger if logger else default_logger
178178

179-
if input_data_loaders is not None and (
179+
if data_loaders is not None and (
180180
input_dataset is not None or input_messages is not None or input_rows is not None
181181
):
182182
raise ValueError("data_loaders cannot be combined with input_dataset, input_messages, or input_rows.")

0 commit comments

Comments
 (0)