@@ -99,6 +99,23 @@ def decorator(func: F) -> F:
9999 # Detect if the user supplied function is a coroutine (async def)
100100 _is_async_function = inspect .iscoroutinefunction (func )
101101
102+ def _is_list_of_message_annotation (annotation : Any ) -> bool :
103+ origin = get_origin (annotation )
104+ args = get_args (annotation )
105+ # Direct List[Message]
106+ if origin in (list , List ) and args and args [0 ] == Message :
107+ return True
108+ # Optional[List[Message]] or Union[List[Message], None]
109+ if origin is Union and args :
110+ # Filter out NoneType
111+ non_none = [a for a in args if a is not type (None )] # noqa: E721
112+ if len (non_none ) == 1 :
113+ inner = non_none [0 ]
114+ inner_origin = get_origin (inner )
115+ inner_args = get_args (inner )
116+ return inner_origin in (list , List ) and inner_args and inner_args [0 ] == Message
117+ return False
118+
102119 def _prepare_final_args (* args : Any , ** kwargs : Any ):
103120 """Prepare final positional and keyword arguments for the user function call.
104121 This includes Pydantic coercion and resource injection. Returns a tuple of
@@ -127,16 +144,11 @@ def _coerce_to_list_message(data_list: Any, arg_name_for_error: str) -> List[Mes
127144 # 1. Conditional Pydantic conversion for 'messages' (pointwise) or 'rollouts_messages' (batch)
128145 if mode == "pointwise" and "messages" in params and "messages" in final_func_args :
129146 messages_param_annotation = params ["messages" ].annotation
130- if (
131- get_origin (messages_param_annotation ) in (list , List )
132- and get_args (messages_param_annotation )
133- and get_args (messages_param_annotation )[0 ] == Message
134- ):
147+ if _is_list_of_message_annotation (messages_param_annotation ):
135148 try :
136149 final_func_args ["messages" ] = _coerce_to_list_message (final_func_args ["messages" ], "messages" )
137- except Exception :
138- # Be lenient: leave messages as-is if coercion fails (backward compatibility)
139- pass
150+ except Exception as err :
151+ raise ValueError (f"Input 'messages' failed Pydantic validation: { err } " ) from None
140152
141153 elif mode == "batch" and "rollouts_messages" in params and "rollouts_messages" in final_func_args :
142154 param_annotation = params ["rollouts_messages" ].annotation
@@ -156,28 +168,22 @@ def _coerce_to_list_message(data_list: Any, arg_name_for_error: str) -> List[Mes
156168 # Ground truth coercion (if needed)
157169 if "ground_truth" in params and "ground_truth" in final_func_args :
158170 gt_ann = params ["ground_truth" ].annotation
159- if get_origin (gt_ann ) in ( list , List ) and get_args ( gt_ann ) and get_args ( gt_ann )[ 0 ] == Message :
171+ if _is_list_of_message_annotation (gt_ann ):
160172 if final_func_args ["ground_truth" ] is not None :
161- # Accept flexible ground_truth inputs: list, dict, or str
162173 gt_val = final_func_args ["ground_truth" ]
163- if isinstance ( gt_val , list ) :
164- try :
174+ try :
175+ if isinstance ( gt_val , list ) :
165176 final_func_args ["ground_truth" ] = _coerce_to_list_message (gt_val , "ground_truth" )
166- except Exception :
167- # Leave as-is if strict coercion fails
168- pass
169- elif isinstance (gt_val , dict ):
170- try :
177+ elif isinstance (gt_val , dict ):
171178 final_func_args ["ground_truth" ] = _coerce_to_list_message ([gt_val ], "ground_truth" )
172- except Exception :
173- pass
174- elif isinstance (gt_val , str ):
175- try :
179+ elif isinstance (gt_val , str ):
176180 final_func_args ["ground_truth" ] = _coerce_to_list_message (
177181 [{"role" : "system" , "content" : gt_val }], "ground_truth"
178182 )
179- except Exception :
180- pass
183+ except Exception as err :
184+ raise ValueError (
185+ f"Input 'ground_truth' failed Pydantic validation for List[Message]: { err } "
186+ ) from None
181187
182188 # Inject resource clients into kwargs (resources are already setup)
183189 if resource_managers :
0 commit comments