Skip to content

Commit e5cb83d

Browse files
Authored by benjibc, Cursor Agent, and Benny Chen
Fix all type errors (#152)
* Fix type handling and add optional checks in reward and evaluation functions
* Fix ruff lint errors
* Fix final type error

Co-authored-by: bchen <bchen@fireworks.ai>
Co-authored-by: Cursor Agent <cursoragent@cursor.com>
Co-authored-by: Benny Chen <bchen@Bennys-MacBook-Air.local>
1 parent 153eacd commit e5cb83d

File tree

3 files changed

+52
-11
lines changed

3 files changed

+52
-11
lines changed

eval_protocol/mcp_servers/tau2/tests/test_tau2_e2e.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -747,15 +747,28 @@ def tau2_airline_eval(
747747
elif role == "user":
748748
trajectory_objects.append(UserMessage(role=role, content=content))
749749
elif role == "tool":
750-
tool_id = msg.tool_call_id
750+
tool_id = msg.tool_call_id or ""
751751
trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=content, requestor="assistant"))
752752

753753
reward = 1.0
754754

755+
# Convert incoming action dicts to typed Action objects for the evaluator
756+
action_objs: Optional[List[Action]] = None
757+
if actions is not None:
758+
action_objs = []
759+
for a in actions:
760+
if isinstance(a, Action):
761+
action_objs.append(a)
762+
elif isinstance(a, dict):
763+
action_objs.append(Action(**a))
764+
else:
765+
raise TypeError("actions must be a list of Action or dict items")
766+
755767
evaluation_criteria = EvaluationCriteria(
756768
nl_assertions=nl_assertions,
757769
communicate_info=communicate_info,
758-
actions=actions,
770+
actions=action_objs,
771+
env_assertions=None,
759772
reward_basis=[
760773
RewardType.NL_ASSERTION,
761774
RewardType.DB,
@@ -796,7 +809,8 @@ def tau2_airline_eval(
796809
action_bases = {RewardType.ACTION}
797810
nl_bases = {RewardType.NL_ASSERTION}
798811
comm_bases = {RewardType.COMMUNICATE}
799-
task_reward_basis = set(task.evaluation_criteria.reward_basis)
812+
# task.evaluation_criteria can be Optional in the type hints; guard for None
813+
task_reward_basis = set(task.evaluation_criteria.reward_basis) if task.evaluation_criteria else set()
800814

801815
reward_breakdown = {}
802816
if task_reward_basis & env_bases:

eval_protocol/rewards/lean_prover.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -467,17 +467,16 @@ def deepseek_huggingface_prover_benchmark(
467467
current_top_level_reason += f" Sub-evaluation: {result_reason}"
468468

469469
if verbose:
470+
info_payload = {
471+
"id": (dataset_item.get("id", "") if dataset_item else ""),
472+
"has_expected_proof": expected_proof is not None,
473+
"has_reference_solution": reference_solution is not None,
474+
"has_answer": (("answer" in dataset_item) if dataset_item else False),
475+
}
470476
combined_metrics["dataset_info"] = MetricResult(
471477
score=1.0,
472478
is_score_valid=True,
473-
reason=json.dumps(
474-
{
475-
"id": dataset_item.get("id", ""),
476-
"has_expected_proof": expected_proof is not None,
477-
"has_reference_solution": reference_solution is not None,
478-
"has_answer": "answer" in dataset_item if dataset_item else False,
479-
}
480-
),
479+
reason=json.dumps(info_payload),
481480
)
482481

483482
return EvaluateResult(score=result_score, reason=current_top_level_reason, metrics=combined_metrics)

eval_protocol/typed_interface.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
cast,
1414
get_args,
1515
get_origin,
16+
overload,
1617
)
1718

1819
from pydantic import TypeAdapter, ValidationError
@@ -36,6 +37,33 @@
3637
F = TypeVar("F", bound=Callable[..., Any])
3738

3839

40+
# Precise overloads help static type checkers preserve the original function signature.
41+
@overload
42+
def reward_function(
43+
_func: F,
44+
*,
45+
mode: EvaluationMode = "pointwise",
46+
id: Optional[str] = None,
47+
requirements: Optional[List[str]] = None,
48+
resources: Optional[ResourceDict] = None,
49+
concurrency: Optional[int] = None,
50+
timeout: Optional[int] = None,
51+
) -> F: ...
52+
53+
54+
@overload
55+
def reward_function(
56+
_func: None = ..., # when used as @reward_function(...)
57+
*,
58+
mode: EvaluationMode = "pointwise",
59+
id: Optional[str] = None,
60+
requirements: Optional[List[str]] = None,
61+
resources: Optional[ResourceDict] = None,
62+
concurrency: Optional[int] = None,
63+
timeout: Optional[int] = None,
64+
) -> Callable[[F], F]: ...
65+
66+
3967
def reward_function(
4068
_func: Optional[F] = None,
4169
*,

0 commit comments

Comments (0)