Skip to content

Commit 2b5e308

Browse files
cursoragentbenjibc
andcommitted
Support dict and list message types in reward functions
Co-authored-by: bchen <bchen@fireworks.ai>
1 parent 1054bf6 commit 2b5e308

File tree

5 files changed

+177
-28
lines changed

5 files changed

+177
-28
lines changed

eval_protocol/rewards/multiple_choice_math_reward.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ def extract_mcq_option(text: str) -> List[Tuple[str, str]]:
8383

8484
@reward_function # type: ignore[arg-type]
8585
def multiple_choice_math_reward(
86-
messages: List[Message],
87-
ground_truth: List[Message],
86+
messages: Union[List[Message], List[Dict[str, Any]]],
87+
ground_truth: Union[List[Message], List[Dict[str, Any]]],
8888
**kwargs: Any,
8989
) -> EvaluateResult:
9090
"""
@@ -130,11 +130,34 @@ def multiple_choice_math_reward(
130130
},
131131
)
132132

133+
def _to_text(content: Any) -> str:
134+
if content is None:
135+
return ""
136+
if isinstance(content, list):
137+
parts: List[str] = []
138+
for part in content:
139+
if isinstance(part, dict):
140+
val = part.get("text")
141+
if isinstance(val, str):
142+
parts.append(val)
143+
else:
144+
text_attr = getattr(part, "text", None)
145+
if isinstance(text_attr, str):
146+
parts.append(text_attr)
147+
return "".join(parts)
148+
if isinstance(content, str):
149+
return content
150+
return str(content)
151+
133152
gen_content = ""
134153
if messages and len(messages) > 0:
135-
gen_response_message = messages[-1]
136-
if gen_response_message.role == "assistant":
137-
gen_content = gen_response_message.content or ""
154+
last_msg = messages[-1]
155+
if isinstance(last_msg, Message):
156+
if last_msg.role == "assistant":
157+
gen_content = _to_text(last_msg.content)
158+
elif isinstance(last_msg, dict):
159+
if last_msg.get("role") == "assistant":
160+
gen_content = _to_text(last_msg.get("content"))
138161

139162
if not gen_content:
140163
metrics["error_generated_message"] = MetricResult(
@@ -150,9 +173,13 @@ def multiple_choice_math_reward(
150173

151174
orig_content = ""
152175
if ground_truth and len(ground_truth) > 0:
153-
orig_response_message = ground_truth[0]
154-
if orig_response_message.role == "assistant":
155-
orig_content = orig_response_message.content or ""
176+
first_gt = ground_truth[0]
177+
if isinstance(first_gt, Message):
178+
if first_gt.role == "assistant":
179+
orig_content = _to_text(first_gt.content)
180+
elif isinstance(first_gt, dict):
181+
if first_gt.get("role") == "assistant":
182+
orig_content = _to_text(first_gt.get("content"))
156183

157184
if not orig_content:
158185
metrics["error_original_message"] = MetricResult(

eval_protocol/rewards/reasoning_steps.py

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
@reward_function
1616
def reasoning_steps_reward(
17-
messages: List[Message],
17+
messages: Union[List[Message], List[Dict[str, Any]]],
1818
pattern: Optional[str] = None,
1919
min_steps: int = 3,
2020
max_steps: Optional[int] = None,
@@ -48,7 +48,33 @@ def reasoning_steps_reward(
4848

4949
response = messages[-1]
5050

51-
if response.role != "assistant" or not response.content:
51+
def _to_text(content: Any) -> str:
52+
if content is None:
53+
return ""
54+
if isinstance(content, list):
55+
parts = []
56+
for part in content:
57+
if isinstance(part, dict):
58+
val = part.get("text")
59+
if isinstance(val, str):
60+
parts.append(val)
61+
else:
62+
val = getattr(part, "text", None)
63+
if isinstance(val, str):
64+
parts.append(val)
65+
return "".join(parts)
66+
if isinstance(content, str):
67+
return content
68+
return str(content)
69+
70+
if isinstance(response, Message):
71+
role_ok = response.role == "assistant"
72+
text: str = _to_text(response.content)
73+
else:
74+
role_ok = response.get("role") == "assistant"
75+
text = str(response.get("content") or "")
76+
77+
if not role_ok or not text:
5278
return EvaluateResult(
5379
score=0.0,
5480
reason="No assistant response found or response has no content",
@@ -60,7 +86,7 @@ def reasoning_steps_reward(
6086
)
6187
},
6288
)
63-
text: str = response.content
89+
# text already set
6490

6591
# Default patterns for detecting reasoning steps
6692
default_patterns = [
@@ -154,7 +180,7 @@ def reasoning_steps_reward(
154180

155181
@reward_function
156182
def sequence_reward(
157-
messages: List[Message],
183+
messages: Union[List[Message], List[Dict[str, Any]]],
158184
sequence_terms: Optional[List[str]] = None,
159185
min_matches: int = 3,
160186
case_sensitive: bool = False,
@@ -187,7 +213,33 @@ def sequence_reward(
187213

188214
response = messages[-1]
189215

190-
if response.role != "assistant" or not response.content:
216+
def _to_text(content: Any) -> str:
217+
if content is None:
218+
return ""
219+
if isinstance(content, list):
220+
parts = []
221+
for part in content:
222+
if isinstance(part, dict):
223+
val = part.get("text")
224+
if isinstance(val, str):
225+
parts.append(val)
226+
else:
227+
val = getattr(part, "text", None)
228+
if isinstance(val, str):
229+
parts.append(val)
230+
return "".join(parts)
231+
if isinstance(content, str):
232+
return content
233+
return str(content)
234+
235+
if isinstance(response, Message):
236+
role_ok = response.role == "assistant"
237+
text: str = _to_text(response.content)
238+
else:
239+
role_ok = response.get("role") == "assistant"
240+
text = str(response.get("content") or "")
241+
242+
if not role_ok or not text:
191243
return EvaluateResult(
192244
score=0.0,
193245
reason="No assistant response found or response has no content",
@@ -199,7 +251,7 @@ def sequence_reward(
199251
)
200252
},
201253
)
202-
text: str = response.content
254+
# text already set
203255

204256
if not sequence_terms:
205257
sequence_terms = [

eval_protocol/rewards/repetition.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,25 @@ def repetition_penalty_reward(
8181

8282
response = messages[-1]
8383

84+
def _to_text(content: Any) -> str:
85+
if content is None:
86+
return ""
87+
if isinstance(content, list):
88+
parts: List[str] = []
89+
for part in content:
90+
if isinstance(part, dict):
91+
val = part.get("text")
92+
if isinstance(val, str):
93+
parts.append(val)
94+
else:
95+
text_attr = getattr(part, "text", None)
96+
if isinstance(text_attr, str):
97+
parts.append(text_attr)
98+
return "".join(parts)
99+
if isinstance(content, str):
100+
return content
101+
return str(content)
102+
84103
if isinstance(response, Message):
85104
if response.role != "assistant":
86105
return EvaluateResult(
@@ -94,7 +113,7 @@ def repetition_penalty_reward(
94113
)
95114
},
96115
)
97-
text = response.content or ""
116+
text = _to_text(response.content)
98117
elif isinstance(response, dict):
99118
if response.get("role") != "assistant":
100119
return EvaluateResult(
@@ -108,7 +127,7 @@ def repetition_penalty_reward(
108127
)
109128
},
110129
)
111-
text = response.get("content", "")
130+
text = _to_text(response.get("content"))
112131
else:
113132
return EvaluateResult(
114133
score=0.0,
@@ -222,6 +241,25 @@ def diversity_reward(
222241

223242
response = messages[-1]
224243

244+
def _to_text(content: Any) -> str:
245+
if content is None:
246+
return ""
247+
if isinstance(content, list):
248+
parts: List[str] = []
249+
for part in content:
250+
if isinstance(part, dict):
251+
val = part.get("text")
252+
if isinstance(val, str):
253+
parts.append(val)
254+
else:
255+
text_attr = getattr(part, "text", None)
256+
if isinstance(text_attr, str):
257+
parts.append(text_attr)
258+
return "".join(parts)
259+
if isinstance(content, str):
260+
return content
261+
return str(content)
262+
225263
if isinstance(response, Message):
226264
if response.role != "assistant":
227265
return EvaluateResult(
@@ -235,7 +273,7 @@ def diversity_reward(
235273
)
236274
},
237275
)
238-
text = response.content or ""
276+
text = _to_text(response.content)
239277
elif isinstance(response, dict):
240278
if response.get("role") != "assistant":
241279
return EvaluateResult(
@@ -249,7 +287,7 @@ def diversity_reward(
249287
)
250288
},
251289
)
252-
text = response.get("content", "")
290+
text = _to_text(response.get("content"))
253291
else:
254292
return EvaluateResult(
255293
score=0.0,

eval_protocol/rewards/tag_count.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
@reward_function # type: ignore[arg-type]
1616
def tag_count_reward(
17-
messages: List[Message],
17+
messages: Union[List[Message], List[Dict[str, Any]]],
1818
*, # Make subsequent parameters keyword-only
1919
required_tags: List[str],
2020
score_per_tag: float = 0.25,
@@ -46,7 +46,33 @@ def tag_count_reward(
4646

4747
response = messages[-1]
4848

49-
if response.role != "assistant" or not response.content:
49+
def _to_text(content: Any) -> str:
50+
if content is None:
51+
return ""
52+
if isinstance(content, list):
53+
parts: List[str] = []
54+
for part in content:
55+
if isinstance(part, dict):
56+
val = part.get("text")
57+
if isinstance(val, str):
58+
parts.append(val)
59+
else:
60+
text_attr = getattr(part, "text", None)
61+
if isinstance(text_attr, str):
62+
parts.append(text_attr)
63+
return "".join(parts)
64+
if isinstance(content, str):
65+
return content
66+
return str(content)
67+
68+
if isinstance(response, Message):
69+
role_ok = response.role == "assistant"
70+
text: str = _to_text(response.content)
71+
else:
72+
role_ok = response.get("role") == "assistant"
73+
text = str(response.get("content") or "")
74+
75+
if not role_ok or not text:
5076
return EvaluateResult(
5177
score=0.0,
5278
reason="No assistant response found or response has no content",
@@ -58,7 +84,7 @@ def tag_count_reward(
5884
)
5985
},
6086
)
61-
text: str = response.content
87+
# text already populated above
6288

6389
tag_metrics = {}
6490
found_tags: Set[str] = set()

eval_protocol/typed_interface.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
get_args,
1515
get_origin,
1616
)
17+
from typing import ParamSpec # noqa: F401
1718

1819
from pydantic import TypeAdapter, ValidationError
1920

@@ -32,7 +33,7 @@
3233
# Define a type for the mode parameter
3334
EvaluationMode = Literal["pointwise", "batch"]
3435

35-
# TypeVar for the function being decorated, to preserve its signature as much as possible.
36+
# Simple TypeVar preserving original callable signature for better type inference
3637
F = TypeVar("F", bound=Callable[..., Any])
3738

3839

@@ -125,13 +126,18 @@ def _coerce_to_list_message(data_list: Any, arg_name_for_error: str) -> List[Mes
125126
return typed_list
126127

127128
# 1. Conditional Pydantic conversion for 'messages' (pointwise) or 'rollouts_messages' (batch)
129+
def _ann_allows_list_of_message(ann: Any) -> bool:
130+
origin = get_origin(ann)
131+
if origin in (list, List):
132+
inner = get_args(ann)
133+
return bool(inner) and inner[0] == Message
134+
if origin is Union:
135+
return any(_ann_allows_list_of_message(opt) for opt in get_args(ann))
136+
return False
137+
128138
if mode == "pointwise" and "messages" in params and "messages" in final_func_args:
129139
messages_param_annotation = params["messages"].annotation
130-
if (
131-
get_origin(messages_param_annotation) in (list, List)
132-
and get_args(messages_param_annotation)
133-
and get_args(messages_param_annotation)[0] == Message
134-
):
140+
if _ann_allows_list_of_message(messages_param_annotation):
135141
try:
136142
final_func_args["messages"] = _coerce_to_list_message(final_func_args["messages"], "messages")
137143
except Exception as err:
@@ -155,7 +161,7 @@ def _coerce_to_list_message(data_list: Any, arg_name_for_error: str) -> List[Mes
155161
# Ground truth coercion (if needed)
156162
if "ground_truth" in params and "ground_truth" in final_func_args:
157163
gt_ann = params["ground_truth"].annotation
158-
if get_origin(gt_ann) in (list, List) and get_args(gt_ann) and get_args(gt_ann)[0] == Message:
164+
if _ann_allows_list_of_message(gt_ann):
159165
if final_func_args["ground_truth"] is not None:
160166
try:
161167
final_func_args["ground_truth"] = _coerce_to_list_message(

0 commit comments

Comments
 (0)