-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathparameterize.py
More file actions
170 lines (131 loc) · 6.53 KB
/
parameterize.py
File metadata and controls
170 lines (131 loc) · 6.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import inspect
from typing import TypedDict, Protocol
from collections.abc import Callable, Sequence, Iterable, Awaitable
from _pytest.mark import ParameterSet
from eval_protocol.models import CompletionParams, EvaluationRow
from eval_protocol.pytest.generate_parameter_combinations import CombinationTuple
from eval_protocol.pytest.types import DatasetPathParam, EvaluationInputParam, InputMessagesParam, TestFunction
class PytestParametrizeArgs(TypedDict):
    """Keyword arguments for ``pytest.mark.parametrize``, as returned by ``pytest_parametrize``."""

    # Names of the test-function parameters being parametrized.
    argnames: Sequence[str]
    # One entry per generated test case; each entry supplies a value for every name in ``argnames``.
    argvalues: Iterable[ParameterSet | Sequence[object] | object]
    # Per-case test IDs, positionally aligned with ``argvalues``. pytest accepts ``None``
    # for an individual entry (auto-generate that ID); ``None`` for the whole field means
    # pytest generates every ID itself.
    ids: Iterable[str | None] | None
class ParameterIdGenerator(Protocol):
    """Structural interface for objects that name parametrized test cases.

    Any object exposing a compatible ``generate_id`` method satisfies this
    protocol; inheriting from it is not required.
    """

    def generate_id(self, combo: CombinationTuple) -> str | None:
        """Return the pytest ID to use for one parameter combination.

        Args:
            combo: The parameter combination tuple to describe.

        Returns:
            The ID string for ``combo``, or ``None`` to fall back to the
            default pytest ID.
        """
        ...
class DefaultParameterIdGenerator:
    """Default ID generator that builds readable IDs from ``completion_params``.

    The ID is the ``":"``-joined string form of every ``str``/``int``/``float``/
    ``bool`` value in the combination's completion params, ordered by key.
    Values of any other type (nested dicts, lists, ...) are skipped.
    """

    # Suffix appended to an ID that had to be shortened.
    _ELLIPSIS = "..."

    def __init__(self, max_length: int = 200):
        """Initialize the ID generator with configuration options.

        Args:
            max_length: Maximum length of generated IDs.
        """
        self.max_length = max_length

    def generate_id(self, combo: CombinationTuple) -> str | None:
        """Generate an ID for a parameter combination.

        Args:
            combo: Five-element combination tuple; only its second element
                (the completion params) contributes to the ID.

        Returns:
            The generated ID, never longer than ``max_length``, or ``None``
            when the combination has no completion params or none of their
            values is a string, number, or boolean.

        Raises:
            ValueError: If ``combo`` does not have exactly five elements.
        """
        # Unpack (rather than index) so a malformed combination fails loudly.
        _, completion_params, _, _, _ = combo
        if not completion_params:
            return None

        # Sort by key so the same params always yield the same ID.
        scalar_values = [
            str(completion_params[key])
            for key in sorted(completion_params)
            if isinstance(completion_params[key], (str, int, float, bool))
        ]
        if not scalar_values:
            return None

        return self._truncate(":".join(scalar_values))

    def _truncate(self, id_str: str) -> str:
        """Shorten ``id_str`` so it never exceeds ``max_length`` characters."""
        limit = max(self.max_length, 0)
        if len(id_str) <= limit:
            return id_str
        if limit <= len(self._ELLIPSIS):
            # No room for text plus an ellipsis: a plain cut keeps the ID within
            # the limit (slicing with ``limit - 3`` here would go negative and
            # produce an ID longer than ``max_length``).
            return id_str[:limit]
        return id_str[: limit - len(self._ELLIPSIS)] + self._ELLIPSIS
