diff --git a/eval_protocol/pytest/parameterize.py b/eval_protocol/pytest/parameterize.py index f8c12259..7892d9c5 100644 --- a/eval_protocol/pytest/parameterize.py +++ b/eval_protocol/pytest/parameterize.py @@ -73,6 +73,9 @@ def _is_pytest_parametrize_with_completion_params(decorator: ast.expr) -> bool: and decorator.func.value.attr == "mark" and decorator.func.attr == "parametrize" ): + # Validate argvalues if present + _validate_parametrize_argvalues(decorator) + # Check positional arguments first (argnames is typically the first positional arg) if len(decorator.args) > 0: argnames_arg = decorator.args[0] @@ -88,6 +91,90 @@ def _is_pytest_parametrize_with_completion_params(decorator: ast.expr) -> bool: return False +def _ast_dict_to_string(dict_node: ast.Dict) -> str: + """ + Convert an AST Dict node to its string representation. + + Args: + dict_node: AST node representing a dictionary + + Returns: + String representation of the dictionary + """ + if not dict_node.keys: + return "{}" + + pairs = [] + for key, value in zip(dict_node.keys, dict_node.values): + if key is not None: + key_str = _ast_node_to_string(key) + value_str = _ast_node_to_string(value) + pairs.append(f"{key_str}: {value_str}") + + return "{" + ", ".join(pairs) + "}" + + +def _ast_node_to_string(node: ast.expr) -> str: + """ + Convert an AST node to its string representation. + + Args: + node: AST node to convert + + Returns: + String representation of the node + """ + if isinstance(node, ast.Constant): + if isinstance(node.value, str): + return repr(node.value) + else: + return str(node.value) + elif isinstance(node, ast.Name): + return node.id + elif isinstance(node, ast.Dict): + return _ast_dict_to_string(node) + elif isinstance(node, ast.List): + elements = [_ast_node_to_string(elt) for elt in node.elts] + return "[" + ", ".join(elements) + "]" + elif isinstance(node, ast.Tuple): + elements = [_ast_node_to_string(elt) for elt in node.elts] + return "(" + ", ".join(elements) + ")" + else: + # For complex expressions, return a simplified representation + return "" + + +def _validate_parametrize_argvalues(decorator: ast.Call) -> None: + """ + Validate that pytest.mark.parametrize argvalues is a list/tuple, not a dict. + + Args: + decorator: AST node representing the pytest.mark.parametrize decorator call + + Raises: + ValueError: If argvalues is a dict instead of a list/tuple + """ + # Check positional arguments (argvalues is typically the second positional arg) + if len(decorator.args) > 1: + argvalues_arg = decorator.args[1] + if isinstance(argvalues_arg, ast.Dict): + dict_repr = _ast_dict_to_string(argvalues_arg) + raise ValueError( + f"For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict. " + f"Use [{dict_repr}] instead of {dict_repr}." + ) + + # Check keyword arguments for argvalues + for keyword in decorator.keywords: + if keyword.arg == "argvalues": + if isinstance(keyword.value, ast.Dict): + dict_repr = _ast_dict_to_string(keyword.value) + raise ValueError( + f"For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict. " + f"Use [{dict_repr}] instead of {dict_repr}." + ) + + def _check_argnames_for_completion_params(argnames_node: ast.expr) -> bool: """ Check if an argnames AST node contains "completion_params". diff --git a/tests/pytest/test_parameterize_validation.py b/tests/pytest/test_parameterize_validation.py new file mode 100644 index 00000000..b116965d --- /dev/null +++ b/tests/pytest/test_parameterize_validation.py @@ -0,0 +1,226 @@ +""" +Test cases for pytest.mark.parametrize validation functionality. +""" + +import ast +import pytest +from eval_protocol.pytest.parameterize import _is_pytest_parametrize_with_completion_params + + +def create_parametrize_decorator(argnames, argvalues, use_keyword=False): + """Create a pytest.mark.parametrize decorator AST node.""" + pytest_name = ast.Name(id="pytest", ctx=ast.Load()) + mark_attr = ast.Attribute(value=pytest_name, attr="mark", ctx=ast.Load()) + parametrize_attr = ast.Attribute(value=mark_attr, attr="parametrize", ctx=ast.Load()) + + if use_keyword: + call = ast.Call( + func=parametrize_attr, + args=[], + keywords=[ + ast.keyword(arg="argnames", value=ast.Constant(value=argnames)), + ast.keyword(arg="argvalues", value=argvalues), + ], + ) + else: + call = ast.Call(func=parametrize_attr, args=[ast.Constant(value=argnames), argvalues], keywords=[]) + + return call + + +class TestParametrizeValidation: + """Test cases for pytest.mark.parametrize validation.""" + + def test_invalid_dict_argvalues_positional(self): + """Test that a dict as positional argvalues throws an error.""" + decorator = create_parametrize_decorator( + "completion_params", ast.Dict(keys=[ast.Constant(value="model")], values=[ast.Constant(value="gpt-4")]) + ) + + with pytest.raises(ValueError) as exc_info: + _is_pytest_parametrize_with_completion_params(decorator) + + error_msg = str(exc_info.value) + assert ( + "For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict" + in error_msg + ) + assert "Use [{'model': 'gpt-4'}] instead of {'model': 'gpt-4'}" in error_msg + + def test_valid_list_argvalues_positional(self): + """Test that a list as positional argvalues works correctly.""" + dict_value = ast.Dict(keys=[ast.Constant(value="model")], values=[ast.Constant(value="gpt-4")]) + list_value = ast.List(elts=[dict_value], ctx=ast.Load()) + + decorator = create_parametrize_decorator("completion_params", list_value) + + result = _is_pytest_parametrize_with_completion_params(decorator) + assert result is True + + def test_invalid_dict_argvalues_keyword(self): + """Test that a dict as keyword argvalues throws an error.""" + decorator = create_parametrize_decorator( + "completion_params", + ast.Dict(keys=[ast.Constant(value="model")], values=[ast.Constant(value="gpt-4")]), + use_keyword=True, + ) + + with pytest.raises(ValueError) as exc_info: + _is_pytest_parametrize_with_completion_params(decorator) + + error_msg = str(exc_info.value) + assert ( + "For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict" + in error_msg + ) + assert "Use [{'model': 'gpt-4'}] instead of {'model': 'gpt-4'}" in error_msg + + def test_valid_list_argvalues_keyword(self): + """Test that a list as keyword argvalues works correctly.""" + dict_value = ast.Dict(keys=[ast.Constant(value="model")], values=[ast.Constant(value="gpt-4")]) + list_value = ast.List(elts=[dict_value], ctx=ast.Load()) + + decorator = create_parametrize_decorator("completion_params", list_value, use_keyword=True) + + result = _is_pytest_parametrize_with_completion_params(decorator) + assert result is True + + def test_dynamic_error_simple_dict(self): + """Test dynamic error message with a simple dict.""" + decorator = create_parametrize_decorator( + "completion_params", ast.Dict(keys=[ast.Constant(value="model")], values=[ast.Constant(value="gpt-4")]) + ) + + with pytest.raises(ValueError) as exc_info: + _is_pytest_parametrize_with_completion_params(decorator) + + error_msg = str(exc_info.value) + assert ( + "For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict" + in error_msg + ) + assert "Use [{'model': 'gpt-4'}] instead of {'model': 'gpt-4'}" in error_msg + + def test_dynamic_error_complex_dict(self): + """Test dynamic error message with a complex dict.""" + decorator = create_parametrize_decorator( + "completion_params", + ast.Dict( + keys=[ + ast.Constant(value="model"), + ast.Constant(value="temperature"), + ast.Constant(value="max_tokens"), + ], + values=[ + ast.Constant(value="accounts/fireworks/models/gpt-oss-120b"), + ast.Constant(value=0.7), + ast.Constant(value=1000), + ], + ), + ) + + with pytest.raises(ValueError) as exc_info: + _is_pytest_parametrize_with_completion_params(decorator) + + error_msg = str(exc_info.value) + assert ( + "For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict" + in error_msg + ) + assert "Use [{" in error_msg + assert "}] instead of {" in error_msg + assert "gpt-oss-120b" in error_msg + assert "0.7" in error_msg + assert "1000" in error_msg + + def test_dynamic_error_nested_dict(self): + """Test dynamic error message with nested structures.""" + # Create a dict with nested dict + nested_dict = ast.Dict( + keys=[ast.Constant(value="config")], + values=[ + ast.Dict( + keys=[ast.Constant(value="model"), ast.Constant(value="api_key")], + values=[ast.Constant(value="gpt-4"), ast.Constant(value="sk-123")], + ) + ], + ) + + decorator = create_parametrize_decorator("completion_params", nested_dict) + + with pytest.raises(ValueError) as exc_info: + _is_pytest_parametrize_with_completion_params(decorator) + + error_msg = str(exc_info.value) + assert ( + "For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict" + in error_msg + ) + assert "Use [{" in error_msg + assert "}] instead of {" in error_msg + + def test_dynamic_error_boolean_values(self): + """Test dynamic error message with boolean values.""" + decorator = create_parametrize_decorator( + "completion_params", + ast.Dict( + keys=[ast.Constant(value="stream"), ast.Constant(value="echo")], + values=[ast.Constant(value=True), ast.Constant(value=False)], + ), + ) + + with pytest.raises(ValueError) as exc_info: + _is_pytest_parametrize_with_completion_params(decorator) + + error_msg = str(exc_info.value) + assert ( + "For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict" + in error_msg + ) + assert "True" in error_msg + assert "False" in error_msg + + def test_dynamic_error_empty_dict(self): + """Test dynamic error message with empty dict.""" + decorator = create_parametrize_decorator("completion_params", ast.Dict(keys=[], values=[])) + + with pytest.raises(ValueError) as exc_info: + _is_pytest_parametrize_with_completion_params(decorator) + + error_msg = str(exc_info.value) + assert ( + "For evaluation_test with completion_params, pytest.mark.parametrize argvalues must be a list or tuple, not a dict" + in error_msg + ) + assert "Use [{}] instead of {}" in error_msg + + def test_valid_tuple_argvalues(self): + """Test that a tuple as argvalues works correctly.""" + dict_value = ast.Dict(keys=[ast.Constant(value="model")], values=[ast.Constant(value="gpt-4")]) + tuple_value = ast.Tuple(elts=[dict_value], ctx=ast.Load()) + + decorator = create_parametrize_decorator("completion_params", tuple_value) + + result = _is_pytest_parametrize_with_completion_params(decorator) + assert result is True + + def test_non_parametrize_decorator(self): + """Test that non-parametrize decorators are ignored.""" + # Create a different decorator + pytest_name = ast.Name(id="pytest", ctx=ast.Load()) + mark_attr = ast.Attribute(value=pytest_name, attr="mark", ctx=ast.Load()) + skipif_attr = ast.Attribute(value=mark_attr, attr="skipif", ctx=ast.Load()) + + decorator = ast.Call(func=skipif_attr, args=[ast.Constant(value=True)], keywords=[]) + + result = _is_pytest_parametrize_with_completion_params(decorator) + assert result is False + + def test_parametrize_without_completion_params(self): + """Test that parametrize without completion_params is ignored.""" + decorator = create_parametrize_decorator( + "other_param", ast.List(elts=[ast.Constant(value="value")], ctx=ast.Load()) + ) + + result = _is_pytest_parametrize_with_completion_params(decorator) + assert result is False