diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index e46f58c9..a749a454 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -730,12 +730,20 @@ async def _collect_result(config, lst, max_retry): _log_eval_error("error", data if "data" in locals() else None, passed=False) raise - return create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param_names) + if asyncio.iscoroutinefunction(test_func): + return create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param_names) + else: + + def sync_wrapper_body(**kwargs): + return asyncio.run(wrapper_body(**kwargs)) + + return create_dynamically_parameterized_wrapper(test_func, sync_wrapper_body, test_param_names) # Create the pytest wrapper pytest_wrapper = create_wrapper_with_signature() pytest_wrapper = pytest.mark.parametrize(test_param_names, param_tuples)(pytest_wrapper) - pytest_wrapper = pytest.mark.asyncio(pytest_wrapper) + if asyncio.iscoroutinefunction(test_func): + pytest_wrapper = pytest.mark.asyncio(pytest_wrapper) def create_dual_mode_wrapper() -> Callable: """ @@ -756,46 +764,66 @@ def create_dual_mode_wrapper() -> Callable: # Check if the test function is async is_async = asyncio.iscoroutinefunction(test_func) - async def call_test_func(**call_kwargs): - """Helper to call test_func with proper async/sync handling""" - if is_async: - return await test_func(**call_kwargs) - else: - return test_func(**call_kwargs) - - async def dual_mode_wrapper(*args, **kwargs): - # Check if this is a direct call with the expected signature - if mode == "pointwise": - # For pointwise mode, check if called with a single row argument - if len(args) == 1 and isinstance(args[0], EvaluationRow) and not kwargs: - return await call_test_func(row=args[0]) - else: - # For batch mode, check if called with rows argument - if ( - len(args) == 1 - and isinstance(args[0], list) - and all(isinstance(r, 
EvaluationRow) for r in args[0]) - and not kwargs - ): - return await call_test_func(rows=args[0]) - # Also check if called with keyword argument 'rows' - if ( - len(args) == 0 - and "rows" in kwargs - and isinstance(kwargs["rows"], list) - and all(isinstance(r, EvaluationRow) for r in kwargs["rows"]) - ): - return await call_test_func(**kwargs) - - # If not a direct call, use the pytest wrapper - return await pytest_wrapper(*args, **kwargs) + if is_async: + + async def dual_mode_wrapper(*args, **kwargs): + # Check if this is a direct call with the expected signature + if mode == "pointwise": + # For pointwise mode, check if called with a single row argument + if len(args) == 1 and isinstance(args[0], EvaluationRow) and not kwargs: + return await test_func(row=args[0]) + else: + # For batch mode, check if called with rows argument + if ( + len(args) == 1 + and isinstance(args[0], list) + and all(isinstance(r, EvaluationRow) for r in args[0]) + and not kwargs + ): + return await test_func(rows=args[0]) + # Also check if called with keyword argument 'rows' + if ( + len(args) == 0 + and "rows" in kwargs + and isinstance(kwargs["rows"], list) + and all(isinstance(r, EvaluationRow) for r in kwargs["rows"]) + ): + return await test_func(**kwargs) + + # If not a direct call, use the pytest wrapper + return await pytest_wrapper(*args, **kwargs) + + _dual_model_wrapper_fn = dual_mode_wrapper + else: + + def dual_mode_wrapper(*args, **kwargs): + if mode == "pointwise": + if len(args) == 1 and isinstance(args[0], EvaluationRow) and not kwargs: + return test_func(row=args[0]) + else: + if ( + len(args) == 1 + and isinstance(args[0], list) + and all(isinstance(r, EvaluationRow) for r in args[0]) + and not kwargs + ): + return test_func(rows=args[0]) + if ( + "rows" in kwargs + and isinstance(kwargs["rows"], list) + and all(isinstance(r, EvaluationRow) for r in kwargs["rows"]) + ): + return test_func(**kwargs) + return pytest_wrapper(*args, **kwargs) + + _dual_model_wrapper_fn = 
dual_mode_wrapper # Copy all attributes from the pytest wrapper to our dual mode wrapper import functools - functools.update_wrapper(dual_mode_wrapper, pytest_wrapper) + functools.update_wrapper(_dual_model_wrapper_fn, pytest_wrapper) - return dual_mode_wrapper + return _dual_model_wrapper_fn # Create the dual mode wrapper dual_mode_wrapper = create_dual_mode_wrapper() diff --git a/eval_protocol/pytest/utils.py b/eval_protocol/pytest/utils.py index e0b8328a..8dce9c6d 100644 --- a/eval_protocol/pytest/utils.py +++ b/eval_protocol/pytest/utils.py @@ -17,6 +17,14 @@ ) +def is_in_event_loop(): + try: + asyncio.get_running_loop() + return True + except RuntimeError: + return False + + def execute_function(func: Callable, **kwargs) -> Any: """ Execute a function with proper async handling. @@ -98,9 +106,16 @@ def create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param """ from functools import wraps - @wraps(test_func) - async def wrapper(**kwargs): - return await wrapper_body(**kwargs) + if asyncio.iscoroutinefunction(wrapper_body): + + @wraps(test_func) + async def wrapper(**kwargs): + return await wrapper_body(**kwargs) + else: + + @wraps(test_func) + def wrapper(**kwargs): + return wrapper_body(**kwargs) parameters = [inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD) for name in test_param_names] wrapper.__signature__ = inspect.Signature(parameters) diff --git a/tests/pytest/test_direct_run.py b/tests/pytest/test_direct_run.py new file mode 100644 index 00000000..7b9c7a5d --- /dev/null +++ b/tests/pytest/test_direct_run.py @@ -0,0 +1,83 @@ +from eval_protocol.models import Message, EvaluationRow, EvaluateResult +from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test +from typing import List +import pytest + + +@evaluation_test( + input_messages=[ + [ + Message(role="user", content="What is the capital of France?"), + ], + [ + Message(role="user", content="What is the capital of the moon?"), + ], + ], + 
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], + rollout_processor=SingleTurnRolloutProcessor(), + mode="all", +) +def test_direct_run(rows: List[EvaluationRow]) -> List[EvaluationRow]: + """Run math evaluation on sample dataset using pytest interface.""" + for idx, row in enumerate(rows): + row.evaluation_result = EvaluateResult(score=idx, reason="test") + return rows + + +def test_direct_run_main(): + rows = [ + EvaluationRow( + messages=[ + Message(role="user", content="What is the capital of France?"), + ], + ), + EvaluationRow( + messages=[ + Message(role="user", content="What is the capital of the moon?"), + ], + ), + ] + res = test_direct_run(rows) + assert res[0].evaluation_result.score == 0 + assert res[1].evaluation_result.score == 1 + + +@evaluation_test( + input_messages=[ + [ + Message(role="user", content="What is the capital of France?"), + ], + [ + Message(role="user", content="What is the capital of the moon?"), + ], + ], + completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], + rollout_processor=SingleTurnRolloutProcessor(), + mode="all", +) +async def test_direct_run_async(rows: List[EvaluationRow]) -> List[EvaluationRow]: + """Run math evaluation on sample dataset using pytest interface.""" + for idx, row in enumerate(rows): + row.evaluation_result = EvaluateResult(score=idx, reason="test") + return rows + + +@pytest.mark.asyncio +async def test_direct_run_async_main(): + rows = [ + EvaluationRow( + messages=[ + Message(role="user", content="1"), + ], + ), + EvaluationRow( + messages=[ + Message(role="user", content="2"), + ], + ), + ] + res = await test_direct_run_async(rows) + assert res[0].messages[0].content == "1" + assert res[1].messages[0].content == "2" + assert res[0].evaluation_result.score == 0 + assert res[1].evaluation_result.score == 1 diff --git a/tests/pytest/test_pytest_ensure_logging.py b/tests/pytest/test_pytest_ensure_logging.py index e57b3c8c..89828fe8 
100644 --- a/tests/pytest/test_pytest_ensure_logging.py +++ b/tests/pytest/test_pytest_ensure_logging.py @@ -2,7 +2,7 @@ from unittest.mock import Mock, patch -async def test_ensure_logging(monkeypatch): +def test_ensure_logging(monkeypatch): """ Ensure that default SQLITE logger gets called by mocking the storage and checking that the storage is called. """ @@ -37,7 +37,7 @@ async def test_ensure_logging(monkeypatch): def eval_fn(row: EvaluationRow) -> EvaluationRow: return row - await eval_fn( + eval_fn( dataset_path=["tests/pytest/data/markdown_dataset.jsonl"], completion_params={"temperature": 0.0, "model": "dummy/local-model"}, ) diff --git a/tests/pytest/test_pytest_ids.py b/tests/pytest/test_pytest_ids.py index b6bb4a35..0a4053ed 100644 --- a/tests/pytest/test_pytest_ids.py +++ b/tests/pytest/test_pytest_ids.py @@ -19,7 +19,7 @@ def read(self): return list(self._rows.values()) -async def test_evaluation_test_decorator(monkeypatch): +def test_evaluation_test_decorator(monkeypatch): from eval_protocol.pytest.evaluation_test import evaluation_test logger = InMemoryLogger() @@ -45,13 +45,13 @@ def eval_fn(row: EvaluationRow) -> EvaluationRow: # Manually invoke all parameter combinations within a single test for ds_path in dataset_paths: - await eval_fn(dataset_path=[ds_path], completion_params={"temperature": 0.0, "model": "dummy/local-model"}) + eval_fn(dataset_path=[ds_path], completion_params={"temperature": 0.0, "model": "dummy/local-model"}) # Assertions on IDs generated by the decorator logic assert len(logger.read()) == 38 -async def test_evaluation_test_decorator_ids_single(monkeypatch): +def test_evaluation_test_decorator_ids_single(monkeypatch): in_memory_logger = InMemoryLogger() unique_run_ids = set() unique_experiment_ids = set() @@ -97,7 +97,7 @@ def eval_fn(row: EvaluationRow) -> EvaluationRow: # Manually invoke all parameter combinations within a single test for ds_path in dataset_paths: for params in completion_params_list: - await 
eval_fn(dataset_path=[ds_path], completion_params=params) + eval_fn(dataset_path=[ds_path], completion_params=params) # Assertions on IDs generated by the decorator logic assert len(unique_invocation_ids) == 1 diff --git a/tests/pytest/test_pytest_stable_row_id.py b/tests/pytest/test_pytest_stable_row_id.py index c2a5709a..a5d6f5db 100644 --- a/tests/pytest/test_pytest_stable_row_id.py +++ b/tests/pytest/test_pytest_stable_row_id.py @@ -5,7 +5,7 @@ from tests.pytest.test_markdown_highlighting import markdown_dataset_to_evaluation_row -async def test_evaluation_test_decorator_ids_single(): +def test_evaluation_test_decorator_ids_single(): from eval_protocol.pytest.evaluation_test import evaluation_test row_ids = set() @@ -35,18 +35,18 @@ def eval_fn(row: EvaluationRow) -> EvaluationRow: # Manually invoke all parameter combinations within a single test for ds_path in input_dataset: for params in completion_params_list: - await eval_fn(dataset_path=[ds_path], completion_params=params) + eval_fn(dataset_path=[ds_path], completion_params=params) # Second invocation to ensure that IDs are stable across multiple invocations for ds_path in input_dataset: for params in completion_params_list: - await eval_fn(dataset_path=[ds_path], completion_params=params) + eval_fn(dataset_path=[ds_path], completion_params=params) # Assertions on IDs generated by the decorator logic assert len(row_ids) == 19 # from the markdown dataset -async def test_evaluation_test_generated_row_ids_without_dataset_keys(): +def test_evaluation_test_generated_row_ids_without_dataset_keys(): from eval_protocol.pytest.evaluation_test import evaluation_test # Adapter that does NOT set row_id; lets evaluation_test generate IDs @@ -86,12 +86,12 @@ def eval_fn(row: EvaluationRow) -> EvaluationRow: # Single invocation (one dataset, one param set) with multiple runs for ds_path in input_dataset: for params in completion_params: - await eval_fn(dataset_path=[ds_path], completion_params=params) + 
eval_fn(dataset_path=[ds_path], completion_params=params) # Second invocation to ensure that IDs are stable across multiple invocations for ds_path in input_dataset: for params in completion_params: - await eval_fn(dataset_path=[ds_path], completion_params=params) + eval_fn(dataset_path=[ds_path], completion_params=params) # Even with multiple runs, generated row_ids should be stable within the invocation assert len(row_ids) == 19 # equals dataset size when IDs are generated once and preserved across runs