1+ import ast
12import inspect
23from typing import TypedDict , Protocol
34from collections .abc import Callable , Sequence , Iterable , Awaitable
910from eval_protocol .pytest .types import DatasetPathParam , EvaluationInputParam , InputMessagesParam , TestFunction
1011
1112
12- class PytestParametrizeArgs (TypedDict ):
def _has_pytest_parametrize_with_completion_params(test_func: TestFunction) -> bool:
    """
    Check if a test function has a pytest.mark.parametrize decorator with "completion_params" in argnames.

    The function's source is fetched with inspect.getsource, dedented, parsed with ast, and every
    (async) function definition found in it is inspected for a matching decorator.

    Args:
        test_func: The test function to analyze

    Returns:
        True if the function has a pytest.mark.parametrize decorator with "completion_params" in argnames.
        False otherwise, including when the source cannot be retrieved (e.g. defined interactively,
        or a builtin / C-implemented callable) or cannot be parsed as valid Python.
    """
    # Imported locally so this helper does not depend on a module-level import being present.
    import textwrap

    try:
        source = inspect.getsource(test_func)
    except (OSError, TypeError):
        # OSError: source file unavailable (e.g. interactive session).
        # TypeError: object has no Python source at all (builtins, some callables).
        return False

    try:
        # Methods and nested functions come back with their original indentation,
        # which ast.parse rejects with IndentationError unless it is stripped first.
        tree = ast.parse(textwrap.dedent(source))
    except SyntaxError:
        # Source code cannot be parsed (IndentationError is a SyntaxError subclass)
        return False

    # Walk through the AST to find pytest.mark.parametrize decorators
    for node in ast.walk(tree):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            if any(
                _is_pytest_parametrize_with_completion_params(decorator)
                for decorator in node.decorator_list
            ):
                return True

    return False
52+
53+
def _is_pytest_parametrize_with_completion_params(decorator: ast.expr) -> bool:
    """
    Tell whether a decorator node is `pytest.mark.parametrize(...)` naming "completion_params".

    Only the literal attribute chain `pytest.mark.parametrize` called as a function is
    recognised. The argnames value is taken from the first positional argument and from
    any `argnames=` keyword argument.

    Args:
        decorator: AST node representing a decorator

    Returns:
        True if this is a pytest.mark.parametrize call with "completion_params" in argnames
    """
    # The decorator must be a call: pytest.mark.parametrize(...)
    if not isinstance(decorator, ast.Call):
        return False

    # Peel the attribute chain from the outside in: <root>.<mark>.<parametrize>
    parametrize_attr = decorator.func
    if not isinstance(parametrize_attr, ast.Attribute):
        return False

    mark_attr = parametrize_attr.value
    if not isinstance(mark_attr, ast.Attribute):
        return False

    root_name = mark_attr.value
    if not isinstance(root_name, ast.Name):
        return False

    if root_name.id != "pytest" or mark_attr.attr != "mark" or parametrize_attr.attr != "parametrize":
        return False

    # argnames is typically the first positional argument
    positional = decorator.args
    if len(positional) > 0 and _check_argnames_for_completion_params(positional[0]):
        return True

    # ...but it may also be passed by keyword
    for kw in decorator.keywords:
        if kw.arg == "argnames" and _check_argnames_for_completion_params(kw.value):
            return True

    return False
88+
89+
def _check_argnames_for_completion_params(argnames_node: ast.expr) -> bool:
    """
    Check if an argnames AST node contains "completion_params".

    Supports the literal forms pytest accepts for argnames:
      * a string, either a single name or a comma-separated list of names
        (e.g. "completion_params" or "dataset_path, completion_params"),
      * a list or tuple of string literals.

    Non-literal values (names, calls, ...) cannot be resolved statically and yield False.

    Args:
        argnames_node: AST node representing the argnames value

    Returns:
        True if argnames contains "completion_params"
    """
    target = "completion_params"

    if isinstance(argnames_node, ast.Constant):
        value = argnames_node.value
        if not isinstance(value, str):
            return False
        # pytest splits string argnames on commas and strips whitespace around each name.
        return any(name.strip() == target for name in value.split(","))

    if isinstance(argnames_node, (ast.List, ast.Tuple)):
        # Sequence case: argnames=["completion_params", ...] or ("completion_params", ...)
        return any(
            isinstance(elt, ast.Constant) and elt.value == target
            for elt in argnames_node.elts
        )

    return False
116+
117+
class PytestMarkParametrizeKwargs(TypedDict):
    """Keyword arguments destined for pytest.mark.parametrize."""

    # Names of the parameters being parametrized.
    argnames: Sequence[str]
    # One entry per generated test case.
    argvalues: Iterable[ParameterSet | Sequence[object] | object]
    # Optional test IDs; None lets pytest fall back to its default IDs.
    ids: Iterable[str] | None
16122
17123
class ParametrizeArgs(TypedDict):
    """
    This contains all the necessary information to properly hijack the test
    function's signature and dynamically inject usage of
    pytest.mark.parametrize.

    The two name lists (sig_parameters and the argnames inside
    pytest_parametrize_kwargs) will differ when a user manually provides the
    pytest.mark.parametrize decorator for completion_params instead of passing
    completion_params on their own: "completion_params" then stays in
    sig_parameters but is left out of argnames.
    """

    # for create_dynamically_parameterized_wrapper
    sig_parameters: Sequence[str]

    # for pytest.mark.parametrize
    pytest_parametrize_kwargs: PytestMarkParametrizeKwargs
