99from eval_protocol .pytest .types import DatasetPathParam , EvaluationInputParam , InputMessagesParam , TestFunction
1010
1111
class PytestMarkParametrizeKwargs(TypedDict):
    """Keyword arguments that are passed to ``pytest.mark.parametrize``."""

    # Names of the parameters being parametrized, in positional order.
    argnames: Sequence[str]
    # One entry per generated test case.
    argvalues: Iterable[ParameterSet | Sequence[object] | object]
    # Optional test IDs; None lets pytest fall back to its default IDs.
    ids: Iterable[str] | None
1616
1717
class ParametrizeArgs(TypedDict):
    """
    Everything needed to hijack the test function's signature and to
    dynamically inject usage of pytest.mark.parametrize.

    ``sig_parameters`` and the ``argnames`` inside
    ``pytest_parametrize_kwargs`` are tracked separately because they can
    differ: this happens when a user manually provides the
    pytest.mark.parametrize decorator instead of passing completion_params
    on their own.
    """

    # for create_dynamically_parameterized_wrapper
    sig_parameters: Sequence[str]

    # for pytest.mark.parametrize
    pytest_parametrize_kwargs: PytestMarkParametrizeKwargs
32+
33+
1834class ParameterIdGenerator (Protocol ):
1935 """Protocol for generating pytest parameter IDs from parameter combinations."""
2036
@@ -30,7 +46,7 @@ def generate_id(self, combo: CombinationTuple) -> str | None:
3046 ...
3147
3248
33- class DefaultParameterIdGenerator :
49+ class DefaultParameterIdGenerator ( ParameterIdGenerator ) :
3450 """Default ID generator that creates meaningful IDs from parameter combinations."""
3551
3652 def __init__ (self , max_length : int = 200 ):
@@ -46,34 +62,48 @@ def generate_id(self, combo: CombinationTuple) -> str | None:
4662 dataset , completion_params , messages , rows , evaluation_test_kwargs = combo
4763
4864 if completion_params :
49- # Get all string, numeric, and boolean values from completion_params, sorted by key
50- str_values = []
51- for key in sorted (completion_params .keys ()):
52- value = completion_params [key ]
53- if isinstance (value , (str , int , float , bool )):
54- str_values .append (str (value ))
65+ id = self .generate_id_from_dict (completion_params , self .max_length )
66+ if id :
67+ return id
68+ else :
69+ if rows :
70+ return f"rows(len={ len (rows )} )"
71+ elif messages :
72+ return f"messages(len={ len (messages )} )"
73+ elif dataset :
74+ return f"dataset(len={ len (dataset )} )"
75+ return None
5576
56- if str_values :
57- id_str = ":" .join (str_values )
77+ @staticmethod
78+ def generate_id_from_dict (d : dict [str , object ], max_length : int = 200 ) -> str | None :
79+ # Get all string, numeric, and boolean values from completion_params, sorted by key
80+ str_values = []
81+ for key in sorted (d .keys ()):
82+ value = d [key ]
83+ if isinstance (value , (str , int , float , bool )):
84+ str_values .append (str (value ))
5885
59- # Truncate if too long
60- if len (id_str ) > self .max_length :
61- id_str = id_str [: self .max_length - 3 ] + "..."
86+ if str_values :
87+ id_str = ":" .join (str_values )
6288
63- return id_str
89+ # Truncate if too long
90+ if len (id_str ) > max_length :
91+ id_str = id_str [: max_length - 3 ] + "..."
6492
93+ return id_str
6594 return None
6695
6796
6897def pytest_parametrize (
6998 combinations : list [CombinationTuple ],
7099 input_dataset : Sequence [DatasetPathParam ] | None ,
71100 completion_params : Sequence [CompletionParams | None ] | None ,
101+ completion_params_provided : bool ,
72102 input_messages : Sequence [list [InputMessagesParam ] | None ] | None ,
73103 input_rows : Sequence [list [EvaluationRow ]] | None ,
74104 evaluation_test_kwargs : Sequence [EvaluationInputParam | None ] | None ,
75105 id_generator : ParameterIdGenerator | None = None ,
76- ) -> PytestParametrizeArgs :
106+ ) -> ParametrizeArgs :
77107 """
78108 This function dynamically generates pytest.mark.parametrize arguments for a given
79109 set of combinations. This is the magic that allows developers to pass in their
@@ -84,16 +114,23 @@ def pytest_parametrize(
84114
85115 # Create parameter tuples for pytest.mark.parametrize
86116 argnames : list [str ] = []
117+ sig_parameters : list [str ] = []
87118 if input_dataset is not None :
88119 argnames .append ("dataset_path" )
120+ sig_parameters .append ("dataset_path" )
89121 if completion_params is not None :
90- argnames .append ("completion_params" )
122+ if completion_params_provided :
123+ argnames .append ("completion_params" )
124+ sig_parameters .append ("completion_params" )
91125 if input_messages is not None :
92126 argnames .append ("input_messages" )
127+ sig_parameters .append ("input_messages" )
93128 if input_rows is not None :
94129 argnames .append ("input_rows" )
130+ sig_parameters .append ("input_rows" )
95131 if evaluation_test_kwargs is not None :
96132 argnames .append ("evaluation_test_kwargs" )
133+ sig_parameters .append ("evaluation_test_kwargs" )
97134
98135 # Use default ID generator if none provided
99136 if id_generator is None :
@@ -109,7 +146,7 @@ def pytest_parametrize(
109146 # Build parameter tuple based on what's provided
110147 if input_dataset is not None :
111148 param_tuple .append (dataset )
112- if completion_params is not None :
149+ if completion_params_provided :
113150 param_tuple .append (cp )
114151 if input_messages is not None :
115152 param_tuple .append (messages )
@@ -132,7 +169,12 @@ def pytest_parametrize(
132169 ids .append (combo_id )
133170
134171 # Return None for ids if no IDs were generated (let pytest use defaults)
135- return PytestParametrizeArgs (argnames = argnames , argvalues = argvalues , ids = ids if ids else None )
172+ return ParametrizeArgs (
173+ pytest_parametrize_kwargs = PytestMarkParametrizeKwargs (
174+ argnames = argnames , argvalues = argvalues , ids = ids if ids else None
175+ ),
176+ sig_parameters = sig_parameters ,
177+ )
136178
137179
138180def create_dynamically_parameterized_wrapper (
0 commit comments