Skip to content

Commit ea9ec83

Browse files
author
Dylan Huang
committed
add parameterized ids
1 parent f77ce2f commit ea9ec83

File tree

2 files changed

+177
-3
lines changed

2 files changed

+177
-3
lines changed

eval_protocol/pytest/parameterize.py

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import inspect
2-
from typing import TypedDict
2+
from typing import TypedDict, Protocol
33
from collections.abc import Callable, Sequence, Iterable, Awaitable
44

55
from _pytest.mark import ParameterSet
@@ -12,6 +12,54 @@
1212
class PytestParametrizeArgs(TypedDict):
    """Keyword arguments produced for a ``pytest.mark.parametrize`` call."""

    # Names of the parametrized arguments, in the order values are supplied.
    argnames: Sequence[str]
    # One entry per parameter combination.
    argvalues: Iterable[ParameterSet | Sequence[object] | object]
    # Explicit test IDs, or None to let pytest fall back to its default IDs.
    ids: Iterable[str] | None
16+
17+
18+
class ParameterIdGenerator(Protocol):
    """Structural interface for objects that turn a parameter combination into a pytest ID."""

    def generate_id(self, combo: CombinationTuple) -> str | None:
        """Produce the pytest ID for one parameter combination.

        Args:
            combo: The parameter combination tuple to describe.

        Returns:
            The ID string for this combination, or None to fall back to
            pytest's default ID.
        """
        ...
31+
32+
33+
class DefaultParameterIdGenerator:
    """Default ID generator that derives a pytest ID from the model name.

    The ID has the form ``model-<name>``, where ``<name>`` is the last
    path segment of the ``"model"`` entry in the combination's completion
    params. Combinations without a usable model name produce no ID.
    """

    def __init__(self, max_length: int = 50):
        """Initialize the ID generator with configuration options.

        Args:
            max_length: Maximum length of generated IDs
        """
        self.max_length = max_length

    def generate_id(self, combo: CombinationTuple) -> str | None:
        """Generate an ID for a parameter combination.

        Args:
            combo: The parameter combination tuple; only its second item
                (the completion params) is consulted.

        Returns:
            ``"model-<name>"`` truncated to ``max_length`` characters, or
            None when there are no completion params, no string ``"model"``
            entry, or the model path has no non-empty final segment.
        """
        # Only the completion params contribute to the ID.
        _, completion_params, *_ = combo
        if not completion_params:
            return None

        model = completion_params.get("model")
        if not model or not isinstance(model, str):
            return None

        # Keep just the final path segment; tolerate a trailing slash so
        # "org/name/" still yields "name" instead of an empty segment.
        model_name = model.rstrip("/").rsplit("/", 1)[-1]
        if not model_name:
            return None

        return self._truncate(f"model-{model_name}")

    def _truncate(self, id_str: str) -> str:
        """Shorten ``id_str`` so it never exceeds ``max_length`` characters."""
        limit = max(self.max_length, 0)
        if len(id_str) <= limit:
            return id_str
        if limit < 3:
            # No room for an ellipsis; a hard cut is the only way to respect
            # the limit (a negative slice would make the ID *longer*).
            return id_str[:limit]
        return id_str[: limit - 3] + "..."
1563

1664

