Skip to content

Commit a9aad40

Browse files
Benny ChenBenny Chen
authored andcommitted
fix a few more
1 parent a58161c commit a9aad40

File tree

5 files changed

+73
-19
lines changed

5 files changed

+73
-19
lines changed

eval_protocol/rewards/accuracy.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,22 @@
1010
import re
1111
from typing import Any, Callable, Dict, List, Optional, Union, cast
1212

13-
from ..models import EvaluateResult, Message, MetricResult
13+
from ..models import EvaluateResult, Message, MetricResult, ChatCompletionContentPartTextParam
14+
15+
16+
def _to_text(content: Optional[Union[str, List[ChatCompletionContentPartTextParam]]]) -> str:
17+
"""Coerce Message.content into a plain string for regex and comparisons."""
18+
if content is None:
19+
return ""
20+
if isinstance(content, str):
21+
return content
22+
# List[ChatCompletionContentPartTextParam]
23+
try:
24+
return "\n".join(part.text for part in content)
25+
except Exception:
26+
return ""
27+
28+
1429
from ..typed_interface import reward_function
1530

1631

@@ -334,7 +349,7 @@ def accuracy_reward(
334349
model_last_message = messages[-1]
335350
if isinstance(model_last_message, Message):
336351
if model_last_message.role == "assistant" and model_last_message.content is not None:
337-
model_response_text = model_last_message.content
352+
model_response_text = _to_text(model_last_message.content)
338353
else:
339354
return EvaluateResult(
340355
score=0.0,
@@ -386,7 +401,7 @@ def accuracy_reward(
386401
first_gt_message = ground_truth[0]
387402
if isinstance(first_gt_message, Message):
388403
if first_gt_message.content is not None:
389-
ground_truth_comparison_text = first_gt_message.content
404+
ground_truth_comparison_text = _to_text(first_gt_message.content)
390405
else:
391406
return EvaluateResult(
392407
score=0.0,

eval_protocol/rewards/json_schema.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import re
33
from typing import Any, Dict, List, Optional, Union
44

5-
from ..models import EvaluateResult, Message, MetricResult
5+
from ..models import EvaluateResult, Message, MetricResult, ChatCompletionContentPartTextParam
66
from ..typed_interface import reward_function
77
from .function_calling import (
88
calculate_jaccard_similarity,
@@ -54,7 +54,15 @@ def json_schema_reward(
5454

5555
if isinstance(last_message, Message):
5656
if last_message.role == "assistant" and last_message.content is not None:
57-
content_text = last_message.content
57+
# Coerce to string if content is list parts
58+
if isinstance(last_message.content, str):
59+
content_text = last_message.content
60+
else:
61+
try:
62+
parts: List[ChatCompletionContentPartTextParam] = last_message.content # type: ignore[assignment]
63+
content_text = "\n".join(p.text for p in parts)
64+
except Exception:
65+
content_text = ""
5866
else:
5967
return EvaluateResult(
6068
score=0.0,

eval_protocol/rewards/language_consistency.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import re
1010
from typing import Any, Dict, List, Optional, Set, Tuple, Union
1111

12-
from ..models import EvaluateResult, Message, MetricResult
12+
from ..models import EvaluateResult, Message, MetricResult, ChatCompletionContentPartTextParam
1313
from ..typed_interface import reward_function
1414

1515
# Dictionary mapping language codes to common words/patterns in that language
@@ -560,12 +560,7 @@ def language_consistency_reward(
560560
Returns:
561561
EvaluateResult with score based on language consistency.
562562
"""
563-
if (
564-
not messages
565-
or not isinstance(messages[-1], Message)
566-
or messages[-1].role != "assistant"
567-
or messages[-1].content is None
568-
):
563+
if not messages or not isinstance(messages[-1], Message) or messages[-1].role != "assistant":
569564
return EvaluateResult(
570565
score=0.0,
571566
reason="Invalid or missing assistant response in messages.",
@@ -578,7 +573,17 @@ def language_consistency_reward(
578573
},
579574
)
580575

581-
text_to_evaluate = messages[-1].content
576+
def _to_text(content: Union[str, List[ChatCompletionContentPartTextParam], None]) -> str:
577+
if content is None:
578+
return ""
579+
if isinstance(content, str):
580+
return content
581+
try:
582+
return "\n".join(part.text for part in content)
583+
except Exception:
584+
return ""
585+
586+
text_to_evaluate = _to_text(messages[-1].content)
582587

583588
# For test_spanish_consistency - special handling for Spanish test case
584589
if "está escrita completamente en español" in text_to_evaluate:
@@ -593,7 +598,7 @@ def language_consistency_reward(
593598
prompt_messages = messages[:-1]
594599
for msg in prompt_messages:
595600
if isinstance(msg, Message) and msg.role == "user": # Decorator ensures msg is Message
596-
content_text: str = msg.content if msg.content is not None else ""
601+
content_text: str = _to_text(msg.content)
597602
if "in Spanish" in content_text:
598603
target_language = "es"
599604
break

eval_protocol/rewards/repetition.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,20 @@
88
import re
99
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
1010

11-
from ..models import EvaluateResult, Message, MetricResult
11+
from ..models import EvaluateResult, Message, MetricResult, ChatCompletionContentPartTextParam
12+
13+
14+
def _to_text(content: Optional[Union[str, List[ChatCompletionContentPartTextParam]]]) -> str:
15+
if content is None:
16+
return ""
17+
if isinstance(content, str):
18+
return content
19+
try:
20+
return "\n".join(part.text for part in content)
21+
except Exception:
22+
return ""
23+
24+
1225
from ..typed_interface import reward_function
1326

1427

@@ -94,7 +107,7 @@ def repetition_penalty_reward(
94107
)
95108
},
96109
)
97-
text = response.content or ""
110+
text = _to_text(response.content)
98111
elif isinstance(response, dict):
99112
if response.get("role") != "assistant":
100113
return EvaluateResult(

eval_protocol/rewards/tag_count.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,20 @@
88
import re
99
from typing import Any, Dict, List, Set, Union
1010

11-
from ..models import EvaluateResult, Message, MetricResult
11+
from ..models import EvaluateResult, Message, MetricResult, ChatCompletionContentPartTextParam
12+
13+
14+
def _to_text(content: Union[str, List[ChatCompletionContentPartTextParam], None]) -> str:
15+
if content is None:
16+
return ""
17+
if isinstance(content, str):
18+
return content
19+
try:
20+
return "\n".join(part.text for part in content)
21+
except Exception:
22+
return ""
23+
24+
1225
from ..typed_interface import reward_function
1326

1427

@@ -46,7 +59,7 @@ def tag_count_reward(
4659

4760
response = messages[-1]
4861

49-
if response.role != "assistant" or not response.content:
62+
if response.role != "assistant" or response.content is None:
5063
return EvaluateResult(
5164
score=0.0,
5265
reason="No assistant response found or response has no content",
@@ -58,7 +71,7 @@ def tag_count_reward(
5871
)
5972
},
6073
)
61-
text: str = response.content
74+
text: str = _to_text(response.content)
6275

6376
tag_metrics = {}
6477
found_tags: Set[str] = set()

0 commit comments

Comments
 (0)