Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
experiment_results/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
7 changes: 3 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ repos:
- id: ruff-format
- id: ruff
args: ["--fix"]

- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.403
- repo: https://github.com/DetachHead/basedpyright-pre-commit-mirror
rev: 1.31.3
hooks:
- id: pyright
- id: basedpyright
7 changes: 6 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,10 @@
"python.testing.autoTestDiscoverOnSaveEnabled": true,
"python.defaultInterpreterPath": "./.venv/bin/python",
"python.testing.cwd": "${workspaceFolder}",
"cursorpyright.analysis.diagnosticMode": "openFilesOnly"
"cursorpyright.analysis.diagnosticMode": "openFilesOnly",
"editor.defaultFormatter": "charliermarsh.ruff",
"editor.formatOnSave": true,
"[python]": {
"editor.defaultFormatter": "charliermarsh.ruff"
}
}
13 changes: 6 additions & 7 deletions eval_protocol/benchmarks/test_gpqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import csv
import io
import re
from typing import List

import requests

Expand All @@ -20,12 +19,12 @@
)


def _load_gpqa_messages_from_csv() -> List[List[Message]]:
def _load_gpqa_messages_from_csv() -> list[list[list[Message]]]:
url = "https://openaipublic.blob.core.windows.net/simple-evals/gpqa_diamond.csv"
resp = requests.get(url, timeout=60)
resp.raise_for_status()

messages_list: List[List[Message]] = []
messages_list: list[list[Message]] = []
reader = csv.DictReader(io.StringIO(resp.text))
for ex in reader:
q = str(ex.get("Question", ""))
Expand All @@ -45,7 +44,7 @@ def _load_gpqa_messages_from_csv() -> List[List[Message]]:
)
if not messages_list:
raise RuntimeError("Failed to load GPQA messages: no rows found from source")
return messages_list
return [messages_list]


def _extract_abcd_letter(text: str) -> str | None:
Expand All @@ -58,7 +57,7 @@ def _extract_abcd_letter(text: str) -> str | None:
_GPQA_INPUT_MESSAGES = _load_gpqa_messages_from_csv()


def _strip_gt_messages(msgs: List[Message]) -> List[Message]:
def _strip_gt_messages(msgs: list[Message]) -> list[Message]:
return [m for m in msgs if not (m.role == "system" and (m.content or "").startswith("__GT__:"))]


Expand All @@ -69,9 +68,9 @@ def __init__(self):
super().__init__()
self.single_turn_processor = SingleTurnRolloutProcessor()

def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
"""Preprocess rows and delegate to SingleTurnRolloutProcessor."""
processed: List[EvaluationRow] = []
processed: list[EvaluationRow] = []

