Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,8 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
im = kwargs["input_messages"]
data = [EvaluationRow(messages=dataset_messages) for dataset_messages in im]
elif "input_rows" in kwargs and kwargs["input_rows"] is not None:
# Use pre-constructed EvaluationRow objects directly
data = kwargs["input_rows"]
# Deep copy pre-constructed EvaluationRow objects
data = [row.model_copy(deep=True) for row in kwargs["input_rows"]]
else:
raise ValueError("No input dataset, input messages, or input rows provided")

Expand Down
71 changes: 68 additions & 3 deletions eval_protocol/pytest/parameterize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import inspect
from typing import TypedDict
from typing import TypedDict, Protocol
from collections.abc import Callable, Sequence, Iterable, Awaitable

from _pytest.mark import ParameterSet
Expand All @@ -12,6 +12,54 @@
class PytestParametrizeArgs(TypedDict):
    """Keyword arguments destined for ``pytest.mark.parametrize``."""

    # Names of the parameters being parametrized, in the order the test
    # function receives them.
    argnames: Sequence[str]
    # One entry per generated test invocation; each entry supplies a value
    # for every name in ``argnames``.
    argvalues: Iterable[ParameterSet | Sequence[object] | object]
    # Explicit test IDs, or None to let pytest generate its defaults.
    ids: Iterable[str] | None


class ParameterIdGenerator(Protocol):
    """Structural interface for turning parameter combinations into pytest IDs.

    Any object with a matching ``generate_id`` method satisfies this
    protocol; no inheritance is required.
    """

    def generate_id(self, combo: CombinationTuple) -> str | None:
        """Produce a pytest parameter ID for *combo*.

        Args:
            combo: The parameter combination tuple.

        Returns:
            The ID string to use for this combination, or ``None`` to fall
            back to pytest's default ID generation.
        """
        ...


class DefaultParameterIdGenerator:
"""Default ID generator that creates meaningful IDs from parameter combinations."""

def __init__(self, max_length: int = 50):
"""Initialize the ID generator with configuration options.

Args:
max_length: Maximum length of generated IDs
"""
self.max_length = max_length

def generate_id(self, combo: CombinationTuple) -> str | None:
"""Generate an ID for a parameter combination."""
dataset, completion_params, messages, rows, evaluation_test_kwargs = combo

# Add model name if available
if completion_params:
model = completion_params.get("model")
if model:
# Extract just the model name, not the full path
model_name = model.split("/")[-1] if "/" in model else model
id_str = f"model-{model_name}"

# Truncate if too long
if len(id_str) > self.max_length:
id_str = id_str[: self.max_length - 3] + "..."

return id_str

return None


