Skip to content

Commit 5e75383

Browse files
author
Dylan Huang
authored
Fix evaluation test type checks (#134)
* fix type checks in IDE * convert to basedpyright and temporarily disable in CI * Update VSCode settings to include default formatter for Ruff * save * part 2 * part 3 * use basedpyright for pre-commit * evaluation_test type checks pass * fix input_messages usage * test_pydantic_multi_agent runs * remove debugger * vite build * fix * fix test_math_dataset * fix test_pytest_tools_are_added_to_row * fix test_pytest_propagate_error * fix input_message data type * fix input_messages
1 parent cbbd407 commit 5e75383

40 files changed

+1273
-1078
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
experiment_results/
2+
13
# Byte-compiled / optimized / DLL files
24
__pycache__/
35
*.py[cod]

.pre-commit-config.yaml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ repos:
2222
- id: ruff-format
2323
- id: ruff
2424
args: ["--fix"]
25-
26-
- repo: https://github.com/RobertCraigie/pyright-python
27-
rev: v1.1.403
25+
- repo: https://github.com/DetachHead/basedpyright-pre-commit-mirror
26+
rev: 1.31.3
2827
hooks:
29-
- id: pyright
28+
- id: basedpyright

.vscode/settings.json

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,10 @@
55
"python.testing.autoTestDiscoverOnSaveEnabled": true,
66
"python.defaultInterpreterPath": "./.venv/bin/python",
77
"python.testing.cwd": "${workspaceFolder}",
8-
"cursorpyright.analysis.diagnosticMode": "openFilesOnly"
8+
"cursorpyright.analysis.diagnosticMode": "openFilesOnly",
9+
"editor.defaultFormatter": "charliermarsh.ruff",
10+
"editor.formatOnSave": true,
11+
"[python]": {
12+
"editor.defaultFormatter": "charliermarsh.ruff"
13+
}
914
}

eval_protocol/benchmarks/test_gpqa.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import csv
33
import io
44
import re
5-
from typing import List
65

76
import requests
87

@@ -20,12 +19,12 @@
2019
)
2120

2221

23-
def _load_gpqa_messages_from_csv() -> List[List[Message]]:
22+
def _load_gpqa_messages_from_csv() -> list[list[list[Message]]]:
2423
url = "https://openaipublic.blob.core.windows.net/simple-evals/gpqa_diamond.csv"
2524
resp = requests.get(url, timeout=60)
2625
resp.raise_for_status()
2726

28-
messages_list: List[List[Message]] = []
27+
messages_list: list[list[Message]] = []
2928
reader = csv.DictReader(io.StringIO(resp.text))
3029
for ex in reader:
3130
q = str(ex.get("Question", ""))
@@ -45,7 +44,7 @@ def _load_gpqa_messages_from_csv() -> List[List[Message]]:
4544
)
4645
if not messages_list:
4746
raise RuntimeError("Failed to load GPQA messages: no rows found from source")
48-
return messages_list
47+
return [messages_list]
4948

5049

5150
def _extract_abcd_letter(text: str) -> str | None:
@@ -58,7 +57,7 @@ def _extract_abcd_letter(text: str) -> str | None:
5857
_GPQA_INPUT_MESSAGES = _load_gpqa_messages_from_csv()
5958

6059

61-
def _strip_gt_messages(msgs: List[Message]) -> List[Message]:
60+
def _strip_gt_messages(msgs: list[Message]) -> list[Message]:
6261
return [m for m in msgs if not (m.role == "system" and (m.content or "").startswith("__GT__:"))]
6362

6463

@@ -69,9 +68,9 @@ def __init__(self):
6968
super().__init__()
7069
self.single_turn_processor = SingleTurnRolloutProcessor()
7170

72-
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
71+
def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
7372
"""Preprocess rows and delegate to SingleTurnRolloutProcessor."""
74-
processed: List[EvaluationRow] = []
73+
processed: list[EvaluationRow] = []
7574

