55
66import requests
77
8- from eval_protocol .models import EvaluateResult , EvaluationRow , Message , MetricResult
8+ from eval_protocol .models import (
9+ EvaluateResult ,
10+ EvaluationRow ,
11+ Message ,
12+ MetricResult ,
13+ ChatCompletionContentPartTextParam ,
14+ )
915from eval_protocol .pytest .default_single_turn_rollout_process import (
1016 SingleTurnRolloutProcessor ,
1117)
@@ -47,6 +53,14 @@ def _load_gpqa_messages_from_csv() -> list[list[list[Message]]]:
4753 return [messages_list ]
4854
4955
56+ def _coerce_content_to_str (
57+ content : str | list [ChatCompletionContentPartTextParam ] | None ,
58+ ) -> str :
59+ if isinstance (content , list ):
60+ return "" .join ([getattr (p , "text" , str (p )) for p in content ])
61+ return str (content or "" )
62+
63+
5064def _extract_abcd_letter (text : str ) -> str | None :
5165 if not text :
5266 return None
@@ -58,9 +72,12 @@ def _extract_abcd_letter(text: str) -> str | None:
5872
5973
def _strip_gt_messages(msgs: list[Message]) -> list[Message]:
    """Return ``msgs`` without the system messages that carry a ground-truth marker.

    A message is dropped when its role is ``"system"`` and its content, after
    being flattened by ``_coerce_content_to_str``, starts with ``"__GT__:"``.
    Every other message is kept, in its original order; the kept entries are
    the same objects that were passed in, not copies.
    """

    def _is_gt_marker(msg: Message) -> bool:
        # Flatten the content before looking at the role so that every
        # message goes through the same coercion step.
        text = _coerce_content_to_str(msg.content)
        return msg.role == "system" and text.startswith("__GT__:")

    return [msg for msg in msgs if not _is_gt_marker(msg)]
6481
6582
6683class GPQAStripGTRolloutProcessor (RolloutProcessor ):
@@ -75,15 +92,23 @@ def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) ->
7592 processed : list [EvaluationRow ] = []
7693
7794 for r in rows :
78- gt_tokens = [
79- m .content for m in r .messages if m .role == "system" and (m .content or "" ).startswith ("__GT__:" )
80- ]
95+ gt_tokens : list [str ] = []
96+ for m in r .messages :
97+ if m .role == "system" :
98+ content_str = _coerce_content_to_str (m .content )
99+ if content_str .startswith ("__GT__:" ):
100+ gt_tokens .append (content_str )
81101 if gt_tokens :
82102 gt_val = gt_tokens [- 1 ].split (":" , 1 )[1 ].strip ()
83103 r .ground_truth = gt_val
84- r .messages = [
85- m for m in r .messages if not (m .role == "system" and (m .content or "" ).startswith ("__GT__:" ))
86- ]
104+ filtered : list [Message ] = []
105+ for m in r .messages :
106+ if m .role == "system" :
107+ content_str = _coerce_content_to_str (m .content )
108+ if content_str .startswith ("__GT__:" ):
109+ continue
110+ filtered .append (m )
111+ r .messages = filtered
87112 processed .append (r )
88113
89114 # Delegate to SingleTurnRolloutProcessor
@@ -103,9 +128,10 @@ def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) ->
103128)
104129def test_gpqa_pointwise (row : EvaluationRow ) -> EvaluationRow :
105130 assistant_msgs = [m for m in row .messages if m .role == "assistant" ]
106- content = assistant_msgs [- 1 ].content if assistant_msgs else ""
131+ raw_content = assistant_msgs [- 1 ].content if assistant_msgs else ""
132+ content_str = _coerce_content_to_str (raw_content )
107133
108- pred = _extract_abcd_letter (content or "" )
134+ pred = _extract_abcd_letter (content_str )
109135 # GPQA diamond CSV constructs options so that the correct answer is always A
110136 gt = "A"
111137
0 commit comments