def pytest_parametrize(
Expand All @@ -21,6 +69,7 @@ def pytest_parametrize(
input_messages: Sequence[list[InputMessagesParam] | None] | None,
input_rows: Sequence[list[EvaluationRow]] | None,
evaluation_test_kwargs: Sequence[EvaluationInputParam | None] | None,
id_generator: ParameterIdGenerator | None = None,
) -> PytestParametrizeArgs:
"""
This function dynamically generates pytest.mark.parametrize arguments for a given
Expand All @@ -43,10 +92,18 @@ def pytest_parametrize(
if evaluation_test_kwargs is not None:
argnames.append("evaluation_test_kwargs")

# Use default ID generator if none provided
if id_generator is None:
id_generator = DefaultParameterIdGenerator()

argvalues: list[ParameterSet | Sequence[object] | object] = []
ids: list[str] = []

for combo in combinations:
dataset, cp, messages, rows, etk = combo
param_tuple: list[object] = []

# Build parameter tuple based on what's provided
if input_dataset is not None:
param_tuple.append(dataset)
if completion_params is not None:
Expand All @@ -57,14 +114,22 @@ def pytest_parametrize(
param_tuple.append(rows)
if evaluation_test_kwargs is not None:
param_tuple.append(etk)
# do validation that the length of argnames is the same as the length of param_tuple

# Validate parameter tuple length
if len(argnames) != len(param_tuple):
raise ValueError(
f"The length of argnames ({len(argnames)}) is not the same as the length of param_tuple ({len(param_tuple)})"
)

argvalues.append(tuple(param_tuple))

return PytestParametrizeArgs(argnames=argnames, argvalues=argvalues)
# Generate ID for this combination
combo_id = id_generator.generate_id(combo)
if combo_id is not None:
ids.append(combo_id)

# Return None for ids if no IDs were generated (let pytest use defaults)
return PytestParametrizeArgs(argnames=argnames, argvalues=argvalues, ids=ids if ids else None)


def create_dynamically_parameterized_wrapper(
Expand Down
109 changes: 109 additions & 0 deletions tests/pytest/test_parameterized_ids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest.parameterize import DefaultParameterIdGenerator, pytest_parametrize
from eval_protocol.pytest.generate_parameter_combinations import generate_parameter_combinations


def test_parameterized_ids():
    """Test that evaluation_test generates proper parameter IDs.

    Decorates a trivial evaluation function and inspects the resulting
    ``pytest.mark.parametrize`` mark to confirm that model-bearing
    completion params yield ``model-<name>`` IDs.
    """
    collected_ids = []

    # The inner function gets a distinct name so it does not shadow this
    # outer test's name (the original code reused the same name, which made
    # the assertions below read as self-references).
    @evaluation_test(
        input_messages=[[[Message(role="user", content="Hello, how are you?")]]],
        completion_params=[
            {"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"},
            {"model": "gpt-4"},
            {"temperature": 0.5},  # No model - should not generate ID
        ],
    )
    def decorated_eval(row: EvaluationRow) -> EvaluationRow:
        # Collect the row to verify it was processed
        collected_ids.append(row.input_metadata.row_id)
        return row

    # The decorated function should exist and be callable
    assert decorated_eval is not None
    assert callable(decorated_eval)

    # Test that the decorator was applied (function should have pytest marks).
    # Note: the previously present function-scope `import pytest` was unused
    # and has been removed.
    marks = getattr(decorated_eval, "pytestmark", [])
    assert len(marks) > 0, "Function should have pytest marks from evaluation_test decorator"

    # Verify it's a parametrize mark
    parametrize_marks = [mark for mark in marks if hasattr(mark, "name") and mark.name == "parametrize"]
    assert len(parametrize_marks) > 0, "Should have parametrize mark"

    # Check that the parametrize mark has IDs
    parametrize_mark = parametrize_marks[0]
    assert hasattr(parametrize_mark, "kwargs"), "Parametrize mark should have kwargs"
    assert "ids" in parametrize_mark.kwargs, "Should have ids in kwargs"

    # Extract the IDs from the parametrize mark
    ids = parametrize_mark.kwargs.get("ids")
    if ids is not None:
        # Should have IDs for models but not for temperature-only params
        expected_ids = ["model-gpt-oss-120b", "model-gpt-4"]
        assert list(ids) == expected_ids, f"Expected {expected_ids}, got {list(ids)}"


def test_default_id_generator():
    """Test the DefaultParameterIdGenerator with various model names."""
    generator = DefaultParameterIdGenerator()

    # (completion_params, expected ID) pairs covering every branch:
    # full provider path, bare model name, params without a model, and no
    # completion params at all.
    cases = [
        ({"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}, "model-gpt-oss-120b"),
        ({"model": "gpt-4"}, "model-gpt-4"),
        ({"temperature": 0.5}, None),
        (None, None),
    ]

    for completion_params, expected in cases:
        combo = (None, completion_params, None, None, None)
        assert generator.generate_id(combo) == expected


def test_pytest_parametrize_with_custom_id_generator():
    """Test pytest_parametrize with a custom ID generator."""
    # Three completion-params variants; the last has no model and therefore
    # yields no generated ID.
    params = [{"model": "gpt-4"}, {"model": "claude-3"}, {"temperature": 0.5}]

    # One combination per variant; every other parameter axis is unused.
    combos = [(None, cp, None, None, None) for cp in params]

    # Exercise the default generator path (id_generator omitted).
    result = pytest_parametrize(
        combinations=combos,
        input_dataset=None,
        completion_params=params,
        input_messages=None,
        input_rows=None,
        evaluation_test_kwargs=None,
    )

    assert result["argnames"] == ["completion_params"]
    assert len(list(result["argvalues"])) == 3
    # Only the two model-bearing combos contribute IDs.
    assert result["ids"] == ["model-gpt-4", "model-claude-3"]


def test_id_generator_max_length():
    """Test that ID generator respects max_length parameter."""
    short_generator = DefaultParameterIdGenerator(max_length=10)

    # A model name far longer than the 10-character budget.
    oversized = "very-long-model-name-that-exceeds-max-length"
    generated = short_generator.generate_id((None, {"model": oversized}, None, None, None))

    # The ID is cut down and suffixed with an ellipsis, staying within budget.
    assert generated == "model-v..."
    assert len(generated) <= 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest import evaluation_test


@evaluation_test(
    completion_params=[{"model": "gpt-4"}, {"model": "gpt-4o"}],
    input_rows=[[EvaluationRow(messages=[Message(role="user", content="Hello, how are you?")])]],
    evaluation_test_kwargs=[{"seen_models": set()}],
)
def test_pytest_input_rows_parametrized_completion_params(row: EvaluationRow, **kwargs) -> EvaluationRow:
    """Tests that parametrized completion params are working correctly for input_rows"""
    # The decorator shares one mutable set across all parametrized runs, so
    # each run can observe which models earlier runs handled.
    seen_models = kwargs["seen_models"]
    current_model = row.input_metadata.completion_params["model"]
    if len(seen_models) == 1:
        # Exactly one run preceded this one; it must have used the other
        # model of the pair.
        other_model = "gpt-4o" if current_model == "gpt-4" else "gpt-4"
        assert other_model in seen_models
    seen_models.add(current_model)
    return row
Loading