diff --git a/eval_protocol/__init__.py b/eval_protocol/__init__.py
index 7b868b2b..5bd5eec3 100644
--- a/eval_protocol/__init__.py
+++ b/eval_protocol/__init__.py
@@ -22,6 +22,7 @@
     rollout,
     test_mcp,
 )
+from .data_loader import DynamicDataLoader, InlineDataLoader

 # Try to import FireworksPolicy if available
 try:
@@ -63,6 +64,8 @@
 __all__ = [
     "DefaultParameterIdGenerator",
+    "DynamicDataLoader",
+    "InlineDataLoader",
     "aha_judge",
     "multi_turn_assistant_to_ground_truth",
     "assistant_to_ground_truth",
diff --git a/eval_protocol/data_loader/__init__.py b/eval_protocol/data_loader/__init__.py
new file mode 100644
index 00000000..4c92b023
--- /dev/null
+++ b/eval_protocol/data_loader/__init__.py
@@ -0,0 +1,4 @@
+from .dynamic_data_loader import DynamicDataLoader
+from .inline_data_loader import InlineDataLoader
+
+__all__ = ["DynamicDataLoader", "InlineDataLoader"]
diff --git a/eval_protocol/data_loader/dynamic_data_loader.py b/eval_protocol/data_loader/dynamic_data_loader.py
new file mode 100644
index 00000000..77c2efc2
--- /dev/null
+++ b/eval_protocol/data_loader/dynamic_data_loader.py
@@ -0,0 +1,42 @@
+from collections.abc import Callable, Sequence
+from dataclasses import dataclass
+
+from eval_protocol.data_loader.models import (
+    DataLoaderResult,
+    DataLoaderVariant,
+    EvaluationDataLoader,
+)
+from eval_protocol.models import EvaluationRow
+
+
+@dataclass(kw_only=True)
+class DynamicDataLoader(EvaluationDataLoader):
+    """Data loader for dynamic data generation."""
+
+    generators: Sequence[Callable[[], list[EvaluationRow]]]
+    """Dynamic data generation functions. These callables are invoked each time data
+    needs to be loaded, allowing for dynamic data generation, lazy loading, or data that
+    changes between evaluation runs. Each function should return a list of EvaluationRow
+    objects. This is useful for scenarios like generating test data on-the-fly, loading
+    data from external sources, or creating data with randomized elements for robust testing."""
+
+    def variants(self) -> Sequence[DataLoaderVariant]:
+        variants: list[DataLoaderVariant] = []
+        for generator in self.generators:
+
+            # Bind the loop variable as a default argument so each variant keeps its
+            # own generator instead of every closure calling the last one in the loop.
+            def _load(
+                generator: Callable[[], list[EvaluationRow]] = generator,
+            ) -> DataLoaderResult:
+                resolved_rows = generator()
+                return DataLoaderResult(
+                    rows=resolved_rows,
+                    type=self.__class__.__name__,
+                    variant_id=generator.__name__,
+                    variant_description=generator.__doc__,
+                )
+
+            variants.append(_load)
+
+        return variants
diff --git a/eval_protocol/data_loader/inline_data_loader.py b/eval_protocol/data_loader/inline_data_loader.py
new file mode 100644
index 00000000..e8226e1e
--- /dev/null
+++ b/eval_protocol/data_loader/inline_data_loader.py
@@ -0,0 +1,68 @@
+from collections.abc import Sequence
+from dataclasses import dataclass
+
+from eval_protocol.data_loader.models import (
+    DataLoaderResult,
+    DataLoaderVariant,
+    EvaluationDataLoader,
+)
+from eval_protocol.models import EvaluationRow, Message
+from eval_protocol.pytest.types import InputMessagesParam
+
+
+DEFAULT_VARIANT_ID: str = "inline"
+
+
+@dataclass(kw_only=True)
+class InlineDataLoader(EvaluationDataLoader):
+    """Data loader for inline ``EvaluationRow`` or message payloads."""
+
+    rows: list[EvaluationRow] | None = None
+    """Pre-defined evaluation rows with tools and metadata. Use this when you have complete
+    EvaluationRow objects that include tools, input_metadata, and other structured data.
+    This is the preferred option when working with tool-calling scenarios or when you need
+    to provide additional metadata like row_id, dataset information, or custom fields."""
+
+    messages: Sequence[InputMessagesParam] | None = None
+    """Raw chat completion message history. Use this when you only have simple
+    conversation history without tools or additional metadata. The messages will be
+    automatically converted to EvaluationRow objects. InputMessagesParam is a list of
+    Message objects representing the conversation flow (user, assistant, system messages)."""
+
+    id: str = DEFAULT_VARIANT_ID
+    """Unique identifier for this data loader variant. Used to label and distinguish
+    different input data sources, versions, or configurations. This helps with tracking
+    and organizing evaluation results from different data sources."""
+
+    description: str | None = None
+    """Optional human-readable description of this data loader. Provides additional
+    context about the data source, purpose, or any special characteristics. Used for
+    documentation and debugging purposes.
If not provided, the variant_id will be used instead.""" + + def __post_init__(self) -> None: + if self.rows is None and self.messages is None: + raise ValueError("InlineDataLoader requires rows or messages to be provided") + + def variants(self) -> Sequence[DataLoaderVariant]: + def _load() -> DataLoaderResult: + resolved_rows: list[EvaluationRow] = [] + if self.rows is not None: + resolved_rows = [row.model_copy(deep=True) for row in self.rows] + if self.messages is not None: + for dataset_messages in self.messages: + row_messages: list[Message] = [] + for msg in dataset_messages: + if isinstance(msg, Message): + row_messages.append(msg.model_copy(deep=True)) + else: + row_messages.append(Message.model_validate(msg)) + resolved_rows.append(EvaluationRow(messages=row_messages)) + + return DataLoaderResult( + rows=resolved_rows, + variant_id=self.id, + variant_description=self.description, + type=self.__class__.__name__, + ) + + return [_load] diff --git a/eval_protocol/data_loader/models.py b/eval_protocol/data_loader/models.py new file mode 100644 index 00000000..0179272e --- /dev/null +++ b/eval_protocol/data_loader/models.py @@ -0,0 +1,128 @@ +"""Data loader abstractions""" + +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Callable +from typing_extensions import Protocol +from abc import ABC, abstractmethod + +from pydantic import BaseModel, Field, field_validator + +from eval_protocol.models import EvaluationRow + + +class DataLoaderResult(BaseModel): + """Rows and metadata returned by a loader variant.""" + + rows: list[EvaluationRow] = Field( + description="List of evaluation rows loaded from the data source. These are the " + "processed and ready-to-use evaluation data that will be fed into the evaluation pipeline." + ) + + type: str = Field( + ..., + description="Type of the data loader that produced this result. Used for identification " + "and debugging purposes (e.g., 'InlineDataLoader', 'DynamicDataLoader').", + ) + + variant_id: str = Field( + ..., + description="Unique identifier for the data loader variant that produced this result. " + "Used for tracking and organizing evaluation results from different data sources.", + ) + + variant_description: str | None = Field( + default=None, + description="Human-readable description of the data loader variant that produced this result. " + "Provides context about what this variant represents, its purpose, or any special characteristics that distinguish " + "it from other variants.", + ) + + preprocessed: bool = Field( + default=False, + description="Whether the data has been preprocessed. This flag indicates if any " + "preprocessing functions have been applied to the data, helping to avoid duplicate " + "processing and track data transformation state.", + ) + + @field_validator("type") + @classmethod + def validate_type(cls, v: str) -> str: + if not v or not v.strip(): + raise ValueError("type must be non-empty") + return v + + @field_validator("variant_id") + @classmethod + def validate_variant_id(cls, v: str) -> str: + if not v or not v.strip(): + raise ValueError("variant_id must be non-empty") + return v + + +class DataLoaderVariant(Protocol): + """Single parameterizable variant from a data loader.""" + + def __call__(self) -> DataLoaderResult: + """Load a dataset for this variant using the provided context.""" + ... 
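+# A loader exposes one variant callable per underlying data source; ``load`` invokes each
+# callable, applies the optional ``preprocess_fn`` to the returned rows, and stamps
+# ``data_loader_*`` keys into every row's ``input_metadata.dataset_info`` for traceability.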
+ + +@dataclass(kw_only=True) +class EvaluationDataLoader(ABC): + """Abstract base class for data loaders that can be consumed by ``evaluation_test``.""" + + preprocess_fn: Callable[[list[EvaluationRow]], list[EvaluationRow]] | None = None + """Optional preprocessing function for evaluation rows. This function is applied + to the loaded data before it's returned, allowing for data cleaning, transformation, + filtering, or other modifications. The function receives a list of EvaluationRow objects + and should return a modified list of EvaluationRow objects.""" + + @abstractmethod + def variants(self) -> Sequence[DataLoaderVariant]: + """Return parameterizable variants emitted by this loader.""" + ... + + def load(self) -> list[DataLoaderResult]: + """Loads all variants of this data loader and return a list of DataLoaderResult.""" + results = [] + for variant in self.variants(): + result = variant() + result = self._process_variant(result) + results.append(result) + return results + + def _process_variant(self, result: DataLoaderResult) -> DataLoaderResult: + """Process a single variant: preprocess data and apply metadata.""" + # Preprocess data + original_count = len(result.rows) + if self.preprocess_fn: + result.rows = self.preprocess_fn(result.rows) + result.preprocessed = True + processed_count = len(result.rows) + else: + processed_count = original_count + + # Apply metadata to rows + self._apply_metadata(result, original_count, processed_count) + return result + + def _apply_metadata(self, result: DataLoaderResult, original_count: int, processed_count: int) -> None: + """Apply metadata to all rows in the result.""" + for row in result.rows: + if row.input_metadata.dataset_info is None: + row.input_metadata.dataset_info = {} + + # Apply result attributes as metadata + for attr_name, attr_value in vars(result).items(): + """ + Exclude rows and private attributes from metadata. 
+                """
+                if attr_name != "rows" and not attr_name.startswith("_"):
+                    row.input_metadata.dataset_info[f"data_loader_{attr_name}"] = attr_value
+
+            # Apply row counts
+            row.input_metadata.dataset_info["data_loader_num_rows"] = original_count
+            row.input_metadata.dataset_info["data_loader_num_rows_after_preprocessing"] = processed_count
diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py
index 47e3c9f4..a60ca51b 100644
--- a/eval_protocol/pytest/evaluation_test.py
+++ b/eval_protocol/pytest/evaluation_test.py
@@ -1,7 +1,6 @@
 import asyncio
 import inspect
 import os
-import sys
 import time
 from collections import defaultdict
 from typing import Any, Callable
@@ -10,6 +9,7 @@
 import pytest

+from eval_protocol.data_loader.models import EvaluationDataLoader
 from eval_protocol.dataset_logger import default_logger
 from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
 from eval_protocol.human_id import generate_id, num_combinations
@@ -71,6 +71,7 @@ def evaluation_test(
     input_messages: Sequence[list[InputMessagesParam] | None] | None = None,
     input_dataset: Sequence[DatasetPathParam] | None = None,
     input_rows: Sequence[list[EvaluationRow]] | None = None,
+    data_loaders: Sequence[EvaluationDataLoader] | EvaluationDataLoader | None = None,
     dataset_adapter: Callable[[list[dict[str, Any]]], Dataset] = default_dataset_adapter,  # pyright: ignore[reportExplicitAny]
     rollout_processor: RolloutProcessor | None = None,
     evaluation_test_kwargs: Sequence[EvaluationInputParam | None] | None = None,
@@ -133,6 +134,8 @@
         input_rows: Pre-constructed EvaluationRow objects to use directly. This is useful
             when you want to provide EvaluationRow objects with custom metadata, input_messages,
             or other fields already populated. Will be passed as "input_dataset" to the test function.
+        data_loaders: Data loaders used to load the input dataset. Cannot be combined with
+            input_dataset, input_messages, or input_rows; each loader applies its own preprocess_fn.
         dataset_adapter: Function to convert the input dataset to a list of EvaluationRows. This
             is useful if you have a custom dataset format.
         completion_params: Generation parameters for the rollout.
@@ -173,6 +175,11 @@
     active_logger: DatasetLogger = logger if logger else default_logger

+    if data_loaders is not None and (
+        input_dataset is not None or input_messages is not None or input_rows is not None
+    ):
+        raise ValueError("data_loaders cannot be combined with input_dataset, input_messages, or input_rows.")
+
     # Optional global overrides via environment for ad-hoc experimentation
     # EP_INPUT_PARAMS_JSON can contain a JSON object that will be deep-merged
     # into input_params (e.g., '{"temperature":0,"extra_body":{"reasoning":{"effort":"low"}}}').
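For orientation, the intended call pattern for the new parameter looks like the following minimal sketch (it mirrors the tests added under tests/data_loader/ later in this diff; the generator and test names here are illustrative):

    from eval_protocol.data_loader import DynamicDataLoader
    from eval_protocol.models import EvaluationRow, Message
    from eval_protocol.pytest import evaluation_test


    def tiny_generator() -> list[EvaluationRow]:
        """Produce a single-row dataset on demand."""
        return [EvaluationRow(messages=[Message(role="user", content="What is 2 + 2?")])]


    @evaluation_test(
        data_loaders=DynamicDataLoader(generators=[tiny_generator]),
    )
    def test_with_data_loader(row: EvaluationRow) -> EvaluationRow:
        # Loader metadata is stamped onto every row by EvaluationDataLoader.load().
        assert row.input_metadata.dataset_info.get("data_loader_type") == "DynamicDataLoader"
        return row

Passing either a single loader or a sequence of loaders is supported; combining data_loaders with input_dataset, input_messages, or input_rows raises a ValueError.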
@@ -198,6 +205,7 @@ def decorator( evaluation_test_kwargs, max_dataset_rows, combine_datasets, + data_loaders, ) if len(combinations) == 0: raise ValueError( @@ -213,6 +221,7 @@ def decorator( completion_params_provided, input_messages, input_rows, + data_loaders, evaluation_test_kwargs, ) @@ -221,10 +230,10 @@ def create_wrapper_with_signature() -> Callable[[], None]: # Create the function body that will be used invocation_id = generate_id() - # Store URL for viewing results (after all postprocessing is complete) - store_local_ui_results_url(invocation_id) - async def wrapper_body(**kwargs: Unpack[ParameterizedTestKwargs]) -> None: + # Store URL for viewing results (after all postprocessing is complete) + store_local_ui_results_url(invocation_id) + eval_metadata = None all_results: list[list[EvaluationRow]] = [[] for _ in range(num_runs)] @@ -240,7 +249,16 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo data: list[EvaluationRow] = [] # Track all rows processed in the current run for error logging processed_rows_in_run: list[EvaluationRow] = [] - if "dataset_path" in kwargs and kwargs["dataset_path"] is not None: + if "data_loaders" in kwargs and kwargs["data_loaders"] is not None: + data_loaders = kwargs["data_loaders"] + data_loaders_list = ( + [data_loaders] if isinstance(data_loaders, EvaluationDataLoader) else data_loaders + ) + for data_loader in data_loaders_list: + results = data_loader.load() + for result in results: + data.extend(result.rows) + elif "dataset_path" in kwargs and kwargs["dataset_path"] is not None: ds_arg: list[str] = kwargs["dataset_path"] # Support either a single path or a list of paths; if a list is provided, # concatenate the rows from each file in order. @@ -261,7 +279,12 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo else: raise ValueError("No input dataset, input messages, or input rows provided") - if preprocess_fn: + """ + data_loaders handles preprocess_fn internally so we want + to specially handle data_loaders here so we don't double + apply preprocess_fn. 
+        """
+        if preprocess_fn and not kwargs.get("data_loaders"):
             data = preprocess_fn(data)

         for row in data:
diff --git a/eval_protocol/pytest/generate_parameter_combinations.py b/eval_protocol/pytest/generate_parameter_combinations.py
index 99c37b74..cf2da2c6 100644
--- a/eval_protocol/pytest/generate_parameter_combinations.py
+++ b/eval_protocol/pytest/generate_parameter_combinations.py
@@ -1,4 +1,5 @@
 from typing import TypedDict
+from eval_protocol.data_loader.models import EvaluationDataLoader
 from eval_protocol.models import CompletionParams, EvaluationRow
 from eval_protocol.pytest.types import Dataset, DatasetPathParam, EvaluationInputParam, InputMessagesParam
 from eval_protocol.pytest.utils import parse_ep_max_rows
@@ -21,6 +22,7 @@
 InputMessagesKwarg = list[InputMessagesParam] | None
 InputRowsKwarg = Dataset | None
 EvaluationTestKwargs = EvaluationInputParam | None
+DataLoadersKwarg = Sequence[EvaluationDataLoader] | EvaluationDataLoader | None

 CombinationTuple = tuple[
     InputDatasetKwarg,
@@ -28,6 +30,7 @@
     InputMessagesKwarg,
     InputRowsKwarg,
     EvaluationTestKwargs,
+    DataLoadersKwarg,
 ]
@@ -42,6 +45,7 @@ class ParameterizedTestKwargs(TypedDict, total=False):
     input_messages: InputMessagesKwarg
     input_rows: InputRowsKwarg
     evaluation_test_kwargs: EvaluationTestKwargs
+    data_loaders: DataLoadersKwarg


 def generate_parameter_combinations(
@@ -52,6 +56,7 @@
     evaluation_test_kwargs: Sequence[EvaluationInputParam | None] | None,
     max_dataset_rows: int | None,
     combine_datasets: bool,
+    data_loaders: Sequence[EvaluationDataLoader] | EvaluationDataLoader | None,
 ) -> list[CombinationTuple]:
     """
     Generate all combinations of parameters for pytest parameterization.
@@ -108,6 +113,12 @@
     if evaluation_test_kwargs is None:
         evaluation_test_kwargs = [None]

+    data_loaders_list: Sequence[DataLoadersKwarg]
+    if data_loaders is not None:
+        data_loaders_list = [data_loaders] if isinstance(data_loaders, EvaluationDataLoader) else data_loaders
+    else:
+        data_loaders_list = [None]
+
     combinations: list[CombinationTuple] = []

     # Generate all combinations
@@ -116,11 +127,19 @@
         for im in messages:
             for ir in input_rows:
                 for etk in evaluation_test_kwargs:
-                    # if no dataset, no messages, and no rows, raise an error
-                    if ds is None and im is None and ir is None:
-                        raise ValueError(
-                            "No dataset, messages, or rows provided. Please provide at least one of input_dataset, input_messages, or input_rows."
-                        )
-                    combinations.append((ds, cp, im, ir, etk))
+                    for dl in data_loaders_list:
+                        # if no dataset, no messages, no rows, and no data loaders, raise an error
+                        if ds is None and im is None and ir is None and dl is None:
+                            raise ValueError(
+                                "No dataset, messages, rows, or data loaders provided. Please provide at least one of input_dataset, input_messages, input_rows, or data_loaders."
+                            )
+
+                        # if more than one of dataset, messages, rows, or data loaders is provided, raise an error
+                        non_none_count = sum(1 for x in [ds, im, ir, dl] if x is not None)
+                        if non_none_count > 1:
+                            raise ValueError(
+                                "More than one of dataset, messages, rows, or data loaders provided. Please provide only one of input_dataset, input_messages, input_rows, or data_loaders."
+ ) + combinations.append((ds, cp, im, ir, etk, dl)) return combinations diff --git a/eval_protocol/pytest/parameterize.py b/eval_protocol/pytest/parameterize.py index a2140da5..f8c12259 100644 --- a/eval_protocol/pytest/parameterize.py +++ b/eval_protocol/pytest/parameterize.py @@ -5,6 +5,7 @@ from _pytest.mark import ParameterSet +from eval_protocol.data_loader.models import EvaluationDataLoader from eval_protocol.models import CompletionParams, EvaluationRow from eval_protocol.pytest.generate_parameter_combinations import CombinationTuple from eval_protocol.pytest.types import DatasetPathParam, EvaluationInputParam, InputMessagesParam, TestFunction @@ -165,7 +166,7 @@ def __init__(self, max_length: int = 200): def generate_id(self, combo: CombinationTuple) -> str | None: """Generate an ID for a parameter combination.""" - dataset, completion_params, messages, rows, evaluation_test_kwargs = combo + dataset, completion_params, messages, rows, evaluation_test_kwargs, data_loaders = combo if completion_params: id = self.generate_id_from_dict(completion_params, self.max_length) @@ -208,6 +209,7 @@ def pytest_parametrize( completion_params_provided: bool, input_messages: Sequence[list[InputMessagesParam] | None] | None, input_rows: Sequence[list[EvaluationRow]] | None, + data_loaders: Sequence[EvaluationDataLoader] | EvaluationDataLoader | None, evaluation_test_kwargs: Sequence[EvaluationInputParam | None] | None, id_generator: ParameterIdGenerator | None = None, ) -> ParametrizeArgs: @@ -231,6 +233,11 @@ def pytest_parametrize( argnames.append("dataset_path") sig_parameters.append("dataset_path") if completion_params is not None: + """ + manually adding completion_params as a pytest.mark.parametrize decorator + automatically adds it to the function signature so we only need to add + it if we provided completion_params using the evaluation_test decorator. 
+ """ if completion_params_provided and not has_pytest_parametrize: argnames.append("completion_params") if has_pytest_parametrize or completion_params_provided: @@ -244,6 +251,9 @@ def pytest_parametrize( if evaluation_test_kwargs is not None: argnames.append("evaluation_test_kwargs") sig_parameters.append("evaluation_test_kwargs") + if data_loaders is not None: + argnames.append("data_loaders") + sig_parameters.append("data_loaders") # Use default ID generator if none provided if id_generator is None: @@ -253,7 +263,7 @@ def pytest_parametrize( ids: list[str] = [] for combo in combinations: - dataset, cp, messages, rows, etk = combo + dataset, cp, messages, rows, etk, dl = combo param_tuple: list[object] = [] # Build parameter tuple based on what's provided @@ -267,6 +277,8 @@ def pytest_parametrize( param_tuple.append(rows) if evaluation_test_kwargs is not None: param_tuple.append(etk) + if data_loaders is not None: + param_tuple.append(dl) # Validate parameter tuple length if len(argnames) != len(param_tuple): diff --git a/eval_protocol/pytest/plugin.py b/eval_protocol/pytest/plugin.py index ad48ad38..8d369e70 100644 --- a/eval_protocol/pytest/plugin.py +++ b/eval_protocol/pytest/plugin.py @@ -327,16 +327,16 @@ def _print_local_ui_results_urls(session): RESULTS_URLS_STASH_KEY = None # Get URLs from pytest stash - urls = [] + urls_dict = {} if RESULTS_URLS_STASH_KEY is not None and RESULTS_URLS_STASH_KEY in session.stash: - urls = session.stash[RESULTS_URLS_STASH_KEY] + urls_dict = session.stash[RESULTS_URLS_STASH_KEY] - if urls: + if urls_dict: print("\n" + "=" * 80, file=sys.__stderr__) print("📊 LOCAL UI EVALUATION RESULTS", file=sys.__stderr__) print("=" * 80, file=sys.__stderr__) - for url_data in urls: + for url_data in urls_dict.values(): print(f"📊 Invocation {url_data['invocation_id']}:", file=sys.__stderr__) print(f" 📊 Aggregate scores: {url_data['pivot_url']}", file=sys.__stderr__) print(f" 📋 Trajectories: {url_data['table_url']}", file=sys.__stderr__) diff --git a/eval_protocol/pytest/store_results_url.py b/eval_protocol/pytest/store_results_url.py index ccf48541..cc362e9b 100644 --- a/eval_protocol/pytest/store_results_url.py +++ b/eval_protocol/pytest/store_results_url.py @@ -8,7 +8,7 @@ class ResultsUrl(TypedDict): table_url: str -RESULTS_URLS_STASH_KEY = StashKey[list[ResultsUrl]]() +RESULTS_URLS_STASH_KEY = StashKey[dict[str, ResultsUrl]]() def _store_local_ui_url_in_stash(invocation_id: str, pivot_url: str, table_url: str): @@ -29,11 +29,14 @@ def _store_local_ui_url_in_stash(invocation_id: str, pivot_url: str, table_url: global RESULTS_URLS_STASH_KEY if RESULTS_URLS_STASH_KEY not in session.stash: # pyright: ignore[reportAny] - session.stash[RESULTS_URLS_STASH_KEY] = [] # pyright: ignore[reportAny] - - session.stash[RESULTS_URLS_STASH_KEY].append( # pyright: ignore[reportAny] - {"invocation_id": invocation_id, "pivot_url": pivot_url, "table_url": table_url} - ) + session.stash[RESULTS_URLS_STASH_KEY] = {} # pyright: ignore[reportAny] + + # Store by invocation_id as key - automatically handles deduplication + session.stash[RESULTS_URLS_STASH_KEY][invocation_id] = { # pyright: ignore[reportAny] + "invocation_id": invocation_id, + "pivot_url": pivot_url, + "table_url": table_url, + } else: pass diff --git a/eval_protocol/quickstart/llm_judge_braintrust.py b/eval_protocol/quickstart/llm_judge_braintrust.py index 91bce9cf..01619bcb 100644 --- a/eval_protocol/quickstart/llm_judge_braintrust.py +++ b/eval_protocol/quickstart/llm_judge_braintrust.py @@ -16,23 +16,22 @@ 
multi_turn_assistant_to_ground_truth, EvaluationRow, SingleTurnRolloutProcessor, + DynamicDataLoader, create_braintrust_adapter, - DefaultParameterIdGenerator, ) -# adapter = create_braintrust_adapter() -# input_rows = [ -# adapter.get_evaluation_rows( -# btql_query=f""" -# select: * -# from: project_logs('{os.getenv("BRAINTRUST_PROJECT_ID")}') traces -# filter: is_root = true -# limit: 10 -# """ -# ) -# ] -input_rows = [] + # uncomment when dataloader is fixed +def braintrust_data_generator(): + adapter = create_braintrust_adapter() + return adapter.get_evaluation_rows( + btql_query=f""" + select: * + from: project_logs('{os.getenv("BRAINTRUST_PROJECT_ID")}') traces + filter: is_root = true + limit: 10 + """ + ) @pytest.mark.skipif(os.environ.get("CI") == "true", reason="Skip in CI") @@ -53,7 +52,9 @@ ], ) @evaluation_test( - input_rows=[input_rows], + data_loaders=DynamicDataLoader( + generators=[braintrust_data_generator], + ), rollout_processor=SingleTurnRolloutProcessor(), preprocess_fn=multi_turn_assistant_to_ground_truth, max_concurrent_evaluations=2, diff --git a/eval_protocol/quickstart/llm_judge_langfuse.py b/eval_protocol/quickstart/llm_judge_langfuse.py index a8e92c05..bdcdb4dd 100644 --- a/eval_protocol/quickstart/llm_judge_langfuse.py +++ b/eval_protocol/quickstart/llm_judge_langfuse.py @@ -14,19 +14,21 @@ EvaluationRow, SingleTurnRolloutProcessor, create_langfuse_adapter, - DefaultParameterIdGenerator, + DynamicDataLoader, ) from eval_protocol.quickstart import aha_judge -adapter = create_langfuse_adapter() -input_rows = adapter.get_evaluation_rows( - to_timestamp=datetime(2025, 9, 12, 0, 11, 18), - limit=711, - sample_size=50, - sleep_between_gets=3.0, - max_retries=5, -) + +def langfuse_data_generator(): + adapter = create_langfuse_adapter() + return adapter.get_evaluation_rows( + to_timestamp=datetime(2025, 9, 12, 0, 11, 18), + limit=711, + sample_size=50, + sleep_between_gets=3.0, + max_retries=5, + ) @pytest.mark.skipif(os.environ.get("CI") == "true", reason="Skip in CI") @@ -47,7 +49,9 @@ ], ) @evaluation_test( - input_rows=[input_rows], + data_loaders=DynamicDataLoader( + generators=[langfuse_data_generator], + ), rollout_processor=SingleTurnRolloutProcessor(), preprocess_fn=multi_turn_assistant_to_ground_truth, max_concurrent_evaluations=2, diff --git a/eval_protocol/quickstart/llm_judge_langsmith.py b/eval_protocol/quickstart/llm_judge_langsmith.py index f62fdb28..be78b28f 100644 --- a/eval_protocol/quickstart/llm_judge_langsmith.py +++ b/eval_protocol/quickstart/llm_judge_langsmith.py @@ -31,32 +31,26 @@ EvaluationRow, SingleTurnRolloutProcessor, LangSmithAdapter, - DefaultParameterIdGenerator, + DynamicDataLoader, ) -def fetch_langsmith_traces_as_evaluation_rows( - project_name: Optional[str] = None, - limit: int = 20, -) -> List[EvaluationRow]: +def langsmith_data_generator() -> List[EvaluationRow]: """Fetch LangSmith root runs and convert to EvaluationRow, mirroring Langfuse adapter shape. 
- Extract messages from run.inputs and run.outputs - Append assistant message from outputs so we can derive ground_truth - Store run_id in input_metadata.session_data """ - project = project_name or os.getenv("LS_PROJECT", "ep-langgraph-examples") + project = os.getenv("LS_PROJECT", "ep-langgraph-examples") try: adapter = LangSmithAdapter() - return adapter.get_evaluation_rows(project_name=project, limit=limit, include_tool_calls=True) + return adapter.get_evaluation_rows(project_name=project, limit=20, include_tool_calls=True) except Exception as e: print(f"❌ LangSmithAdapter failed: {e}") return [] -input_rows = fetch_langsmith_traces_as_evaluation_rows() - - @pytest.mark.skipif(os.environ.get("CI") == "true", reason="Skip in CI") @pytest.mark.parametrize( "completion_params", @@ -72,7 +66,9 @@ def fetch_langsmith_traces_as_evaluation_rows( ], ) @evaluation_test( - input_rows=[input_rows], + data_loaders=DynamicDataLoader( + generators=[langsmith_data_generator], + ), rollout_processor=SingleTurnRolloutProcessor(), preprocess_fn=multi_turn_assistant_to_ground_truth, max_concurrent_evaluations=2, diff --git a/eval_protocol/quickstart/llm_judge_openai_responses.py b/eval_protocol/quickstart/llm_judge_openai_responses.py index a30feee0..a2b6c7c9 100644 --- a/eval_protocol/quickstart/llm_judge_openai_responses.py +++ b/eval_protocol/quickstart/llm_judge_openai_responses.py @@ -26,18 +26,20 @@ EvaluationRow, SingleTurnRolloutProcessor, OpenAIResponsesAdapter, - DefaultParameterIdGenerator, + DynamicDataLoader, ) -adapter = OpenAIResponsesAdapter() -input_rows = adapter.get_evaluation_rows( - response_ids=[ - "resp_0e1b7db5d96e92470068c99506443c819e9305e92915d2405f", - # "resp_05639dcaca074fbc0068c9946593b481908cac70075926d85c", - # "resp_0c96a910416e87aa0068c994d0b34c81a3bda0eddf22445aec", - # "resp_0efe023280e986f90068c994b85e088190bc8d8263fa603e02", - ] -) + +def openai_responses_data_generator(): + adapter = OpenAIResponsesAdapter() + return adapter.get_evaluation_rows( + response_ids=[ + "resp_0e1b7db5d96e92470068c99506443c819e9305e92915d2405f", + # "resp_05639dcaca074fbc0068c9946593b481908cac70075926d85c", + # "resp_0c96a910416e87aa0068c994d0b34c81a3bda0eddf22445aec", + # "resp_0efe023280e986f90068c994b85e088190bc8d8263fa603e02", + ] + ) @pytest.mark.skipif(os.environ.get("CI") == "true", reason="Skip in CI") @@ -53,7 +55,9 @@ ], ) @evaluation_test( - input_rows=[input_rows], + data_loaders=DynamicDataLoader( + generators=[openai_responses_data_generator], + ), rollout_processor=SingleTurnRolloutProcessor(), preprocess_fn=multi_turn_assistant_to_ground_truth, max_concurrent_evaluations=2, diff --git a/tests/chinook/braintrust/test_braintrust_chinook.py b/tests/chinook/braintrust/test_braintrust_chinook.py index d9c2a77d..dd2dc56c 100644 --- a/tests/chinook/braintrust/test_braintrust_chinook.py +++ b/tests/chinook/braintrust/test_braintrust_chinook.py @@ -1,12 +1,13 @@ import os from datetime import datetime, timedelta -from typing import List, Any, Dict +from typing import List import pytest from pydantic import BaseModel from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIModel +from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader from eval_protocol.models import EvaluateResult, EvaluationRow, Message, InputMetadata from eval_protocol.pytest import evaluation_test, NoOpRolloutProcessor @@ -32,7 +33,7 @@ class Response(BaseModel): ) -def fetch_braintrust_traces_as_evaluation_rows(hours_back: int = 24) -> List[EvaluationRow]: +def 
braintrust_data_generator(hours_back: int = 24) -> List[EvaluationRow]: """ Dataset adapter: Use BraintrustAdapter to fetch traces from project logs. """ @@ -53,8 +54,12 @@ def fetch_braintrust_traces_as_evaluation_rows(hours_back: int = 24) -> List[Eva evaluation_rows = list( adapter.get_evaluation_rows( - from_timestamp=from_timestamp, - to_timestamp=now, + btql_query=f""" + select: * + from: project_logs('{os.getenv("BRAINTRUST_PROJECT_ID")}') traces + filter: is_root = true + limit: 10 + """, ) ) @@ -72,7 +77,9 @@ def fetch_braintrust_traces_as_evaluation_rows(hours_back: int = 24) -> List[Eva ) @pytest.mark.asyncio @evaluation_test( - input_rows=[fetch_braintrust_traces_as_evaluation_rows(hours_back=168)], # 1 week back + data_loaders=DynamicDataLoader( + generators=[braintrust_data_generator], + ), rollout_processor=NoOpRolloutProcessor(), # No-op since traces already exist mode="pointwise", ) diff --git a/tests/chinook/langfuse/test_langfuse_chinook.py b/tests/chinook/langfuse/test_langfuse_chinook.py index 1119b563..edb1eaa7 100644 --- a/tests/chinook/langfuse/test_langfuse_chinook.py +++ b/tests/chinook/langfuse/test_langfuse_chinook.py @@ -18,6 +18,7 @@ from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIChatModel +from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader from eval_protocol.models import EvaluateResult, EvaluationRow, Message, InputMetadata from eval_protocol.pytest import evaluation_test, NoOpRolloutProcessor @@ -45,17 +46,11 @@ class Response(BaseModel): reason: str -def fetch_langfuse_traces_as_evaluation_rows( - hours_back: int = 168, tags: List[str] = ["chinook_sql"] -) -> List[EvaluationRow]: +def langfuse_data_generator(hours_back: int = 168, tags: List[str] = ["chinook_sql"]) -> List[EvaluationRow]: try: from eval_protocol.adapters.langfuse import create_langfuse_adapter - adapter = create_langfuse_adapter( - public_key=os.getenv("LANGFUSE_PUBLIC_KEY"), # pyright: ignore[reportArgumentType] - secret_key=os.getenv("LANGFUSE_SECRET_KEY"), # pyright: ignore[reportArgumentType] - host=os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com"), - ) + adapter = create_langfuse_adapter() now = datetime.now() from_timestamp = now - timedelta(hours=hours_back) @@ -72,7 +67,9 @@ def fetch_langfuse_traces_as_evaluation_rows( @pytest.mark.skipif(os.environ.get("CI") == "true", reason="Skip in CI") @pytest.mark.asyncio @evaluation_test( - input_rows=[fetch_langfuse_traces_as_evaluation_rows()], + data_loaders=DynamicDataLoader( + generators=[langfuse_data_generator], + ), rollout_processor=NoOpRolloutProcessor(), mode="pointwise", ) diff --git a/tests/chinook/langfuse/test_remote_langfuse_chinook.py b/tests/chinook/langfuse/test_remote_langfuse_chinook.py index 1b3951c9..c121a6e4 100644 --- a/tests/chinook/langfuse/test_remote_langfuse_chinook.py +++ b/tests/chinook/langfuse/test_remote_langfuse_chinook.py @@ -8,6 +8,7 @@ import pytest import requests +from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader from eval_protocol.models import EvaluationRow, Message from eval_protocol.pytest import evaluation_test from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor @@ -51,12 +52,11 @@ def _is_up() -> bool: return proc -# Ensure server is running BEFORE rollouts start (evaluation_test triggers rollouts before test body) -_SERVER_PROC = _ensure_server_running() -atexit.register(lambda: (_SERVER_PROC and _SERVER_PROC.is_alive() and _SERVER_PROC.terminate())) +def 
remote_langfuse_data_generator() -> List[EvaluationRow]: + # Ensure server is running BEFORE rollouts start (evaluation_test triggers rollouts before test body) + _SERVER_PROC = _ensure_server_running() + atexit.register(lambda: (_SERVER_PROC and _SERVER_PROC.is_alive() and _SERVER_PROC.terminate())) - -def _make_input_rows() -> List[EvaluationRow]: # Minimal single-user-turn message to trigger a response row = EvaluationRow(messages=[Message(role="user", content="Hello there! Please say hi back.")]) return [row] @@ -65,7 +65,9 @@ def _make_input_rows() -> List[EvaluationRow]: @pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") @pytest.mark.asyncio @evaluation_test( - input_rows=[_make_input_rows()], + data_loaders=DynamicDataLoader( + generators=[remote_langfuse_data_generator], + ), completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"}], rollout_processor=RemoteRolloutProcessor( remote_base_url="http://127.0.0.1:7077", diff --git a/tests/chinook/langsmith/test_langsmith_chinook.py b/tests/chinook/langsmith/test_langsmith_chinook.py index a2f03dcb..f5b600f4 100644 --- a/tests/chinook/langsmith/test_langsmith_chinook.py +++ b/tests/chinook/langsmith/test_langsmith_chinook.py @@ -10,6 +10,7 @@ from eval_protocol.models import EvaluateResult, EvaluationRow, InputMetadata from eval_protocol.pytest import NoOpRolloutProcessor, evaluation_test +from eval_protocol.data_loader import DynamicDataLoader from tests.chinook.dataset import collect_dataset @@ -21,17 +22,6 @@ ADAPTER_AVAILABLE = False create_langsmith_adapter = None # type: ignore -try: - from langsmith import Client # type: ignore - - LANGSMITH_CLIENT: Optional[Client] - try: - LANGSMITH_CLIENT = Client() - except Exception as exc: # pragma: no cover - surfaced to the caller - print(f"⚠️ LangSmith client unavailable: {exc}") - LANGSMITH_CLIENT = None -except ImportError: # pragma: no cover - optional dependency - LANGSMITH_CLIENT = None PROJECT_NAME = os.getenv("LANGCHAIN_PROJECT") or os.getenv("LS_PROJECT") or "ep-chinook-langsmith" TRACE_TAGS = ["chinook_sql"] @@ -124,7 +114,9 @@ def fetch_langsmith_traces(limit: int = 20) -> List[EvaluationRow]: @pytest.mark.skipif(os.environ.get("CI") == "true", reason="Skip LangSmith adapter test in CI") @pytest.mark.asyncio @evaluation_test( - input_rows=[fetch_langsmith_traces()], + data_loaders=DynamicDataLoader( + generators=[fetch_langsmith_traces], + ), rollout_processor=NoOpRolloutProcessor(), mode="pointwise", ) @@ -162,6 +154,18 @@ class JudgeResponse(BaseModel): reason=result.output.reason, ) + try: + from langsmith import Client # type: ignore + + LANGSMITH_CLIENT: Optional[Client] + try: + LANGSMITH_CLIENT = Client() + except Exception as exc: # pragma: no cover - surfaced to the caller + print(f"⚠️ LangSmith client unavailable: {exc}") + LANGSMITH_CLIENT = None + except ImportError: # pragma: no cover - optional dependency + LANGSMITH_CLIENT = None + if LANGSMITH_CLIENT and row.input_metadata and row.input_metadata.session_data: run_id = row.input_metadata.session_data.get("langsmith_run_id") if run_id: diff --git a/tests/data_loader/test_dynamic_data_loader.py b/tests/data_loader/test_dynamic_data_loader.py new file mode 100644 index 00000000..18780d9d --- /dev/null +++ b/tests/data_loader/test_dynamic_data_loader.py @@ -0,0 +1,48 @@ +from eval_protocol.data_loader import DynamicDataLoader +from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest import evaluation_test + + 
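+# These tests exercise DynamicDataLoader end to end: variant_id and variant_description
+# come from each generator's __name__ and __doc__, and data_loader_* metadata is stamped
+# onto every row that reaches the test body.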
+def my_factory() -> list[EvaluationRow]:
+    """Factory function that generates evaluation rows dynamically."""
+    return [EvaluationRow(messages=[Message(role="user", content="What is 2 + 2?")])]
+
+
+@evaluation_test(
+    data_loaders=DynamicDataLoader(
+        generators=[my_factory],
+    ),
+)
+def test_dynamic_data_loader(row: EvaluationRow) -> EvaluationRow:
+    """Dynamic data loader should feed dynamically generated message bundles."""
+
+    assert row.messages[0].content == "What is 2 + 2?"
+    assert row.input_metadata.dataset_info is not None
+    assert row.input_metadata.dataset_info.get("data_loader_variant_id") == "my_factory"
+    assert row.input_metadata.dataset_info.get("data_loader_num_rows") == 1
+    assert row.input_metadata.dataset_info.get("data_loader_num_rows_after_preprocessing") == 1
+    assert row.input_metadata.dataset_info.get("data_loader_type") == "DynamicDataLoader"
+    assert (
+        row.input_metadata.dataset_info.get("data_loader_variant_description")
+        == "Factory function that generates evaluation rows dynamically."
+    )
+    assert row.input_metadata.dataset_info.get("data_loader_preprocessed") is False
+    return row
+
+
+@evaluation_test(
+    data_loaders=DynamicDataLoader(
+        generators=[lambda: [EvaluationRow(messages=[Message(role="user", content="What is 3 * 3?")])]],
+    ),
+)
+def test_dynamic_data_loader_lambda(row: EvaluationRow) -> EvaluationRow:
+    """Dynamic data loader should work with lambda functions."""
+
+    assert row.messages[0].content == "What is 3 * 3?"
+    assert row.input_metadata.dataset_info is not None
+    assert row.input_metadata.dataset_info.get("data_loader_variant_id") == "<lambda>"
+    assert row.input_metadata.dataset_info.get("data_loader_num_rows") == 1
+    assert row.input_metadata.dataset_info.get("data_loader_num_rows_after_preprocessing") == 1
+    assert row.input_metadata.dataset_info.get("data_loader_type") == "DynamicDataLoader"
+    assert row.input_metadata.dataset_info.get("data_loader_preprocessed") is False
+    return row
diff --git a/tests/data_loader/test_inline_data_loader.py b/tests/data_loader/test_inline_data_loader.py
new file mode 100644
index 00000000..a6cde17c
--- /dev/null
+++ b/tests/data_loader/test_inline_data_loader.py
@@ -0,0 +1,23 @@
+from eval_protocol.data_loader.inline_data_loader import InlineDataLoader
+from eval_protocol.models import EvaluationRow, Message
+from eval_protocol.pytest import evaluation_test
+from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor
+
+
+@evaluation_test(
+    data_loaders=InlineDataLoader(
+        messages=[[Message(role="user", content="What is 2 + 2?")]],
+    ),
+)
+def test_inline_data_loader(row: EvaluationRow) -> EvaluationRow:
+    """Inline data loader should feed pre-constructed message bundles."""
+
+    assert row.messages[0].content == "What is 2 + 2?"
+ assert row.input_metadata.dataset_info is not None + assert row.input_metadata.dataset_info.get("data_loader_variant_id") == "inline" + assert row.input_metadata.dataset_info.get("data_loader_num_rows") == 1 + assert row.input_metadata.dataset_info.get("data_loader_num_rows_after_preprocessing") == 1 + assert row.input_metadata.dataset_info.get("data_loader_type") == "InlineDataLoader" + assert row.input_metadata.dataset_info.get("data_loader_variant_description") is None + assert row.input_metadata.dataset_info.get("data_loader_preprocessed") is False + return row diff --git a/tests/pytest/test_parameterized_ids.py b/tests/pytest/test_parameterized_ids.py index d3363d0c..cacf0de8 100644 --- a/tests/pytest/test_parameterized_ids.py +++ b/tests/pytest/test_parameterized_ids.py @@ -91,47 +91,47 @@ def test_default_id_generator(): generator = DefaultParameterIdGenerator() # Test with full model path - combo1 = (None, {"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}, None, None, None) + combo1 = (None, {"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}, None, None, None, None) id1 = generator.generate_id(combo1) assert id1 == "fireworks_ai/accounts/fireworks/models/gpt-oss-120b" # Test with simple model name - combo2 = (None, {"model": "gpt-4"}, None, None, None) + combo2 = (None, {"model": "gpt-4"}, None, None, None, None) id2 = generator.generate_id(combo2) assert id2 == "gpt-4" # Test with multiple string parameters - combo3 = (None, {"model": "gpt-4", "stream": "true", "temperature": "0.7"}, None, None, None) + combo3 = (None, {"model": "gpt-4", "stream": "true", "temperature": "0.7"}, None, None, None, None) id3 = generator.generate_id(combo3) assert id3 == "gpt-4:true:0.7" # Test with mixed string and numeric parameters - combo4 = (None, {"model": "gpt-4", "temperature": 0.7, "max_tokens": 100}, None, None, None) + combo4 = (None, {"model": "gpt-4", "temperature": 0.7, "max_tokens": 100}, None, None, None, None) id4 = generator.generate_id(combo4) assert id4 == "100:gpt-4:0.7" # Keys are sorted alphabetically: max_tokens, model, temperature # Test with only numeric values - combo5 = (None, {"temperature": 0.5, "max_tokens": 100}, None, None, None) + combo5 = (None, {"temperature": 0.5, "max_tokens": 100}, None, None, None, None) id5 = generator.generate_id(combo5) assert id5 == "100:0.5" # Keys are sorted alphabetically: max_tokens, temperature # Test with boolean values - combo6 = (None, {"stream": True, "echo": False}, None, None, None) + combo6 = (None, {"stream": True, "echo": False}, None, None, None, None) id6 = generator.generate_id(combo6) assert id6 == "False:True" # Keys are sorted alphabetically: echo, stream # Test with mixed string, numeric, and boolean values - combo7 = (None, {"model": "gpt-4", "temperature": 0.7, "stream": True}, None, None, None) + combo7 = (None, {"model": "gpt-4", "temperature": 0.7, "stream": True}, None, None, None, None) id7 = generator.generate_id(combo7) assert id7 == "gpt-4:True:0.7" # Keys are sorted alphabetically: model, stream, temperature # Test with no supported values (only non-supported types like lists, dicts) - combo8 = (None, {"messages": [{"role": "user"}], "config": {"key": "value"}}, None, None, None) + combo8 = (None, {"messages": [{"role": "user"}], "config": {"key": "value"}}, None, None, None, None) id8 = generator.generate_id(combo8) assert id8 is None # Test with None completion_params - combo9 = (None, None, None, None, None) + combo9 = (None, None, None, None, None, None) id9 = 
generator.generate_id(combo9) assert id9 is None @@ -141,9 +141,9 @@ def test_pytest_parametrize_with_custom_id_generator(): # Create test combinations combinations = [ - (None, {"model": "gpt-4"}, None, None, None), - (None, {"model": "claude-3"}, None, None, None), - (None, {"temperature": 0.5}, None, None, None), # Only numeric values + (None, {"model": "gpt-4"}, None, None, None, None), + (None, {"model": "claude-3"}, None, None, None, None), + (None, {"temperature": 0.5}, None, None, None, None), # Only numeric values ] # Test with default generator @@ -155,6 +155,7 @@ def test_pytest_parametrize_with_custom_id_generator(): completion_params_provided=True, input_messages=None, input_rows=None, + data_loaders=None, evaluation_test_kwargs=None, ) @@ -168,7 +169,7 @@ def test_id_generator_max_length(): generator = DefaultParameterIdGenerator(max_length=10) # Test with long model name - combo = (None, {"model": "very-long-model-name-that-exceeds-max-length"}, None, None, None) + combo = (None, {"model": "very-long-model-name-that-exceeds-max-length"}, None, None, None, None) id_str = generator.generate_id(combo) assert id_str == "very-lo..." assert len(id_str) <= 10 diff --git a/tests/test_batch_evaluation.py b/tests/test_batch_evaluation.py index 772b8290..06ca6e0b 100644 --- a/tests/test_batch_evaluation.py +++ b/tests/test_batch_evaluation.py @@ -33,7 +33,6 @@ def __init__(self, task_def: str, num_rollouts: int = 2, **kwargs): self.filter = kwargs.get("filter", None) -@pytest.mark.integration class TestBatchEvaluation: """Integration tests for batch evaluation functionality.""" @@ -1121,7 +1120,6 @@ def smart_move_generator(**kwargs): assert len(task_manager.server_ports) == 0 -@pytest.mark.integration class TestBatchEvaluationErrorHandling: """Test error handling in batch evaluation scenarios.""" diff --git a/tests/test_fireworks_api.py b/tests/test_fireworks_api.py index c419c544..647168c0 100644 --- a/tests/test_fireworks_api.py +++ b/tests/test_fireworks_api.py @@ -5,62 +5,64 @@ from eval_protocol.auth import get_fireworks_account_id, get_fireworks_api_key -# Get API key using the new auth module -api_key = get_fireworks_api_key() -if api_key: - print(f"API key retrieved via auth module: {api_key[:4]}...{api_key[-4:]}") -else: - print("No API key retrieved via auth module.") -# Get account ID using the new auth module -account_id = get_fireworks_account_id() -if account_id: - print(f"Account ID retrieved via auth module: {account_id}") -else: - print("No account ID retrieved via auth module.") +def test_fireworks_api(): + # Get API key using the new auth module + api_key = get_fireworks_api_key() + if api_key: + print(f"API key retrieved via auth module: {api_key[:4]}...{api_key[-4:]}") + else: + print("No API key retrieved via auth module.") -# Ensure api_key is not None for header construction, default to empty string if None -effective_api_key = api_key if api_key is not None else "" + # Get account ID using the new auth module + account_id = get_fireworks_account_id() + if account_id: + print(f"Account ID retrieved via auth module: {account_id}") + else: + print("No account ID retrieved via auth module.") -# Test API connection -try: - # Try listing models to verify API connectivity - headers = {"Authorization": f"Bearer {effective_api_key}"} - base_url = "https://api.fireworks.ai/v1" + # Ensure api_key is not None for header construction, default to empty string if None + effective_api_key = api_key if api_key is not None else "" - # Check if models endpoint works (to verify 
API connection) - models_url = f"{base_url}/models?limit=1" - print(f"Testing models endpoint: {models_url}") - response = requests.get(models_url, headers=headers) - print(f"Response: {response.status_code} - {response.reason}") - if response.status_code == 200: - print("Successfully connected to Fireworks API") - else: - print(f"Error response: {response.text}") + # Test API connection + try: + # Try listing models to verify API connectivity + headers = {"Authorization": f"Bearer {effective_api_key}"} + base_url = "https://api.fireworks.ai/v1" - if account_id: - # Check if the evaluations endpoint is available - eval_url = f"{base_url}/accounts/{account_id}/evaluations" - print(f"Testing evaluations endpoint: {eval_url}") - response = requests.get(eval_url, headers=headers) + # Check if models endpoint works (to verify API connection) + models_url = f"{base_url}/models?limit=1" + print(f"Testing models endpoint: {models_url}") + response = requests.get(models_url, headers=headers) print(f"Response: {response.status_code} - {response.reason}") - if response.status_code != 200: + if response.status_code == 200: + print("Successfully connected to Fireworks API") + else: print(f"Error response: {response.text}") - # Check if there's an evaluators endpoint - evaluators_url = f"{base_url}/accounts/{account_id}/evaluators" - print(f"Testing evaluators endpoint: {evaluators_url}") - response = requests.get(evaluators_url, headers=headers) - print(f"Response: {response.status_code} - {response.reason}") - if response.status_code != 200: - print(f"Error response: {response.text}") + if account_id: + # Check if the evaluations endpoint is available + eval_url = f"{base_url}/accounts/{account_id}/evaluations" + print(f"Testing evaluations endpoint: {eval_url}") + response = requests.get(eval_url, headers=headers) + print(f"Response: {response.status_code} - {response.reason}") + if response.status_code != 200: + print(f"Error response: {response.text}") - # Look for alternate endpoints - for endpoint in ["evaluation", "evaluator"]: - url = f"{base_url}/accounts/{account_id}/{endpoint}" - print(f"Testing alternate endpoint: {url}") - response = requests.get(url, headers=headers) + # Check if there's an evaluators endpoint + evaluators_url = f"{base_url}/accounts/{account_id}/evaluators" + print(f"Testing evaluators endpoint: {evaluators_url}") + response = requests.get(evaluators_url, headers=headers) print(f"Response: {response.status_code} - {response.reason}") + if response.status_code != 200: + print(f"Error response: {response.text}") + + # Look for alternate endpoints + for endpoint in ["evaluation", "evaluator"]: + url = f"{base_url}/accounts/{account_id}/{endpoint}" + print(f"Testing alternate endpoint: {url}") + response = requests.get(url, headers=headers) + print(f"Response: {response.status_code} - {response.reason}") -except Exception as e: - print(f"Error connecting to Fireworks API: {str(e)}") + except Exception as e: + print(f"Error connecting to Fireworks API: {str(e)}")
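As a reference for reviewers, the following is a minimal sketch of a custom loader built on the new EvaluationDataLoader contract. The JsonlDataLoader name, the one-JSON-object-per-line format, and the assumption that each object validates directly into an EvaluationRow are illustrative only, not part of this change:

    import json
    from collections.abc import Sequence
    from dataclasses import dataclass

    from eval_protocol.data_loader.models import (
        DataLoaderResult,
        DataLoaderVariant,
        EvaluationDataLoader,
    )
    from eval_protocol.models import EvaluationRow


    @dataclass(kw_only=True)
    class JsonlDataLoader(EvaluationDataLoader):
        """Hypothetical loader: one variant per JSONL file path."""

        paths: Sequence[str]

        def variants(self) -> Sequence[DataLoaderVariant]:
            variants: list[DataLoaderVariant] = []
            for path in self.paths:

                # Bind `path` as a default argument so each variant keeps its own file.
                def _load(path: str = path) -> DataLoaderResult:
                    with open(path) as f:
                        rows = [EvaluationRow.model_validate(json.loads(line)) for line in f]
                    return DataLoaderResult(
                        rows=rows,
                        type=self.__class__.__name__,
                        variant_id=path,
                        variant_description=f"Rows loaded from {path}",
                    )

                variants.append(_load)
            return variants

A test would consume it via evaluation_test(data_loaders=JsonlDataLoader(paths=["cases.jsonl"])); load() then applies the loader's preprocess_fn (if any) and stamps the data_loader_* metadata onto each row exactly as for the built-in loaders.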