-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathparameterize.py
More file actions
325 lines (261 loc) · 12.2 KB
/
parameterize.py
File metadata and controls
325 lines (261 loc) · 12.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
import ast
import inspect
from typing import TypedDict, Protocol
from collections.abc import Callable, Sequence, Iterable, Awaitable
from _pytest.mark import ParameterSet
from eval_protocol.models import CompletionParams, EvaluationRow
from eval_protocol.pytest.generate_parameter_combinations import CombinationTuple
from eval_protocol.pytest.types import DatasetPathParam, EvaluationInputParam, InputMessagesParam, TestFunction
def _has_pytest_parametrize_with_completion_params(test_func: TestFunction) -> bool:
    """
    Check if a test function has a pytest.mark.parametrize decorator with argnames="completion_params".

    The function's source is retrieved with ``inspect.getsource`` and dedented, so that
    functions defined inside a class or another function still parse. Every function
    definition in that source, including nested ones, is then walked with ``ast``. The
    search looks for a ``pytest.mark.parametrize`` decorator whose argnames include
    "completion_params".

    This is a best-effort static check: when the source cannot be retrieved or parsed,
    it returns False instead of raising.

    Args:
        test_func: The test function to analyze

    Returns:
        True if a pytest.mark.parametrize decorator with "completion_params" in argnames
        was found, False otherwise (including when the source is unavailable or cannot
        be parsed).
    """
    import textwrap

    try:
        source = inspect.getsource(test_func)
    except (OSError, TypeError):
        # OSError: source not available (e.g. function defined in interactive mode).
        # TypeError: inspect cannot fetch source for this kind of object
        # (e.g. a builtin or a callable instance).
        return False

    try:
        # getsource keeps the original indentation of nested definitions (methods,
        # functions defined inside other functions); without dedent, ast.parse raises
        # IndentationError for those and the decorator would be missed.
        tree = ast.parse(textwrap.dedent(source))
    except SyntaxError:
        # Source code cannot be parsed
        return False

    # Walk through the AST to find pytest.mark.parametrize decorators
    for node in ast.walk(tree):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            if any(_is_pytest_parametrize_with_completion_params(decorator) for decorator in node.decorator_list):
                return True
    return False
def _is_pytest_parametrize_with_completion_params(decorator: ast.expr) -> bool:
"""
Check if a decorator is pytest.mark.parametrize with "completion_params" in argnames.
Args:
decorator: AST node representing a decorator
Returns:
True if this is a pytest.mark.parametrize decorator with "completion_params" in argnames
"""
# Look for pytest.mark.parametrize pattern
if isinstance(decorator, ast.Call):
# Check if it's pytest.mark.parametrize
if isinstance(decorator.func, ast.Attribute):
if (
isinstance(decorator.func.value, ast.Attribute)
and isinstance(decorator.func.value.value, ast.Name)
and decorator.func.value.value.id == "pytest"
and decorator.func.value.attr == "mark"
and decorator.func.attr == "parametrize"
):
# Check positional arguments first (argnames is typically the first positional arg)
if len(decorator.args) > 0:
argnames_arg = decorator.args[0]
if _check_argnames_for_completion_params(argnames_arg):
return True
# Check keyword arguments for argnames
for keyword in decorator.keywords:
if keyword.arg == "argnames":
if _check_argnames_for_completion_params(keyword.value):
return True
return False
def _check_argnames_for_completion_params(argnames_node: ast.expr) -> bool:
    """
    Check if an argnames AST node names the "completion_params" argument.

    pytest accepts argnames either as one comma-separated string (``"completion_params"``
    or ``"model, completion_params"``) or as a list/tuple of name strings. Literal forms
    of both are recognised here. Non-literal argnames (variables, calls, ...) cannot be
    resolved statically and yield False.

    Args:
        argnames_node: AST node representing the argnames value

    Returns:
        True if argnames contains "completion_params"
    """
    if isinstance(argnames_node, ast.Constant):
        if not isinstance(argnames_node.value, str):
            return False
        # pytest splits string argnames on commas and strips surrounding whitespace,
        # so "completion_params" may be one of several names in a single string.
        names = [name.strip() for name in argnames_node.value.split(",")]
        return "completion_params" in names

    if isinstance(argnames_node, (ast.List, ast.Tuple)):
        # List/tuple case: each element is a single argument name.
        return any(
            isinstance(elt, ast.Constant) and elt.value == "completion_params"
            for elt in argnames_node.elts
        )

    return False
