Skip to content

Commit 9044855

Browse files
Benny ChenBenny Chen
authored and committed
fix more errors
1 parent 6459b31 commit 9044855

File tree

4 files changed

+54
-35
lines changed

4 files changed

+54
-35
lines changed

eval_protocol/mcp/execution/manager.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -287,9 +287,21 @@ def extract_text_content(msg_dict):
287287

288288
# Accumulate LLM usage stats from this turn, if there are any
289289
if usage_stats:
290-
trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens
291-
trajectory.usage["completion_tokens"] += usage_stats.completion_tokens
292-
trajectory.usage["total_tokens"] += usage_stats.total_tokens
290+
try:
291+
trajectory.usage["prompt_tokens"] += getattr(usage_stats, "prompt_tokens", 0)
292+
trajectory.usage["completion_tokens"] += getattr(usage_stats, "completion_tokens", 0)
293+
trajectory.usage["total_tokens"] += getattr(usage_stats, "total_tokens", 0)
294+
except Exception:
295+
# Fallback if usage_stats is a dict
296+
trajectory.usage["prompt_tokens"] += int(
297+
getattr(usage_stats, "get", lambda _k, _d=0: 0)("prompt_tokens", 0)
298+
)
299+
trajectory.usage["completion_tokens"] += int(
300+
getattr(usage_stats, "get", lambda _k, _d=0: 0)("completion_tokens", 0)
301+
)
302+
trajectory.usage["total_tokens"] += int(
303+
getattr(usage_stats, "get", lambda _k, _d=0: 0)("total_tokens", 0)
304+
)
293305

294306
# If no tool call is generated, turn is finished
295307
if len(tool_calls) == 1:
@@ -300,7 +312,7 @@ def extract_text_content(msg_dict):
300312
# If there's no user simulator, then it marks the end of the episode as LLM think there is no tool call needed.
301313
elif tool_calls[0].tool_name in ["_playback_terminate", "_no_tool_call"]:
302314
trajectory.terminated = True
303-
trajectory.termination_reason = TerminationReason.from_str(finish_reason)
315+
trajectory.termination_reason = TerminationReason.from_str(str(finish_reason))
304316
break
305317

