Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions eval_protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from .typed_interface import reward_function
from .quickstart import aha_judge, split_multi_turn_rows
from .pytest import evaluation_test, SingleTurnRolloutProcessor
from .pytest.parameterize import DefaultParameterIdGenerator

from .adapters import OpenAIResponsesAdapter

Expand All @@ -61,6 +62,7 @@
warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol")

__all__ = [
"DefaultParameterIdGenerator",
"aha_judge",
"split_multi_turn_rows",
"evaluation_test",
Expand Down
16 changes: 13 additions & 3 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,14 @@ def evaluation_test(
exception_handler_config: Configuration for exception handling and backoff retry logic.
If not provided, a default configuration will be used with common retryable exceptions.
"""
# Default to [None] when completion_params is not provided
# This allows evaluation-only tests (e.g., using NoOpRolloutProcessor)
# to work without requiring model generation parameters
if completion_params is None:
completion_params_provided = False
completion_params = [None]
else:
completion_params_provided = True
if rollout_processor is None:
rollout_processor = NoOpRolloutProcessor()

Expand Down Expand Up @@ -199,8 +205,10 @@ def decorator(
# Create parameter tuples for pytest.mark.parametrize
pytest_parametrize_args = pytest_parametrize(
combinations,
test_func,
input_dataset,
completion_params,
completion_params_provided,
input_messages,
input_rows,
evaluation_test_kwargs,
Expand Down Expand Up @@ -261,7 +269,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
index = abs(index) % (max_index + 1)
row.input_metadata.row_id = generate_id(seed=0, index=index)

completion_params = kwargs["completion_params"]
completion_params = kwargs["completion_params"] if "completion_params" in kwargs else None
# Create eval metadata with test function info and current commit hash
eval_metadata = EvalMetadata(
name=test_func.__name__,
Expand Down Expand Up @@ -565,12 +573,14 @@ async def execute_run_with_progress(run_idx: int, config):
return create_dynamically_parameterized_wrapper(
test_func,
wrapper_body,
pytest_parametrize_args["argnames"],
pytest_parametrize_args["sig_parameters"],
)

# Create the pytest wrapper
pytest_wrapper = create_wrapper_with_signature()
pytest_wrapper = pytest.mark.parametrize(**pytest_parametrize_args)(pytest_wrapper)
pytest_wrapper = pytest.mark.parametrize(**pytest_parametrize_args["pytest_parametrize_kwargs"])(
pytest_wrapper
)
pytest_wrapper = pytest.mark.asyncio(pytest_wrapper)

# Create the dual mode wrapper
Expand Down
2 changes: 1 addition & 1 deletion eval_protocol/pytest/generate_parameter_combinations.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
]


class ParameterizedTestKwargs(TypedDict):
class ParameterizedTestKwargs(TypedDict, total=False):
"""
These are the type of parameters that can be passed to the generated pytest
function. Every experiment is a unique combination of these parameters.
Expand Down
191 changes: 173 additions & 18 deletions eval_protocol/pytest/parameterize.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
import inspect
import textwrap
from collections.abc import Callable, Sequence, Iterable, Awaitable
from typing import TypedDict, Protocol
Expand All @@ -9,12 +10,133 @@
from eval_protocol.pytest.types import DatasetPathParam, EvaluationInputParam, InputMessagesParam, TestFunction


class PytestParametrizeArgs(TypedDict):
def _has_pytest_parametrize_with_completion_params(test_func: TestFunction) -> bool:
    """
    Check whether *test_func* carries a ``pytest.mark.parametrize`` decorator
    whose argnames include ``"completion_params"``.

    The function's source is retrieved with :func:`inspect.getsource` and
    parsed with :mod:`ast`; the decorator list of every (sync or async)
    function definition found in that source is then inspected.

    Args:
        test_func: The test function to analyze.

    Returns:
        True if such a decorator is present, False otherwise — including when
        the source cannot be retrieved or parsed.
    """
    try:
        source = inspect.getsource(test_func)
    except OSError:
        # Function source cannot be retrieved (e.g., defined in interactive mode)
        return False

    try:
        # Dedent before parsing: getsource() keeps the original indentation, so
        # a method or nested function would otherwise raise IndentationError
        # (a SyntaxError subclass) and be silently mis-reported as False.
        tree = ast.parse(textwrap.dedent(source))
    except SyntaxError:
        # Source code cannot be parsed
        return False

    # Walk through the AST to find pytest.mark.parametrize decorators
    for node in ast.walk(tree):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            for decorator in node.decorator_list:
                if _is_pytest_parametrize_with_completion_params(decorator):
                    return True

    return False


def _is_pytest_parametrize_with_completion_params(decorator: ast.expr) -> bool:
    """
    Report whether *decorator* is a ``pytest.mark.parametrize(...)`` call whose
    argnames mention ``"completion_params"``.

    Args:
        decorator: AST node representing a single decorator expression.

    Returns:
        True if this is a pytest.mark.parametrize decorator with
        "completion_params" in argnames, False otherwise.
    """
    # Must be a call expression whose callee is an attribute access.
    if not isinstance(decorator, ast.Call):
        return False
    callee = decorator.func
    if not isinstance(callee, ast.Attribute):
        return False

    # The attribute chain must read exactly: pytest.mark.parametrize
    mark = callee.value
    matches_parametrize = (
        callee.attr == "parametrize"
        and isinstance(mark, ast.Attribute)
        and mark.attr == "mark"
        and isinstance(mark.value, ast.Name)
        and mark.value.id == "pytest"
    )
    if not matches_parametrize:
        return False

    # argnames is typically the first positional argument ...
    if decorator.args and _check_argnames_for_completion_params(decorator.args[0]):
        return True

    # ... but it may also be supplied as a keyword argument.
    return any(
        kw.arg == "argnames" and _check_argnames_for_completion_params(kw.value)
        for kw in decorator.keywords
    )


def _check_argnames_for_completion_params(argnames_node: ast.expr) -> bool:
    """
    Check if an argnames AST node contains "completion_params".

    pytest accepts argnames either as a single (possibly comma-separated)
    string — e.g. ``"completion_params,model"`` — or as a list/tuple of
    strings, so all of those shapes are recognized here.

    Args:
        argnames_node: AST node representing the argnames value.

    Returns:
        True if argnames contains "completion_params".
    """
    if isinstance(argnames_node, ast.Constant):
        # String case: pytest treats "a,b,c" as a comma-separated name list,
        # so split and strip instead of requiring an exact full-string match.
        if isinstance(argnames_node.value, str):
            names = [name.strip() for name in argnames_node.value.split(",")]
            return "completion_params" in names
    elif isinstance(argnames_node, (ast.List, ast.Tuple)):
        # List/tuple case: argnames=["completion_params", ...] or (...)
        return any(
            isinstance(elt, ast.Constant) and elt.value == "completion_params"
            for elt in argnames_node.elts
        )

    return False


class PytestMarkParametrizeKwargs(TypedDict):
    """Keyword arguments splatted verbatim into ``pytest.mark.parametrize``."""

    # Names of the parameters injected into each generated test, one per argument.
    argnames: Sequence[str]
    # One entry (a value, a tuple of values, or a ParameterSet) per test case.
    argvalues: Iterable[ParameterSet | Sequence[object] | object]
    # Human-readable test IDs; None lets pytest derive its default IDs.
    ids: Iterable[str] | None


class ParametrizeArgs(TypedDict):
    """
    This contains all the necessary information to properly hijack the test
    function's signature and dynamically inject usage of
    pytest.mark.parametrize. The two will differ when a user manually provides
    the pytest.mark.parametrize decorator instead of passing completion_params
    on their own.
    """

    # Parameter names to expose in the generated wrapper's signature
    # (consumed by create_dynamically_parameterized_wrapper).
    sig_parameters: Sequence[str]

    # Keyword arguments to splat into pytest.mark.parametrize.
    pytest_parametrize_kwargs: PytestMarkParametrizeKwargs


class ParameterIdGenerator(Protocol):
"""Protocol for generating pytest parameter IDs from parameter combinations."""

Expand All @@ -30,7 +152,7 @@ def generate_id(self, combo: CombinationTuple) -> str | None:
...


class DefaultParameterIdGenerator:
class DefaultParameterIdGenerator(ParameterIdGenerator):
"""Default ID generator that creates meaningful IDs from parameter combinations."""

def __init__(self, max_length: int = 200):
Expand All @@ -46,34 +168,49 @@ def generate_id(self, combo: CombinationTuple) -> str | None:
dataset, completion_params, messages, rows, evaluation_test_kwargs = combo

if completion_params:
# Get all string, numeric, and boolean values from completion_params, sorted by key
str_values = []
for key in sorted(completion_params.keys()):
value = completion_params[key]
if isinstance(value, (str, int, float, bool)):
str_values.append(str(value))
id = self.generate_id_from_dict(completion_params, self.max_length)
if id:
return id
else:
if rows:
return f"rows(len={len(rows)})"
elif messages:
return f"messages(len={len(messages)})"
elif dataset:
return f"dataset(len={len(dataset)})"
return None

if str_values:
id_str = ":".join(str_values)
@staticmethod
def generate_id_from_dict(d: dict[str, object], max_length: int = 200) -> str | None:
# Get all string, numeric, and boolean values from completion_params, sorted by key
str_values = []
for key in sorted(d.keys()):
value = d[key]
if isinstance(value, (str, int, float, bool)):
str_values.append(str(value))

# Truncate if too long
if len(id_str) > self.max_length:
id_str = id_str[: self.max_length - 3] + "..."
if str_values:
id_str = ":".join(str_values)

return id_str
# Truncate if too long
if len(id_str) > max_length:
id_str = id_str[: max_length - 3] + "..."

return id_str
return None


def pytest_parametrize(
combinations: list[CombinationTuple],
test_func: TestFunction | None,
input_dataset: Sequence[DatasetPathParam] | None,
completion_params: Sequence[CompletionParams | None] | None,
completion_params_provided: bool,
input_messages: Sequence[list[InputMessagesParam] | None] | None,
input_rows: Sequence[list[EvaluationRow]] | None,
evaluation_test_kwargs: Sequence[EvaluationInputParam | None] | None,
id_generator: ParameterIdGenerator | None = None,
) -> PytestParametrizeArgs:
) -> ParametrizeArgs:
"""
This function dynamically generates pytest.mark.parametrize arguments for a given
set of combinations. This is the magic that allows developers to pass in their
Expand All @@ -82,18 +219,31 @@ def pytest_parametrize(
API.
"""

if test_func is not None:
has_pytest_parametrize = _has_pytest_parametrize_with_completion_params(test_func)
else:
has_pytest_parametrize = False

# Create parameter tuples for pytest.mark.parametrize
argnames: list[str] = []
sig_parameters: list[str] = []
if input_dataset is not None:
argnames.append("dataset_path")
sig_parameters.append("dataset_path")
if completion_params is not None:
argnames.append("completion_params")
if completion_params_provided and not has_pytest_parametrize:
argnames.append("completion_params")
if has_pytest_parametrize or completion_params_provided:
sig_parameters.append("completion_params")
if input_messages is not None:
argnames.append("input_messages")
sig_parameters.append("input_messages")
if input_rows is not None:
argnames.append("input_rows")
sig_parameters.append("input_rows")
if evaluation_test_kwargs is not None:
argnames.append("evaluation_test_kwargs")
sig_parameters.append("evaluation_test_kwargs")

# Use default ID generator if none provided
if id_generator is None:
Expand All @@ -109,7 +259,7 @@ def pytest_parametrize(
# Build parameter tuple based on what's provided
if input_dataset is not None:
param_tuple.append(dataset)
if completion_params is not None:
if completion_params_provided:
param_tuple.append(cp)
if input_messages is not None:
param_tuple.append(messages)
Expand All @@ -132,7 +282,12 @@ def pytest_parametrize(
ids.append(combo_id)

# Return None for ids if no IDs were generated (let pytest use defaults)
return PytestParametrizeArgs(argnames=argnames, argvalues=argvalues, ids=ids if ids else None)
return ParametrizeArgs(
pytest_parametrize_kwargs=PytestMarkParametrizeKwargs(
argnames=argnames, argvalues=argvalues, ids=ids if ids else None
),
sig_parameters=sig_parameters,
)


def create_dynamically_parameterized_wrapper(
Expand Down
12 changes: 8 additions & 4 deletions eval_protocol/quickstart/llm_judge_openai_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
EvaluationRow,
SingleTurnRolloutProcessor,
OpenAIResponsesAdapter,
DefaultParameterIdGenerator,
)

adapter = OpenAIResponsesAdapter()
Expand All @@ -41,17 +42,20 @@


@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Skip in CI")
@pytest.mark.asyncio
@evaluation_test(
input_rows=[input_rows],
completion_params=[
@pytest.mark.parametrize(
"completion_params",
[
{
"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1",
},
{
"model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct-0905",
},
],
ids=DefaultParameterIdGenerator.generate_id_from_dict,
)
@evaluation_test(
input_rows=[input_rows],
rollout_processor=SingleTurnRolloutProcessor(),
preprocess_fn=split_multi_turn_rows,
mode="all",
Expand Down
Loading
Loading