Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 40 additions & 18 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from dataclasses import replace
from typing import Any, Callable, Dict, List, Literal, Optional, Union
from collections import defaultdict

import hashlib
import ast
from mcp.types import Completion
import pytest

Expand Down Expand Up @@ -244,6 +245,7 @@ def evaluation_test( # noqa: C901
max_dataset_rows: Optional[int] = None,
mcp_config_path: Optional[str] = None,
max_concurrent_rollouts: int = 8,
max_concurrent_evaluations: int = 64,
server_script_path: Optional[str] = None,
steps: int = 30,
mode: EvaluationTestMode = "pointwise",
Expand Down Expand Up @@ -308,6 +310,7 @@ def evaluation_test( # noqa: C901
max_dataset_rows: Limit dataset to the first N rows.
mcp_config_path: Path to MCP config file that follows MCPMultiClientConfiguration schema
max_concurrent_rollouts: Maximum number of concurrent rollouts to run in parallel.
max_concurrent_evaluations: Maximum number of concurrent evaluations to run in parallel.
server_script_path: Path to the MCP server script to run (default: "examples/tau2_mcp/server.py").
steps: Number of rollout steps to execute (default: 30).
mode: Evaluation mode. "pointwise" (default) applies test function to each row (rollout result).
Expand Down Expand Up @@ -582,29 +585,42 @@ def _log_eval_error(
for row in fresh_dataset:
active_logger.log(row)

if mode == "pointwise":
# Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution
semaphore = asyncio.Semaphore(max_concurrent_rollouts)
tasks = []
# prepare parallel eval helper function
semaphore = asyncio.Semaphore(max_concurrent_evaluations)

async def _execute_with_semaphore(row):
async with semaphore:
# NOTE: we will still evaluate errored rows (give users control over this)
# i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
async def _execute_eval_with_semaphore(**inner_kwargs):
async with semaphore:
# NOTE: we will still evaluate errored rows (give users control over this)
# i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
if "row" in inner_kwargs:
result = await execute_with_params(
test_func,
processed_row=row,
processed_row=inner_kwargs["row"],
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
)
if result is None or not isinstance(result, EvaluationRow):
raise ValueError(
f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
)
return result
if "rows" in inner_kwargs:
results = await execute_with_params(
test_func,
processed_dataset=inner_kwargs["rows"],
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
)
if results is None or not isinstance(results, list):
raise ValueError(
f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
)
return results

if mode == "pointwise":
# Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution
tasks = []
# Use wrapper that handles retry logic internally
async for row in rollout_processor_with_retry(rollout_processor, fresh_dataset, config):
tasks.append(asyncio.create_task(_execute_with_semaphore(row)))
tasks.append(asyncio.create_task(_execute_eval_with_semaphore(row=row)))

results = await asyncio.gather(*tasks)

Expand Down Expand Up @@ -645,14 +661,13 @@ async def _collect_result(config, lst):
for result in rollout_results:
for row in result:
row_groups[row.input_metadata.row_id].append(row)
results = []
tasks = []
for row_id, rows in row_groups.items():
result = await execute_with_params(
test_func,
processed_dataset=rows,
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
)
results.extend(result)
tasks.append(asyncio.create_task(_execute_eval_with_semaphore(rows=rows)))
results = []
for task in tasks:
res = await task
results.extend(res)
all_results[i] = results
else:
# Batch mode: collect all results first, then evaluate (no pipelining)
Expand Down Expand Up @@ -789,6 +804,13 @@ async def dual_mode_wrapper(*args, **kwargs):
# If not a direct call, use the pytest wrapper
return await pytest_wrapper(*args, **kwargs)

dual_mode_wrapper._origin_func = test_func
dual_mode_wrapper._metainfo = {
"mode": mode,
"max_rollout_concurrency": max_concurrent_rollouts,
"max_evaluation_concurrency": max_concurrent_evaluations,
}

# Copy all attributes from the pytest wrapper to our dual mode wrapper
import functools

Expand Down
34 changes: 34 additions & 0 deletions tests/pytest/test_get_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import asyncio
from typing import Dict, List

from eval_protocol.pytest import evaluation_test
from eval_protocol.models import EvaluationRow, Message


@evaluation_test(
    input_messages=[
        [
            Message(role="user", content="What is the capital of France?"),
        ],
        [
            Message(role="user", content="What is the capital of the moon?"),
        ],
    ],
    completion_params=[{"model": "accounts/fireworks/models/kimi-k2-instruct"}] * 2,
    mode="groupwise",
    max_concurrent_rollouts=5,
    max_concurrent_evaluations=10,
)
def test_pytest_async(rows: List[EvaluationRow]) -> List[EvaluationRow]:
    """Identity groupwise evaluation used only to exercise decorator metadata.

    The rows are returned unchanged; this function exists so that the
    attributes ``@evaluation_test`` attaches to its wrapper (``_origin_func``
    and ``_metainfo``) can be inspected by a separate test.
    """
    return rows


def test_pytest_func_metainfo():
    """Check that @evaluation_test exposes the original function and its metadata."""
    # The decorator should attach the undecorated function for introspection.
    assert hasattr(test_pytest_async, "_origin_func")
    assert not asyncio.iscoroutinefunction(test_pytest_async._origin_func)
    # The public wrapper itself is a coroutine function.
    assert asyncio.iscoroutinefunction(test_pytest_async)
    # Decorator arguments should round-trip into the _metainfo mapping.
    meta = test_pytest_async._metainfo
    assert meta["mode"] == "groupwise"
    assert meta["max_rollout_concurrency"] == 5
    assert meta["max_evaluation_concurrency"] == 10
Loading