def pytest_parametrize(
    combinations: list[CombinationTuple],
    input_dataset: Sequence[DatasetPathParam] | None,
    completion_params: Sequence[CompletionParams | None] | None,
    input_messages: Sequence[list[InputMessagesParam] | None] | None,
    input_rows: Sequence[list[EvaluationRow]] | None,
    evaluation_test_kwargs: Sequence[EvaluationInputParam | None] | None,
    id_generator: ParameterIdGenerator | None = None,
) -> PytestParametrizeArgs:
    """
    Dynamically generate ``pytest.mark.parametrize`` arguments for a set of combinations.

    This is the magic that allows developers to pass in their inputs in a single
    decorator and generate all combinations of experiments without having to
    create their own fixtures and conforming to eval-protocol's API.

    Args:
        combinations: Five-element tuples ordered as (dataset, completion params,
            messages, rows, evaluation test kwargs).
        input_dataset: When not ``None``, ``dataset_path`` is parametrized.
        completion_params: When not ``None``, ``completion_params`` is parametrized.
        input_messages: When not ``None``, ``input_messages`` is parametrized.
        input_rows: When not ``None``, ``input_rows`` is parametrized.
        evaluation_test_kwargs: When not ``None``, ``evaluation_test_kwargs`` is
            parametrized.
        id_generator: Produces the test ID for each combination; defaults to
            ``DefaultParameterIdGenerator()``.

    Returns:
        The ``argnames``, ``argvalues`` and ``ids`` to pass to
        ``pytest.mark.parametrize``. ``ids`` is ``None`` when the generator
        produced no ID at all; otherwise it holds exactly one entry per
        combination, with ``None`` marking the combinations that should keep
        pytest's auto-generated ID.

    Raises:
        ValueError: If a combination does not have exactly five elements.
    """
    # Use default ID generator if none provided
    if id_generator is None:
        id_generator = DefaultParameterIdGenerator()

    # One (argname, enabled) slot per combination field, in combination order.
    # Only the inputs the caller actually supplied become test parameters.
    slots: list[tuple[str, bool]] = [
        ("dataset_path", input_dataset is not None),
        ("completion_params", completion_params is not None),
        ("input_messages", input_messages is not None),
        ("input_rows", input_rows is not None),
        ("evaluation_test_kwargs", evaluation_test_kwargs is not None),
    ]
    argnames: list[str] = [name for name, enabled in slots if enabled]

    argvalues: list[ParameterSet | Sequence[object] | object] = []
    ids: list[str | None] = []
    for combo in combinations:
        # Unpacking enforces the five-element shape of every combination.
        dataset, cp, messages, rows, etk = combo
        combo_values = (dataset, cp, messages, rows, etk)
        argvalues.append(tuple(value for value, (_, enabled) in zip(combo_values, slots) if enabled))

        # Record an entry for EVERY combination, even when the generator declines
        # (returns None). Skipping those would leave ``ids`` shorter than and
        # misaligned with ``argvalues``; pytest treats a None entry as "use the
        # auto-generated ID for this case".
        ids.append(id_generator.generate_id(combo))

    # Return None for ids if no IDs were generated (let pytest use defaults)
    has_custom_ids = any(combo_id is not None for combo_id in ids)
    return PytestParametrizeArgs(argnames=argnames, argvalues=argvalues, ids=ids if has_custom_ids else None)  # pyright: ignore[reportArgumentType]
def create_dynamically_parameterized_wrapper(
    test_func: TestFunction, wrapper_body: Callable[..., Awaitable[None]], test_param_names: Sequence[str]
) -> Callable[..., None]:
    """
    Build an async wrapper whose visible signature is defined at runtime.

    The returned coroutine function:
    1. Carries over ``test_func``'s metadata via ``functools.wraps``
    2. Accepts keyword arguments only and forwards them unchanged to ``wrapper_body``
    3. Advertises, through ``__signature__``, one positional-or-keyword parameter
       per entry of ``test_param_names``

    Overriding the signature is what lets pytest.mark.parametrize map its
    argnames onto the wrapper even though the underlying function only takes
    ``**kwargs``.

    Args:
        test_func: The original test function whose metadata is preserved
        wrapper_body: Awaitable-returning callable holding the actual test logic
        test_param_names: Parameter names, in order, for the advertised signature

    Returns:
        The wrapper function exposing the requested signature and delegating to wrapper_body
    """
    from functools import wraps

    @wraps(test_func)  # pyright: ignore[reportArgumentType]
    async def wrapper(**kwargs) -> None:  # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
        return await wrapper_body(**kwargs)

    # Assemble the advertised parameter list one name at a time.
    param_kind = inspect.Parameter.POSITIONAL_OR_KEYWORD
    advertised_params: list[inspect.Parameter] = []
    for param_name in test_param_names:
        advertised_params.append(inspect.Parameter(param_name, param_kind))

    # Set after ``wraps`` has run so the dynamic signature is the one introspection sees.
    dynamic_signature = inspect.Signature(advertised_params)
    wrapper.__signature__ = dynamic_signature  # pyright: ignore[reportAttributeAccessIssue]
    return wrapper  # pyright: ignore[reportUnknownVariableType, reportReturnType]