class PytestMarkParametrizeKwargs(TypedDict):
    """Keyword arguments that are forwarded to ``pytest.mark.parametrize``."""

    # Names of the test-function arguments being parametrized.
    argnames: Sequence[str]
    # One entry per generated test case (a ParameterSet or the values for argnames).
    argvalues: Iterable[ParameterSet | Sequence[object] | object]
    # Optional test-case IDs; None leaves ID generation to pytest.
    ids: Iterable[str] | None
class ParametrizeArgs(TypedDict):
    """
    Everything needed to hijack a test function's signature and dynamically apply
    pytest.mark.parametrize to it.

    ``sig_parameters`` and ``pytest_parametrize_kwargs["argnames"]`` usually hold the
    same names. They differ when the user applies pytest.mark.parametrize over
    completion_params themselves instead of passing completion_params to the decorator.
    In that case the name belongs in the wrapper's signature but not in the generated
    argnames.
    """

    # Parameter names for the signature built by create_dynamically_parameterized_wrapper.
    sig_parameters: Sequence[str]
    # Keyword arguments to pass to pytest.mark.parametrize.
    pytest_parametrize_kwargs: PytestMarkParametrizeKwargs
class ParameterIdGenerator(Protocol):
    """Strategy interface for naming the pytest parameter set built from a combination."""

    def generate_id(self, combo: CombinationTuple) -> str | None:
        """Return the pytest ID to use for one parameter combination.

        Args:
            combo: The parameter combination tuple to name

        Returns:
            The ID string, or None to fall back to pytest's default ID
        """
        ...
class DefaultParameterIdGenerator(ParameterIdGenerator):
    """Default ID generator that creates meaningful IDs from parameter combinations.

    When a combination has non-empty completion_params, the ID is built from the scalar
    values of that dict (see ``generate_id_from_dict``). Otherwise the ID describes the
    size of the first non-empty input among rows, messages and dataset, in that order.
    """

    def __init__(self, max_length: int = 200):
        """Initialize the ID generator with configuration options.

        Args:
            max_length: Maximum length of IDs built from completion_params
        """
        self.max_length = max_length

    def generate_id(self, combo: CombinationTuple) -> str | None:
        """Generate an ID for a parameter combination.

        Args:
            combo: ``(dataset, completion_params, messages, rows, evaluation_test_kwargs)``

        Returns:
            An ID string. Returns None to let pytest generate one in two cases:
            when completion_params are present but yield no usable ID, or when
            no input is present at all.
        """
        dataset, completion_params, messages, rows, _evaluation_test_kwargs = combo

        if completion_params:
            # With completion_params present, no size-based fallback is used:
            # an empty/missing dict-based ID means "use pytest's default".
            return self.generate_id_from_dict(completion_params, self.max_length) or None

        if rows:
            return f"rows(len={len(rows)})"
        if messages:
            return f"messages(len={len(messages)})"
        if dataset:
            return f"dataset(len={len(dataset)})"
        return None

    @staticmethod
    def generate_id_from_dict(d: dict[str, object], max_length: int = 200) -> str | None:
        """Join the scalar values of ``d``, ordered by key, into a ':'-separated ID.

        Only str, int, float and bool values are used; any other values are skipped.
        IDs longer than ``max_length`` are truncated to ``max_length`` characters,
        ending in "..." when there is room for it.

        Args:
            d: Mapping whose scalar values form the ID (typically completion_params)
            max_length: Maximum length of the returned ID

        Returns:
            The ID string, or None if ``d`` has no scalar values
        """
        str_values = [str(d[key]) for key in sorted(d.keys()) if isinstance(d[key], (str, int, float, bool))]
        if not str_values:
            return None

        id_str = ":".join(str_values)
        if len(id_str) > max_length:
            if max_length >= 3:
                # Reserve three characters for the ellipsis so the result fits max_length.
                id_str = id_str[: max_length - 3] + "..."
            else:
                # No room for an ellipsis: slicing with max_length - 3 would use a negative
                # index and produce an ID longer than max_length, so hard-truncate instead.
                id_str = id_str[: max(max_length, 0)]
        return id_str
