Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions eval_protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
rollout,
test_mcp,
)
from .data_loader import DynamicDataLoader, InlineDataLoader

# Try to import FireworksPolicy if available
try:
Expand Down Expand Up @@ -63,6 +64,8 @@

__all__ = [
"DefaultParameterIdGenerator",
"DynamicDataLoader",
"InlineDataLoader",
"aha_judge",
"multi_turn_assistant_to_ground_truth",
"assistant_to_ground_truth",
Expand Down
4 changes: 4 additions & 0 deletions eval_protocol/data_loader/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .dynamic_data_loader import DynamicDataLoader
from .inline_data_loader import InlineDataLoader

__all__ = ["DynamicDataLoader", "InlineDataLoader"]
38 changes: 38 additions & 0 deletions eval_protocol/data_loader/dynamic_data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from collections.abc import Callable, Sequence
from dataclasses import dataclass

from eval_protocol.data_loader.models import (
DataLoaderResult,
DataLoaderVariant,
EvaluationDataLoader,
)
from eval_protocol.models import EvaluationRow


@dataclass(kw_only=True)
class DynamicDataLoader(EvaluationDataLoader):
"""Data loader for dynamic data generation."""

generators: Sequence[Callable[[], list[EvaluationRow]]]
"""Dynamic data generation functions. These callables are invoked each time data
needs to be loaded, allowing for dynamic data generation, lazy loading, or data that
changes between evaluation runs. Each function should return a list of EvaluationRow
objects. This is useful for scenarios like generating test data on-the-fly, loading
data from external sources, or creating data with randomized elements for robust testing."""

def variants(self) -> Sequence[DataLoaderVariant]:
variants: Sequence[DataLoaderVariant] = []
for generator in self.generators:

def _load() -> DataLoaderResult:
resolved_rows = generator()
return DataLoaderResult(
rows=resolved_rows,
type=self.__class__.__name__,
variant_id=generator.__name__,
variant_description=generator.__doc__,
)

variants.append(_load)

return variants
38 changes: 38 additions & 0 deletions eval_protocol/data_loader/factory_data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from collections.abc import Callable, Sequence
from dataclasses import dataclass

from eval_protocol.data_loader.models import (
DataLoaderResult,
DataLoaderVariant,
EvaluationDataLoader,
)
from eval_protocol.models import EvaluationRow


@dataclass(kw_only=True)
class DynamicDataLoader(EvaluationDataLoader):
"""Data loader for dynamic data generation."""

factory: Sequence[Callable[[], list[EvaluationRow]]]
"""Dynamic data generation functions. These callables are invoked each time data
needs to be loaded, allowing for dynamic data generation, lazy loading, or data that
changes between evaluation runs. Each function should return a list of EvaluationRow
objects. This is useful for scenarios like generating test data on-the-fly, loading
data from external sources, or creating data with randomized elements for robust testing."""

def variants(self) -> Sequence[DataLoaderVariant]:
variants: Sequence[DataLoaderVariant] = []
for factory in self.factory:

def _load() -> DataLoaderResult:
resolved_rows = factory()
return DataLoaderResult(
rows=resolved_rows,
type=self.__class__.__name__,
variant_id=factory.__name__,
variant_description=factory.__doc__,
)

variants.append(_load)

return variants
68 changes: 68 additions & 0 deletions eval_protocol/data_loader/inline_data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from collections.abc import Sequence
from dataclasses import dataclass

from eval_protocol.data_loader.models import (
DataLoaderResult,
DataLoaderVariant,
EvaluationDataLoader,
)
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.types import InputMessagesParam


DEFAULT_VARIANT_ID: str = "inline"


@dataclass(kw_only=True)
class InlineDataLoader(EvaluationDataLoader):
"""Data loader for inline ``EvaluationRow`` or message payloads."""

rows: list[EvaluationRow] | None = None
"""Pre-defined evaluation rows with tools and metadata. Use this when you have complete
EvaluationRow objects that include tools, input_metadata, and other structured data.
This is the preferred option when working with tool-calling scenarios or when you need
to provide additional metadata like row_id, dataset information, or custom fields."""

messages: Sequence[InputMessagesParam] | None = None
"""Raw chat completion message history. Use this when you only have simple
conversation history without tools or additional metadata. The messages will be
automatically converted to EvaluationRow objects. InputMessagesParam is a list of
Message objects representing the conversation flow (user, assistant, system messages)."""

id: str = DEFAULT_VARIANT_ID
"""Unique identifier for this data loader variant. Used to label and distinguish
different input data sources, versions, or configurations. This helps with tracking
and organizing evaluation results from different data sources."""

description: str | None = None
"""Optional human-readable description of this data loader. Provides additional
context about the data source, purpose, or any special characteristics. Used for
documentation and debugging purposes. If not provided, the variant_id will be used instead."""

def __post_init__(self) -> None:
if self.rows is None and self.messages is None:
raise ValueError("InlineDataLoader requires rows or messages to be provided")

def variants(self) -> Sequence[DataLoaderVariant]:
def _load() -> DataLoaderResult:
resolved_rows: list[EvaluationRow] = []
if self.rows is not None:
resolved_rows = [row.model_copy(deep=True) for row in self.rows]
if self.messages is not None:
for dataset_messages in self.messages:
row_messages: list[Message] = []
for msg in dataset_messages:
if isinstance(msg, Message):
row_messages.append(msg.model_copy(deep=True))
else:
row_messages.append(Message.model_validate(msg))
resolved_rows.append(EvaluationRow(messages=row_messages))

return DataLoaderResult(
rows=resolved_rows,
variant_id=self.id,
variant_description=self.description,
type=self.__class__.__name__,
)

return [_load]
128 changes: 128 additions & 0 deletions eval_protocol/data_loader/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""Data loader abstractions"""

from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass
from typing import Callable
from typing_extensions import Protocol
from abc import ABC, abstractmethod

from pydantic import BaseModel, Field, field_validator

from eval_protocol.models import EvaluationRow


class DataLoaderResult(BaseModel):
"""Rows and metadata returned by a loader variant."""

rows: list[EvaluationRow] = Field(
description="List of evaluation rows loaded from the data source. These are the "
"processed and ready-to-use evaluation data that will be fed into the evaluation pipeline."
)

type: str = Field(
...,
description="Type of the data loader that produced this result. Used for identification "
"and debugging purposes (e.g., 'InlineDataLoader', 'DynamicDataLoader').",
)

variant_id: str = Field(
...,
description="Unique identifier for the data loader variant that produced this result. "
"Used for tracking and organizing evaluation results from different data sources.",
)

variant_description: str | None = Field(
default=None,
description="Human-readable description of the data loader variant that produced this result. "
"Provides context about what this variant represents, its purpose, or any special characteristics that distinguish "
"it from other variants.",
)

preprocessed: bool = Field(
default=False,
description="Whether the data has been preprocessed. This flag indicates if any "
"preprocessing functions have been applied to the data, helping to avoid duplicate "
"processing and track data transformation state.",
)

@field_validator("type")
@classmethod
def validate_type(cls, v: str) -> str:
if not v or not v.strip():
raise ValueError("type must be non-empty")
return v

@field_validator("variant_id")
@classmethod
def validate_variant_id(cls, v: str) -> str:
if not v or not v.strip():
raise ValueError("variant_id must be non-empty")
return v


class DataLoaderVariant(Protocol):
"""Single parameterizable variant from a data loader."""

def __call__(self) -> DataLoaderResult:
"""Load a dataset for this variant using the provided context."""
...


@dataclass(kw_only=True)
class EvaluationDataLoader(ABC):
"""Abstract base class for data loaders that can be consumed by ``evaluation_test``."""

preprocess_fn: Callable[[list[EvaluationRow]], list[EvaluationRow]] | None = None
"""Optional preprocessing function for evaluation rows. This function is applied
to the loaded data before it's returned, allowing for data cleaning, transformation,
filtering, or other modifications. The function receives a list of EvaluationRow objects
and should return a modified list of EvaluationRow objects."""

@abstractmethod
def variants(self) -> Sequence[DataLoaderVariant]:
"""Return parameterizable variants emitted by this loader."""
...

def load(self) -> list[DataLoaderResult]:
"""Loads all variants of this data loader and return a list of DataLoaderResult."""
results = []
for variant in self.variants():
result = variant()
result = self._process_variant(result)
results.append(result)
return results

def _process_variant(self, result: DataLoaderResult) -> DataLoaderResult:
"""Process a single variant: preprocess data and apply metadata."""
# Preprocess data
original_count = len(result.rows)
if self.preprocess_fn:
result.rows = self.preprocess_fn(result.rows)
result.preprocessed = True
processed_count = len(result.rows)
else:
processed_count = original_count

# Apply metadata to rows
self._apply_metadata(result, original_count, processed_count)
return result

def _apply_metadata(self, result: DataLoaderResult, original_count: int, processed_count: int) -> None:
"""Apply metadata to all rows in the result."""
for row in result.rows:
if row.input_metadata.dataset_info is None:
row.input_metadata.dataset_info = {}

# Apply result attributes as metadata
for attr_name, attr_value in vars(result).items():
"""
Exclude rows and private attributes from metadata.
"""
if attr_name != "rows" and not attr_name.startswith("_"):
row.input_metadata.dataset_info[f"data_loader_{attr_name}"] = attr_value

# Apply row counts
row.input_metadata.dataset_info["data_loader_num_rows"] = original_count
row.input_metadata.dataset_info["data_loader_num_rows_after_preprocessing"] = processed_count
Loading
Loading