7675
for r in rows:
7776
gt_tokens = [

eval_protocol/benchmarks/test_livebench_data_analysis.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ def _extract_gt(row: EvaluationRow) -> Dict[str, Any]:
409409

410410
@evaluation_test(
411411
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
412-
input_messages=[[m for m in r.messages] for r in _CTA_ROWS],
412+
input_messages=[[[m for m in r.messages] for r in _CTA_ROWS]],
413413
rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],
414414
rollout_processor=SingleTurnRolloutProcessor(),
415415
aggregation_method="mean",
@@ -451,7 +451,7 @@ def test_livebench_cta_pointwise(row: EvaluationRow) -> EvaluationRow:
451451

452452
@evaluation_test(
453453
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
454-
input_messages=[[m for m in r.messages] for r in _TABLEJOIN_ROWS],
454+
input_messages=[[[m for m in r.messages] for r in _TABLEJOIN_ROWS]],
455455
rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],
456456
rollout_processor=LiveBenchGroundTruthRolloutProcessor(_TABLEJOIN_ROWS),
457457
aggregation_method="mean",

eval_protocol/models.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -439,11 +439,19 @@ class EvaluationThreshold(BaseModel):
439439
success: float = Field(
440440
..., description="Minimum success rate threshold (fraction of total score, 0.0 to 1.0)", ge=0.0, le=1.0
441441
)
442-
standard_error: Optional[float] = Field(
443-
None, description="Maximum standard error threshold (fraction of total score, 0.0 to 1.0)", ge=0.0, le=1.0
442+
standard_error: float | None = Field(
443+
default=None,
444+
description="Maximum standard error threshold (fraction of total score, 0.0 to 1.0)",
445+
ge=0.0,
446+
le=1.0,
444447
)
445448

446449

450+
class EvaluationThresholdDict(TypedDict):
451+
success: float
452+
standard_error: float | None
453+
454+
447455
class EvalMetadata(BaseModel):
448456
"""Metadata about the evaluation that was run."""
449457

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
from typing import Any, Dict, List
2-
1+
from typing import Any
32
from eval_protocol.models import EvaluationRow
43

54

6-
def default_dataset_adapter(rows: List[Dict[str, Any]]) -> List[EvaluationRow]:
5+
def default_dataset_adapter(rows: list[dict[str, Any]]) -> list[EvaluationRow]:  # pyright: ignore[reportExplicitAny]
    """
    Default dataset adapter: materialize each raw row mapping into an
    ``EvaluationRow`` without any further transformation.
    """
    adapted: list[EvaluationRow] = []
    for raw_row in rows:
        adapted.append(EvaluationRow(**raw_row))  # pyright: ignore[reportAny]
    return adapted
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import asyncio
2+
from collections.abc import Callable
3+
import functools
4+
5+
from eval_protocol.models import EvaluationRow
6+
from eval_protocol.pytest.types import EvaluationTestMode, TestFunction
7+
8+
9+
def create_dual_mode_wrapper(  # pyright: ignore[reportUnknownParameterType]
    test_func: TestFunction,
    mode: EvaluationTestMode,
    max_concurrent_rollouts: int,
    max_concurrent_evaluations: int,
    pytest_wrapper: Callable[[], None],
):
    """Wrap *test_func* so it works both as a pytest test and as a direct call.

    The returned coroutine function inspects how it was invoked:

    * pointwise mode — a single positional ``EvaluationRow`` (no kwargs) is
      routed straight to ``test_func(row=...)``;
    * batch mode — a single positional ``list[EvaluationRow]`` (no kwargs),
      or a ``rows=`` keyword holding such a list, is routed to
      ``test_func(rows=...)``;
    * any other calling pattern is treated as a pytest-parameterized
      invocation and forwarded to *pytest_wrapper*.

    Returns:
        An async callable carrying ``_origin_func`` and ``_metainfo``
        attributes for downstream introspection.
    """
    # Resolved once up front so the dispatcher does not re-inspect per call.
    test_func_is_coroutine = asyncio.iscoroutinefunction(test_func)

    async def invoke(**call_kwargs):  # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
        """Call test_func, awaiting it only when it is a coroutine function."""
        if not test_func_is_coroutine:
            return test_func(**call_kwargs)  # pyright: ignore[reportUnknownVariableType, reportCallIssue]
        return await test_func(**call_kwargs)  # pyright: ignore[reportUnknownVariableType, reportGeneralTypeIssues, reportCallIssue]

    def _is_row_list(candidate: object) -> bool:
        """True when *candidate* is a list containing only EvaluationRow items."""
        return isinstance(candidate, list) and all(isinstance(item, EvaluationRow) for item in candidate)  # pyright: ignore[reportUnknownVariableType]

    async def dual_mode_wrapper(*args, **kwargs):  # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
        # A bare single positional argument is the "direct call" shape.
        single_positional = len(args) == 1 and not kwargs  # pyright: ignore[reportUnknownArgumentType]

        if mode == "pointwise":
            if single_positional and isinstance(args[0], EvaluationRow):
                return await invoke(row=args[0])  # pyright: ignore[reportUnknownVariableType]
        else:
            if single_positional and _is_row_list(args[0]):
                return await invoke(rows=args[0])  # pyright: ignore[reportUnknownVariableType]
            # Direct batch call may also come in as a ``rows=`` keyword.
            if not args and "rows" in kwargs and _is_row_list(kwargs["rows"]):
                return await invoke(**kwargs)  # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType]

        # Not a recognized direct call — defer to the pytest execution path.
        return await pytest_wrapper(*args, **kwargs)  # pyright: ignore[reportUnknownVariableType, reportGeneralTypeIssues]

    # Tag the dispatcher before update_wrapper so introspection metadata is
    # present even if the pytest wrapper carries its own __dict__ entries.
    dual_mode_wrapper._origin_func = test_func  # pyright: ignore[reportFunctionMemberAccess]
    dual_mode_wrapper._metainfo = {  # pyright: ignore[reportFunctionMemberAccess]
        "mode": mode,
        "max_rollout_concurrency": max_concurrent_rollouts,
        "max_evaluation_concurrency": max_concurrent_evaluations,
    }

    # Mirror the pytest wrapper's metadata (__name__, __doc__, marks, ...)
    # onto the dispatcher so pytest collection still recognizes it.
    functools.update_wrapper(dual_mode_wrapper, pytest_wrapper)  # pyright: ignore[reportUnknownArgumentType]

    return dual_mode_wrapper  # pyright: ignore[reportUnknownVariableType]

0 commit comments

Comments (0)