Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 35 additions & 8 deletions eval_protocol/rewards/multiple_choice_math_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ def extract_mcq_option(text: str) -> List[Tuple[str, str]]:

@reward_function # type: ignore[arg-type]
def multiple_choice_math_reward(
messages: List[Message],
ground_truth: List[Message],
messages: Union[List[Message], List[Dict[str, Any]]],
ground_truth: Union[List[Message], List[Dict[str, Any]]],
**kwargs: Any,
) -> EvaluateResult:
"""
Expand Down Expand Up @@ -130,11 +130,34 @@ def multiple_choice_math_reward(
},
)

def _to_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, list):
parts: List[str] = []
for part in content:
if isinstance(part, dict):
val = part.get("text")
if isinstance(val, str):
parts.append(val)
else:
text_attr = getattr(part, "text", None)
if isinstance(text_attr, str):
parts.append(text_attr)
return "".join(parts)
if isinstance(content, str):
return content
return str(content)

gen_content = ""
if messages and len(messages) > 0:
gen_response_message = messages[-1]
if gen_response_message.role == "assistant":
gen_content = gen_response_message.content or ""
last_msg = messages[-1]
if isinstance(last_msg, Message):
if last_msg.role == "assistant":
gen_content = _to_text(last_msg.content)
elif isinstance(last_msg, dict):
if last_msg.get("role") == "assistant":
gen_content = _to_text(last_msg.get("content"))

if not gen_content:
metrics["error_generated_message"] = MetricResult(
Expand All @@ -150,9 +173,13 @@ def multiple_choice_math_reward(

orig_content = ""
if ground_truth and len(ground_truth) > 0:
orig_response_message = ground_truth[0]
if orig_response_message.role == "assistant":
orig_content = orig_response_message.content or ""
first_gt = ground_truth[0]
if isinstance(first_gt, Message):
if first_gt.role == "assistant":
orig_content = _to_text(first_gt.content)
elif isinstance(first_gt, dict):
if first_gt.get("role") == "assistant":
orig_content = _to_text(first_gt.get("content"))

if not orig_content:
metrics["error_original_message"] = MetricResult(
Expand Down
64 changes: 58 additions & 6 deletions eval_protocol/rewards/reasoning_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

@reward_function
def reasoning_steps_reward(
messages: List[Message],
messages: Union[List[Message], List[Dict[str, Any]]],
pattern: Optional[str] = None,
min_steps: int = 3,
max_steps: Optional[int] = None,
Expand Down Expand Up @@ -48,7 +48,33 @@ def reasoning_steps_reward(

response = messages[-1]

if response.role != "assistant" or not response.content:
def _to_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, list):
parts = []
for part in content:
if isinstance(part, dict):
val = part.get("text")
if isinstance(val, str):
parts.append(val)
else:
val = getattr(part, "text", None)
if isinstance(val, str):
parts.append(val)
return "".join(parts)
if isinstance(content, str):
return content
return str(content)

if isinstance(response, Message):
role_ok = response.role == "assistant"
text: str = _to_text(response.content)
else:
role_ok = response.get("role") == "assistant"
text = str(response.get("content") or "")

if not role_ok or not text:
return EvaluateResult(
score=0.0,
reason="No assistant response found or response has no content",
Expand All @@ -60,7 +86,7 @@ def reasoning_steps_reward(
)
},
)
text: str = response.content
# text already set

# Default patterns for detecting reasoning steps
default_patterns = [
Expand Down Expand Up @@ -154,7 +180,7 @@ def reasoning_steps_reward(

@reward_function
def sequence_reward(
messages: List[Message],
messages: Union[List[Message], List[Dict[str, Any]]],
sequence_terms: Optional[List[str]] = None,
min_matches: int = 3,
case_sensitive: bool = False,
Expand Down Expand Up @@ -187,7 +213,33 @@ def sequence_reward(

response = messages[-1]

if response.role != "assistant" or not response.content:
def _to_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, list):
parts = []
for part in content:
if isinstance(part, dict):
val = part.get("text")
if isinstance(val, str):
parts.append(val)
else:
val = getattr(part, "text", None)
if isinstance(val, str):
parts.append(val)
return "".join(parts)
if isinstance(content, str):
return content
return str(content)

if isinstance(response, Message):
role_ok = response.role == "assistant"
text: str = _to_text(response.content)
else:
role_ok = response.get("role") == "assistant"
text = str(response.get("content") or "")

if not role_ok or not text:
return EvaluateResult(
score=0.0,
reason="No assistant response found or response has no content",
Expand All @@ -199,7 +251,7 @@ def sequence_reward(
)
},
)
text: str = response.content
# text already set

if not sequence_terms:
sequence_terms = [
Expand Down
46 changes: 42 additions & 4 deletions eval_protocol/rewards/repetition.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,25 @@ def repetition_penalty_reward(

response = messages[-1]

def _to_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, list):
parts: List[str] = []
for part in content:
if isinstance(part, dict):
val = part.get("text")
if isinstance(val, str):
parts.append(val)
else:
text_attr = getattr(part, "text", None)
if isinstance(text_attr, str):
parts.append(text_attr)
return "".join(parts)
if isinstance(content, str):
return content
return str(content)

if isinstance(response, Message):
if response.role != "assistant":
return EvaluateResult(
Expand All @@ -94,7 +113,7 @@ def repetition_penalty_reward(
)
},
)
text = response.content or ""
text = _to_text(response.content)
elif isinstance(response, dict):
if response.get("role") != "assistant":
return EvaluateResult(
Expand All @@ -108,7 +127,7 @@ def repetition_penalty_reward(
)
},
)
text = response.get("content", "")
text = _to_text(response.get("content"))
else:
return EvaluateResult(
score=0.0,
Expand Down Expand Up @@ -222,6 +241,25 @@ def diversity_reward(

response = messages[-1]

def _to_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, list):
parts: List[str] = []
for part in content:
if isinstance(part, dict):
val = part.get("text")
if isinstance(val, str):
parts.append(val)
else:
text_attr = getattr(part, "text", None)
if isinstance(text_attr, str):
parts.append(text_attr)
return "".join(parts)
if isinstance(content, str):
return content
return str(content)

if isinstance(response, Message):
if response.role != "assistant":
return EvaluateResult(
Expand All @@ -235,7 +273,7 @@ def diversity_reward(
)
},
)
text = response.content or ""
text = _to_text(response.content)
elif isinstance(response, dict):
if response.get("role") != "assistant":
return EvaluateResult(
Expand All @@ -249,7 +287,7 @@ def diversity_reward(
)
},
)
text = response.get("content", "")
text = _to_text(response.get("content"))
else:
return EvaluateResult(
score=0.0,
Expand Down
32 changes: 29 additions & 3 deletions eval_protocol/rewards/tag_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

@reward_function # type: ignore[arg-type]
def tag_count_reward(
messages: List[Message],
messages: Union[List[Message], List[Dict[str, Any]]],
*, # Make subsequent parameters keyword-only
required_tags: List[str],
score_per_tag: float = 0.25,
Expand Down Expand Up @@ -46,7 +46,33 @@ def tag_count_reward(

response = messages[-1]

if response.role != "assistant" or not response.content:
def _to_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, list):
parts: List[str] = []
for part in content:
if isinstance(part, dict):
val = part.get("text")
if isinstance(val, str):
parts.append(val)
else:
text_attr = getattr(part, "text", None)
if isinstance(text_attr, str):
parts.append(text_attr)
return "".join(parts)
if isinstance(content, str):
return content
return str(content)

if isinstance(response, Message):
role_ok = response.role == "assistant"
text: str = _to_text(response.content)
else:
role_ok = response.get("role") == "assistant"
text = str(response.get("content") or "")

if not role_ok or not text:
return EvaluateResult(
score=0.0,
reason="No assistant response found or response has no content",
Expand All @@ -58,7 +84,7 @@ def tag_count_reward(
)
},
)
text: str = response.content
# text already populated above

tag_metrics = {}
found_tags: Set[str] = set()
Expand Down
20 changes: 13 additions & 7 deletions eval_protocol/typed_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
get_args,
get_origin,
)
from typing import ParamSpec # noqa: F401

from pydantic import TypeAdapter, ValidationError

Expand All @@ -32,7 +33,7 @@
# Define a type for the mode parameter
EvaluationMode = Literal["pointwise", "batch"]

# TypeVar for the function being decorated, to preserve its signature as much as possible.
# Simple TypeVar preserving original callable signature for better type inference
F = TypeVar("F", bound=Callable[..., Any])


Expand Down Expand Up @@ -125,13 +126,18 @@ def _coerce_to_list_message(data_list: Any, arg_name_for_error: str) -> List[Mes
return typed_list

# 1. Conditional Pydantic conversion for 'messages' (pointwise) or 'rollouts_messages' (batch)
def _ann_allows_list_of_message(ann: Any) -> bool:
origin = get_origin(ann)
if origin in (list, List):
inner = get_args(ann)
return bool(inner) and inner[0] == Message
if origin is Union:
return any(_ann_allows_list_of_message(opt) for opt in get_args(ann))
return False

if mode == "pointwise" and "messages" in params and "messages" in final_func_args:
messages_param_annotation = params["messages"].annotation
if (
get_origin(messages_param_annotation) in (list, List)
and get_args(messages_param_annotation)
and get_args(messages_param_annotation)[0] == Message
):
if _ann_allows_list_of_message(messages_param_annotation):
try:
final_func_args["messages"] = _coerce_to_list_message(final_func_args["messages"], "messages")
except Exception as err:
Expand All @@ -155,7 +161,7 @@ def _coerce_to_list_message(data_list: Any, arg_name_for_error: str) -> List[Mes
# Ground truth coercion (if needed)
if "ground_truth" in params and "ground_truth" in final_func_args:
gt_ann = params["ground_truth"].annotation
if get_origin(gt_ann) in (list, List) and get_args(gt_ann) and get_args(gt_ann)[0] == Message:
if _ann_allows_list_of_message(gt_ann):
if final_func_args["ground_truth"] is not None:
try:
final_func_args["ground_truth"] = _coerce_to_list_message(
Expand Down
Loading