Skip to content

Commit 58a70d0

Browse files
cursoragentbenjibc
andcommitted
Fix type handling and add optional checks in reward and evaluation functions
Co-authored-by: bchen <bchen@fireworks.ai>
1 parent 56ea7ec commit 58a70d0

File tree

3 files changed

+37
-10
lines changed

3 files changed

+37
-10
lines changed

eval_protocol/mcp_servers/tau2/tests/test_tau2_e2e.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -747,7 +747,7 @@ def tau2_airline_eval(
747747
elif role == "user":
748748
trajectory_objects.append(UserMessage(role=role, content=content))
749749
elif role == "tool":
750-
tool_id = msg.tool_call_id
750+
tool_id = msg.tool_call_id or ""
751751
trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=content, requestor="assistant"))
752752

753753
reward = 1.0
@@ -756,6 +756,7 @@ def tau2_airline_eval(
756756
nl_assertions=nl_assertions,
757757
communicate_info=communicate_info,
758758
actions=actions,
759+
env_assertions=None,
759760
reward_basis=[
760761
RewardType.NL_ASSERTION,
761762
RewardType.DB,
@@ -796,7 +797,8 @@ def tau2_airline_eval(
796797
action_bases = {RewardType.ACTION}
797798
nl_bases = {RewardType.NL_ASSERTION}
798799
comm_bases = {RewardType.COMMUNICATE}
799-
task_reward_basis = set(task.evaluation_criteria.reward_basis)
800+
# task.evaluation_criteria can be Optional in the type hints; guard for None
801+
task_reward_basis = set(task.evaluation_criteria.reward_basis) if task.evaluation_criteria else set()
800802

801803
reward_breakdown = {}
802804
if task_reward_basis & env_bases:

eval_protocol/rewards/lean_prover.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -467,17 +467,16 @@ def deepseek_huggingface_prover_benchmark(
467467
current_top_level_reason += f" Sub-evaluation: {result_reason}"
468468

469469
if verbose:
470+
info_payload = {
471+
"id": (dataset_item.get("id", "") if dataset_item else ""),
472+
"has_expected_proof": expected_proof is not None,
473+
"has_reference_solution": reference_solution is not None,
474+
"has_answer": (("answer" in dataset_item) if dataset_item else False),
475+
}
470476
combined_metrics["dataset_info"] = MetricResult(
471477
score=1.0,
472478
is_score_valid=True,
473-
reason=json.dumps(
474-
{
475-
"id": dataset_item.get("id", ""),
476-
"has_expected_proof": expected_proof is not None,
477-
"has_reference_solution": reference_solution is not None,
478-
"has_answer": "answer" in dataset_item if dataset_item else False,
479-
}
480-
),
479+
reason=json.dumps(info_payload),
481480
)
482481

483482
return EvaluateResult(score=result_score, reason=current_top_level_reason, metrics=combined_metrics)

eval_protocol/typed_interface.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
cast,
1414
get_args,
1515
get_origin,
16+
overload,
1617
)
1718

1819
from pydantic import TypeAdapter, ValidationError
@@ -36,6 +37,31 @@
3637
F = TypeVar("F", bound=Callable[..., Any])
3738

3839

40+
# Precise overloads help static type checkers preserve the original function signature.
41+
@overload
42+
def reward_function(
43+
_func: F,
44+
*,
45+
mode: EvaluationMode = "pointwise",
46+
id: Optional[str] = None,
47+
requirements: Optional[List[str]] = None,
48+
resources: Optional[ResourceDict] = None,
49+
concurrency: Optional[int] = None,
50+
timeout: Optional[int] = None,
51+
) -> F: ...
52+
53+
@overload
54+
def reward_function(
55+
_func: None = ..., # when used as @reward_function(...)
56+
*,
57+
mode: EvaluationMode = "pointwise",
58+
id: Optional[str] = None,
59+
requirements: Optional[List[str]] = None,
60+
resources: Optional[ResourceDict] = None,
61+
concurrency: Optional[int] = None,
62+
timeout: Optional[int] = None,
63+
) -> Callable[[F], F]: ...
64+
3965
def reward_function(
4066
_func: Optional[F] = None,
4167
*,

0 commit comments

Comments
 (0)