Skip to content

Commit 057e132

Browse files
author
Dylan Huang
authored
Supports explicit pytest parametrization (#190)
* v2 proposal * allow for manual parametrization using pytest * delete proposal * test_import_logs works * add ids
1 parent 64abf2d commit 057e132

File tree

6 files changed

+264
-53
lines changed

6 files changed

+264
-53
lines changed

eval_protocol/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from .typed_interface import reward_function
4040
from .quickstart import aha_judge, split_multi_turn_rows
4141
from .pytest import evaluation_test, SingleTurnRolloutProcessor
42+
from .pytest.parameterize import DefaultParameterIdGenerator
4243

4344
from .adapters import OpenAIResponsesAdapter
4445

@@ -61,6 +62,7 @@
6162
warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol")
6263

6364
__all__ = [
65+
"DefaultParameterIdGenerator",
6466
"aha_judge",
6567
"split_multi_turn_rows",
6668
"evaluation_test",

eval_protocol/pytest/evaluation_test.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,14 @@ def evaluation_test(
158158
exception_handler_config: Configuration for exception handling and backoff retry logic.
159159
If not provided, a default configuration will be used with common retryable exceptions.
160160
"""
161+
# Default to [None] when completion_params is not provided
162+
# This allows evaluation-only tests (e.g., using NoOpRolloutProcessor)
163+
# to work without requiring model generation parameters
161164
if completion_params is None:
165+
completion_params_provided = False
162166
completion_params = [None]
167+
else:
168+
completion_params_provided = True
163169
if rollout_processor is None:
164170
rollout_processor = NoOpRolloutProcessor()
165171

@@ -199,8 +205,10 @@ def decorator(
199205
# Create parameter tuples for pytest.mark.parametrize
200206
pytest_parametrize_args = pytest_parametrize(
201207
combinations,
208+
test_func,
202209
input_dataset,
203210
completion_params,
211+
completion_params_provided,
204212
input_messages,
205213
input_rows,
206214
evaluation_test_kwargs,
@@ -261,7 +269,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
261269
index = abs(index) % (max_index + 1)
262270
row.input_metadata.row_id = generate_id(seed=0, index=index)
263271

264-
completion_params = kwargs["completion_params"]
272+
completion_params = kwargs["completion_params"] if "completion_params" in kwargs else None
265273
# Create eval metadata with test function info and current commit hash
266274
eval_metadata = EvalMetadata(
267275
name=test_func.__name__,
@@ -565,12 +573,14 @@ async def execute_run_with_progress(run_idx: int, config):
565573
return create_dynamically_parameterized_wrapper(
566574
test_func,
567575
wrapper_body,
568-
pytest_parametrize_args["argnames"],
576+
pytest_parametrize_args["sig_parameters"],
569577
)
570578

571579
# Create the pytest wrapper
572580
pytest_wrapper = create_wrapper_with_signature()
573-
pytest_wrapper = pytest.mark.parametrize(**pytest_parametrize_args)(pytest_wrapper)
581+
pytest_wrapper = pytest.mark.parametrize(**pytest_parametrize_args["pytest_parametrize_kwargs"])(
582+
pytest_wrapper
583+
)
574584
pytest_wrapper = pytest.mark.asyncio(pytest_wrapper)
575585

576586
# Create the dual mode wrapper

eval_protocol/pytest/generate_parameter_combinations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
]
3232

3333

34-
class ParameterizedTestKwargs(TypedDict):
34+
class ParameterizedTestKwargs(TypedDict, total=False):
3535
"""
3636
These are the type of parameters that can be passed to the generated pytest
3737
function. Every experiment is a unique combination of these parameters.

eval_protocol/pytest/parameterize.py

Lines changed: 173 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import ast
12
import inspect
23
from typing import TypedDict, Protocol
34
from collections.abc import Callable, Sequence, Iterable, Awaitable
@@ -9,12 +10,133 @@
910
from eval_protocol.pytest.types import DatasetPathParam, EvaluationInputParam, InputMessagesParam, TestFunction
1011

1112

12-
class PytestParametrizeArgs(TypedDict):
13+
def _has_pytest_parametrize_with_completion_params(test_func: TestFunction) -> bool:
    """
    Check if a test function has a pytest.mark.parametrize decorator with "completion_params" in argnames.

    The function's source is fetched with inspect.getsource, dedented, parsed with ast, and the
    decorators of the top-level function definition in that source are inspected.

    Args:
        test_func: The test function to analyze

    Returns:
        True if the function has a pytest.mark.parametrize decorator with "completion_params" in argnames,
        False otherwise. False is also returned (nothing is raised) when the source cannot be
        retrieved or cannot be parsed.
    """
    import textwrap

    try:
        source = inspect.getsource(test_func)
    except (OSError, TypeError):
        # OSError: source is unavailable (e.g. defined in interactive mode).
        # TypeError: the object is not something inspect can fetch source for.
        return False

    try:
        # Methods and nested functions come back with their original indentation,
        # which ast.parse rejects with IndentationError unless it is removed first.
        tree = ast.parse(textwrap.dedent(source))
    except SyntaxError:
        # Source code cannot be parsed
        return False

    # Only look at the definition(s) at the top level of the retrieved source;
    # decorators on helpers nested inside the test body do not parametrize the test itself.
    return any(
        _is_pytest_parametrize_with_completion_params(decorator)
        for node in tree.body
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
        for decorator in node.decorator_list
    )
52+
53+
54+
def _is_pytest_parametrize_with_completion_params(decorator: ast.expr) -> bool:
    """
    Tell whether a decorator node is `pytest.mark.parametrize(...)` naming "completion_params".

    Args:
        decorator: AST node representing a decorator

    Returns:
        True only if the node is a call spelled exactly `pytest.mark.parametrize` whose argnames
        (first positional argument or the `argnames` keyword) contain "completion_params"
    """
    # A bare `@something` without a call cannot carry argnames.
    if not isinstance(decorator, ast.Call):
        return False

    callee = decorator.func
    if not isinstance(callee, ast.Attribute) or callee.attr != "parametrize":
        return False

    # The receiver of `.parametrize` must be the attribute chain `pytest.mark`.
    receiver = callee.value
    is_pytest_mark = (
        isinstance(receiver, ast.Attribute)
        and isinstance(receiver.value, ast.Name)
        and receiver.value.id == "pytest"
        and receiver.attr == "mark"
    )
    if not is_pytest_mark:
        return False

    # argnames is typically the first positional argument.
    positional = decorator.args
    if positional and _check_argnames_for_completion_params(positional[0]):
        return True

    # Otherwise fall back to an explicit `argnames=` keyword.
    return any(
        kw.arg == "argnames" and _check_argnames_for_completion_params(kw.value)
        for kw in decorator.keywords
    )
88+
89+
90+
def _check_argnames_for_completion_params(argnames_node: ast.expr) -> bool:
    """
    Check if an argnames AST node contains "completion_params".

    Handles the literal forms pytest.mark.parametrize accepts for argnames: a string holding one
    name or several comma-separated names, or a list/tuple of string literals. Non-literal
    expressions (e.g. a variable) cannot be resolved statically and yield False.

    Args:
        argnames_node: AST node representing the argnames value

    Returns:
        True if argnames contains "completion_params"
    """
    target = "completion_params"

    if isinstance(argnames_node, ast.Constant):
        value = argnames_node.value
        if not isinstance(value, str):
            return False
        # String case: argnames="completion_params" or argnames="dataset_path, completion_params".
        # pytest splits the string on commas and strips whitespace around each name.
        return any(name.strip() == target for name in value.split(","))

    if isinstance(argnames_node, (ast.List, ast.Tuple)):
        # Sequence case: argnames=["completion_params", ...] or ("completion_params", ...)
        return any(isinstance(elt, ast.Constant) and elt.value == target for elt in argnames_node.elts)

    return False
116+
117+
118+
class PytestMarkParametrizeKwargs(TypedDict):
13119
argnames: Sequence[str]
14120
argvalues: Iterable[ParameterSet | Sequence[object] | object]
15121
ids: Iterable[str] | None
16122

17123

124+
class ParametrizeArgs(TypedDict):
125+
"""
126+
This contains all the necessary information to properly hijack the test
127+
function's signature and dynamically inject usage of
128+
pytest.mark.parametrize. The two will differ when a user manually provides
129+
the pytest.mark.parametrize decorator instead of passing completion_params
130+
on their own.
131+
"""
132+
133+
# for create_dynamically_parameterized_wrapper
134+
sig_parameters: Sequence[str]
135+
136+
# for pytest.mark.parametrize
137+
pytest_parametrize_kwargs: PytestMarkParametrizeKwargs
138+
139+
18140
class ParameterIdGenerator(Protocol):
19141
"""Protocol for generating pytest parameter IDs from parameter combinations."""
20142

@@ -30,7 +152,7 @@ def generate_id(self, combo: CombinationTuple) -> str | None:
30152
...
31153

32154

33-
class DefaultParameterIdGenerator:
155+
class DefaultParameterIdGenerator(ParameterIdGenerator):
34156
"""Default ID generator that creates meaningful IDs from parameter combinations."""
35157

36158
def __init__(self, max_length: int = 200):
@@ -46,34 +168,49 @@ def generate_id(self, combo: CombinationTuple) -> str | None:
46168
dataset, completion_params, messages, rows, evaluation_test_kwargs = combo
47169

48170
if completion_params:
49-
# Get all string, numeric, and boolean values from completion_params, sorted by key
50-
str_values = []
51-
for key in sorted(completion_params.keys()):
52-
value = completion_params[key]
53-
if isinstance(value, (str, int, float, bool)):
54-
str_values.append(str(value))
171+
id = self.generate_id_from_dict(completion_params, self.max_length)
172+
if id:
173+
return id
174+
else:
175+
if rows:
176+
return f"rows(len={len(rows)})"
177+
elif messages:
178+
return f"messages(len={len(messages)})"
179+
elif dataset:
180+
return f"dataset(len={len(dataset)})"
181+
return None
55182

56-
if str_values:
57-
id_str = ":".join(str_values)
183+
@staticmethod
184+
def generate_id_from_dict(d: dict[str, object], max_length: int = 200) -> str | None:
185+
# Get all string, numeric, and boolean values from completion_params, sorted by key
186+
str_values = []
187+
for key in sorted(d.keys()):
188+
value = d[key]
189+
if isinstance(value, (str, int, float, bool)):
190+
str_values.append(str(value))
58191

59-
# Truncate if too long
60-
if len(id_str) > self.max_length:
61-
id_str = id_str[: self.max_length - 3] + "..."
192+
if str_values:
193+
id_str = ":".join(str_values)
62194

63-
return id_str
195+
# Truncate if too long
196+
if len(id_str) > max_length:
197+
id_str = id_str[: max_length - 3] + "..."
64198

199+
return id_str
65200
return None
66201

67202

68203
def pytest_parametrize(
69204
combinations: list[CombinationTuple],
205+
test_func: TestFunction | None,
70206
input_dataset: Sequence[DatasetPathParam] | None,
71207
completion_params: Sequence[CompletionParams | None] | None,
208+
completion_params_provided: bool,
72209
input_messages: Sequence[list[InputMessagesParam] | None] | None,
73210
input_rows: Sequence[list[EvaluationRow]] | None,
74211
evaluation_test_kwargs: Sequence[EvaluationInputParam | None] | None,
75212
id_generator: ParameterIdGenerator | None = None,
76-
) -> PytestParametrizeArgs:
213+
) -> ParametrizeArgs:
77214
"""
78215
This function dynamically generates pytest.mark.parametrize arguments for a given
79216
set of combinations. This is the magic that allows developers to pass in their
@@ -82,18 +219,31 @@ def pytest_parametrize(
82219
API.
83220
"""
84221

222+
if test_func is not None:
223+
has_pytest_parametrize = _has_pytest_parametrize_with_completion_params(test_func)
224+
else:
225+
has_pytest_parametrize = False
226+
85227
# Create parameter tuples for pytest.mark.parametrize
86228
argnames: list[str] = []
229+
sig_parameters: list[str] = []
87230
if input_dataset is not None:
88231
argnames.append("dataset_path")
232+
sig_parameters.append("dataset_path")
89233
if completion_params is not None:
90-
argnames.append("completion_params")
234+
if completion_params_provided and not has_pytest_parametrize:
235+
argnames.append("completion_params")
236+
if has_pytest_parametrize or completion_params_provided:
237+
sig_parameters.append("completion_params")
91238
if input_messages is not None:
92239
argnames.append("input_messages")
240+
sig_parameters.append("input_messages")
93241
if input_rows is not None:
94242
argnames.append("input_rows")
243+
sig_parameters.append("input_rows")
95244
if evaluation_test_kwargs is not None:
96245
argnames.append("evaluation_test_kwargs")
246+
sig_parameters.append("evaluation_test_kwargs")
97247

98248
# Use default ID generator if none provided
99249
if id_generator is None:
@@ -109,7 +259,7 @@ def pytest_parametrize(
109259
# Build parameter tuple based on what's provided
110260
if input_dataset is not None:
111261
param_tuple.append(dataset)
112-
if completion_params is not None:
262+
if completion_params_provided:
113263
param_tuple.append(cp)
114264
if input_messages is not None:
115265
param_tuple.append(messages)
@@ -132,7 +282,12 @@ def pytest_parametrize(
132282
ids.append(combo_id)
133283

134284
# Return None for ids if no IDs were generated (let pytest use defaults)
135-
return PytestParametrizeArgs(argnames=argnames, argvalues=argvalues, ids=ids if ids else None)
285+
return ParametrizeArgs(
286+
pytest_parametrize_kwargs=PytestMarkParametrizeKwargs(
287+
argnames=argnames, argvalues=argvalues, ids=ids if ids else None
288+
),
289+
sig_parameters=sig_parameters,
290+
)
136291

137292

138293
def create_dynamically_parameterized_wrapper(

eval_protocol/quickstart/llm_judge_openai_responses.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
EvaluationRow,
2828
SingleTurnRolloutProcessor,
2929
OpenAIResponsesAdapter,
30+
DefaultParameterIdGenerator,
3031
)
3132

3233
adapter = OpenAIResponsesAdapter()
@@ -41,17 +42,20 @@
4142

4243

4344
@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Skip in CI")
44-
@pytest.mark.asyncio
45-
@evaluation_test(
46-
input_rows=[input_rows],
47-
completion_params=[
45+
@pytest.mark.parametrize(
46+
"completion_params",
47+
[
4848
{
4949
"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1",
5050
},
5151
{
5252
"model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct-0905",
5353
},
5454
],
55+
ids=DefaultParameterIdGenerator.generate_id_from_dict,
56+
)
57+
@evaluation_test(
58+
input_rows=[input_rows],
5559
rollout_processor=SingleTurnRolloutProcessor(),
5660
preprocess_fn=split_multi_turn_rows,
5761
mode="all",

0 commit comments

Comments
 (0)