From 44b6f27ae884b393b9f013a9094b9c86b1ac3a55 Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Tue, 19 Aug 2025 18:58:24 -0700 Subject: [PATCH 1/4] properly handle sync and async function wrap --- eval_protocol/pytest/evaluation_test.py | 89 +++++++++++++++---------- eval_protocol/pytest/utils.py | 11 ++- tests/pytest/test_direct_run.py | 82 +++++++++++++++++++++++ 3 files changed, 142 insertions(+), 40 deletions(-) create mode 100644 tests/pytest/test_direct_run.py diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index e46f58c9..c02e0e1f 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -729,13 +729,19 @@ async def _collect_result(config, lst, max_retry): except Exception: _log_eval_error("error", data if "data" in locals() else None, passed=False) raise + if asyncio.iscoroutinefunction(test_func): + return create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param_names) + else: + def sync_wrapper_body(**kwargs): + return asyncio.run(wrapper_body(**kwargs)) + return create_dynamically_parameterized_wrapper(test_func, sync_wrapper_body, test_param_names) - return create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param_names) # Create the pytest wrapper pytest_wrapper = create_wrapper_with_signature() pytest_wrapper = pytest.mark.parametrize(test_param_names, param_tuples)(pytest_wrapper) - pytest_wrapper = pytest.mark.asyncio(pytest_wrapper) + if asyncio.iscoroutinefunction(test_func): + pytest_wrapper = pytest.mark.asyncio(pytest_wrapper) def create_dual_mode_wrapper() -> Callable: """ @@ -756,46 +762,55 @@ def create_dual_mode_wrapper() -> Callable: # Check if the test function is async is_async = asyncio.iscoroutinefunction(test_func) - async def call_test_func(**call_kwargs): - """Helper to call test_func with proper async/sync handling""" - if is_async: - return await test_func(**call_kwargs) - else: - return 
test_func(**call_kwargs) - - async def dual_mode_wrapper(*args, **kwargs): - # Check if this is a direct call with the expected signature - if mode == "pointwise": - # For pointwise mode, check if called with a single row argument - if len(args) == 1 and isinstance(args[0], EvaluationRow) and not kwargs: - return await call_test_func(row=args[0]) - else: - # For batch mode, check if called with rows argument - if ( - len(args) == 1 - and isinstance(args[0], list) - and all(isinstance(r, EvaluationRow) for r in args[0]) - and not kwargs - ): - return await call_test_func(rows=args[0]) - # Also check if called with keyword argument 'rows' - if ( - len(args) == 0 - and "rows" in kwargs - and isinstance(kwargs["rows"], list) - and all(isinstance(r, EvaluationRow) for r in kwargs["rows"]) - ): - return await call_test_func(**kwargs) - - # If not a direct call, use the pytest wrapper - return await pytest_wrapper(*args, **kwargs) + if is_async: + async def dual_mode_wrapper(*args, **kwargs): + # Check if this is a direct call with the expected signature + if mode == "pointwise": + # For pointwise mode, check if called with a single row argument + if len(args) == 1 and isinstance(args[0], EvaluationRow) and not kwargs: + return await test_func(row=args[0]) + else: + # For batch mode, check if called with rows argument + if ( + len(args) == 1 + and isinstance(args[0], list) + and all(isinstance(r, EvaluationRow) for r in args[0]) + and not kwargs + ): + return await test_func(rows=args[0]) + # Also check if called with keyword argument 'rows' + if ( + len(args) == 0 + and "rows" in kwargs + and isinstance(kwargs["rows"], list) + and all(isinstance(r, EvaluationRow) for r in kwargs["rows"]) + ): + return await test_func(**kwargs) + + # If not a direct call, use the pytest wrapper + return await pytest_wrapper(*args, **kwargs) + + _dual_model_wrapper_fn = dual_mode_wrapper + else: + def dual_mode_wrapper(*args, **kwargs): + if mode == "pointwise": + if len(args) == 1 and 
isinstance(args[0], EvaluationRow) and not kwargs: + return test_func(row=args[0]) + else: + if len(args) == 1 and isinstance(args[0], list) and all(isinstance(r, EvaluationRow) for r in args[0]) and not kwargs: + return test_func(rows=args[0]) + if "rows" in kwargs and isinstance(kwargs["rows"], list) and all(isinstance(r, EvaluationRow) for r in kwargs["rows"]): + return test_func(**kwargs) + return pytest_wrapper(*args, **kwargs) + + _dual_model_wrapper_fn = dual_mode_wrapper # Copy all attributes from the pytest wrapper to our dual mode wrapper import functools - functools.update_wrapper(dual_mode_wrapper, pytest_wrapper) + functools.update_wrapper(_dual_model_wrapper_fn, pytest_wrapper) - return dual_mode_wrapper + return _dual_model_wrapper_fn # Create the dual mode wrapper dual_mode_wrapper = create_dual_mode_wrapper() diff --git a/eval_protocol/pytest/utils.py b/eval_protocol/pytest/utils.py index e0b8328a..ec1bd362 100644 --- a/eval_protocol/pytest/utils.py +++ b/eval_protocol/pytest/utils.py @@ -98,9 +98,14 @@ def create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param """ from functools import wraps - @wraps(test_func) - async def wrapper(**kwargs): - return await wrapper_body(**kwargs) + if asyncio.iscoroutinefunction(wrapper_body): + @wraps(test_func) + async def wrapper(**kwargs): + return await wrapper_body(**kwargs) + else: + @wraps(test_func) + def wrapper(**kwargs): + return wrapper_body(**kwargs) parameters = [inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD) for name in test_param_names] wrapper.__signature__ = inspect.Signature(parameters) diff --git a/tests/pytest/test_direct_run.py b/tests/pytest/test_direct_run.py new file mode 100644 index 00000000..a35d5490 --- /dev/null +++ b/tests/pytest/test_direct_run.py @@ -0,0 +1,82 @@ +from eval_protocol.models import Message, EvaluationRow, EvaluateResult +from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test +from typing import List 
+import pytest + + +@evaluation_test( + input_messages=[ + [ + Message(role="user", content="What is the capital of France?"), + ], + [ + Message(role="user", content="What is the capital of the moon?"), + ], + ], + completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], + rollout_processor=SingleTurnRolloutProcessor(), + mode="listwise", +) +def test_direct_run(rows: List[EvaluationRow]) -> List[EvaluationRow]: + """Run math evaluation on sample dataset using pytest interface.""" + for idx, row in enumerate(rows): + row.evaluation_result = EvaluateResult(score=idx, reason="test") + return rows + + +def test_direct_run_main(): + rows = [ + EvaluationRow( + messages=[ + Message(role="user", content="What is the capital of France?"), + ], + ), + EvaluationRow( + messages=[ + Message(role="user", content="What is the capital of the moon?"), + ], + ), + ] + res = test_direct_run(rows) + assert res[0].evaluation_result.score == 0 + assert res[1].evaluation_result.score == 1 + + +@evaluation_test( + input_messages=[ + [ + Message(role="user", content="What is the capital of France?"), + ], + [ + Message(role="user", content="What is the capital of the moon?"), + ], + ], + completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], + rollout_processor=SingleTurnRolloutProcessor(), + mode="listwise", +) +async def test_direct_run_async(rows: List[EvaluationRow]) -> List[EvaluationRow]: + """Run math evaluation on sample dataset using pytest interface.""" + for idx, row in enumerate(rows): + row.evaluation_result = EvaluateResult(score=idx, reason="test") + return rows + + + +@pytest.mark.asyncio +async def test_direct_run_async_main(): + rows = [ + EvaluationRow( + messages=[ + Message(role="user", content="What is the capital of France?"), + ], + ), + EvaluationRow( + messages=[ + Message(role="user", content="What is the capital of the moon?"), + ], + ), + ] + res = await test_direct_run_async(rows) + assert 
res[0].evaluation_result.score == 0 + assert res[1].evaluation_result.score == 1 \ No newline at end of file From 999ed18a5d4027b1e7db1d72842731d75134469d Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Tue, 19 Aug 2025 21:55:53 -0700 Subject: [PATCH 2/4] format --- eval_protocol/pytest/evaluation_test.py | 23 ++++++++++++++++++----- eval_protocol/pytest/utils.py | 2 ++ tests/pytest/test_direct_run.py | 13 +++++++------ 3 files changed, 27 insertions(+), 11 deletions(-) diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index c02e0e1f..a749a454 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -729,13 +729,15 @@ async def _collect_result(config, lst, max_retry): except Exception: _log_eval_error("error", data if "data" in locals() else None, passed=False) raise + if asyncio.iscoroutinefunction(test_func): return create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param_names) else: + def sync_wrapper_body(**kwargs): return asyncio.run(wrapper_body(**kwargs)) - return create_dynamically_parameterized_wrapper(test_func, sync_wrapper_body, test_param_names) + return create_dynamically_parameterized_wrapper(test_func, sync_wrapper_body, test_param_names) # Create the pytest wrapper pytest_wrapper = create_wrapper_with_signature() @@ -763,6 +765,7 @@ def create_dual_mode_wrapper() -> Callable: is_async = asyncio.iscoroutinefunction(test_func) if is_async: + async def dual_mode_wrapper(*args, **kwargs): # Check if this is a direct call with the expected signature if mode == "pointwise": @@ -789,20 +792,30 @@ async def dual_mode_wrapper(*args, **kwargs): # If not a direct call, use the pytest wrapper return await pytest_wrapper(*args, **kwargs) - + _dual_model_wrapper_fn = dual_mode_wrapper else: + def dual_mode_wrapper(*args, **kwargs): if mode == "pointwise": if len(args) == 1 and isinstance(args[0], EvaluationRow) and not kwargs: return 
test_func(row=args[0]) else: - if len(args) == 1 and isinstance(args[0], list) and all(isinstance(r, EvaluationRow) for r in args[0]) and not kwargs: + if ( + len(args) == 1 + and isinstance(args[0], list) + and all(isinstance(r, EvaluationRow) for r in args[0]) + and not kwargs + ): return test_func(rows=args[0]) - if "rows" in kwargs and isinstance(kwargs["rows"], list) and all(isinstance(r, EvaluationRow) for r in kwargs["rows"]): + if ( + "rows" in kwargs + and isinstance(kwargs["rows"], list) + and all(isinstance(r, EvaluationRow) for r in kwargs["rows"]) + ): return test_func(**kwargs) return pytest_wrapper(*args, **kwargs) - + _dual_model_wrapper_fn = dual_mode_wrapper # Copy all attributes from the pytest wrapper to our dual mode wrapper diff --git a/eval_protocol/pytest/utils.py b/eval_protocol/pytest/utils.py index ec1bd362..fbd44f25 100644 --- a/eval_protocol/pytest/utils.py +++ b/eval_protocol/pytest/utils.py @@ -99,10 +99,12 @@ def create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param from functools import wraps if asyncio.iscoroutinefunction(wrapper_body): + @wraps(test_func) async def wrapper(**kwargs): return await wrapper_body(**kwargs) else: + @wraps(test_func) def wrapper(**kwargs): return wrapper_body(**kwargs) diff --git a/tests/pytest/test_direct_run.py b/tests/pytest/test_direct_run.py index a35d5490..7b9c7a5d 100644 --- a/tests/pytest/test_direct_run.py +++ b/tests/pytest/test_direct_run.py @@ -15,7 +15,7 @@ ], completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], rollout_processor=SingleTurnRolloutProcessor(), - mode="listwise", + mode="all", ) def test_direct_run(rows: List[EvaluationRow]) -> List[EvaluationRow]: """Run math evaluation on sample dataset using pytest interface.""" @@ -53,7 +53,7 @@ def test_direct_run_main(): ], completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], rollout_processor=SingleTurnRolloutProcessor(), - mode="listwise", + 
mode="all", ) async def test_direct_run_async(rows: List[EvaluationRow]) -> List[EvaluationRow]: """Run math evaluation on sample dataset using pytest interface.""" @@ -62,21 +62,22 @@ async def test_direct_run_async(rows: List[EvaluationRow]) -> List[EvaluationRow return rows - @pytest.mark.asyncio async def test_direct_run_async_main(): rows = [ EvaluationRow( messages=[ - Message(role="user", content="What is the capital of France?"), + Message(role="user", content="1"), ], ), EvaluationRow( messages=[ - Message(role="user", content="What is the capital of the moon?"), + Message(role="user", content="2"), ], ), ] res = await test_direct_run_async(rows) + assert res[0].messages[0].content == "1" + assert res[1].messages[0].content == "2" assert res[0].evaluation_result.score == 0 - assert res[1].evaluation_result.score == 1 \ No newline at end of file + assert res[1].evaluation_result.score == 1 From 0a78e79b258b15b30faca6c476d2d7a27da756ec Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Tue, 19 Aug 2025 23:24:46 -0700 Subject: [PATCH 3/4] fix test --- eval_protocol/pytest/utils.py | 7 +++++++ tests/pytest/test_pytest_ensure_logging.py | 4 ++-- tests/pytest/test_pytest_ids.py | 8 ++++---- tests/pytest/test_pytest_stable_row_id.py | 12 ++++++------ 4 files changed, 19 insertions(+), 12 deletions(-) diff --git a/eval_protocol/pytest/utils.py b/eval_protocol/pytest/utils.py index fbd44f25..29135651 100644 --- a/eval_protocol/pytest/utils.py +++ b/eval_protocol/pytest/utils.py @@ -16,6 +16,13 @@ RolloutProcessorConfig, ) +def is_in_event_loop(): + try: + asyncio.get_event_loop() + return True + except RuntimeError: + return False + def execute_function(func: Callable, **kwargs) -> Any: """ diff --git a/tests/pytest/test_pytest_ensure_logging.py b/tests/pytest/test_pytest_ensure_logging.py index e57b3c8c..89828fe8 100644 --- a/tests/pytest/test_pytest_ensure_logging.py +++ b/tests/pytest/test_pytest_ensure_logging.py @@ -2,7 +2,7 @@ from unittest.mock import Mock, 
patch -async def test_ensure_logging(monkeypatch): +def test_ensure_logging(monkeypatch): """ Ensure that default SQLITE logger gets called by mocking the storage and checking that the storage is called. """ @@ -37,7 +37,7 @@ async def test_ensure_logging(monkeypatch): def eval_fn(row: EvaluationRow) -> EvaluationRow: return row - await eval_fn( + eval_fn( dataset_path=["tests/pytest/data/markdown_dataset.jsonl"], completion_params={"temperature": 0.0, "model": "dummy/local-model"}, ) diff --git a/tests/pytest/test_pytest_ids.py b/tests/pytest/test_pytest_ids.py index b6bb4a35..0a4053ed 100644 --- a/tests/pytest/test_pytest_ids.py +++ b/tests/pytest/test_pytest_ids.py @@ -19,7 +19,7 @@ def read(self): return list(self._rows.values()) -async def test_evaluation_test_decorator(monkeypatch): +def test_evaluation_test_decorator(monkeypatch): from eval_protocol.pytest.evaluation_test import evaluation_test logger = InMemoryLogger() @@ -45,13 +45,13 @@ def eval_fn(row: EvaluationRow) -> EvaluationRow: # Manually invoke all parameter combinations within a single test for ds_path in dataset_paths: - await eval_fn(dataset_path=[ds_path], completion_params={"temperature": 0.0, "model": "dummy/local-model"}) + eval_fn(dataset_path=[ds_path], completion_params={"temperature": 0.0, "model": "dummy/local-model"}) # Assertions on IDs generated by the decorator logic assert len(logger.read()) == 38 -async def test_evaluation_test_decorator_ids_single(monkeypatch): +def test_evaluation_test_decorator_ids_single(monkeypatch): in_memory_logger = InMemoryLogger() unique_run_ids = set() unique_experiment_ids = set() @@ -97,7 +97,7 @@ def eval_fn(row: EvaluationRow) -> EvaluationRow: # Manually invoke all parameter combinations within a single test for ds_path in dataset_paths: for params in completion_params_list: - await eval_fn(dataset_path=[ds_path], completion_params=params) + eval_fn(dataset_path=[ds_path], completion_params=params) # Assertions on IDs generated by the decorator 
logic assert len(unique_invocation_ids) == 1 diff --git a/tests/pytest/test_pytest_stable_row_id.py b/tests/pytest/test_pytest_stable_row_id.py index c2a5709a..a5d6f5db 100644 --- a/tests/pytest/test_pytest_stable_row_id.py +++ b/tests/pytest/test_pytest_stable_row_id.py @@ -5,7 +5,7 @@ from tests.pytest.test_markdown_highlighting import markdown_dataset_to_evaluation_row -async def test_evaluation_test_decorator_ids_single(): +def test_evaluation_test_decorator_ids_single(): from eval_protocol.pytest.evaluation_test import evaluation_test row_ids = set() @@ -35,18 +35,18 @@ def eval_fn(row: EvaluationRow) -> EvaluationRow: # Manually invoke all parameter combinations within a single test for ds_path in input_dataset: for params in completion_params_list: - await eval_fn(dataset_path=[ds_path], completion_params=params) + eval_fn(dataset_path=[ds_path], completion_params=params) # Second invocation to ensure that IDs are stable across multiple invocations for ds_path in input_dataset: for params in completion_params_list: - await eval_fn(dataset_path=[ds_path], completion_params=params) + eval_fn(dataset_path=[ds_path], completion_params=params) # Assertions on IDs generated by the decorator logic assert len(row_ids) == 19 # from the markdown dataset -async def test_evaluation_test_generated_row_ids_without_dataset_keys(): +def test_evaluation_test_generated_row_ids_without_dataset_keys(): from eval_protocol.pytest.evaluation_test import evaluation_test # Adapter that does NOT set row_id; lets evaluation_test generate IDs @@ -86,12 +86,12 @@ def eval_fn(row: EvaluationRow) -> EvaluationRow: # Single invocation (one dataset, one param set) with multiple runs for ds_path in input_dataset: for params in completion_params: - await eval_fn(dataset_path=[ds_path], completion_params=params) + eval_fn(dataset_path=[ds_path], completion_params=params) # Second invocation to ensure that IDs are stable across multiple invocations for ds_path in input_dataset: for params in 
completion_params: - await eval_fn(dataset_path=[ds_path], completion_params=params) + eval_fn(dataset_path=[ds_path], completion_params=params) # Even with multiple runs, generated row_ids should be stable within the invocation assert len(row_ids) == 19 # equals dataset size when IDs are generated once and preserved across runs From 58d9cf2ced22207ada280c0f2a91e0b5bd7a0268 Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Wed, 20 Aug 2025 01:03:13 -0700 Subject: [PATCH 4/4] format --- eval_protocol/pytest/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/eval_protocol/pytest/utils.py b/eval_protocol/pytest/utils.py index 29135651..8dce9c6d 100644 --- a/eval_protocol/pytest/utils.py +++ b/eval_protocol/pytest/utils.py @@ -16,6 +16,7 @@ RolloutProcessorConfig, ) + def is_in_event_loop(): try: asyncio.get_event_loop()