Skip to content

Commit 6fe0baa

Browse files
author
Dylan Huang
authored
Enhance pytest parameterization validation and AST node conversion (#234)
- Added validation to ensure that argvalues in pytest.mark.parametrize is a list or tuple, not a dict. - Introduced utility functions to convert AST nodes, including dictionaries, lists, and tuples, to their string representations for better error messaging. - Updated the _is_pytest_parametrize_with_completion_params function to include argvalues validation.
1 parent 2d1094c commit 6fe0baa

File tree

2 files changed

+313
-0
lines changed

2 files changed

+313
-0
lines changed

eval_protocol/pytest/parameterize.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ def _is_pytest_parametrize_with_completion_params(decorator: ast.expr) -> bool:
7373
and decorator.func.value.attr == "mark"
7474
and decorator.func.attr == "parametrize"
7575
):
76+
# Validate argvalues if present
77+
_validate_parametrize_argvalues(decorator)
78+
7679
# Check positional arguments first (argnames is typically the first positional arg)
7780
if len(decorator.args) > 0:
7881
argnames_arg = decorator.args[0]
@@ -88,6 +91,90 @@ def _is_pytest_parametrize_with_completion_params(decorator: ast.expr) -> bool:
8891
return False
8992

9093

94+
def _ast_dict_to_string(dict_node: ast.Dict) -> str:
95+
"""
96+
Convert an AST Dict node to its string representation.
97+
98+
Args:
99+
dict_node: AST node representing a dictionary
100+
101+
Returns:
102+
String representation of the dictionary
103+
"""
104+
if not dict_node.keys:
105+
return "{}"
106+
107+
pairs = []
108+
for key, value in zip(dict_node.keys, dict_node.values):
109+
if key is not None:
110+
key_str = _ast_node_to_string(key)
111+
value_str = _ast_node_to_string(value)
112+
pairs.append(f"{key_str}: {value_str}")
113+
114+
return "{" + ", ".join(pairs) + "}"
115+
116+
117+
def _ast_node_to_string(node: ast.expr) -> str:
118+
"""
119+
Convert an AST node to its string representation.
120+
121+
Args:
122+
node: AST node to convert
123+
124+
Returns:
125+
String representation of the node
126+
"""
127+
if isinstance(node, ast.Constant):
128+
if isinstance(node.value, str):
129+
return repr(node.value)
130+
else:
131+
return str(node.value)
132+
elif isinstance(node, ast.Name):
133+
return node.id
134+
elif isinstance(node, ast.Dict):
135+
return _ast_dict_to_string(node)
136+
elif isinstance(node, ast.List):
137+
elements = [_ast_node_to_string(elt) for elt in node.elts]
138+
return "[" + ", ".join(elements) + "]"
139+
elif isinstance(node, ast.Tuple):
140+
elements = [_ast_node_to_string(elt) for elt in node.elts]
141+
return "(" + ", ".join(elements) + ")"
142+
else:
143+
# For complex expressions, return a simplified representation
144+
return "<complex expression>"
145+
146+
147+
def _validate_parametrize_argvalues(decorator: ast.Call) -> None:
148+
"""
149+
Validate that pytest.mark.parametrize argvalues is a list/tuple, not a dict.
150+
151+
Args:
152+
decorator: AST node representing the pytest.mark.parametrize decorator call
153+
154+
Raises:
155+
ValueError: If argvalues is a dict instead of a list/tuple
156+
"""
157+
# Check positional arguments (argvalues is typically the second positional arg)
158+
if len(decorator.args) > 1:
159+
argvalues_arg = decorator.args[1]
160+
if isinstance(argvalues_arg, ast.Dict):
161+
dict_repr = _ast_dict_to_string(argvalues_arg)
162+
raise ValueError(
163+
f"For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict. "
164+
f"Use [{dict_repr}] instead of {dict_repr}."
165+
)
166+
167+
# Check keyword arguments for argvalues
168+
for keyword in decorator.keywords:
169+
if keyword.arg == "argvalues":
170+
if isinstance(keyword.value, ast.Dict):
171+
dict_repr = _ast_dict_to_string(keyword.value)
172+
raise ValueError(
173+
f"For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict. "
174+
f"Use [{dict_repr}] instead of {dict_repr}."
175+
)
176+
177+
91178
def _check_argnames_for_completion_params(argnames_node: ast.expr) -> bool:
92179
"""
93180
Check if an argnames AST node contains "completion_params".
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
"""
2+
Test cases for pytest.mark.parametrize validation functionality.
3+
"""
4+
5+
import ast
6+
import pytest
7+
from eval_protocol.pytest.parameterize import _is_pytest_parametrize_with_completion_params
8+
9+
10+
def create_parametrize_decorator(argnames, argvalues, use_keyword=False):
11+
"""Create a pytest.mark.parametrize decorator AST node."""
12+
pytest_name = ast.Name(id="pytest", ctx=ast.Load())
13+
mark_attr = ast.Attribute(value=pytest_name, attr="mark", ctx=ast.Load())
14+
parametrize_attr = ast.Attribute(value=mark_attr, attr="parametrize", ctx=ast.Load())
15+
16+
if use_keyword:
17+
call = ast.Call(
18+
func=parametrize_attr,
19+
args=[],
20+
keywords=[
21+
ast.keyword(arg="argnames", value=ast.Constant(value=argnames)),
22+
ast.keyword(arg="argvalues", value=argvalues),
23+
],
24+
)
25+
else:
26+
call = ast.Call(func=parametrize_attr, args=[ast.Constant(value=argnames), argvalues], keywords=[])
27+
28+
return call
29+
30+
31+
class TestParametrizeValidation:
32+
"""Test cases for pytest.mark.parametrize validation."""
33+
34+
def test_invalid_dict_argvalues_positional(self):
35+
"""Test that a dict as positional argvalues throws an error."""
36+
decorator = create_parametrize_decorator(
37+
"completion_params", ast.Dict(keys=[ast.Constant(value="model")], values=[ast.Constant(value="gpt-4")])
38+
)
39+
40+
with pytest.raises(ValueError) as exc_info:
41+
_is_pytest_parametrize_with_completion_params(decorator)
42+
43+
error_msg = str(exc_info.value)
44+
assert (
45+
"For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict"
46+
in error_msg
47+
)
48+
assert "Use [{'model': 'gpt-4'}] instead of {'model': 'gpt-4'}" in error_msg
49+
50+
def test_valid_list_argvalues_positional(self):
51+
"""Test that a list as positional argvalues works correctly."""
52+
dict_value = ast.Dict(keys=[ast.Constant(value="model")], values=[ast.Constant(value="gpt-4")])
53+
list_value = ast.List(elts=[dict_value], ctx=ast.Load())
54+
55+
decorator = create_parametrize_decorator("completion_params", list_value)
56+
57+
result = _is_pytest_parametrize_with_completion_params(decorator)
58+
assert result is True
59+
60+
def test_invalid_dict_argvalues_keyword(self):
61+
"""Test that a dict as keyword argvalues throws an error."""
62+
decorator = create_parametrize_decorator(
63+
"completion_params",
64+
ast.Dict(keys=[ast.Constant(value="model")], values=[ast.Constant(value="gpt-4")]),
65+
use_keyword=True,
66+
)
67+
68+
with pytest.raises(ValueError) as exc_info:
69+
_is_pytest_parametrize_with_completion_params(decorator)
70+
71+
error_msg = str(exc_info.value)
72+
assert (
73+
"For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict"
74+
in error_msg
75+
)
76+
assert "Use [{'model': 'gpt-4'}] instead of {'model': 'gpt-4'}" in error_msg
77+
78+
def test_valid_list_argvalues_keyword(self):
79+
"""Test that a list as keyword argvalues works correctly."""
80+
dict_value = ast.Dict(keys=[ast.Constant(value="model")], values=[ast.Constant(value="gpt-4")])
81+
list_value = ast.List(elts=[dict_value], ctx=ast.Load())
82+
83+
decorator = create_parametrize_decorator("completion_params", list_value, use_keyword=True)
84+
85+
result = _is_pytest_parametrize_with_completion_params(decorator)
86+
assert result is True
87+
88+
def test_dynamic_error_simple_dict(self):
89+
"""Test dynamic error message with a simple dict."""
90+
decorator = create_parametrize_decorator(
91+
"completion_params", ast.Dict(keys=[ast.Constant(value="model")], values=[ast.Constant(value="gpt-4")])
92+
)
93+
94+
with pytest.raises(ValueError) as exc_info:
95+
_is_pytest_parametrize_with_completion_params(decorator)
96+
97+
error_msg = str(exc_info.value)
98+
assert (
99+
"For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict"
100+
in error_msg
101+
)
102+
assert "Use [{'model': 'gpt-4'}] instead of {'model': 'gpt-4'}" in error_msg
103+
104+
def test_dynamic_error_complex_dict(self):
105+
"""Test dynamic error message with a complex dict."""
106+
decorator = create_parametrize_decorator(
107+
"completion_params",
108+
ast.Dict(
109+
keys=[
110+
ast.Constant(value="model"),
111+
ast.Constant(value="temperature"),
112+
ast.Constant(value="max_tokens"),
113+
],
114+
values=[
115+
ast.Constant(value="accounts/fireworks/models/gpt-oss-120b"),
116+
ast.Constant(value=0.7),
117+
ast.Constant(value=1000),
118+
],
119+
),
120+
)
121+
122+
with pytest.raises(ValueError) as exc_info:
123+
_is_pytest_parametrize_with_completion_params(decorator)
124+
125+
error_msg = str(exc_info.value)
126+
assert (
127+
"For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict"
128+
in error_msg
129+
)
130+
assert "Use [{" in error_msg
131+
assert "}] instead of {" in error_msg
132+
assert "gpt-oss-120b" in error_msg
133+
assert "0.7" in error_msg
134+
assert "1000" in error_msg
135+
136+
def test_dynamic_error_nested_dict(self):
137+
"""Test dynamic error message with nested structures."""
138+
# Create a dict with nested dict
139+
nested_dict = ast.Dict(
140+
keys=[ast.Constant(value="config")],
141+
values=[
142+
ast.Dict(
143+
keys=[ast.Constant(value="model"), ast.Constant(value="api_key")],
144+
values=[ast.Constant(value="gpt-4"), ast.Constant(value="sk-123")],
145+
)
146+
],
147+
)
148+
149+
decorator = create_parametrize_decorator("completion_params", nested_dict)
150+
151+
with pytest.raises(ValueError) as exc_info:
152+
_is_pytest_parametrize_with_completion_params(decorator)
153+
154+
error_msg = str(exc_info.value)
155+
assert (
156+
"For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict"
157+
in error_msg
158+
)
159+
assert "Use [{" in error_msg
160+
assert "}] instead of {" in error_msg
161+
162+
def test_dynamic_error_boolean_values(self):
163+
"""Test dynamic error message with boolean values."""
164+
decorator = create_parametrize_decorator(
165+
"completion_params",
166+
ast.Dict(
167+
keys=[ast.Constant(value="stream"), ast.Constant(value="echo")],
168+
values=[ast.Constant(value=True), ast.Constant(value=False)],
169+
),
170+
)
171+
172+
with pytest.raises(ValueError) as exc_info:
173+
_is_pytest_parametrize_with_completion_params(decorator)
174+
175+
error_msg = str(exc_info.value)
176+
assert (
177+
"For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict"
178+
in error_msg
179+
)
180+
assert "True" in error_msg
181+
assert "False" in error_msg
182+
183+
def test_dynamic_error_empty_dict(self):
184+
"""Test dynamic error message with empty dict."""
185+
decorator = create_parametrize_decorator("completion_params", ast.Dict(keys=[], values=[]))
186+
187+
with pytest.raises(ValueError) as exc_info:
188+
_is_pytest_parametrize_with_completion_params(decorator)
189+
190+
error_msg = str(exc_info.value)
191+
assert (
192+
"For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict"
193+
in error_msg
194+
)
195+
assert "Use [{}] instead of {}" in error_msg
196+
197+
def test_valid_tuple_argvalues(self):
198+
"""Test that a tuple as argvalues works correctly."""
199+
dict_value = ast.Dict(keys=[ast.Constant(value="model")], values=[ast.Constant(value="gpt-4")])
200+
tuple_value = ast.Tuple(elts=[dict_value], ctx=ast.Load())
201+
202+
decorator = create_parametrize_decorator("completion_params", tuple_value)
203+
204+
result = _is_pytest_parametrize_with_completion_params(decorator)
205+
assert result is True
206+
207+
def test_non_parametrize_decorator(self):
208+
"""Test that non-parametrize decorators are ignored."""
209+
# Create a different decorator
210+
pytest_name = ast.Name(id="pytest", ctx=ast.Load())
211+
mark_attr = ast.Attribute(value=pytest_name, attr="mark", ctx=ast.Load())
212+
skipif_attr = ast.Attribute(value=mark_attr, attr="skipif", ctx=ast.Load())
213+
214+
decorator = ast.Call(func=skipif_attr, args=[ast.Constant(value=True)], keywords=[])
215+
216+
result = _is_pytest_parametrize_with_completion_params(decorator)
217+
assert result is False
218+
219+
def test_parametrize_without_completion_params(self):
220+
"""Test that parametrize without completion_params is ignored."""
221+
decorator = create_parametrize_decorator(
222+
"other_param", ast.List(elts=[ast.Constant(value="value")], ctx=ast.Load())
223+
)
224+
225+
result = _is_pytest_parametrize_with_completion_params(decorator)
226+
assert result is False

0 commit comments

Comments
 (0)