Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 65 additions & 37 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,12 +730,20 @@ async def _collect_result(config, lst, max_retry):
_log_eval_error("error", data if "data" in locals() else None, passed=False)
raise

return create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param_names)
if asyncio.iscoroutinefunction(test_func):
return create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param_names)
else:

def sync_wrapper_body(**kwargs):
return asyncio.run(wrapper_body(**kwargs))

return create_dynamically_parameterized_wrapper(test_func, sync_wrapper_body, test_param_names)

# Create the pytest wrapper
pytest_wrapper = create_wrapper_with_signature()
pytest_wrapper = pytest.mark.parametrize(test_param_names, param_tuples)(pytest_wrapper)
pytest_wrapper = pytest.mark.asyncio(pytest_wrapper)
if asyncio.iscoroutinefunction(test_func):
pytest_wrapper = pytest.mark.asyncio(pytest_wrapper)

def create_dual_mode_wrapper() -> Callable:
"""
Expand All @@ -756,46 +764,66 @@ def create_dual_mode_wrapper() -> Callable:
# Check if the test function is async
is_async = asyncio.iscoroutinefunction(test_func)

async def call_test_func(**call_kwargs):
"""Helper to call test_func with proper async/sync handling"""
if is_async:
return await test_func(**call_kwargs)
else:
return test_func(**call_kwargs)

async def dual_mode_wrapper(*args, **kwargs):
# Check if this is a direct call with the expected signature
if mode == "pointwise":
# For pointwise mode, check if called with a single row argument
if len(args) == 1 and isinstance(args[0], EvaluationRow) and not kwargs:
return await call_test_func(row=args[0])
else:
# For batch mode, check if called with rows argument
if (
len(args) == 1
and isinstance(args[0], list)
and all(isinstance(r, EvaluationRow) for r in args[0])
and not kwargs
):
return await call_test_func(rows=args[0])
# Also check if called with keyword argument 'rows'
if (
len(args) == 0
and "rows" in kwargs
and isinstance(kwargs["rows"], list)
and all(isinstance(r, EvaluationRow) for r in kwargs["rows"])
):
return await call_test_func(**kwargs)

# If not a direct call, use the pytest wrapper
return await pytest_wrapper(*args, **kwargs)
if is_async:

async def dual_mode_wrapper(*args, **kwargs):
    """Async entry point supporting both direct invocation and pytest runs.

    Direct calls that pass EvaluationRow data (a single row in pointwise
    mode, or a list of rows in batch mode) are forwarded straight to
    test_func; any other call shape falls through to the parameterized
    pytest wrapper.
    """
    # Check if this is a direct call with the expected signature
    if mode == "pointwise":
        # For pointwise mode, check if called with a single row argument
        if len(args) == 1 and isinstance(args[0], EvaluationRow) and not kwargs:
            return await test_func(row=args[0])
    else:
        # For batch mode, check if called with rows argument
        if (
            len(args) == 1
            and isinstance(args[0], list)
            and all(isinstance(r, EvaluationRow) for r in args[0])
            and not kwargs
        ):
            return await test_func(rows=args[0])
        # Also check if called with keyword argument 'rows'
        if (
            len(args) == 0
            and "rows" in kwargs
            and isinstance(kwargs["rows"], list)
            and all(isinstance(r, EvaluationRow) for r in kwargs["rows"])
        ):
            return await test_func(**kwargs)

    # If not a direct call, use the pytest wrapper
    return await pytest_wrapper(*args, **kwargs)

# Single name for whichever variant (async or sync) was constructed.
_dual_model_wrapper_fn = dual_mode_wrapper
else:

def dual_mode_wrapper(*args, **kwargs):
    """Sync entry point supporting both direct invocation and pytest runs.

    Mirrors the async variant: direct calls that pass EvaluationRow data
    go straight to test_func; any other call shape falls through to the
    parameterized pytest wrapper.
    """
    if mode == "pointwise":
        # Direct pointwise call: exactly one EvaluationRow, no keywords.
        if len(args) == 1 and isinstance(args[0], EvaluationRow) and not kwargs:
            return test_func(row=args[0])
    else:
        # Direct batch call: one positional list of EvaluationRow, no keywords.
        if (
            len(args) == 1
            and isinstance(args[0], list)
            and all(isinstance(r, EvaluationRow) for r in args[0])
            and not kwargs
        ):
            return test_func(rows=args[0])
        # Direct batch call via keyword: rows=<list of EvaluationRow>.
        # Require no positional args (same guard as the async variant) so
        # stray positional arguments are not silently dropped.
        if (
            not args
            and "rows" in kwargs
            and isinstance(kwargs["rows"], list)
            and all(isinstance(r, EvaluationRow) for r in kwargs["rows"])
        ):
            return test_func(**kwargs)
    # Not a recognized direct call: defer to the pytest wrapper.
    return pytest_wrapper(*args, **kwargs)

# Single name for whichever variant (async or sync) was constructed.
_dual_model_wrapper_fn = dual_mode_wrapper

# Copy all attributes from the pytest wrapper to our dual mode wrapper
import functools

functools.update_wrapper(dual_mode_wrapper, pytest_wrapper)
functools.update_wrapper(_dual_model_wrapper_fn, pytest_wrapper)

return dual_mode_wrapper
return _dual_model_wrapper_fn

# Create the dual mode wrapper
dual_mode_wrapper = create_dual_mode_wrapper()
Expand Down
21 changes: 18 additions & 3 deletions eval_protocol/pytest/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@
)


def is_in_event_loop() -> bool:
    """Return True if called from inside a running asyncio event loop.

    Uses ``asyncio.get_running_loop()`` rather than ``get_event_loop()``:
    the latter may return (or create) a loop in the main thread even when
    no loop is running, which would wrongly report True from plain
    synchronous code.
    """
    try:
        asyncio.get_running_loop()
        return True
    except RuntimeError:
        # No loop is currently running in this thread.
        return False


def execute_function(func: Callable, **kwargs) -> Any:
"""
Execute a function with proper async handling.
Expand Down Expand Up @@ -98,9 +106,16 @@ def create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param
"""
from functools import wraps

@wraps(test_func)
async def wrapper(**kwargs):
return await wrapper_body(**kwargs)
if asyncio.iscoroutinefunction(wrapper_body):

@wraps(test_func)
async def wrapper(**kwargs):
return await wrapper_body(**kwargs)
else:

@wraps(test_func)
def wrapper(**kwargs):
return wrapper_body(**kwargs)

parameters = [inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD) for name in test_param_names]
wrapper.__signature__ = inspect.Signature(parameters)
Expand Down
83 changes: 83 additions & 0 deletions tests/pytest/test_direct_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from eval_protocol.models import Message, EvaluationRow, EvaluateResult
from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test
from typing import List
import pytest


@evaluation_test(
    input_messages=[
        [
            Message(role="user", content="What is the capital of France?"),
        ],
        [
            Message(role="user", content="What is the capital of the moon?"),
        ],
    ],
    completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
    rollout_processor=SingleTurnRolloutProcessor(),
    mode="all",
)
def test_direct_run(rows: List[EvaluationRow]) -> List[EvaluationRow]:
    """Batch evaluator smoke test: score each row with its batch index."""
    for position, current_row in enumerate(rows):
        current_row.evaluation_result = EvaluateResult(score=position, reason="test")
    return rows


def test_direct_run_main():
    """Call the decorated evaluator directly (sync path) and verify index-based scores."""
    prompts = [
        "What is the capital of France?",
        "What is the capital of the moon?",
    ]
    rows = [EvaluationRow(messages=[Message(role="user", content=p)]) for p in prompts]
    result = test_direct_run(rows)
    for expected_score, row in enumerate(result):
        assert row.evaluation_result.score == expected_score


@evaluation_test(
    input_messages=[
        [
            Message(role="user", content="What is the capital of France?"),
        ],
        [
            Message(role="user", content="What is the capital of the moon?"),
        ],
    ],
    completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
    rollout_processor=SingleTurnRolloutProcessor(),
    mode="all",
)
async def test_direct_run_async(rows: List[EvaluationRow]) -> List[EvaluationRow]:
    """Async batch evaluator smoke test: score each row with its batch index."""
    for position, current_row in enumerate(rows):
        current_row.evaluation_result = EvaluateResult(score=position, reason="test")
    return rows


@pytest.mark.asyncio
async def test_direct_run_async_main():
    """Call the decorated evaluator directly (async path); verify content passthrough and scores."""
    contents = ["1", "2"]
    rows = [EvaluationRow(messages=[Message(role="user", content=c)]) for c in contents]
    res = await test_direct_run_async(rows)
    for idx, (row, original_content) in enumerate(zip(res, contents)):
        assert row.messages[0].content == original_content
        assert row.evaluation_result.score == idx
4 changes: 2 additions & 2 deletions tests/pytest/test_pytest_ensure_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from unittest.mock import Mock, patch


async def test_ensure_logging(monkeypatch):
def test_ensure_logging(monkeypatch):
"""
Ensure that default SQLITE logger gets called by mocking the storage and checking that the storage is called.
"""
Expand Down Expand Up @@ -37,7 +37,7 @@ async def test_ensure_logging(monkeypatch):
def eval_fn(row: EvaluationRow) -> EvaluationRow:
return row

await eval_fn(
eval_fn(
dataset_path=["tests/pytest/data/markdown_dataset.jsonl"],
completion_params={"temperature": 0.0, "model": "dummy/local-model"},
)
Expand Down
8 changes: 4 additions & 4 deletions tests/pytest/test_pytest_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def read(self):
return list(self._rows.values())


async def test_evaluation_test_decorator(monkeypatch):
def test_evaluation_test_decorator(monkeypatch):
from eval_protocol.pytest.evaluation_test import evaluation_test

logger = InMemoryLogger()
Expand All @@ -45,13 +45,13 @@ def eval_fn(row: EvaluationRow) -> EvaluationRow:

# Manually invoke all parameter combinations within a single test
for ds_path in dataset_paths:
await eval_fn(dataset_path=[ds_path], completion_params={"temperature": 0.0, "model": "dummy/local-model"})
eval_fn(dataset_path=[ds_path], completion_params={"temperature": 0.0, "model": "dummy/local-model"})

# Assertions on IDs generated by the decorator logic
assert len(logger.read()) == 38


async def test_evaluation_test_decorator_ids_single(monkeypatch):
def test_evaluation_test_decorator_ids_single(monkeypatch):
in_memory_logger = InMemoryLogger()
unique_run_ids = set()
unique_experiment_ids = set()
Expand Down Expand Up @@ -97,7 +97,7 @@ def eval_fn(row: EvaluationRow) -> EvaluationRow:
# Manually invoke all parameter combinations within a single test
for ds_path in dataset_paths:
for params in completion_params_list:
await eval_fn(dataset_path=[ds_path], completion_params=params)
eval_fn(dataset_path=[ds_path], completion_params=params)

# Assertions on IDs generated by the decorator logic
assert len(unique_invocation_ids) == 1
Expand Down
12 changes: 6 additions & 6 deletions tests/pytest/test_pytest_stable_row_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from tests.pytest.test_markdown_highlighting import markdown_dataset_to_evaluation_row


async def test_evaluation_test_decorator_ids_single():
def test_evaluation_test_decorator_ids_single():
from eval_protocol.pytest.evaluation_test import evaluation_test

row_ids = set()
Expand Down Expand Up @@ -35,18 +35,18 @@ def eval_fn(row: EvaluationRow) -> EvaluationRow:
# Manually invoke all parameter combinations within a single test
for ds_path in input_dataset:
for params in completion_params_list:
await eval_fn(dataset_path=[ds_path], completion_params=params)
eval_fn(dataset_path=[ds_path], completion_params=params)

# Second invocation to ensure that IDs are stable across multiple invocations
for ds_path in input_dataset:
for params in completion_params_list:
await eval_fn(dataset_path=[ds_path], completion_params=params)
eval_fn(dataset_path=[ds_path], completion_params=params)

# Assertions on IDs generated by the decorator logic
assert len(row_ids) == 19 # from the markdown dataset


async def test_evaluation_test_generated_row_ids_without_dataset_keys():
def test_evaluation_test_generated_row_ids_without_dataset_keys():
from eval_protocol.pytest.evaluation_test import evaluation_test

# Adapter that does NOT set row_id; lets evaluation_test generate IDs
Expand Down Expand Up @@ -86,12 +86,12 @@ def eval_fn(row: EvaluationRow) -> EvaluationRow:
# Single invocation (one dataset, one param set) with multiple runs
for ds_path in input_dataset:
for params in completion_params:
await eval_fn(dataset_path=[ds_path], completion_params=params)
eval_fn(dataset_path=[ds_path], completion_params=params)

# Second invocation to ensure that IDs are stable across multiple invocations
for ds_path in input_dataset:
for params in completion_params:
await eval_fn(dataset_path=[ds_path], completion_params=params)
eval_fn(dataset_path=[ds_path], completion_params=params)

# Even with multiple runs, generated row_ids should be stable within the invocation
assert len(row_ids) == 19 # equals dataset size when IDs are generated once and preserved across runs