Skip to content

Commit f5a1d7b

Browse files
author
Dylan Huang
committed
evaluation_test type checks pass
1 parent c74aef2 commit f5a1d7b

12 files changed

+590
-561
lines changed
Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
from typing import Any, Dict, List
2-
1+
from typing import Any
32
from eval_protocol.models import EvaluationRow
43

54

6-
def default_dataset_adapter(rows: List[Dict[str, Any]]) -> List[EvaluationRow]:
5+
def default_dataset_adapter(rows: list[dict[str, Any]]) -> list[EvaluationRow]: # pyright: ignore[reportExplicitAny]
76
"""
87
Default dataset adapter that simply returns the rows as is.
98
"""
10-
return [EvaluationRow(**row) for row in rows]
9+
return [EvaluationRow(**row) for row in rows] # pyright: ignore[reportAny]
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import asyncio
2+
from collections.abc import Callable
3+
import functools
4+
5+
from eval_protocol.models import EvaluationRow
6+
from eval_protocol.pytest.types import EvaluationTestMode, TestFunction
7+
8+
9+
def create_dual_mode_wrapper( # pyright: ignore[reportUnknownParameterType]
10+
test_func: TestFunction,
11+
mode: EvaluationTestMode,
12+
max_concurrent_rollouts: int,
13+
max_concurrent_evaluations: int,
14+
pytest_wrapper: Callable[[], None],
15+
):
16+
"""
17+
Creates a wrapper that supports both pytest parameterized execution and direct function calls.
18+
19+
This wrapper enables the decorated evaluation test function to be used in two ways:
20+
1. As a pytest test (via pytest.mark.parametrize) with full parameterization
21+
2. As a direct function call with EvaluationRow data for programmatic use
22+
23+
The wrapper automatically detects the calling pattern and routes to the appropriate
24+
execution path, ensuring consistent behavior regardless of how the function is invoked.
25+
26+
Returns:
27+
A callable that can handle both pytest test execution and direct function calls
28+
"""
29+
30+
# Check if the test function is async
31+
is_async = asyncio.iscoroutinefunction(test_func)
32+
33+
async def call_test_func(**call_kwargs): # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
34+
"""Helper to call test_func with proper async/sync handling"""
35+
if is_async:
36+
return await test_func(**call_kwargs) # pyright: ignore[reportUnknownVariableType, reportGeneralTypeIssues, reportCallIssue]
37+
else:
38+
return test_func(**call_kwargs) # pyright: ignore[reportUnknownVariableType, reportCallIssue]
39+
40+
async def dual_mode_wrapper(*args, **kwargs): # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
41+
# Check if this is a direct call with the expected signature
42+
if mode == "pointwise":
43+
# For pointwise mode, check if called with a single row argument
44+
if len(args) == 1 and isinstance(args[0], EvaluationRow) and not kwargs: # pyright: ignore[reportUnknownArgumentType]
45+
return await call_test_func(row=args[0]) # pyright: ignore[reportUnknownVariableType]
46+
else:
47+
# For batch mode, check if called with rows argument
48+
if (
49+
len(args) == 1 # pyright: ignore[reportUnknownArgumentType]
50+
and isinstance(args[0], list)
51+
and all(isinstance(r, EvaluationRow) for r in args[0]) # pyright: ignore[reportUnknownVariableType]
52+
and not kwargs
53+
):
54+
return await call_test_func(rows=args[0]) # pyright: ignore[reportUnknownVariableType]
55+
# Also check if called with keyword argument 'rows'
56+
if (
57+
len(args) == 0 # pyright: ignore[reportUnknownArgumentType]
58+
and "rows" in kwargs
59+
and isinstance(kwargs["rows"], list)
60+
and all(isinstance(r, EvaluationRow) for r in kwargs["rows"]) # pyright: ignore[reportUnknownVariableType]
61+
):
62+
return await call_test_func(**kwargs) # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType]
63+
64+
# If not a direct call, use the pytest wrapper
65+
return await pytest_wrapper(*args, **kwargs) # pyright: ignore[reportUnknownVariableType, reportGeneralTypeIssues]
66+
67+
dual_mode_wrapper._origin_func = test_func # pyright: ignore[reportFunctionMemberAccess]
68+
dual_mode_wrapper._metainfo = { # pyright: ignore[reportFunctionMemberAccess]
69+
"mode": mode,
70+
"max_rollout_concurrency": max_concurrent_rollouts,
71+
"max_evaluation_concurrency": max_concurrent_evaluations,
72+
}
73+
74+
# Copy all attributes from the pytest wrapper to our dual mode wrapper
75+
76+
functools.update_wrapper(dual_mode_wrapper, pytest_wrapper) # pyright: ignore[reportUnknownArgumentType]
77+
78+
return dual_mode_wrapper # pyright: ignore[reportUnknownVariableType]

0 commit comments

Comments
 (0)