diff --git a/eval_protocol/rewards/multiple_choice_math_reward.py b/eval_protocol/rewards/multiple_choice_math_reward.py index 5768de80..f51cbbeb 100644 --- a/eval_protocol/rewards/multiple_choice_math_reward.py +++ b/eval_protocol/rewards/multiple_choice_math_reward.py @@ -83,8 +83,8 @@ def extract_mcq_option(text: str) -> List[Tuple[str, str]]: @reward_function # type: ignore[arg-type] def multiple_choice_math_reward( - messages: List[Message], - ground_truth: List[Message], + messages: Union[List[Message], List[Dict[str, Any]]], + ground_truth: Union[List[Message], List[Dict[str, Any]]], **kwargs: Any, ) -> EvaluateResult: """ @@ -130,11 +130,34 @@ def multiple_choice_math_reward( }, ) + def _to_text(content: Any) -> str: + if content is None: + return "" + if isinstance(content, list): + parts: List[str] = [] + for part in content: + if isinstance(part, dict): + val = part.get("text") + if isinstance(val, str): + parts.append(val) + else: + text_attr = getattr(part, "text", None) + if isinstance(text_attr, str): + parts.append(text_attr) + return "".join(parts) + if isinstance(content, str): + return content + return str(content) + gen_content = "" if messages and len(messages) > 0: - gen_response_message = messages[-1] - if gen_response_message.role == "assistant": - gen_content = gen_response_message.content or "" + last_msg = messages[-1] + if isinstance(last_msg, Message): + if last_msg.role == "assistant": + gen_content = _to_text(last_msg.content) + elif isinstance(last_msg, dict): + if last_msg.get("role") == "assistant": + gen_content = _to_text(last_msg.get("content")) if not gen_content: metrics["error_generated_message"] = MetricResult( @@ -150,9 +173,13 @@ def multiple_choice_math_reward( orig_content = "" if ground_truth and len(ground_truth) > 0: - orig_response_message = ground_truth[0] - if orig_response_message.role == "assistant": - orig_content = orig_response_message.content or "" + first_gt = ground_truth[0] + if isinstance(first_gt, Message): + if first_gt.role == "assistant": + orig_content = _to_text(first_gt.content) + elif isinstance(first_gt, dict): + if first_gt.get("role") == "assistant": + orig_content = _to_text(first_gt.get("content")) if not orig_content: metrics["error_original_message"] = MetricResult( diff --git a/eval_protocol/rewards/reasoning_steps.py b/eval_protocol/rewards/reasoning_steps.py index 98da3d85..42a92aee 100644 --- a/eval_protocol/rewards/reasoning_steps.py +++ b/eval_protocol/rewards/reasoning_steps.py @@ -14,7 +14,7 @@ @reward_function def reasoning_steps_reward( - messages: List[Message], + messages: Union[List[Message], List[Dict[str, Any]]], pattern: Optional[str] = None, min_steps: int = 3, max_steps: Optional[int] = None, @@ -48,7 +48,33 @@ def reasoning_steps_reward( response = messages[-1] - if response.role != "assistant" or not response.content: + def _to_text(content: Any) -> str: + if content is None: + return "" + if isinstance(content, list): + parts = [] + for part in content: + if isinstance(part, dict): + val = part.get("text") + if isinstance(val, str): + parts.append(val) + else: + val = getattr(part, "text", None) + if isinstance(val, str): + parts.append(val) + return "".join(parts) + if isinstance(content, str): + return content + return str(content) + + if isinstance(response, Message): + role_ok = response.role == "assistant" + text: str = _to_text(response.content) + else: + role_ok = response.get("role") == "assistant" + text = str(response.get("content") or "") + + if not role_ok or not text: return EvaluateResult( score=0.0, reason="No assistant response found or response has no content", @@ -60,7 +86,7 @@ def reasoning_steps_reward( ) }, ) - text: str = response.content + # text already set # Default patterns for detecting reasoning steps default_patterns = [ @@ -154,7 +180,7 @@ def reasoning_steps_reward( @reward_function def sequence_reward( - messages: List[Message], + messages: Union[List[Message], List[Dict[str, Any]]], sequence_terms: Optional[List[str]] = None, min_matches: int = 3, case_sensitive: bool = False, @@ -187,7 +213,33 @@ def sequence_reward( response = messages[-1] - if response.role != "assistant" or not response.content: + def _to_text(content: Any) -> str: + if content is None: + return "" + if isinstance(content, list): + parts = [] + for part in content: + if isinstance(part, dict): + val = part.get("text") + if isinstance(val, str): + parts.append(val) + else: + val = getattr(part, "text", None) + if isinstance(val, str): + parts.append(val) + return "".join(parts) + if isinstance(content, str): + return content + return str(content) + + if isinstance(response, Message): + role_ok = response.role == "assistant" + text: str = _to_text(response.content) + else: + role_ok = response.get("role") == "assistant" + text = str(response.get("content") or "") + + if not role_ok or not text: return EvaluateResult( score=0.0, reason="No assistant response found or response has no content", @@ -199,7 +251,7 @@ def sequence_reward( ) }, ) - text: str = response.content + # text already set if not sequence_terms: sequence_terms = [ diff --git a/eval_protocol/rewards/repetition.py b/eval_protocol/rewards/repetition.py index aaaa247f..64a5088f 100644 --- a/eval_protocol/rewards/repetition.py +++ b/eval_protocol/rewards/repetition.py @@ -81,6 +81,25 @@ def repetition_penalty_reward( response = messages[-1] + def _to_text(content: Any) -> str: + if content is None: + return "" + if isinstance(content, list): + parts: List[str] = [] + for part in content: + if isinstance(part, dict): + val = part.get("text") + if isinstance(val, str): + parts.append(val) + else: + text_attr = getattr(part, "text", None) + if isinstance(text_attr, str): + parts.append(text_attr) + return "".join(parts) + if isinstance(content, str): + return content + return str(content) + if isinstance(response, Message): if response.role != "assistant": return EvaluateResult( @@ -94,7 +113,7 @@ def repetition_penalty_reward( ) }, ) - text = response.content or "" + text = _to_text(response.content) elif isinstance(response, dict): if response.get("role") != "assistant": return EvaluateResult( @@ -108,7 +127,7 @@ def repetition_penalty_reward( ) }, ) - text = response.get("content", "") + text = _to_text(response.get("content")) else: return EvaluateResult( score=0.0, @@ -222,6 +241,25 @@ def diversity_reward( response = messages[-1] + def _to_text(content: Any) -> str: + if content is None: + return "" + if isinstance(content, list): + parts: List[str] = [] + for part in content: + if isinstance(part, dict): + val = part.get("text") + if isinstance(val, str): + parts.append(val) + else: + text_attr = getattr(part, "text", None) + if isinstance(text_attr, str): + parts.append(text_attr) + return "".join(parts) + if isinstance(content, str): + return content + return str(content) + if isinstance(response, Message): if response.role != "assistant": return EvaluateResult( @@ -235,7 +273,7 @@ def diversity_reward( ) }, ) - text = response.content or "" + text = _to_text(response.content) elif isinstance(response, dict): if response.get("role") != "assistant": return EvaluateResult( @@ -249,7 +287,7 @@ def diversity_reward( ) }, ) - text = response.get("content", "") + text = _to_text(response.get("content")) else: return EvaluateResult( score=0.0, diff --git a/eval_protocol/rewards/tag_count.py b/eval_protocol/rewards/tag_count.py index 83acef6f..8109bf4a 100644 --- a/eval_protocol/rewards/tag_count.py +++ b/eval_protocol/rewards/tag_count.py @@ -14,7 +14,7 @@ @reward_function # type: ignore[arg-type] def tag_count_reward( - messages: List[Message], + messages: Union[List[Message], List[Dict[str, Any]]], *, # Make subsequent parameters keyword-only required_tags: List[str], score_per_tag: float = 0.25, @@ -46,7 +46,33 @@ def tag_count_reward( response = messages[-1] - if response.role != "assistant" or not response.content: + def _to_text(content: Any) -> str: + if content is None: + return "" + if isinstance(content, list): + parts: List[str] = [] + for part in content: + if isinstance(part, dict): + val = part.get("text") + if isinstance(val, str): + parts.append(val) + else: + text_attr = getattr(part, "text", None) + if isinstance(text_attr, str): + parts.append(text_attr) + return "".join(parts) + if isinstance(content, str): + return content + return str(content) + + if isinstance(response, Message): + role_ok = response.role == "assistant" + text: str = _to_text(response.content) + else: + role_ok = response.get("role") == "assistant" + text = str(response.get("content") or "") + + if not role_ok or not text: return EvaluateResult( score=0.0, reason="No assistant response found or response has no content", @@ -58,7 +84,7 @@ def tag_count_reward( ) }, ) - text: str = response.content + # text already populated above tag_metrics = {} found_tags: Set[str] = set() diff --git a/eval_protocol/typed_interface.py b/eval_protocol/typed_interface.py index 054babee..7864f6d3 100644 --- a/eval_protocol/typed_interface.py +++ b/eval_protocol/typed_interface.py @@ -14,6 +14,7 @@ get_args, get_origin, ) +from typing import ParamSpec # noqa: F401 from pydantic import TypeAdapter, ValidationError @@ -32,7 +33,7 @@ # Define a type for the mode parameter EvaluationMode = Literal["pointwise", "batch"] -# TypeVar for the function being decorated, to preserve its signature as much as possible. +# Simple TypeVar preserving original callable signature for better type inference F = TypeVar("F", bound=Callable[..., Any]) @@ -125,13 +126,18 @@ def _coerce_to_list_message(data_list: Any, arg_name_for_error: str) -> List[Mes return typed_list # 1. Conditional Pydantic conversion for 'messages' (pointwise) or 'rollouts_messages' (batch) + def _ann_allows_list_of_message(ann: Any) -> bool: + origin = get_origin(ann) + if origin in (list, List): + inner = get_args(ann) + return bool(inner) and inner[0] == Message + if origin is Union: + return any(_ann_allows_list_of_message(opt) for opt in get_args(ann)) + return False + if mode == "pointwise" and "messages" in params and "messages" in final_func_args: messages_param_annotation = params["messages"].annotation - if ( - get_origin(messages_param_annotation) in (list, List) - and get_args(messages_param_annotation) - and get_args(messages_param_annotation)[0] == Message - ): + if _ann_allows_list_of_message(messages_param_annotation): try: final_func_args["messages"] = _coerce_to_list_message(final_func_args["messages"], "messages") except Exception as err: @@ -155,7 +161,7 @@ def _coerce_to_list_message(data_list: Any, arg_name_for_error: str) -> List[Mes # Ground truth coercion (if needed) if "ground_truth" in params and "ground_truth" in final_func_args: gt_ann = params["ground_truth"].annotation - if get_origin(gt_ann) in (list, List) and get_args(gt_ann) and get_args(gt_ann)[0] == Message: + if _ann_allows_list_of_message(gt_ann): if final_func_args["ground_truth"] is not None: try: final_func_args["ground_truth"] = _coerce_to_list_message(