Skip to content

Commit 8e406ee

Browse files
author
Dylan Huang
committed
allow for manual parametrization using pytest
1 parent aa6077c commit 8e406ee

File tree

5 files changed

+147
-51
lines changed

5 files changed

+147
-51
lines changed

eval_protocol/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from .typed_interface import reward_function
4040
from .quickstart import aha_judge, split_multi_turn_rows
4141
from .pytest import evaluation_test, SingleTurnRolloutProcessor
42+
from .pytest.parameterize import DefaultParameterIdGenerator
4243

4344
from .adapters import OpenAIResponsesAdapter
4445

@@ -61,6 +62,7 @@
6162
warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol")
6263

6364
__all__ = [
65+
"DefaultParameterIdGenerator",
6466
"aha_judge",
6567
"split_multi_turn_rows",
6668
"evaluation_test",

eval_protocol/pytest/evaluation_test.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,14 @@ def evaluation_test(
158158
exception_handler_config: Configuration for exception handling and backoff retry logic.
159159
If not provided, a default configuration will be used with common retryable exceptions.
160160
"""
161+
# Default to [None] when completion_params is not provided
162+
# This allows evaluation-only tests (e.g., using NoOpRolloutProcessor)
163+
# to work without requiring model generation parameters
161164
if completion_params is None:
165+
completion_params_provided = False
162166
completion_params = [None]
167+
else:
168+
completion_params_provided = True
163169
if rollout_processor is None:
164170
rollout_processor = NoOpRolloutProcessor()
165171

@@ -201,6 +207,7 @@ def decorator(
201207
combinations,
202208
input_dataset,
203209
completion_params,
210+
completion_params_provided,
204211
input_messages,
205212
input_rows,
206213
evaluation_test_kwargs,
@@ -565,12 +572,14 @@ async def execute_run_with_progress(run_idx: int, config):
565572
return create_dynamically_parameterized_wrapper(
566573
test_func,
567574
wrapper_body,
568-
pytest_parametrize_args["argnames"],
575+
pytest_parametrize_args["sig_parameters"],
569576
)
570577

571578
# Create the pytest wrapper
572579
pytest_wrapper = create_wrapper_with_signature()
573-
pytest_wrapper = pytest.mark.parametrize(**pytest_parametrize_args)(pytest_wrapper)
580+
pytest_wrapper = pytest.mark.parametrize(**pytest_parametrize_args["pytest_parametrize_kwargs"])(
581+
pytest_wrapper
582+
)
574583
pytest_wrapper = pytest.mark.asyncio(pytest_wrapper)
575584

576585
# Create the dual mode wrapper

eval_protocol/pytest/parameterize.py

Lines changed: 60 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,28 @@
99
from eval_protocol.pytest.types import DatasetPathParam, EvaluationInputParam, InputMessagesParam, TestFunction
1010

1111

12-
class PytestParametrizeArgs(TypedDict):
12+
class PytestMarkParametrizeKwargs(TypedDict):
1313
argnames: Sequence[str]
1414
argvalues: Iterable[ParameterSet | Sequence[object] | object]
1515
ids: Iterable[str] | None
1616

1717

18+
class ParametrizeArgs(TypedDict):
19+
"""
20+
This contains all the necessary information to properly hijack the test
21+
function's signature and dynamically inject usage of
22+
pytest.mark.parametrize. sig_parameters and argnames will differ when a user manually provides
23+
the pytest.mark.parametrize decorator instead of passing completion_params
24+
directly.
25+
"""
26+
27+
# for create_dynamically_parameterized_wrapper
28+
sig_parameters: Sequence[str]
29+
30+
# for pytest.mark.parametrize
31+
pytest_parametrize_kwargs: PytestMarkParametrizeKwargs
32+
33+
1834
class ParameterIdGenerator(Protocol):
1935
"""Protocol for generating pytest parameter IDs from parameter combinations."""
2036

@@ -30,7 +46,7 @@ def generate_id(self, combo: CombinationTuple) -> str | None:
3046
...
3147

3248

33-
class DefaultParameterIdGenerator:
49+
class DefaultParameterIdGenerator(ParameterIdGenerator):
3450
"""Default ID generator that creates meaningful IDs from parameter combinations."""
3551

3652
def __init__(self, max_length: int = 200):
@@ -46,34 +62,48 @@ def generate_id(self, combo: CombinationTuple) -> str | None:
4662
dataset, completion_params, messages, rows, evaluation_test_kwargs = combo
4763

4864
if completion_params:
49-
# Get all string, numeric, and boolean values from completion_params, sorted by key
50-
str_values = []
51-
for key in sorted(completion_params.keys()):
52-
value = completion_params[key]
53-
if isinstance(value, (str, int, float, bool)):
54-
str_values.append(str(value))
65+
id = self.generate_id_from_dict(completion_params, self.max_length)
66+
if id:
67+
return id
68+
else:
69+
if rows:
70+
return f"rows(len={len(rows)})"
71+
elif messages:
72+
return f"messages(len={len(messages)})"
73+
elif dataset:
74+
return f"dataset(len={len(dataset)})"
75+
return None
5576

56-
if str_values:
57-
id_str = ":".join(str_values)
77+
@staticmethod
78+
def generate_id_from_dict(d: dict[str, object], max_length: int = 200) -> str | None:
79+
# Get all string, numeric, and boolean values from completion_params, sorted by key
80+
str_values = []
81+
for key in sorted(d.keys()):
82+
value = d[key]
83+
if isinstance(value, (str, int, float, bool)):
84+
str_values.append(str(value))
5885

59-
# Truncate if too long
60-
if len(id_str) > self.max_length:
61-
id_str = id_str[: self.max_length - 3] + "..."
86+
if str_values:
87+
id_str = ":".join(str_values)
6288

63-
return id_str
89+
# Truncate if too long
90+
if len(id_str) > max_length:
91+
id_str = id_str[: max_length - 3] + "..."
6492

93+
return id_str
6594
return None
6695

6796

6897
def pytest_parametrize(
6998
combinations: list[CombinationTuple],
7099
input_dataset: Sequence[DatasetPathParam] | None,
71100
completion_params: Sequence[CompletionParams | None] | None,
101+
completion_params_provided: bool,
72102
input_messages: Sequence[list[InputMessagesParam] | None] | None,
73103
input_rows: Sequence[list[EvaluationRow]] | None,
74104
evaluation_test_kwargs: Sequence[EvaluationInputParam | None] | None,
75105
id_generator: ParameterIdGenerator | None = None,
76-
) -> PytestParametrizeArgs:
106+
) -> ParametrizeArgs:
77107
"""
78108
This function dynamically generates pytest.mark.parametrize arguments for a given
79109
set of combinations. This is the magic that allows developers to pass in their
@@ -84,16 +114,23 @@ def pytest_parametrize(
84114

85115
# Create parameter tuples for pytest.mark.parametrize
86116
argnames: list[str] = []
117+
sig_parameters: list[str] = []
87118
if input_dataset is not None:
88119
argnames.append("dataset_path")
120+
sig_parameters.append("dataset_path")
89121
if completion_params is not None:
90-
argnames.append("completion_params")
122+
if completion_params_provided:
123+
argnames.append("completion_params")
124+
sig_parameters.append("completion_params")
91125
if input_messages is not None:
92126
argnames.append("input_messages")
127+
sig_parameters.append("input_messages")
93128
if input_rows is not None:
94129
argnames.append("input_rows")
130+
sig_parameters.append("input_rows")
95131
if evaluation_test_kwargs is not None:
96132
argnames.append("evaluation_test_kwargs")
133+
sig_parameters.append("evaluation_test_kwargs")
97134

98135
# Use default ID generator if none provided
99136
if id_generator is None:
@@ -109,7 +146,7 @@ def pytest_parametrize(
109146
# Build parameter tuple based on what's provided
110147
if input_dataset is not None:
111148
param_tuple.append(dataset)
112-
if completion_params is not None:
149+
if completion_params_provided:
113150
param_tuple.append(cp)
114151
if input_messages is not None:
115152
param_tuple.append(messages)
@@ -132,7 +169,12 @@ def pytest_parametrize(
132169
ids.append(combo_id)
133170

134171
# Return None for ids if no IDs were generated (let pytest use defaults)
135-
return PytestParametrizeArgs(argnames=argnames, argvalues=argvalues, ids=ids if ids else None)
172+
return ParametrizeArgs(
173+
pytest_parametrize_kwargs=PytestMarkParametrizeKwargs(
174+
argnames=argnames, argvalues=argvalues, ids=ids if ids else None
175+
),
176+
sig_parameters=sig_parameters,
177+
)
136178

137179

138180
def create_dynamically_parameterized_wrapper(

eval_protocol/quickstart/llm_judge_openai_responses.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
EvaluationRow,
2828
SingleTurnRolloutProcessor,
2929
OpenAIResponsesAdapter,
30+
DefaultParameterIdGenerator,
3031
)
3132

3233
adapter = OpenAIResponsesAdapter()
@@ -41,17 +42,20 @@
4142

4243

4344
@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Skip in CI")
44-
@pytest.mark.asyncio
45-
@evaluation_test(
46-
input_rows=[input_rows],
47-
completion_params=[
45+
@pytest.mark.parametrize(
46+
"completion_params",
47+
[
4848
{
4949
"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1",
5050
},
5151
{
5252
"model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct-0905",
5353
},
5454
],
55+
ids=DefaultParameterIdGenerator.generate_id_from_dict,
56+
)
57+
@evaluation_test(
58+
input_rows=[input_rows],
5559
rollout_processor=SingleTurnRolloutProcessor(),
5660
preprocess_fn=split_multi_turn_rows,
5761
mode="all",

tests/pytest/test_parameterized_ids.py

Lines changed: 66 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,47 @@
1+
from collections.abc import Awaitable, Callable
2+
3+
import pytest
14
from eval_protocol.models import EvaluationRow, Message
25
from eval_protocol.pytest import evaluation_test
36
from eval_protocol.pytest.parameterize import DefaultParameterIdGenerator, pytest_parametrize
47
from eval_protocol.pytest.generate_parameter_combinations import generate_parameter_combinations
8+
from eval_protocol.pytest.types import TestFunction
9+
10+
11+
def verify_parametrize_mark(test_function: TestFunction, expected_ids_set: list[object]):
12+
# The function should exist and be callable
13+
assert test_function is not None
14+
assert callable(test_function)
15+
16+
# Test that the decorator was applied (function should have pytest marks)
17+
import pytest
18+
19+
marks = getattr(test_function, "pytestmark", [])
20+
assert len(marks) > 0, "Function should have pytest marks from evaluation_test decorator"
21+
22+
# Verify it's a parametrize mark
23+
parametrize_marks = [mark for mark in marks if hasattr(mark, "name") and mark.name == "parametrize"]
24+
assert len(parametrize_marks) > 0, "Should have parametrize mark"
25+
26+
assert len(parametrize_marks) == len(expected_ids_set), (
27+
f"Expected {len(expected_ids_set)} parametrize marks, got {len(parametrize_marks)}"
28+
)
29+
30+
# Check that the parametrize mark has IDs
31+
for parametrize_mark, expected_ids in zip(parametrize_marks, expected_ids_set):
32+
assert hasattr(parametrize_mark, "kwargs"), "Parametrize mark should have kwargs"
33+
assert "ids" in parametrize_mark.kwargs, "Should have ids in kwargs"
34+
35+
# Extract the IDs from the parametrize mark
36+
ids = parametrize_mark.kwargs.get("ids")
37+
if not ids:
38+
raise ValueError("No IDs found in parametrize mark")
39+
# Should have IDs for all parameters that have string/numeric values
40+
assert ids == expected_ids, f"Expected {expected_ids}, got {ids}"
541

642

743
def test_parameterized_ids():
844
"""Test that evaluation_test generates proper parameter IDs."""
9-
collected_ids = []
1045

1146
@evaluation_test(
1247
input_messages=[[[Message(role="user", content="Hello, how are you?")]]],
@@ -17,35 +52,38 @@ def test_parameterized_ids():
1752
],
1853
)
1954
def test_parameterized_ids(row: EvaluationRow) -> EvaluationRow:
20-
# Collect the row to verify it was processed
21-
collected_ids.append(row.input_metadata.row_id)
2255
return row
2356

24-
# The function should exist and be callable
25-
assert test_parameterized_ids is not None
26-
assert callable(test_parameterized_ids)
27-
28-
# Test that the decorator was applied (function should have pytest marks)
29-
import pytest
57+
verify_parametrize_mark(
58+
test_parameterized_ids, [["fireworks_ai/accounts/fireworks/models/gpt-oss-120b", "gpt-4", "0.5"]]
59+
)
3060

31-
marks = getattr(test_parameterized_ids, "pytestmark", [])
32-
assert len(marks) > 0, "Function should have pytest marks from evaluation_test decorator"
3361

34-
# Verify it's a parametrize mark
35-
parametrize_marks = [mark for mark in marks if hasattr(mark, "name") and mark.name == "parametrize"]
36-
assert len(parametrize_marks) > 0, "Should have parametrize mark"
62+
def test_parametrized_ids_with_manual_decorator_and_input_rows():
63+
"""Test that evaluation_test generates proper parameter IDs."""
3764

38-
# Check that the parametrize mark has IDs
39-
parametrize_mark = parametrize_marks[0]
40-
assert hasattr(parametrize_mark, "kwargs"), "Parametrize mark should have kwargs"
41-
assert "ids" in parametrize_mark.kwargs, "Should have ids in kwargs"
65+
@pytest.mark.parametrize(
66+
"completion_params",
67+
[
68+
{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"},
69+
{"model": "gpt-4"},
70+
{"temperature": 0.5},
71+
],
72+
ids=DefaultParameterIdGenerator.generate_id_from_dict,
73+
)
74+
@evaluation_test(
75+
input_rows=[[EvaluationRow(messages=[Message(role="user", content="Hello, how are you?")])]],
76+
)
77+
def test_parameterized_ids(row: EvaluationRow) -> EvaluationRow:
78+
return row
4279

43-
# Extract the IDs from the parametrize mark
44-
ids = parametrize_mark.kwargs.get("ids")
45-
if ids is not None:
46-
# Should have IDs for all parameters that have string/numeric values
47-
expected_ids = ["fireworks_ai/accounts/fireworks/models/gpt-oss-120b", "gpt-4", "0.5"]
48-
assert list(ids) == expected_ids, f"Expected {expected_ids}, got {list(ids)}"
80+
verify_parametrize_mark(
81+
test_parameterized_ids,
82+
[
83+
["rows(len=1)"],
84+
DefaultParameterIdGenerator.generate_id_from_dict,
85+
],
86+
)
4987

5088

5189
def test_default_id_generator():
@@ -113,14 +151,15 @@ def test_pytest_parametrize_with_custom_id_generator():
113151
combinations=combinations,
114152
input_dataset=None,
115153
completion_params=[{"model": "gpt-4"}, {"model": "claude-3"}, {"temperature": 0.5}],
154+
completion_params_provided=True,
116155
input_messages=None,
117156
input_rows=None,
118157
evaluation_test_kwargs=None,
119158
)
120159

121-
assert result["argnames"] == ["completion_params"]
122-
assert len(list(result["argvalues"])) == 3
123-
assert result["ids"] == ["gpt-4", "claude-3", "0.5"] # All have string/numeric values
160+
assert result["pytest_parametrize_kwargs"]["argnames"] == ["completion_params"]
161+
assert len(list(result["pytest_parametrize_kwargs"]["argvalues"])) == 3
162+
assert result["pytest_parametrize_kwargs"]["ids"] == ["gpt-4", "claude-3", "0.5"] # All have string/numeric values
124163

125164

126165
def test_id_generator_max_length():

0 commit comments

Comments
 (0)