138+
139+
18140class ParameterIdGenerator (Protocol ):
19141 """Protocol for generating pytest parameter IDs from parameter combinations."""
20142
@@ -30,7 +152,7 @@ def generate_id(self, combo: CombinationTuple) -> str | None:
30152 ...
31153
32154
33- class DefaultParameterIdGenerator :
155+ class DefaultParameterIdGenerator ( ParameterIdGenerator ) :
34156 """Default ID generator that creates meaningful IDs from parameter combinations."""
35157
36158 def __init__ (self , max_length : int = 200 ):
@@ -46,34 +168,49 @@ def generate_id(self, combo: CombinationTuple) -> str | None:
46168 dataset , completion_params , messages , rows , evaluation_test_kwargs = combo
47169
48170 if completion_params :
49- # Get all string, numeric, and boolean values from completion_params, sorted by key
50- str_values = []
51- for key in sorted (completion_params .keys ()):
52- value = completion_params [key ]
53- if isinstance (value , (str , int , float , bool )):
54- str_values .append (str (value ))
171+ id = self .generate_id_from_dict (completion_params , self .max_length )
172+ if id :
173+ return id
174+ else :
175+ if rows :
176+ return f"rows(len={ len (rows )} )"
177+ elif messages :
178+ return f"messages(len={ len (messages )} )"
179+ elif dataset :
180+ return f"dataset(len={ len (dataset )} )"
181+ return None
55182
56- if str_values :
57- id_str = ":" .join (str_values )
183+ @staticmethod
184+ def generate_id_from_dict (d : dict [str , object ], max_length : int = 200 ) -> str | None :
185+ # Get all string, numeric, and boolean values from completion_params, sorted by key
186+ str_values = []
187+ for key in sorted (d .keys ()):
188+ value = d [key ]
189+ if isinstance (value , (str , int , float , bool )):
190+ str_values .append (str (value ))
58191
59- # Truncate if too long
60- if len (id_str ) > self .max_length :
61- id_str = id_str [: self .max_length - 3 ] + "..."
192+ if str_values :
193+ id_str = ":" .join (str_values )
62194
63- return id_str
195+ # Truncate if too long
196+ if len (id_str ) > max_length :
197+ id_str = id_str [: max_length - 3 ] + "..."
64198
199+ return id_str
65200 return None
66201
67202
68203def pytest_parametrize (
69204 combinations : list [CombinationTuple ],
205+ test_func : TestFunction | None ,
70206 input_dataset : Sequence [DatasetPathParam ] | None ,
71207 completion_params : Sequence [CompletionParams | None ] | None ,
208+ completion_params_provided : bool ,
72209 input_messages : Sequence [list [InputMessagesParam ] | None ] | None ,
73210 input_rows : Sequence [list [EvaluationRow ]] | None ,
74211 evaluation_test_kwargs : Sequence [EvaluationInputParam | None ] | None ,
75212 id_generator : ParameterIdGenerator | None = None ,
76- ) -> PytestParametrizeArgs :
213+ ) -> ParametrizeArgs :
77214 """
78215 This function dynamically generates pytest.mark.parametrize arguments for a given
79216 set of combinations. This is the magic that allows developers to pass in their
@@ -82,18 +219,31 @@ def pytest_parametrize(
82219 API.
83220 """
84221
222+ if test_func is not None :
223+ has_pytest_parametrize = _has_pytest_parametrize_with_completion_params (test_func )
224+ else :
225+ has_pytest_parametrize = False
226+
85227 # Create parameter tuples for pytest.mark.parametrize
86228 argnames : list [str ] = []
229+ sig_parameters : list [str ] = []
87230 if input_dataset is not None :
88231 argnames .append ("dataset_path" )
232+ sig_parameters .append ("dataset_path" )
89233 if completion_params is not None :
90- argnames .append ("completion_params" )
234+ if completion_params_provided and not has_pytest_parametrize :
235+ argnames .append ("completion_params" )
236+ if has_pytest_parametrize or completion_params_provided :
237+ sig_parameters .append ("completion_params" )
91238 if input_messages is not None :
92239 argnames .append ("input_messages" )
240+ sig_parameters .append ("input_messages" )
93241 if input_rows is not None :
94242 argnames .append ("input_rows" )
243+ sig_parameters .append ("input_rows" )
95244 if evaluation_test_kwargs is not None :
96245 argnames .append ("evaluation_test_kwargs" )
246+ sig_parameters .append ("evaluation_test_kwargs" )
97247
98248 # Use default ID generator if none provided
99249 if id_generator is None :
@@ -109,7 +259,7 @@ def pytest_parametrize(
109259 # Build parameter tuple based on what's provided
110260 if input_dataset is not None :
111261 param_tuple .append (dataset )
112- if completion_params is not None :
262+ if completion_params_provided :
113263 param_tuple .append (cp )
114264 if input_messages is not None :
115265 param_tuple .append (messages )
@@ -132,7 +282,12 @@ def pytest_parametrize(
132282 ids .append (combo_id )
133283
134284 # Return None for ids if no IDs were generated (let pytest use defaults)
135- return PytestParametrizeArgs (argnames = argnames , argvalues = argvalues , ids = ids if ids else None )
285+ return ParametrizeArgs (
286+ pytest_parametrize_kwargs = PytestMarkParametrizeKwargs (
287+ argnames = argnames , argvalues = argvalues , ids = ids if ids else None
288+ ),
289+ sig_parameters = sig_parameters ,
290+ )
136291
137292
138293def create_dynamically_parameterized_wrapper (
0 commit comments