1765
def pytest_parametrize(
@@ -21,6 +69,7 @@ def pytest_parametrize(
2169
input_messages: Sequence[list[InputMessagesParam] | None] | None,
2270
input_rows: Sequence[list[EvaluationRow]] | None,
2371
evaluation_test_kwargs: Sequence[EvaluationInputParam | None] | None,
72+
id_generator: ParameterIdGenerator | None = None,
2473
) -> PytestParametrizeArgs:
2574
"""
2675
This function dynamically generates pytest.mark.parametrize arguments for a given
@@ -43,10 +92,18 @@ def pytest_parametrize(
4392
if evaluation_test_kwargs is not None:
4493
argnames.append("evaluation_test_kwargs")
4594

95+
# Use default ID generator if none provided
96+
if id_generator is None:
97+
id_generator = DefaultParameterIdGenerator()
98+
4699
argvalues: list[ParameterSet | Sequence[object] | object] = []
100+
ids: list[str] = []
101+
47102
for combo in combinations:
48103
dataset, cp, messages, rows, etk = combo
49104
param_tuple: list[object] = []
105+
106+
# Build parameter tuple based on what's provided
50107
if input_dataset is not None:
51108
param_tuple.append(dataset)
52109
if completion_params is not None:
@@ -57,14 +114,22 @@ def pytest_parametrize(
57114
param_tuple.append(rows)
58115
if evaluation_test_kwargs is not None:
59116
param_tuple.append(etk)
60-
# do validation that the length of argnames is the same as the length of param_tuple
117+
118+
# Validate parameter tuple length
61119
if len(argnames) != len(param_tuple):
62120
raise ValueError(
63121
f"The length of argnames ({len(argnames)}) is not the same as the length of param_tuple ({len(param_tuple)})"
64122
)
123+
65124
argvalues.append(tuple(param_tuple))
66125

67-
return PytestParametrizeArgs(argnames=argnames, argvalues=argvalues)
126+
# Generate ID for this combination
127+
combo_id = id_generator.generate_id(combo)
128+
if combo_id is not None:
129+
ids.append(combo_id)
130+
131+
# Return None for ids if no IDs were generated (let pytest use defaults)
132+
return PytestParametrizeArgs(argnames=argnames, argvalues=argvalues, ids=ids if ids else None)
68133

69134

70135
def create_dynamically_parameterized_wrapper(
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
from eval_protocol.models import EvaluationRow, Message
2+
from eval_protocol.pytest import evaluation_test
3+
from eval_protocol.pytest.parameterize import DefaultParameterIdGenerator, pytest_parametrize
4+
from eval_protocol.pytest.generate_parameter_combinations import generate_parameter_combinations
5+
6+
7+
def test_parameterized_ids():
    """Test that evaluation_test attaches a parametrize mark carrying model-based IDs."""
    collected_ids = []

    # The inner function keeps its original name on purpose: it is the object
    # handed to the decorator, and it shadows this test's own name afterwards.
    @evaluation_test(
        input_messages=[[[Message(role="user", content="Hello, how are you?")]]],
        completion_params=[
            {"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"},
            {"model": "gpt-4"},
            {"temperature": 0.5},  # No model - should not generate ID
        ],
    )
    def test_parameterized_ids(row: EvaluationRow) -> EvaluationRow:
        # Collect the row to verify it was processed
        collected_ids.append(row.input_metadata.row_id)
        return row

    # The function should exist and be callable
    assert test_parameterized_ids is not None
    assert callable(test_parameterized_ids)

    # Test that the decorator was applied (function should have pytest marks)
    marks = getattr(test_parameterized_ids, "pytestmark", [])
    assert len(marks) > 0, "Function should have pytest marks from evaluation_test decorator"

    # Verify it's a parametrize mark
    parametrize_marks = [mark for mark in marks if getattr(mark, "name", None) == "parametrize"]
    assert len(parametrize_marks) > 0, "Should have parametrize mark"

    # Check that the parametrize mark has IDs
    parametrize_mark = parametrize_marks[0]
    assert hasattr(parametrize_mark, "kwargs"), "Parametrize mark should have kwargs"
    assert "ids" in parametrize_mark.kwargs, "Should have ids in kwargs"

    # Extract the IDs from the parametrize mark
    ids = parametrize_mark.kwargs.get("ids")
    if ids is not None:
        # Should have IDs for models but not for temperature-only params
        # NOTE(review): three completion_params but only two expected IDs.
        # pytest expects one id per parameter set when an explicit ids list
        # is given, so confirm this parametrization still collects.
        expected_ids = ["model-gpt-oss-120b", "model-gpt-4"]
        assert list(ids) == expected_ids, f"Expected {expected_ids}, got {list(ids)}"
49+
50+
51+
def test_default_id_generator():
    """Exercise DefaultParameterIdGenerator across several completion-param shapes."""
    id_gen = DefaultParameterIdGenerator()

    # (completion_params, expected ID) pairs; only the second slot of the
    # combination tuple is populated in each case.
    cases = [
        # Full model path: only the last segment is kept
        ({"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}, "model-gpt-oss-120b"),
        # Simple model name
        ({"model": "gpt-4"}, "model-gpt-4"),
        # No model key
        ({"temperature": 0.5}, None),
        # No completion params at all
        (None, None),
    ]

    for completion_params, expected in cases:
        combo = (None, completion_params, None, None, None)
        actual = id_gen.generate_id(combo)
        if expected is None:
            assert actual is None
        else:
            assert actual == expected
74+
75+
76+
def test_pytest_parametrize_with_custom_id_generator():
    """Test pytest_parametrize with both the default and a custom ID generator."""

    class FirstKeyIdGenerator:
        """Custom generator that names each combination after its first param key."""

        def generate_id(self, combo):
            completion_params = combo[1]
            return f"key-{next(iter(completion_params))}"

    # Create test combinations
    completion_params = [{"model": "gpt-4"}, {"model": "claude-3"}, {"temperature": 0.5}]
    combinations = [(None, params, None, None, None) for params in completion_params]

    # Test with default generator
    result = pytest_parametrize(
        combinations=combinations,
        input_dataset=None,
        completion_params=completion_params,
        input_messages=None,
        input_rows=None,
        evaluation_test_kwargs=None,
    )

    assert result["argnames"] == ["completion_params"]
    assert len(list(result["argvalues"])) == 3
    # The model-less combination yields no ID, so it is omitted from the list.
    assert result["ids"] == ["model-gpt-4", "model-claude-3"]

    # Test with a custom generator that produces an ID for every combination
    custom_result = pytest_parametrize(
        combinations=combinations,
        input_dataset=None,
        completion_params=completion_params,
        input_messages=None,
        input_rows=None,
        evaluation_test_kwargs=None,
        id_generator=FirstKeyIdGenerator(),
    )

    assert custom_result["argnames"] == ["completion_params"]
    assert len(list(custom_result["argvalues"])) == 3
    assert custom_result["ids"] == ["key-model", "key-model", "key-temperature"]
99+
100+
101+
def test_id_generator_max_length():
    """Verify that generated IDs are truncated to the configured max_length."""
    limit = 10
    id_gen = DefaultParameterIdGenerator(max_length=limit)

    # A model name far longer than the limit must be cut and end in an ellipsis
    long_model = "very-long-model-name-that-exceeds-max-length"
    truncated = id_gen.generate_id((None, {"model": long_model}, None, None, None))

    assert truncated == "model-v..."
    assert len(truncated) <= limit

0 commit comments

Comments
 (0)