Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions eval_protocol/mcp_servers/tau2/tests/test_tau2_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,15 +747,28 @@ def tau2_airline_eval(
elif role == "user":
trajectory_objects.append(UserMessage(role=role, content=content))
elif role == "tool":
tool_id = msg.tool_call_id
tool_id = msg.tool_call_id or ""
trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=content, requestor="assistant"))

reward = 1.0

# Convert incoming action dicts to typed Action objects for the evaluator
action_objs: Optional[List[Action]] = None
if actions is not None:
action_objs = []
for a in actions:
if isinstance(a, Action):
action_objs.append(a)
elif isinstance(a, dict):
action_objs.append(Action(**a))
else:
raise TypeError("actions must be a list of Action or dict items")

evaluation_criteria = EvaluationCriteria(
nl_assertions=nl_assertions,
communicate_info=communicate_info,
actions=actions,
actions=action_objs,
env_assertions=None,
reward_basis=[
RewardType.NL_ASSERTION,
RewardType.DB,
Expand Down Expand Up @@ -796,7 +809,8 @@ def tau2_airline_eval(
action_bases = {RewardType.ACTION}
nl_bases = {RewardType.NL_ASSERTION}
comm_bases = {RewardType.COMMUNICATE}
task_reward_basis = set(task.evaluation_criteria.reward_basis)
# task.evaluation_criteria can be Optional in the type hints; guard for None
task_reward_basis = set(task.evaluation_criteria.reward_basis) if task.evaluation_criteria else set()

reward_breakdown = {}
if task_reward_basis & env_bases:
Expand Down
15 changes: 7 additions & 8 deletions eval_protocol/rewards/lean_prover.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,17 +467,16 @@ def deepseek_huggingface_prover_benchmark(
current_top_level_reason += f" Sub-evaluation: {result_reason}"

if verbose:
info_payload = {
"id": (dataset_item.get("id", "") if dataset_item else ""),
"has_expected_proof": expected_proof is not None,
"has_reference_solution": reference_solution is not None,
"has_answer": (("answer" in dataset_item) if dataset_item else False),
}
combined_metrics["dataset_info"] = MetricResult(
score=1.0,
is_score_valid=True,
reason=json.dumps(
{
"id": dataset_item.get("id", ""),
"has_expected_proof": expected_proof is not None,
"has_reference_solution": reference_solution is not None,
"has_answer": "answer" in dataset_item if dataset_item else False,
}
),
reason=json.dumps(info_payload),
)

return EvaluateResult(score=result_score, reason=current_top_level_reason, metrics=combined_metrics)
28 changes: 28 additions & 0 deletions eval_protocol/typed_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
cast,
get_args,
get_origin,
overload,
)

from pydantic import TypeAdapter, ValidationError
Expand All @@ -36,6 +37,33 @@
F = TypeVar("F", bound=Callable[..., Any])


# Precise overloads help static type checkers preserve the original function signature.
# Overload 1 — bare-decorator form: `@reward_function` applied directly to a
# function. `_func` is the decorated callable; the decorator returns it typed
# as the same `F`, so static checkers preserve the original signature.
# NOTE(review): keyword options are declared here with defaults, but in this
# form they cannot actually be supplied (the call passes only `_func`).
@overload
def reward_function(
    _func: F,
    *,
    mode: EvaluationMode = "pointwise",
    id: Optional[str] = None,
    requirements: Optional[List[str]] = None,
    resources: Optional[ResourceDict] = None,
    concurrency: Optional[int] = None,
    timeout: Optional[int] = None,
) -> F: ...


# Overload 2 — decorator-factory form: `@reward_function(...)` called with
# configuration keywords first. `_func` is absent (None), and the call returns
# a decorator `Callable[[F], F]` that, applied to a function, preserves its
# static signature.
@overload
def reward_function(
    _func: None = ...,  # when used as @reward_function(...)
    *,
    mode: EvaluationMode = "pointwise",
    id: Optional[str] = None,
    requirements: Optional[List[str]] = None,
    resources: Optional[ResourceDict] = None,
    concurrency: Optional[int] = None,
    timeout: Optional[int] = None,
) -> Callable[[F], F]: ...


def reward_function(
_func: Optional[F] = None,
*,
Expand Down
Loading