def pytest_parametrize(
    combinations: list[CombinationTuple],
    test_func: TestFunction | None,
    input_dataset: Sequence[DatasetPathParam] | None,
    completion_params: Sequence[CompletionParams | None] | None,
    completion_params_provided: bool,
    input_messages: Sequence[list[InputMessagesParam] | None] | None,
    input_rows: Sequence[list[EvaluationRow]] | None,
    evaluation_test_kwargs: Sequence[EvaluationInputParam | None] | None,
    id_generator: ParameterIdGenerator | None = None,
) -> ParametrizeArgs:
    """
    Dynamically generate pytest.mark.parametrize arguments for a given set of combinations.

    This is the magic that allows developers to pass in their inputs in a single
    decorator and generate all combinations of experiments. They do not have to
    create their own fixtures and conform to eval-protocol's API.

    Args:
        combinations: One ``(dataset, completion_params, messages, rows,
            evaluation_test_kwargs)`` tuple per test case.
        test_func: The user's test function, inspected for its own
            pytest.mark.parametrize over "completion_params"; may be None.
        input_dataset: Adds a "dataset_path" argument when not None.
        completion_params: When not None, "completion_params" is added to argnames
            only if ``completion_params_provided`` is True and ``test_func`` does not
            parametrize it itself. It is added to the signature if either holds.
        completion_params_provided: Whether completion_params were passed to the decorator.
        input_messages: Adds an "input_messages" argument when not None.
        input_rows: Adds an "input_rows" argument when not None.
        evaluation_test_kwargs: Adds an "evaluation_test_kwargs" argument when not None.
        id_generator: Names each parameter set; defaults to DefaultParameterIdGenerator.

    Returns:
        The wrapper signature parameter names plus the pytest.mark.parametrize kwargs.
        When any ID is produced, ``ids`` holds one entry per combination, where None
        means "let pytest pick". Otherwise ``ids`` is None.

    Raises:
        ValueError: If a built parameter tuple does not line up with argnames.
    """
    if test_func is not None:
        has_pytest_parametrize = _has_pytest_parametrize_with_completion_params(test_func)
    else:
        has_pytest_parametrize = False

    # Create parameter tuples for pytest.mark.parametrize
    argnames: list[str] = []
    sig_parameters: list[str] = []
    if input_dataset is not None:
        argnames.append("dataset_path")
        sig_parameters.append("dataset_path")
    if completion_params is not None:
        # If the user's own pytest.mark.parametrize supplies completion_params, we must not
        # parametrize it a second time, but the wrapper signature still needs the name.
        if completion_params_provided and not has_pytest_parametrize:
            argnames.append("completion_params")
        if has_pytest_parametrize or completion_params_provided:
            sig_parameters.append("completion_params")
    if input_messages is not None:
        argnames.append("input_messages")
        sig_parameters.append("input_messages")
    if input_rows is not None:
        argnames.append("input_rows")
        sig_parameters.append("input_rows")
    if evaluation_test_kwargs is not None:
        argnames.append("evaluation_test_kwargs")
        sig_parameters.append("evaluation_test_kwargs")

    # Use default ID generator if none provided
    if id_generator is None:
        id_generator = DefaultParameterIdGenerator()

    argvalues: list[ParameterSet | Sequence[object] | object] = []
    # pytest maps ids to argvalues by index and rejects an ids list whose length differs
    # from argvalues, so keep one entry per combination; None entries tell pytest to
    # auto-generate that parameter set's ID.
    ids: list[str | None] = []
    for combo in combinations:
        dataset, cp, messages, rows, etk = combo
        param_tuple: list[object] = []
        # Build parameter tuple based on what's provided
        if input_dataset is not None:
            param_tuple.append(dataset)
        # NOTE(review): cp is appended whenever completion_params_provided is True, even when
        # "completion_params" was left out of argnames above (e.g. the test also has its own
        # pytest.mark.parametrize); the length check below turns that into a ValueError.
        if completion_params_provided:
            param_tuple.append(cp)
        if input_messages is not None:
            param_tuple.append(messages)
        if input_rows is not None:
            param_tuple.append(rows)
        if evaluation_test_kwargs is not None:
            param_tuple.append(etk)

        # Validate parameter tuple length
        if len(argnames) != len(param_tuple):
            raise ValueError(
                f"The length of argnames ({len(argnames)}) is not the same as the length of param_tuple ({len(param_tuple)})"
            )
        argvalues.append(tuple(param_tuple))

        # Generate ID for this combination (None -> pytest default for this entry)
        ids.append(id_generator.generate_id(combo))

    # Return None for ids if no IDs were generated (let pytest use defaults everywhere)
    has_custom_ids = any(combo_id is not None for combo_id in ids)
    return ParametrizeArgs(
        pytest_parametrize_kwargs=PytestMarkParametrizeKwargs(
            argnames=argnames,
            argvalues=argvalues,
            # pytest accepts None entries in ids; the TypedDict declares Iterable[str].
            ids=ids if has_custom_ids else None,  # pyright: ignore[reportArgumentType]
        ),
        sig_parameters=sig_parameters,
    )
