From e4933e61432dbcd887037f061b2a06ab7452fe7c Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Mon, 1 Sep 2025 23:21:28 +0000 Subject: [PATCH 1/3] Fix type handling and add optional checks in reward and evaluation functions Co-authored-by: bchen --- .../mcp_servers/tau2/tests/test_tau2_e2e.py | 6 +++-- eval_protocol/rewards/lean_prover.py | 15 +++++------ eval_protocol/typed_interface.py | 26 +++++++++++++++++++ 3 files changed, 37 insertions(+), 10 deletions(-) diff --git a/eval_protocol/mcp_servers/tau2/tests/test_tau2_e2e.py b/eval_protocol/mcp_servers/tau2/tests/test_tau2_e2e.py index 5003ae16..1c062b8c 100644 --- a/eval_protocol/mcp_servers/tau2/tests/test_tau2_e2e.py +++ b/eval_protocol/mcp_servers/tau2/tests/test_tau2_e2e.py @@ -747,7 +747,7 @@ def tau2_airline_eval( elif role == "user": trajectory_objects.append(UserMessage(role=role, content=content)) elif role == "tool": - tool_id = msg.tool_call_id + tool_id = msg.tool_call_id or "" trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=content, requestor="assistant")) reward = 1.0 @@ -756,6 +756,7 @@ def tau2_airline_eval( nl_assertions=nl_assertions, communicate_info=communicate_info, actions=actions, + env_assertions=None, reward_basis=[ RewardType.NL_ASSERTION, RewardType.DB, @@ -796,7 +797,8 @@ def tau2_airline_eval( action_bases = {RewardType.ACTION} nl_bases = {RewardType.NL_ASSERTION} comm_bases = {RewardType.COMMUNICATE} - task_reward_basis = set(task.evaluation_criteria.reward_basis) + # task.evaluation_criteria can be Optional in the type hints; guard for None + task_reward_basis = set(task.evaluation_criteria.reward_basis) if task.evaluation_criteria else set() reward_breakdown = {} if task_reward_basis & env_bases: diff --git a/eval_protocol/rewards/lean_prover.py b/eval_protocol/rewards/lean_prover.py index 347cf6a4..45b23bf5 100644 --- a/eval_protocol/rewards/lean_prover.py +++ b/eval_protocol/rewards/lean_prover.py @@ -467,17 +467,16 @@ def deepseek_huggingface_prover_benchmark( current_top_level_reason += f" Sub-evaluation: {result_reason}" if verbose: + info_payload = { + "id": (dataset_item.get("id", "") if dataset_item else ""), + "has_expected_proof": expected_proof is not None, + "has_reference_solution": reference_solution is not None, + "has_answer": (("answer" in dataset_item) if dataset_item else False), + } combined_metrics["dataset_info"] = MetricResult( score=1.0, is_score_valid=True, - reason=json.dumps( - { - "id": dataset_item.get("id", ""), - "has_expected_proof": expected_proof is not None, - "has_reference_solution": reference_solution is not None, - "has_answer": "answer" in dataset_item if dataset_item else False, - } - ), + reason=json.dumps(info_payload), ) return EvaluateResult(score=result_score, reason=current_top_level_reason, metrics=combined_metrics) diff --git a/eval_protocol/typed_interface.py b/eval_protocol/typed_interface.py index 97dafdf3..1d5e57ce 100644 --- a/eval_protocol/typed_interface.py +++ b/eval_protocol/typed_interface.py @@ -13,6 +13,7 @@ cast, get_args, get_origin, + overload, ) from pydantic import TypeAdapter, ValidationError @@ -36,6 +37,31 @@ F = TypeVar("F", bound=Callable[..., Any]) +# Precise overloads help static type checkers preserve the original function signature. +@overload +def reward_function( + _func: F, + *, + mode: EvaluationMode = "pointwise", + id: Optional[str] = None, + requirements: Optional[List[str]] = None, + resources: Optional[ResourceDict] = None, + concurrency: Optional[int] = None, + timeout: Optional[int] = None, +) -> F: ... + +@overload +def reward_function( + _func: None = ..., # when used as @reward_function(...) + *, + mode: EvaluationMode = "pointwise", + id: Optional[str] = None, + requirements: Optional[List[str]] = None, + resources: Optional[ResourceDict] = None, + concurrency: Optional[int] = None, + timeout: Optional[int] = None, +) -> Callable[[F], F]: ... + def reward_function( _func: Optional[F] = None, *, From f989cd1da8cc5403ddd32b4c7ac5d5dc7c93af82 Mon Sep 17 00:00:00 2001 From: Benny Chen Date: Tue, 2 Sep 2025 20:17:43 +0800 Subject: [PATCH 2/3] fix ruff --- eval_protocol/typed_interface.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/eval_protocol/typed_interface.py b/eval_protocol/typed_interface.py index 1d5e57ce..19146a1c 100644 --- a/eval_protocol/typed_interface.py +++ b/eval_protocol/typed_interface.py @@ -50,6 +50,7 @@ def reward_function( timeout: Optional[int] = None, ) -> F: ... + @overload def reward_function( _func: None = ..., # when used as @reward_function(...) @@ -62,6 +63,7 @@ def reward_function( timeout: Optional[int] = None, ) -> Callable[[F], F]: ... + def reward_function( _func: Optional[F] = None, *, From fae0deb9b0c4b9f9b4a17a62d6ae9797f509f366 Mon Sep 17 00:00:00 2001 From: Benny Chen Date: Tue, 2 Sep 2025 20:35:15 +0800 Subject: [PATCH 3/3] fix final type error --- .../mcp_servers/tau2/tests/test_tau2_e2e.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/eval_protocol/mcp_servers/tau2/tests/test_tau2_e2e.py b/eval_protocol/mcp_servers/tau2/tests/test_tau2_e2e.py index 1c062b8c..03a61be4 100644 --- a/eval_protocol/mcp_servers/tau2/tests/test_tau2_e2e.py +++ b/eval_protocol/mcp_servers/tau2/tests/test_tau2_e2e.py @@ -752,10 +752,22 @@ def tau2_airline_eval( reward = 1.0 + # Convert incoming action dicts to typed Action objects for the evaluator + action_objs: Optional[List[Action]] = None + if actions is not None: + action_objs = [] + for a in actions: + if isinstance(a, Action): + action_objs.append(a) + elif isinstance(a, dict): + action_objs.append(Action(**a)) + else: + raise TypeError("actions must be a list of Action or dict items") + evaluation_criteria = EvaluationCriteria( nl_assertions=nl_assertions, communicate_info=communicate_info, - actions=actions, + actions=action_objs, env_assertions=None, reward_basis=[ RewardType.NL_ASSERTION,