for r in rows:
gt_tokens = [
Expand Down
4 changes: 2 additions & 2 deletions eval_protocol/benchmarks/test_livebench_data_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def _extract_gt(row: EvaluationRow) -> Dict[str, Any]:

@evaluation_test(
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
input_messages=[[m for m in r.messages] for r in _CTA_ROWS],
input_messages=[[[m for m in r.messages] for r in _CTA_ROWS]],
rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],
rollout_processor=SingleTurnRolloutProcessor(),
aggregation_method="mean",
Expand Down Expand Up @@ -451,7 +451,7 @@ def test_livebench_cta_pointwise(row: EvaluationRow) -> EvaluationRow:

@evaluation_test(
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
input_messages=[[m for m in r.messages] for r in _TABLEJOIN_ROWS],
input_messages=[[[m for m in r.messages] for r in _TABLEJOIN_ROWS]],
rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],
rollout_processor=LiveBenchGroundTruthRolloutProcessor(_TABLEJOIN_ROWS),
aggregation_method="mean",
Expand Down
12 changes: 10 additions & 2 deletions eval_protocol/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,11 +439,19 @@ class EvaluationThreshold(BaseModel):
success: float = Field(
..., description="Minimum success rate threshold (fraction of total score, 0.0 to 1.0)", ge=0.0, le=1.0
)
standard_error: Optional[float] = Field(
None, description="Maximum standard error threshold (fraction of total score, 0.0 to 1.0)", ge=0.0, le=1.0
standard_error: float | None = Field(
default=None,
description="Maximum standard error threshold (fraction of total score, 0.0 to 1.0)",
ge=0.0,
le=1.0,
)


class EvaluationThresholdDict(TypedDict):
    """Plain-dict counterpart of ``EvaluationThreshold`` with the same two fields.

    NOTE(review): both keys are required in a TypedDict by default, so callers
    must always supply ``standard_error`` (possibly as None); if it should be
    omittable, mark it ``NotRequired`` — confirm intended usage.
    """

    # Minimum success rate threshold, fraction of total score (0.0 to 1.0).
    success: float
    # Maximum standard error threshold (0.0 to 1.0); None means no limit is enforced.
    standard_error: float | None


class EvalMetadata(BaseModel):
"""Metadata about the evaluation that was run."""

Expand Down
7 changes: 3 additions & 4 deletions eval_protocol/pytest/default_dataset_adapter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import Any, Dict, List

from typing import Any
from eval_protocol.models import EvaluationRow


def default_dataset_adapter(rows: list[dict[str, Any]]) -> list[EvaluationRow]:  # pyright: ignore[reportExplicitAny]
    """
    Default dataset adapter that simply returns the rows as is.

    Each raw mapping is expanded as keyword arguments into the EvaluationRow
    constructor; no filtering or transformation is applied and order is kept.
    """
    converted: list[EvaluationRow] = []
    for raw_row in rows:
        converted.append(EvaluationRow(**raw_row))  # pyright: ignore[reportAny]
    return converted
78 changes: 78 additions & 0 deletions eval_protocol/pytest/dual_mode_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import asyncio
import functools
import inspect
from collections.abc import Callable

from eval_protocol.models import EvaluationRow
from eval_protocol.pytest.types import EvaluationTestMode, TestFunction


def create_dual_mode_wrapper(  # pyright: ignore[reportUnknownParameterType]
    test_func: TestFunction,
    mode: EvaluationTestMode,
    max_concurrent_rollouts: int,
    max_concurrent_evaluations: int,
    pytest_wrapper: Callable[[], None],
):
    """
    Creates a wrapper that supports both pytest parameterized execution and direct function calls.

    This wrapper enables the decorated evaluation test function to be used in two ways:
    1. As a pytest test (via pytest.mark.parametrize) with full parameterization
    2. As a direct function call with EvaluationRow data for programmatic use

    The wrapper automatically detects the calling pattern and routes to the appropriate
    execution path, ensuring consistent behavior regardless of how the function is invoked.

    Args:
        test_func: The user's evaluation test function; may be sync or async.
        mode: "pointwise" (direct calls pass a single EvaluationRow) or any
            other value (direct calls pass a list of EvaluationRow).
        max_concurrent_rollouts: Recorded in the wrapper's ``_metainfo``.
        max_concurrent_evaluations: Recorded in the wrapper's ``_metainfo``.
        pytest_wrapper: Fallback callable used when the call does not match a
            direct-call pattern; may be sync or async.

    Returns:
        An async callable that can handle both pytest test execution and direct
        function calls, carrying ``_origin_func`` and ``_metainfo`` attributes.
    """

    # Decided once up front; governs whether test_func must be awaited.
    is_async = asyncio.iscoroutinefunction(test_func)

    async def call_test_func(**call_kwargs):  # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
        """Helper to call test_func with proper async/sync handling."""
        if is_async:
            return await test_func(**call_kwargs)  # pyright: ignore[reportUnknownVariableType, reportGeneralTypeIssues, reportCallIssue]
        return test_func(**call_kwargs)  # pyright: ignore[reportUnknownVariableType, reportCallIssue]

    async def dual_mode_wrapper(*args, **kwargs):  # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
        if mode == "pointwise":
            # Direct pointwise call: exactly one positional EvaluationRow, no kwargs.
            if len(args) == 1 and isinstance(args[0], EvaluationRow) and not kwargs:  # pyright: ignore[reportUnknownArgumentType]
                return await call_test_func(row=args[0])  # pyright: ignore[reportUnknownVariableType]
        else:
            # Direct batch call: one positional list of EvaluationRow ...
            if (
                len(args) == 1  # pyright: ignore[reportUnknownArgumentType]
                and isinstance(args[0], list)
                and all(isinstance(r, EvaluationRow) for r in args[0])  # pyright: ignore[reportUnknownVariableType]
                and not kwargs
            ):
                return await call_test_func(rows=args[0])  # pyright: ignore[reportUnknownVariableType]
            # ... or a single keyword argument 'rows' holding such a list.
            if (
                not args  # pyright: ignore[reportUnknownArgumentType]
                and "rows" in kwargs
                and isinstance(kwargs["rows"], list)
                and all(isinstance(r, EvaluationRow) for r in kwargs["rows"])  # pyright: ignore[reportUnknownVariableType]
            ):
                return await call_test_func(**kwargs)  # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType]

        # Not a direct call: delegate to the pytest wrapper. It may be a plain
        # sync function (returning None) or a coroutine function; awaiting a
        # non-awaitable would raise TypeError, so only await real awaitables.
        result = pytest_wrapper(*args, **kwargs)  # pyright: ignore[reportCallIssue, reportGeneralTypeIssues]
        if inspect.isawaitable(result):
            return await result
        return result

    # Copy wrapper metadata (__name__, __doc__, __dict__, ...) from the pytest
    # wrapper FIRST: update_wrapper merges pytest_wrapper.__dict__ into ours,
    # so setting our own attributes afterwards guarantees they survive.
    functools.update_wrapper(dual_mode_wrapper, pytest_wrapper)  # pyright: ignore[reportUnknownArgumentType]

    dual_mode_wrapper._origin_func = test_func  # pyright: ignore[reportFunctionMemberAccess]
    dual_mode_wrapper._metainfo = {  # pyright: ignore[reportFunctionMemberAccess]
        "mode": mode,
        "max_rollout_concurrency": max_concurrent_rollouts,
        "max_evaluation_concurrency": max_concurrent_evaluations,
    }

    return dual_mode_wrapper  # pyright: ignore[reportUnknownVariableType]
Loading
Loading