From 278b663594eff6817cfbc9b00e8a3b8f95d64169 Mon Sep 17 00:00:00 2001 From: "Yufei (Benny) Chen" <1585539+benjibc@users.noreply.github.com> Date: Thu, 18 Sep 2025 00:38:08 -0700 Subject: [PATCH] Add data loader abstraction to evaluation_test --- eval_protocol/adapters/langfuse.py | 2 +- eval_protocol/adapters/langsmith.py | 4 +- eval_protocol/pytest/__init__.py | 10 + eval_protocol/pytest/data_loaders.py | 173 +++++++++++++++++ eval_protocol/pytest/evaluation_test.py | 179 +++++++++++++++--- .../pytest/generate_parameter_combinations.py | 4 +- tests/pytest/test_data_loader.py | 81 ++++++++ 7 files changed, 421 insertions(+), 32 deletions(-) create mode 100644 eval_protocol/pytest/data_loaders.py create mode 100644 tests/pytest/test_data_loader.py diff --git a/eval_protocol/adapters/langfuse.py b/eval_protocol/adapters/langfuse.py index 44c43fe2..d59ef3fb 100644 --- a/eval_protocol/adapters/langfuse.py +++ b/eval_protocol/adapters/langfuse.py @@ -219,7 +219,7 @@ def __init__(self): if not LANGFUSE_AVAILABLE: raise ImportError("Langfuse not installed. Install with: pip install 'eval-protocol[langfuse]'") - self.client = get_client() + self.client = get_client() # pyright: ignore[reportCallIssue] def get_evaluation_rows( self, diff --git a/eval_protocol/adapters/langsmith.py b/eval_protocol/adapters/langsmith.py index fc1daf71..7b83ccab 100644 --- a/eval_protocol/adapters/langsmith.py +++ b/eval_protocol/adapters/langsmith.py @@ -35,10 +35,10 @@ class LangSmithAdapter(BaseAdapter): - outputs: { messages: [...] } | { content } | { result } | { answer } | { output } | str | list[dict] """ - def __init__(self, client: Optional[Client] = None) -> None: + def __init__(self, client: Optional[Any] = None) -> None: if not LANGSMITH_AVAILABLE: raise ImportError("LangSmith not installed. Install with: pip install 'eval-protocol[langsmith]'") - self.client = client or Client() + self.client = client or Client() # pyright: ignore[reportCallIssue] def get_evaluation_rows( self, diff --git a/eval_protocol/pytest/__init__.py b/eval_protocol/pytest/__init__.py index b6d02ae2..fa9096de 100644 --- a/eval_protocol/pytest/__init__.py +++ b/eval_protocol/pytest/__init__.py @@ -7,6 +7,12 @@ from .exception_config import ExceptionHandlerConfig, BackoffConfig, get_default_exception_handler_config from .rollout_processor import RolloutProcessor from .types import RolloutProcessorConfig +from .data_loaders import ( + EvaluationDataLoader, + InlineDataLoader, + LangfuseAdapterLoader, + LangfuseLoaderConfig, +) # Conditional import for optional dependencies try: @@ -38,6 +44,10 @@ "ExceptionHandlerConfig", "BackoffConfig", "get_default_exception_handler_config", + "EvaluationDataLoader", + "InlineDataLoader", + "LangfuseAdapterLoader", + "LangfuseLoaderConfig", ] # Only add to __all__ if available diff --git a/eval_protocol/pytest/data_loaders.py b/eval_protocol/pytest/data_loaders.py new file mode 100644 index 00000000..e7f4f629 --- /dev/null +++ b/eval_protocol/pytest/data_loaders.py @@ -0,0 +1,173 @@ +"""Data loader abstractions for evaluation tests.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Callable, Protocol, Sequence + +from eval_protocol.adapters.base import BaseAdapter +from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest.types import EvaluationTestMode, InputMessagesParam +from eval_protocol.dataset_logger.dataset_logger import DatasetLogger + + +@dataclass(slots=True) +class DataLoaderContext: + """Context provided to loader variants when materializing data.""" + + max_rows: int | None + preprocess_fn: Callable[[list[EvaluationRow]], list[EvaluationRow]] | None + logger: DatasetLogger + invocation_id: str + experiment_id: str + mode: EvaluationTestMode + + +@dataclass(slots=True) +class DataLoaderResult: + """Rows and metadata returned by a loader variant.""" + + rows: list[EvaluationRow] + source_id: str + source_metadata: dict[str, Any] = field(default_factory=dict) + raw_payload: Any | None = None + preprocessed: bool = False + + +@dataclass(slots=True) +class DataLoaderVariant: + """Single parameterizable variant from a data loader.""" + + id: str + description: str + loader: Callable[[DataLoaderContext], DataLoaderResult] + metadata: dict[str, Any] = field(default_factory=dict) + + def load(self, ctx: DataLoaderContext) -> DataLoaderResult: + """Load a dataset for this variant using the provided context.""" + + return self.loader(ctx) + + +class EvaluationDataLoader(Protocol): + """Protocol for data loaders that can be consumed by ``evaluation_test``.""" + + def variants(self) -> Sequence[DataLoaderVariant]: + """Return parameterizable variants emitted by this loader.""" + + ... + + +@dataclass(slots=True) +class InlineDataLoader(EvaluationDataLoader): + """Data loader for inline ``EvaluationRow`` or message payloads.""" + + rows: Sequence[EvaluationRow] | None = None + messages: Sequence[InputMessagesParam] | None = None + variant_id: str = "inline" + description: str | None = None + + def __post_init__(self) -> None: + if self.rows is None and self.messages is None: + raise ValueError("InlineDataLoader requires rows or messages to be provided") + + def variants(self) -> Sequence[DataLoaderVariant]: + def _load(ctx: DataLoaderContext) -> DataLoaderResult: + resolved_rows: list[EvaluationRow] = [] + if self.rows is not None: + resolved_rows.extend(row.model_copy(deep=True) for row in self.rows) + if self.messages is not None: + for dataset_messages in self.messages: + row_messages: list[Message] = [] + for msg in dataset_messages: + if isinstance(msg, Message): + row_messages.append(msg.model_copy(deep=True)) + else: + row_messages.append(Message.model_validate(msg)) + resolved_rows.append(EvaluationRow(messages=row_messages)) + + if ctx.max_rows is not None: + resolved_rows = resolved_rows[: ctx.max_rows] + + metadata = { + "data_loader_variant_id": self.variant_id, + "data_loader_type": "inline", + "row_count": len(resolved_rows), + } + + return DataLoaderResult( + rows=resolved_rows, + source_id=self.variant_id, + source_metadata=metadata, + ) + + description = self.description or self.variant_id + return [ + DataLoaderVariant( + id=self.variant_id, + description=description, + loader=_load, + metadata={"type": "inline"}, + ) + ] + + +@dataclass(slots=True) +class LangfuseLoaderConfig: + """Configuration for a single Langfuse adapter variant.""" + + id: str + kwargs: dict[str, Any] = field(default_factory=dict) + description: str | None = None + + +@dataclass(slots=True) +class LangfuseAdapterLoader(EvaluationDataLoader): + """Wrap a ``LangfuseAdapter`` (or compatible adapter) as a data loader.""" + + adapter: BaseAdapter + variants_config: Sequence[LangfuseLoaderConfig] + + def variants(self) -> Sequence[DataLoaderVariant]: + loader_variants: list[DataLoaderVariant] = [] + + for config in self.variants_config: + + def _load(ctx: DataLoaderContext, *, _config: LangfuseLoaderConfig = config) -> DataLoaderResult: + rows = self.adapter.get_evaluation_rows(**_config.kwargs) + if ctx.max_rows is not None: + rows = rows[: ctx.max_rows] + + metadata = { + "data_loader_variant_id": _config.id, + "data_loader_type": "langfuse", + "adapter_kwargs": _config.kwargs, + } + + return DataLoaderResult( + rows=[row.model_copy(deep=True) for row in rows], + source_id=_config.id, + source_metadata=metadata, + ) + + loader_variants.append( + DataLoaderVariant( + id=config.id, + description=config.description or config.id, + loader=_load, + metadata={"type": "langfuse", "adapter_kwargs": config.kwargs}, + ) + ) + + return loader_variants + + +__all__ = [ + "DataLoaderContext", + "DataLoaderResult", + "DataLoaderVariant", + "EvaluationDataLoader", + "InlineDataLoader", + "LangfuseAdapterLoader", + "LangfuseLoaderConfig", +] diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index a7ec65f3..020f6505 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -26,10 +26,22 @@ from eval_protocol.pytest.evaluation_test_postprocess import postprocess from eval_protocol.pytest.execution import execute_pytest from eval_protocol.pytest.generate_parameter_combinations import ( + CombinationTuple, ParameterizedTestKwargs, generate_parameter_combinations, ) -from eval_protocol.pytest.parameterize import pytest_parametrize, create_dynamically_parameterized_wrapper +from eval_protocol.pytest.data_loaders import ( + DataLoaderContext, + DataLoaderResult, + DataLoaderVariant, + EvaluationDataLoader, +) +from eval_protocol.pytest.parameterize import ( + DefaultParameterIdGenerator, + PytestParametrizeArgs, + create_dynamically_parameterized_wrapper, + pytest_parametrize, +) from eval_protocol.pytest.validate_signature import validate_signature from eval_protocol.pytest.default_dataset_adapter import default_dataset_adapter from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor @@ -69,6 +81,7 @@ def evaluation_test( input_messages: Sequence[list[InputMessagesParam] | None] | None = None, input_dataset: Sequence[DatasetPathParam] | None = None, input_rows: Sequence[list[EvaluationRow]] | None = None, + data_loaders: Sequence[EvaluationDataLoader] | EvaluationDataLoader | None = None, dataset_adapter: Callable[[list[dict[str, Any]]], Dataset] = default_dataset_adapter, # pyright: ignore[reportExplicitAny] rollout_processor: RolloutProcessor | None = None, evaluation_test_kwargs: Sequence[EvaluationInputParam | None] | None = None, @@ -131,6 +144,9 @@ def evaluation_test( input_rows: Pre-constructed EvaluationRow objects to use directly. This is useful when you want to provide EvaluationRow objects with custom metadata, input_messages, or other fields already populated. Will be passed as "input_dataset" to the test function. + data_loaders: Data loader(s) that produce datasets or message bundles. When provided, + ``input_dataset``, ``input_messages``, and ``input_rows`` must be omitted. Each loader + can expose one or more parameterizable variants, similar to ``torch.utils.data.DataLoader``. dataset_adapter: Function to convert the input dataset to a list of EvaluationRows. This is useful if you have a custom dataset format. completion_params: Generation parameters for the rollout. @@ -165,6 +181,11 @@ def evaluation_test( active_logger: DatasetLogger = logger if logger else default_logger + if data_loaders is not None and ( + input_dataset is not None or input_messages is not None or input_rows is not None + ): + raise ValueError("data_loaders cannot be combined with input_dataset, input_messages, or input_rows.") + # Optional global overrides via environment for ad-hoc experimentation # EP_INPUT_PARAMS_JSON can contain a JSON object that will be deep-merged # into input_params (e.g., '{"temperature":0,"extra_body":{"reasoning":{"effort":"low"}}}'). @@ -175,36 +196,112 @@ def evaluation_test( original_completion_params = completion_params passed_threshold = parse_ep_passed_threshold(passed_threshold) + def _normalize_data_loaders( + loaders: Sequence[EvaluationDataLoader] | EvaluationDataLoader | None, + ) -> list[EvaluationDataLoader]: + if loaders is None: + return [] + if isinstance(loaders, Sequence): + return list(loaders) + return [loaders] + + def _build_data_loader_parametrize_args( + loader_variants: Sequence[DataLoaderVariant], + completion_params_seq: Sequence[CompletionParams | None] | None, + evaluation_test_kwargs_seq: Sequence[EvaluationInputParam | None] | None, + ) -> PytestParametrizeArgs: + if not loader_variants: + raise ValueError("No data loader variants were produced by the provided data_loaders.") + + argnames: list[str] = ["data_loader_variant"] + if completion_params_seq is not None: + argnames.append("completion_params") + if evaluation_test_kwargs_seq is not None: + argnames.append("evaluation_test_kwargs") + + completion_values: Sequence[CompletionParams | None] + if completion_params_seq is None: + completion_values = [None] + else: + completion_values = completion_params_seq + + etk_values: Sequence[EvaluationInputParam | None] + if evaluation_test_kwargs_seq is None: + etk_values = [None] + else: + etk_values = evaluation_test_kwargs_seq + + argvalues: list[tuple[object, ...]] = [] + ids: list[str] = [] + id_generator = DefaultParameterIdGenerator() + + for variant in loader_variants: + for cp in completion_values: + for etk in etk_values: + param_tuple: list[object] = [variant] + if completion_params_seq is not None: + param_tuple.append(cp) + if evaluation_test_kwargs_seq is not None: + param_tuple.append(etk) + argvalues.append(tuple(param_tuple)) + + cp_id = id_generator.generate_id((None, cp, None, None, etk)) + if cp_id: + ids.append(f"{variant.id}-{cp_id}") + else: + ids.append(variant.id) + + return PytestParametrizeArgs( + argnames=argnames, + argvalues=argvalues, + ids=ids if any(ids) else None, + ) + def decorator( test_func: TestFunction, ) -> TestFunction: sig = inspect.signature(test_func) validate_signature(sig, mode, completion_params) - # Calculate all possible combinations of parameters - combinations = generate_parameter_combinations( - input_dataset, - completion_params, - input_messages, - input_rows, - evaluation_test_kwargs, - max_dataset_rows, - combine_datasets, - ) - if len(combinations) == 0: - raise ValueError( - "No combinations of parameters were found. Please provide at least a model and one of input_dataset, input_messages, or input_rows." - ) + normalized_loaders = _normalize_data_loaders(data_loaders) + loader_variants: list[DataLoaderVariant] = [] + for loader in normalized_loaders: + loader_variants.extend(loader.variants()) - # Create parameter tuples for pytest.mark.parametrize - pytest_parametrize_args = pytest_parametrize( - combinations, - input_dataset, - completion_params, - input_messages, - input_rows, - evaluation_test_kwargs, - ) + has_data_loader_variants = len(loader_variants) > 0 + + combinations: list[CombinationTuple] = [] + if has_data_loader_variants: + pytest_parametrize_args = _build_data_loader_parametrize_args( + loader_variants, + completion_params, + evaluation_test_kwargs, + ) + else: + # Calculate all possible combinations of parameters + combinations = generate_parameter_combinations( + input_dataset, + completion_params, + input_messages, + input_rows, + evaluation_test_kwargs, + max_dataset_rows, + combine_datasets, + ) + if len(combinations) == 0: + raise ValueError( + "No combinations of parameters were found. Please provide at least a model and one of input_dataset, input_messages, input_rows, or data_loaders." + ) + + # Create parameter tuples for pytest.mark.parametrize + pytest_parametrize_args = pytest_parametrize( + combinations, + input_dataset, + completion_params, + input_messages, + input_rows, + evaluation_test_kwargs, + ) # Create wrapper function with exact signature that pytest expects def create_wrapper_with_signature() -> Callable[[], None]: @@ -225,9 +322,24 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo try: # Handle dataset loading data: list[EvaluationRow] = [] + batch: DataLoaderResult | None = None # Track all rows processed in the current run for error logging processed_rows_in_run: list[EvaluationRow] = [] - if "dataset_path" in kwargs and kwargs["dataset_path"] is not None: + data_loader_variant = kwargs.get("data_loader_variant") + if data_loader_variant is not None: + loader_context = DataLoaderContext( + max_rows=max_dataset_rows, + preprocess_fn=preprocess_fn, + logger=active_logger, + invocation_id=invocation_id, + experiment_id=experiment_id, + mode=mode, + ) + batch = data_loader_variant.load(loader_context) + data = batch.rows + if max_dataset_rows is not None and len(data) > max_dataset_rows: + data = data[:max_dataset_rows] + elif "dataset_path" in kwargs and kwargs["dataset_path"] is not None: ds_arg: list[str] = kwargs["dataset_path"] # Support either a single path or a list of paths; if a list is provided, # concatenate the rows from each file in order. @@ -246,11 +358,22 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo # Deep copy pre-constructed EvaluationRow objects data = [row.model_copy(deep=True) for row in kwargs["input_rows"]] else: - raise ValueError("No input dataset, input messages, or input rows provided") + raise ValueError("No input dataset, input messages, input rows, or data loader provided") - if preprocess_fn: + if preprocess_fn and not (batch is not None and batch.preprocessed): data = preprocess_fn(data) + if data_loader_variant is not None and batch is not None: + for row in data: + dataset_info = dict(row.input_metadata.dataset_info or {}) + dataset_info.setdefault("data_loader_variant_id", data_loader_variant.id) + dataset_info.setdefault("data_loader_variant_description", data_loader_variant.description) + dataset_info.setdefault("data_loader_source_id", batch.source_id) + if batch.source_metadata: + for key, value in batch.source_metadata.items(): + dataset_info.setdefault(key, value) + row.input_metadata.dataset_info = dataset_info + for row in data: # generate a stable row_id for each row if row.input_metadata.row_id is None: @@ -261,7 +384,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo index = abs(index) % (max_index + 1) row.input_metadata.row_id = generate_id(seed=0, index=index) - completion_params = kwargs["completion_params"] + completion_params = kwargs.get("completion_params") # Create eval metadata with test function info and current commit hash eval_metadata = EvalMetadata( name=test_func.__name__, diff --git a/eval_protocol/pytest/generate_parameter_combinations.py b/eval_protocol/pytest/generate_parameter_combinations.py index 6a1dcf2f..eeef044f 100644 --- a/eval_protocol/pytest/generate_parameter_combinations.py +++ b/eval_protocol/pytest/generate_parameter_combinations.py @@ -1,5 +1,6 @@ from typing import TypedDict from eval_protocol.models import CompletionParams, EvaluationRow +from eval_protocol.pytest.data_loaders import DataLoaderVariant from eval_protocol.pytest.types import Dataset, DatasetPathParam, EvaluationInputParam, InputMessagesParam from eval_protocol.pytest.utils import parse_ep_max_rows from collections.abc import Sequence @@ -31,7 +32,7 @@ ] -class ParameterizedTestKwargs(TypedDict): +class ParameterizedTestKwargs(TypedDict, total=False): """ These are the type of parameters that can be passed to the generated pytest function. Every experiment is a unique combination of these parameters. @@ -42,6 +43,7 @@ class ParameterizedTestKwargs(TypedDict): input_messages: InputMessagesKwarg input_rows: InputRowsKwarg evaluation_test_kwargs: EvaluationTestKwargs + data_loader_variant: DataLoaderVariant def generate_parameter_combinations( diff --git a/tests/pytest/test_data_loader.py b/tests/pytest/test_data_loader.py new file mode 100644 index 00000000..87e250b7 --- /dev/null +++ b/tests/pytest/test_data_loader.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from typing import Any + + +import pytest + +pytest.importorskip("openai") +pytest.importorskip("loguru") +pytest.importorskip("toml") +pytest.importorskip("addict") +pytest.importorskip("deepdiff") +pytest.importorskip("dotenv") + +from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest import ( + InlineDataLoader, + LangfuseAdapterLoader, + LangfuseLoaderConfig, + NoOpRolloutProcessor, + evaluation_test, +) + + +@evaluation_test( + data_loaders=InlineDataLoader( + messages=[[Message(role="user", content="What is 2 + 2?")]], + ), + completion_params=[{"model": "no-op"}], + rollout_processor=NoOpRolloutProcessor(), +) +def test_inline_data_loader(row: EvaluationRow) -> EvaluationRow: + """Inline data loader should feed pre-constructed message bundles.""" + + assert row.messages[0].content == "What is 2 + 2?" + assert row.input_metadata.dataset_info is not None + assert row.input_metadata.dataset_info.get("data_loader_variant_id") == "inline" + return row + + +class _FakeLangfuseAdapter: + def __init__(self) -> None: + self.calls: list[dict[str, Any]] = [] + + def get_evaluation_rows(self, **kwargs: Any) -> list[EvaluationRow]: + self.calls.append(kwargs) + return [ + EvaluationRow(messages=[Message(role="user", content="trace-0")]), + EvaluationRow(messages=[Message(role="user", content="trace-1")]), + ] + + +_fake_adapter = _FakeLangfuseAdapter() + + +def _preprocess_rows(rows: list[EvaluationRow]) -> list[EvaluationRow]: + for row in rows: + row.messages[0].content = f"processed-{row.messages[0].content}" + return rows + + +@evaluation_test( + data_loaders=LangfuseAdapterLoader( + adapter=_fake_adapter, + variants_config=[LangfuseLoaderConfig(id="recent", kwargs={"limit": 5})], + ), + completion_params=[{"model": "no-op"}], + rollout_processor=NoOpRolloutProcessor(), + max_dataset_rows=1, + preprocess_fn=_preprocess_rows, +) +def test_langfuse_data_loader(row: EvaluationRow) -> EvaluationRow: + """Langfuse data loader should pull traces and respect preprocess/max_rows.""" + + assert _fake_adapter.calls == [{"limit": 5}] + assert row.messages[0].content == "processed-trace-0" + assert row.input_metadata.dataset_info is not None + dataset_info = row.input_metadata.dataset_info + assert dataset_info.get("data_loader_variant_id") == "recent" + assert dataset_info.get("adapter_kwargs") == {"limit": 5} + return row