Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,9 @@ jobs:
- name: Ruff lint
run: uv run ruff check .

- name: Type check with pyright
- name: Run pre-commit (format, lint, type check)
run: |
# 'set +e' disables immediate exit on error so we can capture and report errors but exit 0
# Note: We currently suppress pyright failures to allow CI to pass while we iteratively fix all type issues.
# Once all type errors are resolved, we will remove this suppression and enforce strict type checking.
set +e
uv run basedpyright || true
uv run pre-commit run --all-files

test-core:
name: Core Tests (Python ${{ matrix.python-version }})
Expand Down
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ repos:
NODE_OPTIONS: "--max-old-space-size=4096"
# Only check Python files in the main package to reduce memory usage
files: ^eval_protocol/.*\.py$
additional_dependencies: ["pre-commit>=3.7.0"]
2 changes: 1 addition & 1 deletion eval_protocol/mcp_servers/tau2/tests/test_tau2_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ def tau2_airline_eval(
elif role == "user":
trajectory_objects.append(UserMessage(role=role, content=content))
elif role == "tool":
tool_id = msg.tool_call_id or ""
tool_id = msg.tool_call_id if isinstance(msg.tool_call_id, str) else ""
trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=content, requestor="assistant"))

reward = 1.0
Expand Down
11 changes: 7 additions & 4 deletions eval_protocol/rewards/function_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import re
import warnings
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Union, Callable, cast

# Import OpenAI at module level for mocking in tests
try:
Expand Down Expand Up @@ -451,7 +451,8 @@ def schema_jaccard_reward(
DeprecationWarning,
stacklevel=2,
)
return exact_tool_match_reward(messages=messages, ground_truth=ground_truth, **kwargs)
_exact_tool_match: Callable[..., EvaluateResult] = cast(Callable[..., EvaluateResult], exact_tool_match_reward)
return _exact_tool_match(messages=messages, ground_truth=ground_truth, **kwargs)


@reward_function
Expand Down Expand Up @@ -493,7 +494,8 @@ def llm_judge_reward(
DeprecationWarning,
stacklevel=2,
)
return exact_tool_match_reward(messages=messages, ground_truth=ground_truth, **kwargs)
_exact_tool_match: Callable[..., EvaluateResult] = cast(Callable[..., EvaluateResult], exact_tool_match_reward)
return _exact_tool_match(messages=messages, ground_truth=ground_truth, **kwargs)


@reward_function
Expand Down Expand Up @@ -537,7 +539,8 @@ def composite_function_call_reward(
DeprecationWarning,
stacklevel=2,
)
return exact_tool_match_reward(messages=messages, ground_truth=ground_truth, **kwargs)
_exact_tool_match: Callable[..., EvaluateResult] = cast(Callable[..., EvaluateResult], exact_tool_match_reward)
return _exact_tool_match(messages=messages, ground_truth=ground_truth, **kwargs)


# JSON schema reward functions have been moved to json_schema.py module
2 changes: 1 addition & 1 deletion eval_protocol/rewards/lean_prover.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Dict, List, Optional

from eval_protocol.models import EvaluateResult, Message, MetricResult
from eval_protocol.reward_function import reward_function
from eval_protocol.typed_interface import reward_function


@reward_function
Expand Down
Loading