306318
# Execute each tool call sequentially
@@ -404,11 +416,22 @@ def extract_text_content(msg_dict):
404416
)
405417
update_evaluation_row_messages()
406418
if usage_stats:
407-
trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens
408-
trajectory.usage["completion_tokens"] += usage_stats.completion_tokens
409-
trajectory.usage["total_tokens"] += usage_stats.total_tokens
419+
try:
420+
trajectory.usage["prompt_tokens"] += getattr(usage_stats, "prompt_tokens", 0)
421+
trajectory.usage["completion_tokens"] += getattr(usage_stats, "completion_tokens", 0)
422+
trajectory.usage["total_tokens"] += getattr(usage_stats, "total_tokens", 0)
423+
except Exception:
424+
trajectory.usage["prompt_tokens"] += int(
425+
getattr(usage_stats, "get", lambda _k, _d=0: 0)("prompt_tokens", 0)
426+
)
427+
trajectory.usage["completion_tokens"] += int(
428+
getattr(usage_stats, "get", lambda _k, _d=0: 0)("completion_tokens", 0)
429+
)
430+
trajectory.usage["total_tokens"] += int(
431+
getattr(usage_stats, "get", lambda _k, _d=0: 0)("total_tokens", 0)
432+
)
410433
trajectory.terminated = True
411-
trajectory.termination_reason = TerminationReason.from_str(finish_reason)
434+
trajectory.termination_reason = TerminationReason.from_str(str(finish_reason))
412435
trajectory.control_plane_summary.update(
413436
{
414437
"total_reward": trajectory.total_reward,

eval_protocol/mcp/mcp_multi_client.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from mcp.client.stdio import stdio_client
1010
from mcp.client.streamable_http import streamablehttp_client
1111
from mcp.types import CallToolResult
12-
from openai.types import FunctionDefinition
12+
from openai.types.shared_params.function_definition import FunctionDefinition
1313
from openai.types.chat import ChatCompletionToolParam
1414

1515
from eval_protocol.models import (
@@ -135,7 +135,7 @@ async def get_available_tools(self) -> List[ChatCompletionToolParam]:
135135
all_tools.append(
136136
ChatCompletionToolParam(
137137
function=FunctionDefinition(
138-
name=tool.name, # Prefix with server name
138+
name=tool.name,
139139
description=tool.description,
140140
parameters=tool.inputSchema,
141141
),
@@ -155,7 +155,8 @@ async def call_tool(self, tool_name: str, tool_args: Dict[str, Any]) -> CallTool
155155
result = await session.call_tool(tool_name, tool_args)
156156
return result
157157
except Exception as e:
158-
return f"Error calling tool {tool_name}: {e}"
158+
# Re-raise the exception so the return type remains CallToolResult
159+
raise RuntimeError(f"Error calling tool {tool_name}: {e}") from e
159160

160161
async def cleanup(self):
161162
"""Clean up resources"""

eval_protocol/mcp/mcpgym.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from concurrent.futures import ThreadPoolExecutor
2525
from datetime import date, datetime
2626
from enum import Enum
27-
from typing import Any, Callable, Dict, Optional, Tuple
27+
from typing import Any, Callable, Dict, Optional, Tuple, Literal, cast
2828

2929
import uvicorn
3030
from mcp.server.fastmcp import Context, FastMCP
@@ -152,12 +152,14 @@ def _get_session_id(self, ctx: Context) -> str:
152152
print(f"🔍 _get_session_id: hasattr(client_params, 'clientInfo'): {hasattr(client_params, 'clientInfo')}")
153153

154154
if hasattr(client_params, "clientInfo"):
155-
client_info = client_params.clientInfo
155+
client_info = getattr(client_params, "clientInfo", None)
156156
print(f"🔍 _get_session_id: client_info: {client_info}")
157-
print(f"🔍 _get_session_id: hasattr(client_info, '_extra'): {hasattr(client_info, '_extra')}")
157+
print(
158+
f"🔍 _get_session_id: hasattr(client_info, '_extra'): {hasattr(client_info, '_extra') if client_info is not None else False}"
159+
)
158160

159161
if client_info and hasattr(client_info, "_extra"):
160-
extra_data = client_info._extra
162+
extra_data = getattr(client_info, "_extra", None)
161163
print(f"🔍 _get_session_id: extra_data: {extra_data}")
162164
print(f"🔍 _get_session_id: extra_data type: {type(extra_data)}")
163165

@@ -547,7 +549,7 @@ def format_observation(self, obs: Any, env: Any) -> Dict[str, Any]:
547549
else:
548550
return {"observation": serialized_obs}
549551

550-
def run(self, transport: str = "streamable-http", **kwargs):
552+
def run(self, transport: Literal["stdio", "sse", "streamable-http"] = "streamable-http", **kwargs):
551553
"""Run the unified MCP-Gym server with high concurrency settings."""
552554
if transport == "streamable-http":
553555
# Run with custom high-concurrency uvicorn config
@@ -558,7 +560,7 @@ async def run_with_high_concurrency():
558560
if not kwargs.get("redirect_slashes", True) and hasattr(starlette_app, "router"):
559561
starlette_app.router.redirect_slashes = False
560562

561-
starlette_app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="*")
563+
starlette_app.add_middleware(cast(Any, ProxyHeadersMiddleware), trusted_hosts="*")
562564

563565
config = uvicorn.Config(
564566
starlette_app,
@@ -606,7 +608,7 @@ def _to_json_serializable(self, obj: Any) -> Any:
606608
return obj.model_dump()
607609

608610
# Handle dataclasses
609-
elif dataclasses.is_dataclass(obj):
611+
elif dataclasses.is_dataclass(obj) and not isinstance(obj, type):
610612
return dataclasses.asdict(obj)
611613

612614
# Handle dictionaries

eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,13 @@ async def setup(self):
3838
async def _get_tools(self) -> Optional[List[dict[str, Any]]]:
3939
if self.evaluation_row.tools is None:
4040
if self.mcp_client:
41-
raw_tools = await self.mcp_client.get_available_tools()
41+
raw_tools = await self.mcp_client.connect_to_servers() or None # ensure servers are connected
42+
raw_tools = await self.mcp_client.get_available_tools() if self.mcp_client else None
4243
tools_dicts: List[dict[str, Any]] = []
4344
for t in raw_tools or []:
4445
if isinstance(t, dict):
45-
# Already a dict-like structure
4646
tools_dicts.append(t)
4747
continue
48-
# Fallback: extract attributes from OpenAI types
4948
tool_type = getattr(t, "type", "function")
5049
func = getattr(t, "function", None)
5150
name = getattr(func, "name", None)
@@ -104,12 +103,10 @@ async def call_agent(self) -> Optional[Union[str, List[ChatCompletionContentPart
104103
return message.content
105104

106105
async def _call_model(self, messages: list[Message], tools: Optional[List[dict[str, Any]]]) -> Message:
107-
# Convert Message models to plain dicts for LLM call
108106
messages_payload: List[Dict[str, Any]] = [
109-
message.model_dump() if hasattr(message, "model_dump") else message # type: ignore[misc]
107+
(message.model_dump() if hasattr(message, "model_dump") else message) # type: ignore[misc]
110108
for message in messages
111109
]
112-
# Normalize tool definitions into OpenAI-compatible dicts
113110
payload_tools: List[Dict[str, Any]] = []
114111
for tool in tools or []:
115112
if isinstance(tool, dict):
@@ -119,7 +116,6 @@ async def _call_model(self, messages: list[Message], tools: Optional[List[dict[s
119116
elif isinstance(fn, dict):
120117
fn_payload = fn
121118
else:
122-
# Best effort fallback
123119
name = getattr(fn, "name", None)
124120
params = getattr(fn, "parameters", None)
125121
if hasattr(params, "model_dump"):
@@ -131,7 +127,6 @@ async def _call_model(self, messages: list[Message], tools: Optional[List[dict[s
131127
fn_payload = {"name": name, "parameters": params_payload}
132128
payload_tools.append({"type": tool.get("type", "function"), "function": fn_payload})
133129
else:
134-
# Attribute-based fallback
135130
tool_type = getattr(tool, "type", "function")
136131
func = getattr(tool, "function", None)
137132
name = getattr(func, "name", None)
@@ -145,14 +140,17 @@ async def _call_model(self, messages: list[Message], tools: Optional[List[dict[s
145140
payload_tools.append({"type": tool_type, "function": {"name": name, "parameters": params_payload}})
146141

147142
response = await self._policy._make_llm_call(messages=messages_payload, tools=payload_tools)
148-
# Coerce content to a string to align with our Message model type expectations
149143
raw_content = response["choices"][0]["message"].get("content")
150144
if isinstance(raw_content, list):
151-
content_for_model = "".join([getattr(p, "text", str(p)) for p in raw_content])
145+
146+
def _part_to_text(p: Any) -> str:
147+
return getattr(p, "text", str(p))
148+
149+
content_for_model: Union[str, List[Any]] = "".join(_part_to_text(p) for p in raw_content)
152150
else:
153151
content_for_model = raw_content
154152
return Message(
155-
role=response["choices"][0]["message"]["role"],
153+
role=response["choices"][0]["message"].get("role", "assistant"),
156154
content=content_for_model,
157155
tool_calls=response["choices"][0]["message"].get("tool_calls"),
158156
)
@@ -184,14 +182,9 @@ def _get_content_from_tool_result(self, tool_result: CallToolResult) -> List[Tex
184182
def _format_tool_message_content(
185183
self, content: List[TextContent]
186184
) -> Union[str, List[ChatCompletionContentPartTextParam]]:
187-
"""Format tool result content for inclusion in a tool message.
188-
189-
- If a single text item, return plain string per OpenAI semantics.
190-
- If multiple items, return a list of text parts.
191-
"""
192185
if len(content) == 1 and isinstance(content[0], TextContent):
193186
return content[0].text
194-
return [ChatCompletionContentPartTextParam(text=c.text, type="text") for c in content]
187+
return [ChatCompletionContentPartTextParam(text=c.text, type="text") for c in content if hasattr(c, "text")]
195188

196189

197190
class AgentRolloutProcessor(RolloutProcessor):

0 commit comments

Comments
 (0)