Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions eval_protocol/pytest/parameterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def _is_pytest_parametrize_with_completion_params(decorator: ast.expr) -> bool:
and decorator.func.value.attr == "mark"
and decorator.func.attr == "parametrize"
):
# Validate argvalues if present
_validate_parametrize_argvalues(decorator)

# Check positional arguments first (argnames is typically the first positional arg)
if len(decorator.args) > 0:
argnames_arg = decorator.args[0]
Expand All @@ -88,6 +91,90 @@ def _is_pytest_parametrize_with_completion_params(decorator: ast.expr) -> bool:
return False


def _ast_dict_to_string(dict_node: ast.Dict) -> str:
"""
Convert an AST Dict node to its string representation.

Args:
dict_node: AST node representing a dictionary

Returns:
String representation of the dictionary
"""
if not dict_node.keys:
return "{}"

pairs = []
for key, value in zip(dict_node.keys, dict_node.values):
if key is not None:
key_str = _ast_node_to_string(key)
value_str = _ast_node_to_string(value)
pairs.append(f"{key_str}: {value_str}")

return "{" + ", ".join(pairs) + "}"


def _ast_node_to_string(node: ast.expr) -> str:
"""
Convert an AST node to its string representation.

Args:
node: AST node to convert

Returns:
String representation of the node
"""
if isinstance(node, ast.Constant):
if isinstance(node.value, str):
return repr(node.value)
else:
return str(node.value)
elif isinstance(node, ast.Name):
return node.id
elif isinstance(node, ast.Dict):
return _ast_dict_to_string(node)
elif isinstance(node, ast.List):
elements = [_ast_node_to_string(elt) for elt in node.elts]
return "[" + ", ".join(elements) + "]"
elif isinstance(node, ast.Tuple):
elements = [_ast_node_to_string(elt) for elt in node.elts]
return "(" + ", ".join(elements) + ")"
else:
# For complex expressions, return a simplified representation
return "<complex expression>"


def _validate_parametrize_argvalues(decorator: ast.Call) -> None:
"""
Validate that pytest.mark.parametrize argvalues is a list/tuple, not a dict.

Args:
decorator: AST node representing the pytest.mark.parametrize decorator call

Raises:
ValueError: If argvalues is a dict instead of a list/tuple
"""
# Check positional arguments (argvalues is typically the second positional arg)
if len(decorator.args) > 1:
argvalues_arg = decorator.args[1]
if isinstance(argvalues_arg, ast.Dict):
dict_repr = _ast_dict_to_string(argvalues_arg)
raise ValueError(
f"For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict. "
f"Use [{dict_repr}] instead of {dict_repr}."
)

# Check keyword arguments for argvalues
for keyword in decorator.keywords:
if keyword.arg == "argvalues":
if isinstance(keyword.value, ast.Dict):
dict_repr = _ast_dict_to_string(keyword.value)
raise ValueError(
f"For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict. "
f"Use [{dict_repr}] instead of {dict_repr}."
)