def create_dynamically_parameterized_wrapper(
    test_func: TestFunction, wrapper_body: Callable[..., Awaitable[None]], test_param_names: Sequence[str]
) -> Callable[..., None]:
    """
    Build an async wrapper around ``wrapper_body`` whose signature lists ``test_param_names``.

    The wrapper is meant to be decorated with pytest.mark.parametrize, so its visible
    signature must name exactly the parametrized arguments. To achieve that it:

    1. copies the metadata of ``test_func`` with functools.wraps;
    2. exposes an explicit ``__signature__`` with one positional-or-keyword parameter
       per entry in ``test_param_names``;
    3. forwards every keyword argument it receives to ``wrapper_body`` and awaits it.

    Args:
        test_func: The original test function whose metadata is copied
        wrapper_body: Coroutine function that holds the actual test logic
        test_param_names: Parameter names for the synthesized signature

    Returns:
        The wrapper. It is defined with ``async def``, so calling it returns a
        coroutine even though the annotation says ``Callable[..., None]``.
    """
    from functools import wraps

    dynamic_signature = inspect.Signature(
        [inspect.Parameter(param_name, inspect.Parameter.POSITIONAL_OR_KEYWORD) for param_name in test_param_names]
    )

    @wraps(test_func)  # pyright: ignore[reportArgumentType]
    async def wrapper(**kwargs) -> None:  # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
        # Accepts keyword arguments only and hands them straight to the real body.
        return await wrapper_body(**kwargs)

    # An explicit __signature__ takes precedence over the one inspect would otherwise
    # derive by following __wrapped__ (set by functools.wraps) back to test_func.
    wrapper.__signature__ = dynamic_signature  # pyright: ignore[reportAttributeAccessIssue]
    return wrapper  # pyright: ignore[reportUnknownVariableType, reportReturnType]