diff --git a/eval_protocol/__init__.py b/eval_protocol/__init__.py index cd1efd2c..e6c000d2 100644 --- a/eval_protocol/__init__.py +++ b/eval_protocol/__init__.py @@ -39,6 +39,7 @@ from .typed_interface import reward_function from .quickstart import aha_judge, split_multi_turn_rows from .pytest import evaluation_test, SingleTurnRolloutProcessor +from .pytest.parameterize import DefaultParameterIdGenerator from .adapters import OpenAIResponsesAdapter @@ -61,6 +62,7 @@ warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol") __all__ = [ + "DefaultParameterIdGenerator", "aha_judge", "split_multi_turn_rows", "evaluation_test", diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index a7ec65f3..e51d008b 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -158,8 +158,14 @@ def evaluation_test( exception_handler_config: Configuration for exception handling and backoff retry logic. If not provided, a default configuration will be used with common retryable exceptions. 
""" + # Default to [None] when completion_params is not provided + # This allows evaluation-only tests (e.g., using NoOpRolloutProcessor) + # to work without requiring model generation parameters if completion_params is None: + completion_params_provided = False completion_params = [None] + else: + completion_params_provided = True if rollout_processor is None: rollout_processor = NoOpRolloutProcessor() @@ -199,8 +205,10 @@ def decorator( # Create parameter tuples for pytest.mark.parametrize pytest_parametrize_args = pytest_parametrize( combinations, + test_func, input_dataset, completion_params, + completion_params_provided, input_messages, input_rows, evaluation_test_kwargs, @@ -261,7 +269,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo index = abs(index) % (max_index + 1) row.input_metadata.row_id = generate_id(seed=0, index=index) - completion_params = kwargs["completion_params"] + completion_params = kwargs["completion_params"] if "completion_params" in kwargs else None # Create eval metadata with test function info and current commit hash eval_metadata = EvalMetadata( name=test_func.__name__, @@ -565,12 +573,14 @@ async def execute_run_with_progress(run_idx: int, config): return create_dynamically_parameterized_wrapper( test_func, wrapper_body, - pytest_parametrize_args["argnames"], + pytest_parametrize_args["sig_parameters"], ) # Create the pytest wrapper pytest_wrapper = create_wrapper_with_signature() - pytest_wrapper = pytest.mark.parametrize(**pytest_parametrize_args)(pytest_wrapper) + pytest_wrapper = pytest.mark.parametrize(**pytest_parametrize_args["pytest_parametrize_kwargs"])( + pytest_wrapper + ) pytest_wrapper = pytest.mark.asyncio(pytest_wrapper) # Create the dual mode wrapper diff --git a/eval_protocol/pytest/generate_parameter_combinations.py b/eval_protocol/pytest/generate_parameter_combinations.py index 6a1dcf2f..99c37b74 100644 --- a/eval_protocol/pytest/generate_parameter_combinations.py +++ 
b/eval_protocol/pytest/generate_parameter_combinations.py @@ -31,7 +31,7 @@ ] -class ParameterizedTestKwargs(TypedDict): +class ParameterizedTestKwargs(TypedDict, total=False): """ These are the type of parameters that can be passed to the generated pytest function. Every experiment is a unique combination of these parameters. diff --git a/eval_protocol/pytest/parameterize.py b/eval_protocol/pytest/parameterize.py index cba8f65c..a2140da5 100644 --- a/eval_protocol/pytest/parameterize.py +++ b/eval_protocol/pytest/parameterize.py @@ -1,3 +1,4 @@ +import ast import inspect from typing import TypedDict, Protocol from collections.abc import Callable, Sequence, Iterable, Awaitable @@ -9,12 +10,133 @@ from eval_protocol.pytest.types import DatasetPathParam, EvaluationInputParam, InputMessagesParam, TestFunction -class PytestParametrizeArgs(TypedDict): +def _has_pytest_parametrize_with_completion_params(test_func: TestFunction) -> bool: + """ + Check if a test function has a pytest.mark.parametrize decorator with argnames="completion_params". + + This function uses inspect.getsource and ast to parse the function's source code and look for + pytest.mark.parametrize decorators that include "completion_params" in their argnames. 
+ + Args: + test_func: The test function to analyze + + Returns: + True if the function has a pytest.mark.parametrize decorator with "completion_params" in argnames, + False otherwise + + Note: + Exceptions are handled internally: if the source cannot be retrieved (OSError, e.g., + function defined interactively) or cannot be parsed (SyntaxError), this returns False. + """ + try: + source = inspect.getsource(test_func) + except OSError: + # Function source cannot be retrieved (e.g., defined in interactive mode) + return False + + try: + tree = ast.parse(source) + except SyntaxError: + # Source code cannot be parsed + return False + + # Walk through the AST to find pytest.mark.parametrize decorators + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef) or isinstance(node, ast.AsyncFunctionDef): + # Check decorators on this function + for decorator in node.decorator_list: + if _is_pytest_parametrize_with_completion_params(decorator): + return True + + return False + + +def _is_pytest_parametrize_with_completion_params(decorator: ast.expr) -> bool: + """ + Check if a decorator is pytest.mark.parametrize with "completion_params" in argnames. 
+ + Args: + decorator: AST node representing a decorator + + Returns: + True if this is a pytest.mark.parametrize decorator with "completion_params" in argnames + """ + # Look for pytest.mark.parametrize pattern + if isinstance(decorator, ast.Call): + # Check if it's pytest.mark.parametrize + if isinstance(decorator.func, ast.Attribute): + if ( + isinstance(decorator.func.value, ast.Attribute) + and isinstance(decorator.func.value.value, ast.Name) + and decorator.func.value.value.id == "pytest" + and decorator.func.value.attr == "mark" + and decorator.func.attr == "parametrize" + ): + # Check positional arguments first (argnames is typically the first positional arg) + if len(decorator.args) > 0: + argnames_arg = decorator.args[0] + if _check_argnames_for_completion_params(argnames_arg): + return True + + # Check keyword arguments for argnames + for keyword in decorator.keywords: + if keyword.arg == "argnames": + if _check_argnames_for_completion_params(keyword.value): + return True + + return False + + +def _check_argnames_for_completion_params(argnames_node: ast.expr) -> bool: + """ + Check if an argnames AST node contains "completion_params". + + Args: + argnames_node: AST node representing the argnames value + + Returns: + True if argnames contains "completion_params" + """ + if isinstance(argnames_node, ast.Constant): + # Single string case: argnames="completion_params" + if argnames_node.value == "completion_params": + return True + elif isinstance(argnames_node, ast.List): + # List case: argnames=["completion_params", ...] + for elt in argnames_node.elts: + if isinstance(elt, ast.Constant) and elt.value == "completion_params": + return True + elif isinstance(argnames_node, ast.Tuple): + # Tuple case: argnames=("completion_params", ...) 
+ for elt in argnames_node.elts: + if isinstance(elt, ast.Constant) and elt.value == "completion_params": + return True + + return False + + +class PytestMarkParametrizeKwargs(TypedDict): argnames: Sequence[str] argvalues: Iterable[ParameterSet | Sequence[object] | object] ids: Iterable[str] | None +class ParametrizeArgs(TypedDict): + """ + This contains all the necessary information to properly hijack the test + function's signature and dynamically inject usage of + pytest.mark.parametrize. The two will differ when a user manually provides + the pytest.mark.parametrize decorator instead of passing completion_params + on their own. + """ + + # for create_dynamically_parameterized_wrapper + sig_parameters: Sequence[str] + + # for pytest.mark.parametrize + pytest_parametrize_kwargs: PytestMarkParametrizeKwargs + + class ParameterIdGenerator(Protocol): """Protocol for generating pytest parameter IDs from parameter combinations.""" @@ -30,7 +152,7 @@ def generate_id(self, combo: CombinationTuple) -> str | None: ... 
-class DefaultParameterIdGenerator: +class DefaultParameterIdGenerator(ParameterIdGenerator): """Default ID generator that creates meaningful IDs from parameter combinations.""" def __init__(self, max_length: int = 200): @@ -46,34 +168,49 @@ def generate_id(self, combo: CombinationTuple) -> str | None: dataset, completion_params, messages, rows, evaluation_test_kwargs = combo if completion_params: - # Get all string, numeric, and boolean values from completion_params, sorted by key - str_values = [] - for key in sorted(completion_params.keys()): - value = completion_params[key] - if isinstance(value, (str, int, float, bool)): - str_values.append(str(value)) + id = self.generate_id_from_dict(completion_params, self.max_length) + if id: + return id + else: + if rows: + return f"rows(len={len(rows)})" + elif messages: + return f"messages(len={len(messages)})" + elif dataset: + return f"dataset(len={len(dataset)})" + return None - if str_values: - id_str = ":".join(str_values) + @staticmethod + def generate_id_from_dict(d: dict[str, object], max_length: int = 200) -> str | None: + # Get all string, numeric, and boolean values from completion_params, sorted by key + str_values = [] + for key in sorted(d.keys()): + value = d[key] + if isinstance(value, (str, int, float, bool)): + str_values.append(str(value)) - # Truncate if too long - if len(id_str) > self.max_length: - id_str = id_str[: self.max_length - 3] + "..." + if str_values: + id_str = ":".join(str_values) - return id_str + # Truncate if too long + if len(id_str) > max_length: + id_str = id_str[: max_length - 3] + "..." 
+ return id_str return None def pytest_parametrize( combinations: list[CombinationTuple], + test_func: TestFunction | None, input_dataset: Sequence[DatasetPathParam] | None, completion_params: Sequence[CompletionParams | None] | None, + completion_params_provided: bool, input_messages: Sequence[list[InputMessagesParam] | None] | None, input_rows: Sequence[list[EvaluationRow]] | None, evaluation_test_kwargs: Sequence[EvaluationInputParam | None] | None, id_generator: ParameterIdGenerator | None = None, -) -> PytestParametrizeArgs: +) -> ParametrizeArgs: """ This function dynamically generates pytest.mark.parametrize arguments for a given set of combinations. This is the magic that allows developers to pass in their @@ -82,18 +219,31 @@ def pytest_parametrize( API. """ + if test_func is not None: + has_pytest_parametrize = _has_pytest_parametrize_with_completion_params(test_func) + else: + has_pytest_parametrize = False + # Create parameter tuples for pytest.mark.parametrize argnames: list[str] = [] + sig_parameters: list[str] = [] if input_dataset is not None: argnames.append("dataset_path") + sig_parameters.append("dataset_path") if completion_params is not None: - argnames.append("completion_params") + if completion_params_provided and not has_pytest_parametrize: + argnames.append("completion_params") + if has_pytest_parametrize or completion_params_provided: + sig_parameters.append("completion_params") if input_messages is not None: argnames.append("input_messages") + sig_parameters.append("input_messages") if input_rows is not None: argnames.append("input_rows") + sig_parameters.append("input_rows") if evaluation_test_kwargs is not None: argnames.append("evaluation_test_kwargs") + sig_parameters.append("evaluation_test_kwargs") # Use default ID generator if none provided if id_generator is None: @@ -109,7 +259,7 @@ def pytest_parametrize( # Build parameter tuple based on what's provided if input_dataset is not None: param_tuple.append(dataset) - if 
completion_params is not None: + if completion_params_provided: param_tuple.append(cp) if input_messages is not None: param_tuple.append(messages) @@ -132,7 +282,12 @@ def pytest_parametrize( ids.append(combo_id) # Return None for ids if no IDs were generated (let pytest use defaults) - return PytestParametrizeArgs(argnames=argnames, argvalues=argvalues, ids=ids if ids else None) + return ParametrizeArgs( + pytest_parametrize_kwargs=PytestMarkParametrizeKwargs( + argnames=argnames, argvalues=argvalues, ids=ids if ids else None + ), + sig_parameters=sig_parameters, + ) def create_dynamically_parameterized_wrapper( diff --git a/eval_protocol/quickstart/llm_judge_openai_responses.py b/eval_protocol/quickstart/llm_judge_openai_responses.py index 5d8cb983..06fe502c 100644 --- a/eval_protocol/quickstart/llm_judge_openai_responses.py +++ b/eval_protocol/quickstart/llm_judge_openai_responses.py @@ -27,6 +27,7 @@ EvaluationRow, SingleTurnRolloutProcessor, OpenAIResponsesAdapter, + DefaultParameterIdGenerator, ) adapter = OpenAIResponsesAdapter() @@ -41,10 +42,9 @@ @pytest.mark.skipif(os.environ.get("CI") == "true", reason="Skip in CI") -@pytest.mark.asyncio -@evaluation_test( - input_rows=[input_rows], - completion_params=[ +@pytest.mark.parametrize( + "completion_params", + [ { "model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", }, @@ -52,6 +52,10 @@ "model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct-0905", }, ], + ids=DefaultParameterIdGenerator.generate_id_from_dict, +) +@evaluation_test( + input_rows=[input_rows], rollout_processor=SingleTurnRolloutProcessor(), preprocess_fn=split_multi_turn_rows, mode="all", diff --git a/tests/pytest/test_parameterized_ids.py b/tests/pytest/test_parameterized_ids.py index b182bfe5..d3363d0c 100644 --- a/tests/pytest/test_parameterized_ids.py +++ b/tests/pytest/test_parameterized_ids.py @@ -1,12 +1,47 @@ +from collections.abc import Awaitable, Callable + +import pytest from eval_protocol.models import 
EvaluationRow, Message from eval_protocol.pytest import evaluation_test from eval_protocol.pytest.parameterize import DefaultParameterIdGenerator, pytest_parametrize from eval_protocol.pytest.generate_parameter_combinations import generate_parameter_combinations +from eval_protocol.pytest.types import TestFunction + + +def verify_parametrize_mark(test_function: TestFunction, expected_ids_set: list[object]): + # The function should exist and be callable + assert test_function is not None + assert callable(test_function) + + # Test that the decorator was applied (function should have pytest marks) + import pytest + + marks = getattr(test_function, "pytestmark", []) + assert len(marks) > 0, "Function should have pytest marks from evaluation_test decorator" + + # Verify it's a parametrize mark + parametrize_marks = [mark for mark in marks if hasattr(mark, "name") and mark.name == "parametrize"] + assert len(parametrize_marks) > 0, "Should have parametrize mark" + + assert len(parametrize_marks) == len(expected_ids_set), ( + f"Expected {len(expected_ids_set)} parametrize marks, got {len(parametrize_marks)}" + ) + + # Check that the parametrize mark has IDs + for parametrize_mark, expected_ids in zip(parametrize_marks, expected_ids_set): + assert hasattr(parametrize_mark, "kwargs"), "Parametrize mark should have kwargs" + assert "ids" in parametrize_mark.kwargs, "Should have ids in kwargs" + + # Extract the IDs from the parametrize mark + ids = parametrize_mark.kwargs.get("ids") + if not ids: + raise ValueError("No IDs found in parametrize mark") + # Should have IDs for all parameters that have string/numeric values + assert ids == expected_ids, f"Expected {expected_ids}, got {ids}" def test_parameterized_ids(): """Test that evaluation_test generates proper parameter IDs.""" - collected_ids = [] @evaluation_test( input_messages=[[[Message(role="user", content="Hello, how are you?")]]], @@ -17,35 +52,38 @@ def test_parameterized_ids(): ], ) def test_parameterized_ids(row: 
EvaluationRow) -> EvaluationRow: - # Collect the row to verify it was processed - collected_ids.append(row.input_metadata.row_id) return row - # The function should exist and be callable - assert test_parameterized_ids is not None - assert callable(test_parameterized_ids) - - # Test that the decorator was applied (function should have pytest marks) - import pytest + verify_parametrize_mark( + test_parameterized_ids, [["fireworks_ai/accounts/fireworks/models/gpt-oss-120b", "gpt-4", "0.5"]] + ) - marks = getattr(test_parameterized_ids, "pytestmark", []) - assert len(marks) > 0, "Function should have pytest marks from evaluation_test decorator" - # Verify it's a parametrize mark - parametrize_marks = [mark for mark in marks if hasattr(mark, "name") and mark.name == "parametrize"] - assert len(parametrize_marks) > 0, "Should have parametrize mark" +def test_parametrized_ids_with_manual_decorator_and_input_rows(): + """Test that evaluation_test generates proper parameter IDs.""" - # Check that the parametrize mark has IDs - parametrize_mark = parametrize_marks[0] - assert hasattr(parametrize_mark, "kwargs"), "Parametrize mark should have kwargs" - assert "ids" in parametrize_mark.kwargs, "Should have ids in kwargs" + @pytest.mark.parametrize( + "completion_params", + [ + {"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}, + {"model": "gpt-4"}, + {"temperature": 0.5}, + ], + ids=DefaultParameterIdGenerator.generate_id_from_dict, + ) + @evaluation_test( + input_rows=[[EvaluationRow(messages=[Message(role="user", content="Hello, how are you?")])]], + ) + def test_parameterized_ids(row: EvaluationRow) -> EvaluationRow: + return row - # Extract the IDs from the parametrize mark - ids = parametrize_mark.kwargs.get("ids") - if ids is not None: - # Should have IDs for all parameters that have string/numeric values - expected_ids = ["fireworks_ai/accounts/fireworks/models/gpt-oss-120b", "gpt-4", "0.5"] - assert list(ids) == expected_ids, f"Expected {expected_ids}, 
got {list(ids)}" + verify_parametrize_mark( + test_parameterized_ids, + [ + ["rows(len=1)"], + DefaultParameterIdGenerator.generate_id_from_dict, + ], + ) def test_default_id_generator(): @@ -111,16 +149,18 @@ def test_pytest_parametrize_with_custom_id_generator(): # Test with default generator result = pytest_parametrize( combinations=combinations, + test_func=None, input_dataset=None, completion_params=[{"model": "gpt-4"}, {"model": "claude-3"}, {"temperature": 0.5}], + completion_params_provided=True, input_messages=None, input_rows=None, evaluation_test_kwargs=None, ) - assert result["argnames"] == ["completion_params"] - assert len(list(result["argvalues"])) == 3 - assert result["ids"] == ["gpt-4", "claude-3", "0.5"] # All have string/numeric values + assert result["pytest_parametrize_kwargs"]["argnames"] == ["completion_params"] + assert len(list(result["pytest_parametrize_kwargs"]["argvalues"])) == 3 + assert result["pytest_parametrize_kwargs"]["ids"] == ["gpt-4", "claude-3", "0.5"] # All have string/numeric values def test_id_generator_max_length():