def _check_argnames_for_completion_params(argnames_node: ast.expr) -> bool:
"""
Check if an argnames AST node contains "completion_params".
Expand Down
226 changes: 226 additions & 0 deletions tests/pytest/test_parameterize_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
"""
Test cases for pytest.mark.parametrize validation functionality.
"""

import ast
import pytest
from eval_protocol.pytest.parameterize import _is_pytest_parametrize_with_completion_params


def create_parametrize_decorator(argnames, argvalues, use_keyword=False):
"""Create a pytest.mark.parametrize decorator AST node."""
pytest_name = ast.Name(id="pytest", ctx=ast.Load())
mark_attr = ast.Attribute(value=pytest_name, attr="mark", ctx=ast.Load())
parametrize_attr = ast.Attribute(value=mark_attr, attr="parametrize", ctx=ast.Load())

if use_keyword:
call = ast.Call(
func=parametrize_attr,
args=[],
keywords=[
ast.keyword(arg="argnames", value=ast.Constant(value=argnames)),
ast.keyword(arg="argvalues", value=argvalues),
],
)
else:
call = ast.Call(func=parametrize_attr, args=[ast.Constant(value=argnames), argvalues], keywords=[])

return call


class TestParametrizeValidation:
"""Test cases for pytest.mark.parametrize validation."""

def test_invalid_dict_argvalues_positional(self):
"""Test that a dict as positional argvalues throws an error."""
decorator = create_parametrize_decorator(
"completion_params", ast.Dict(keys=[ast.Constant(value="model")], values=[ast.Constant(value="gpt-4")])
)

with pytest.raises(ValueError) as exc_info:
_is_pytest_parametrize_with_completion_params(decorator)

error_msg = str(exc_info.value)
assert (
"For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict"
in error_msg
)
assert "Use [{'model': 'gpt-4'}] instead of {'model': 'gpt-4'}" in error_msg

def test_valid_list_argvalues_positional(self):
"""Test that a list as positional argvalues works correctly."""
dict_value = ast.Dict(keys=[ast.Constant(value="model")], values=[ast.Constant(value="gpt-4")])
list_value = ast.List(elts=[dict_value], ctx=ast.Load())

decorator = create_parametrize_decorator("completion_params", list_value)

result = _is_pytest_parametrize_with_completion_params(decorator)
assert result is True

def test_invalid_dict_argvalues_keyword(self):
"""Test that a dict as keyword argvalues throws an error."""
decorator = create_parametrize_decorator(
"completion_params",
ast.Dict(keys=[ast.Constant(value="model")], values=[ast.Constant(value="gpt-4")]),
use_keyword=True,
)

with pytest.raises(ValueError) as exc_info:
_is_pytest_parametrize_with_completion_params(decorator)

error_msg = str(exc_info.value)
assert (
"For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict"
in error_msg
)
assert "Use [{'model': 'gpt-4'}] instead of {'model': 'gpt-4'}" in error_msg

def test_valid_list_argvalues_keyword(self):
"""Test that a list as keyword argvalues works correctly."""
dict_value = ast.Dict(keys=[ast.Constant(value="model")], values=[ast.Constant(value="gpt-4")])
list_value = ast.List(elts=[dict_value], ctx=ast.Load())

decorator = create_parametrize_decorator("completion_params", list_value, use_keyword=True)

result = _is_pytest_parametrize_with_completion_params(decorator)
assert result is True

def test_dynamic_error_simple_dict(self):
"""Test dynamic error message with a simple dict."""
decorator = create_parametrize_decorator(
"completion_params", ast.Dict(keys=[ast.Constant(value="model")], values=[ast.Constant(value="gpt-4")])
)

with pytest.raises(ValueError) as exc_info:
_is_pytest_parametrize_with_completion_params(decorator)

error_msg = str(exc_info.value)
assert (
"For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict"
in error_msg
)
assert "Use [{'model': 'gpt-4'}] instead of {'model': 'gpt-4'}" in error_msg

def test_dynamic_error_complex_dict(self):
"""Test dynamic error message with a complex dict."""
decorator = create_parametrize_decorator(
"completion_params",
ast.Dict(
keys=[
ast.Constant(value="model"),
ast.Constant(value="temperature"),
ast.Constant(value="max_tokens"),
],
values=[
ast.Constant(value="accounts/fireworks/models/gpt-oss-120b"),
ast.Constant(value=0.7),
ast.Constant(value=1000),
],
),
)

with pytest.raises(ValueError) as exc_info:
_is_pytest_parametrize_with_completion_params(decorator)

error_msg = str(exc_info.value)
assert (
"For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict"
in error_msg
)
assert "Use [{" in error_msg
assert "}] instead of {" in error_msg
assert "gpt-oss-120b" in error_msg
assert "0.7" in error_msg
assert "1000" in error_msg

def test_dynamic_error_nested_dict(self):
"""Test dynamic error message with nested structures."""
# Create a dict with nested dict
nested_dict = ast.Dict(
keys=[ast.Constant(value="config")],
values=[
ast.Dict(
keys=[ast.Constant(value="model"), ast.Constant(value="api_key")],
values=[ast.Constant(value="gpt-4"), ast.Constant(value="sk-123")],
)
],
)

decorator = create_parametrize_decorator("completion_params", nested_dict)

with pytest.raises(ValueError) as exc_info:
_is_pytest_parametrize_with_completion_params(decorator)

error_msg = str(exc_info.value)
assert (
"For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict"
in error_msg
)
assert "Use [{" in error_msg
assert "}] instead of {" in error_msg

def test_dynamic_error_boolean_values(self):
"""Test dynamic error message with boolean values."""
decorator = create_parametrize_decorator(
"completion_params",
ast.Dict(
keys=[ast.Constant(value="stream"), ast.Constant(value="echo")],
values=[ast.Constant(value=True), ast.Constant(value=False)],
),
)

with pytest.raises(ValueError) as exc_info:
_is_pytest_parametrize_with_completion_params(decorator)

error_msg = str(exc_info.value)
assert (
"For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict"
in error_msg
)
assert "True" in error_msg
assert "False" in error_msg

def test_dynamic_error_empty_dict(self):
"""Test dynamic error message with empty dict."""
decorator = create_parametrize_decorator("completion_params", ast.Dict(keys=[], values=[]))

with pytest.raises(ValueError) as exc_info:
_is_pytest_parametrize_with_completion_params(decorator)

error_msg = str(exc_info.value)
assert (
"For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict"
in error_msg
)
assert "Use [{}] instead of {}" in error_msg

def test_valid_tuple_argvalues(self):
"""Test that a tuple as argvalues works correctly."""
dict_value = ast.Dict(keys=[ast.Constant(value="model")], values=[ast.Constant(value="gpt-4")])
tuple_value = ast.Tuple(elts=[dict_value], ctx=ast.Load())

decorator = create_parametrize_decorator("completion_params", tuple_value)

result = _is_pytest_parametrize_with_completion_params(decorator)
assert result is True

def test_non_parametrize_decorator(self):
"""Test that non-parametrize decorators are ignored."""
# Create a different decorator
pytest_name = ast.Name(id="pytest", ctx=ast.Load())
mark_attr = ast.Attribute(value=pytest_name, attr="mark", ctx=ast.Load())
skipif_attr = ast.Attribute(value=mark_attr, attr="skipif", ctx=ast.Load())

decorator = ast.Call(func=skipif_attr, args=[ast.Constant(value=True)], keywords=[])

result = _is_pytest_parametrize_with_completion_params(decorator)
assert result is False

def test_parametrize_without_completion_params(self):
"""Test that parametrize without completion_params is ignored."""
decorator = create_parametrize_decorator(
"other_param", ast.List(elts=[ast.Constant(value="value")], ctx=ast.Load())
)

result = _is_pytest_parametrize_with_completion_params(decorator)
assert result is False
Loading