Skip to content

Commit 84f3fc3

Browse files
Benny Chen
authored and committed
fix more tests
1 parent ffaebe8 commit 84f3fc3

File tree

2 files changed

+30
-24
lines changed

2 files changed

+30
-24
lines changed

eval_protocol/mcp_agent/orchestration/local_docker_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -649,7 +649,7 @@ async def list_tools_on_instance(self, instance: ManagedInstanceInfo) -> types.L
649649
)
650650
target_base_url = instance.mcp_endpoint_url.rstrip("/")
651651
try:
652-
async with streamablehttp_client(target_base_url) as (
652+
async with streamablehttp_client(base_url=target_base_url) as (
653653
read_s,
654654
write_s,
655655
_, # get_session_id_func usually not needed for a single call

eval_protocol/typed_interface.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,23 @@ def decorator(func: F) -> F:
9999
# Detect if the user supplied function is a coroutine (async def)
100100
_is_async_function = inspect.iscoroutinefunction(func)
101101

102+
def _is_list_of_message_annotation(annotation: Any) -> bool:
    """Return True when *annotation* is ``List[Message]`` or an optional form of it.

    Recognized shapes:
      * ``List[Message]`` / ``list[Message]``
      * ``Optional[List[Message]]`` / ``Union[List[Message], None]``
      * ``List[Message] | None`` (PEP 604), on interpreters that expose
        ``types.UnionType``

    A union with more than one non-``None`` member is not treated as a
    list-of-message annotation.

    Args:
        annotation: The parameter annotation to inspect (any object).

    Returns:
        A strict ``bool`` -- never a falsy tuple or other truthy/falsy stand-in.
    """
    # Function-scope import so this helper does not rely on the module's imports.
    import types as _types

    def _is_message_list(candidate: Any) -> bool:
        # bool(...) keeps the short-circuit chain from leaking `()` to the caller
        # when the list annotation is unparameterized (e.g. a bare ``List``).
        elem_args = get_args(candidate)
        return get_origin(candidate) in (list, List) and bool(elem_args) and elem_args[0] == Message

    # Direct List[Message]
    if _is_message_list(annotation):
        return True

    # Optional[List[Message]], Union[List[Message], None] or List[Message] | None.
    # ``types.UnionType`` only exists on Python 3.10+, hence the getattr fallback.
    union_origins = (Union, getattr(_types, "UnionType", Union))
    annotation_origin = get_origin(annotation)
    if not any(annotation_origin is union_kind for union_kind in union_origins):
        return False

    # Filter out NoneType; exactly one remaining member must be List[Message].
    non_none = [member for member in get_args(annotation) if member is not type(None)]  # noqa: E721
    return len(non_none) == 1 and _is_message_list(non_none[0])
118+
102119
def _prepare_final_args(*args: Any, **kwargs: Any):
103120
"""Prepare final positional and keyword arguments for the user function call.
104121
This includes Pydantic coercion and resource injection. Returns a tuple of
@@ -127,16 +144,11 @@ def _coerce_to_list_message(data_list: Any, arg_name_for_error: str) -> List[Mes
127144
# 1. Conditional Pydantic conversion for 'messages' (pointwise) or 'rollouts_messages' (batch)
128145
if mode == "pointwise" and "messages" in params and "messages" in final_func_args:
129146
messages_param_annotation = params["messages"].annotation
130-
if (
131-
get_origin(messages_param_annotation) in (list, List)
132-
and get_args(messages_param_annotation)
133-
and get_args(messages_param_annotation)[0] == Message
134-
):
147+
if _is_list_of_message_annotation(messages_param_annotation):
135148
try:
136149
final_func_args["messages"] = _coerce_to_list_message(final_func_args["messages"], "messages")
137-
except Exception:
138-
# Be lenient: leave messages as-is if coercion fails (backward compatibility)
139-
pass
150+
except Exception as err:
151+
raise ValueError(f"Input 'messages' failed Pydantic validation: {err}") from None
140152

141153
elif mode == "batch" and "rollouts_messages" in params and "rollouts_messages" in final_func_args:
142154
param_annotation = params["rollouts_messages"].annotation
@@ -156,28 +168,22 @@ def _coerce_to_list_message(data_list: Any, arg_name_for_error: str) -> List[Mes
156168
# Ground truth coercion (if needed)
157169
if "ground_truth" in params and "ground_truth" in final_func_args:
158170
gt_ann = params["ground_truth"].annotation
159-
if get_origin(gt_ann) in (list, List) and get_args(gt_ann) and get_args(gt_ann)[0] == Message:
171+
if _is_list_of_message_annotation(gt_ann):
160172
if final_func_args["ground_truth"] is not None:
161-
# Accept flexible ground_truth inputs: list, dict, or str
162173
gt_val = final_func_args["ground_truth"]
163-
if isinstance(gt_val, list):
164-
try:
174+
try:
175+
if isinstance(gt_val, list):
165176
final_func_args["ground_truth"] = _coerce_to_list_message(gt_val, "ground_truth")
166-
except Exception:
167-
# Leave as-is if strict coercion fails
168-
pass
169-
elif isinstance(gt_val, dict):
170-
try:
177+
elif isinstance(gt_val, dict):
171178
final_func_args["ground_truth"] = _coerce_to_list_message([gt_val], "ground_truth")
172-
except Exception:
173-
pass
174-
elif isinstance(gt_val, str):
175-
try:
179+
elif isinstance(gt_val, str):
176180
final_func_args["ground_truth"] = _coerce_to_list_message(
177181
[{"role": "system", "content": gt_val}], "ground_truth"
178182
)
179-
except Exception:
180-
pass
183+
except Exception as err:
184+
raise ValueError(
185+
f"Input 'ground_truth' failed Pydantic validation for List[Message]: {err}"
186+
) from None
181187

182188
# Inject resource clients into kwargs (resources are already setup)
183189
if resource_managers:

0 commit comments

Comments
 (0)