diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index 45ad02ac..b6e10ab7 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -11,7 +11,8 @@ from dataclasses import replace from typing import Any, Callable, Dict, List, Literal, Optional, Union from collections import defaultdict - +import hashlib +import ast from mcp.types import Completion import pytest @@ -244,6 +245,7 @@ def evaluation_test( # noqa: C901 max_dataset_rows: Optional[int] = None, mcp_config_path: Optional[str] = None, max_concurrent_rollouts: int = 8, + max_concurrent_evaluations: int = 64, server_script_path: Optional[str] = None, steps: int = 30, mode: EvaluationTestMode = "pointwise", @@ -308,6 +310,7 @@ def evaluation_test( # noqa: C901 max_dataset_rows: Limit dataset to the first N rows. mcp_config_path: Path to MCP config file that follows MCPMultiClientConfiguration schema max_concurrent_rollouts: Maximum number of concurrent rollouts to run in parallel. + max_concurrent_evaluations: Maximum number of concurrent evaluations to run in parallel. server_script_path: Path to the MCP server script to run (default: "examples/tau2_mcp/server.py"). steps: Number of rollout steps to execute (default: 30). mode: Evaluation mode. "pointwise" (default) applies test function to each row (rollout result). 
@@ -582,18 +585,17 @@ def _log_eval_error( for row in fresh_dataset: active_logger.log(row) - if mode == "pointwise": - # Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution - semaphore = asyncio.Semaphore(max_concurrent_rollouts) - tasks = [] + # prepare parallel eval helper function + semaphore = asyncio.Semaphore(max_concurrent_evaluations) - async def _execute_with_semaphore(row): - async with semaphore: - # NOTE: we will still evaluate errored rows (give users control over this) - # i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func + async def _execute_eval_with_semaphore(**inner_kwargs): + async with semaphore: + # NOTE: we will still evaluate errored rows (give users control over this) + # i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func + if "row" in inner_kwargs: result = await execute_with_params( test_func, - processed_row=row, + processed_row=inner_kwargs["row"], evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {}, ) if result is None or not isinstance(result, EvaluationRow): @@ -601,10 +603,24 @@ async def _execute_with_semaphore(row): f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test." ) return result + if "rows" in inner_kwargs: + results = await execute_with_params( + test_func, + processed_dataset=inner_kwargs["rows"], + evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {}, + ) + if results is None or not isinstance(results, list): + raise ValueError( + f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test." 
+ ) + return results + if mode == "pointwise": + # Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution + tasks = [] # Use wrapper that handles retry logic internally async for row in rollout_processor_with_retry(rollout_processor, fresh_dataset, config): - tasks.append(asyncio.create_task(_execute_with_semaphore(row))) + tasks.append(asyncio.create_task(_execute_eval_with_semaphore(row=row))) results = await asyncio.gather(*tasks) @@ -645,14 +661,13 @@ async def _collect_result(config, lst): for result in rollout_results: for row in result: row_groups[row.input_metadata.row_id].append(row) - results = [] + tasks = [] for row_id, rows in row_groups.items(): - result = await execute_with_params( - test_func, - processed_dataset=rows, - evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {}, - ) - results.extend(result) + tasks.append(asyncio.create_task(_execute_eval_with_semaphore(rows=rows))) + results = [] + for task in tasks: + res = await task + results.extend(res) all_results[i] = results else: # Batch mode: collect all results first, then evaluate (no pipelining) @@ -789,6 +804,13 @@ async def dual_mode_wrapper(*args, **kwargs): # If not a direct call, use the pytest wrapper return await pytest_wrapper(*args, **kwargs) + dual_mode_wrapper._origin_func = test_func + dual_mode_wrapper._metainfo = { + "mode": mode, + "max_rollout_concurrency": max_concurrent_rollouts, + "max_evaluation_concurrency": max_concurrent_evaluations, + } + # Copy all attributes from the pytest wrapper to our dual mode wrapper import functools diff --git a/tests/pytest/test_get_metadata.py b/tests/pytest/test_get_metadata.py new file mode 100644 index 00000000..3917fb3b --- /dev/null +++ b/tests/pytest/test_get_metadata.py @@ -0,0 +1,34 @@ +import asyncio +from typing import Dict, List + +from eval_protocol.pytest import evaluation_test +from eval_protocol.models import EvaluationRow, Message + + +@evaluation_test( + 
input_messages=[ + [ + Message(role="user", content="What is the capital of France?"), + ], + [ + Message(role="user", content="What is the capital of the moon?"), + ], + ], + completion_params=[{"model": "accounts/fireworks/models/kimi-k2-instruct"}] * 2, + mode="groupwise", + max_concurrent_rollouts=5, + max_concurrent_evaluations=10, +) +def test_pytest_async(rows: List[EvaluationRow]) -> List[EvaluationRow]: + """Groupwise pass-through evaluation used to verify decorator metadata.""" + return rows + + +def test_pytest_func_metainfo(): + assert hasattr(test_pytest_async, "_origin_func") + origin_func = test_pytest_async._origin_func + assert not asyncio.iscoroutinefunction(origin_func) + assert asyncio.iscoroutinefunction(test_pytest_async) + assert test_pytest_async._metainfo["mode"] == "groupwise" + assert test_pytest_async._metainfo["max_rollout_concurrency"] == 5 + assert test_pytest_async._metainfo["max_evaluation_concurrency"] == 10