diff --git a/AGENTS.md b/AGENTS.md index 845893e..ad5ff65 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -380,8 +380,34 @@ async def deploy_multiple_agents(): # 运行 agents = asyncio.run(deploy_multiple_agents()) + +### 8. 变更后的类型检查要求 + +所有由 AI(或自动化 agent)提交或修改的代码变更,必须在提交/合并前后执行静态类型检查,并在变更记录中包含检查结果摘要: + +- **运行命令**:使用项目根目录的 mypy 配置运行: + + ```bash + mypy --config-file mypy.ini . + ``` + +- **必需项**:AI 在每次修改代码并准备提交时,必须: + - 运行上述类型检查命令并等待完成; + - 若检查通过,在提交消息或 PR 描述中写入简短摘要(例如:"类型检查通过,检查文件数 N"); + - 若检查失败,AI 应在 PR 描述中列出前 30 条错误(或最关键的若干条),并给出优先修复建议或自动修复方案。 + +- **CI 行为**:项目 CI 可根据仓库策略决定是否将类型检查失败作为阻断条件;AI 应遵从仓库当前 CI 策略并在 PR 中说明检查结果。 + +此要求旨在保证类型安全随代码变更持续得到验证,减少回归并提高编辑器与 Copilot 的诊断可靠性。 ``` +### 9. 运行命令约定 + +请使用 `uv run ...` 执行所有 Python 相关命令,避免直接调用系统的 `python`。例如: + +- `uv run mypy --config-file mypy.ini .` +- `uv run python examples/quick_start.py` + ## 常见问题 ### Q: Agent Runtime 启动失败怎么办? diff --git a/agentrun/__init__.py b/agentrun/__init__.py index 7ca1290..4508200 100644 --- a/agentrun/__init__.py +++ b/agentrun/__init__.py @@ -90,11 +90,8 @@ # Server from agentrun.server import ( AgentRequest, - AgentResponse, AgentResult, AgentRunServer, - AgentStreamIterator, - AgentStreamResponse, AsyncInvokeAgentHandler, InvokeAgentHandler, Message, diff --git a/agentrun/integration/langchain/__init__.py b/agentrun/integration/langchain/__init__.py index 4fc8bc2..a094bfa 100644 --- a/agentrun/integration/langchain/__init__.py +++ b/agentrun/integration/langchain/__init__.py @@ -1,8 +1,29 @@ -"""LangChain 集成模块,提供 AgentRun 模型与沙箱的 LangChain 适配。 / LangChain 集成 Module""" +"""LangChain 集成模块 + +使用 AgentRunConverter 将 LangChain 事件转换为 AG-UI 协议事件: + + >>> from agentrun.integration.langchain import AgentRunConverter + >>> + >>> async def invoke_agent(request: AgentRequest): + ... converter = AgentRunConverter() + ... async for event in agent.astream_events(input_data, version="v2"): + ... for item in converter.convert(event): + ... 
yield item + +支持多种调用方式: +- agent.astream_events(input, version="v2") - 支持 token by token +- agent.stream(input, stream_mode="updates") - 按节点输出 +- agent.astream(input, stream_mode="updates") - 异步按节点输出 +""" + +from agentrun.integration.langgraph.agent_converter import ( + AgentRunConverter, +) # 向后兼容 from .builtin import model, sandbox_toolset, toolset __all__ = [ + "AgentRunConverter", "model", "toolset", "sandbox_toolset", diff --git a/agentrun/integration/langchain/model_adapter.py b/agentrun/integration/langchain/model_adapter.py index cc729e6..8f9e494 100644 --- a/agentrun/integration/langchain/model_adapter.py +++ b/agentrun/integration/langchain/model_adapter.py @@ -23,7 +23,6 @@ def __init__(self): def wrap_model(self, common_model: Any) -> Any: """包装 CommonModel 为 LangChain BaseChatModel / LangChain Model Adapter""" - from httpx import AsyncClient from langchain_openai import ChatOpenAI info = common_model.get_model_info() # 确保模型可用 @@ -32,6 +31,7 @@ def wrap_model(self, common_model: Any) -> Any: api_key=info.api_key, model=info.model, base_url=info.base_url, - async_client=AsyncClient(headers=info.headers), + default_headers=info.headers, stream_usage=True, + streaming=True, # 启用流式输出以支持 token by token ) diff --git a/agentrun/integration/langgraph/__init__.py b/agentrun/integration/langgraph/__init__.py index d90648a..a0e9a68 100644 --- a/agentrun/integration/langgraph/__init__.py +++ b/agentrun/integration/langgraph/__init__.py @@ -1,12 +1,34 @@ -"""LangGraph 集成模块。 / LangGraph 集成 Module +"""LangGraph 集成模块 -提供 AgentRun 模型与沙箱工具的 LangGraph 适配入口。 / 提供 AgentRun 模型with沙箱工具的 LangGraph 适配入口。 -LangGraph 与 LangChain 兼容,因此直接复用 LangChain 的转换逻辑。 / LangGraph with LangChain 兼容,因此直接复用 LangChain 的转换逻辑。 +使用 AgentRunConverter 将 LangGraph 事件转换为 AG-UI 协议事件: + + >>> from agentrun.integration.langgraph import AgentRunConverter + >>> + >>> async def invoke_agent(request: AgentRequest): + ... converter = AgentRunConverter() + ... 
async for event in agent.astream_events(input_data, version="v2"): + ... for item in converter.convert(event): + ... yield item + +或使用静态方法(无状态): + + >>> from agentrun.integration.langgraph import AgentRunConverter + >>> + >>> async for event in agent.astream_events(input_data, version="v2"): + ... for item in AgentRunConverter.to_agui_events(event): + ... yield item + +支持多种调用方式: +- agent.astream_events(input, version="v2") - 支持 token by token +- agent.stream(input, stream_mode="updates") - 按节点输出 +- agent.astream(input, stream_mode="updates") - 异步按节点输出 """ +from .agent_converter import AgentRunConverter from .builtin import model, sandbox_toolset, toolset __all__ = [ + "AgentRunConverter", "model", "toolset", "sandbox_toolset", diff --git a/agentrun/integration/langgraph/agent_converter.py b/agentrun/integration/langgraph/agent_converter.py new file mode 100644 index 0000000..2055b4b --- /dev/null +++ b/agentrun/integration/langgraph/agent_converter.py @@ -0,0 +1,1073 @@ +"""LangGraph/LangChain 事件转换模块 / LangGraph/LangChain Event Converter + +提供将 LangGraph/LangChain 流式事件转换为 AG-UI 协议事件的方法。 + +使用示例: + + # 使用 AgentRunConverter 类(推荐) + >>> converter = AgentRunConverter() + >>> async for event in agent.astream_events(input_data, version="v2"): + ... for item in converter.convert(event): + ... yield item + + # 使用静态方法(无状态) + >>> async for event in agent.astream_events(input_data, version="v2"): + ... for item in AgentRunConverter.to_agui_events(event): + ... yield item + + # 使用 stream (updates 模式) + >>> for event in agent.stream(input_data, stream_mode="updates"): + ... for item in AgentRunConverter.to_agui_events(event): + ... yield item + + # 使用 astream (updates 模式) + >>> async for event in agent.astream(input_data, stream_mode="updates"): + ... for item in AgentRunConverter.to_agui_events(event): + ... 
yield item +""" + +import json +from typing import Any, Dict, Iterator, List, Optional, Union + +from agentrun.server.model import AgentResult, EventType +from agentrun.utils.log import logger + +# 需要从工具输入中过滤掉的内部字段(LangGraph/MCP 注入的运行时对象) +_TOOL_INPUT_INTERNAL_KEYS = frozenset({ + "runtime", # MCP ToolRuntime 对象 + "__pregel_runtime", + "__pregel_task_id", + "__pregel_send", + "__pregel_read", + "__pregel_checkpointer", + "__pregel_scratchpad", + "__pregel_call", + "config", # LangGraph config 对象,包含内部状态 + "configurable", +}) + + +class AgentRunConverter: + """AgentRun 事件转换器 + + 将 LangGraph/LangChain 流式事件转换为 AG-UI 协议事件。 + 此类维护必要的状态以确保: + 1. 流式工具调用的 tool_call_id 一致性 + 2. AG-UI 协议要求的事件顺序(TOOL_CALL_START → TOOL_CALL_ARGS → TOOL_CALL_END) + + 在流式工具调用中,第一个 chunk 包含 id 和 name,后续 chunk 只有 index 和 args。 + 此类维护 index -> id 的映射,确保所有相关事件使用相同的 tool_call_id。 + + 同时,此类跟踪已发送 TOOL_CALL_START 的工具调用,确保: + - 在流式场景中,TOOL_CALL_START 在第一个参数 chunk 前发送 + - 避免在 on_tool_start 中重复发送 TOOL_CALL_START + + Example: + >>> from agentrun.integration.langgraph import AgentRunConverter + >>> + >>> async def invoke_agent(request: AgentRequest): + ... converter = AgentRunConverter() + ... async for event in agent.astream_events(input, version="v2"): + ... for item in converter.convert(event): + ... 
yield item + """ + + def __init__(self) -> None: + self._tool_call_id_map: Dict[int, str] = {} + self._tool_call_started_set: set = set() + # tool_name -> [tool_call_id] 队列映射 + # 用于在 on_tool_start 中查找对应的 tool_call_id(当 runtime.tool_call_id 不可用时) + self._tool_name_to_call_ids: Dict[str, List[str]] = {} + # run_id -> tool_call_id 映射 + # 用于在 on_tool_end 中查找对应的 tool_call_id + self._run_id_to_tool_call_id: Dict[str, str] = {} + + def convert( + self, + event: Union[Dict[str, Any], Any], + messages_key: str = "messages", + ) -> Iterator[Union[AgentResult, str]]: + """转换单个事件为 AG-UI 协议事件 + + Args: + event: LangGraph/LangChain 流式事件(StreamEvent 对象或 Dict) + messages_key: state 中消息列表的 key,默认 "messages" + + Yields: + str (文本内容) 或 AgentResult (AG-UI 事件) + """ + # 调试日志:输入事件 + event_dict = self._event_to_dict(event) + event_type = event_dict.get("event", "") + + # 始终打印事件类型,用于调试 + logger.debug( + f"[AgentRunConverter] Raw event type: {type(event).__name__}, " + f"event_type={event_type}, " + f"is_dict={isinstance(event, dict)}" + ) + + if event_type in ( + "on_chat_model_stream", + "on_tool_start", + "on_tool_end", + ): + logger.debug( + f"[AgentRunConverter] Input event: type={event_type}, " + f"run_id={event_dict.get('run_id', '')}, " + f"name={event_dict.get('name', '')}, " + f"tool_call_started_set={self._tool_call_started_set}, " + f"tool_name_to_call_ids={self._tool_name_to_call_ids}" + ) + + for item in self.to_agui_events( + event, + messages_key, + self._tool_call_id_map, + self._tool_call_started_set, + self._tool_name_to_call_ids, + self._run_id_to_tool_call_id, + ): + # 调试日志:输出事件 + if isinstance(item, AgentResult): + logger.debug(f"[AgentRunConverter] Output event: {item}") + yield item + + def reset(self) -> None: + """重置状态,清空 tool_call_id 映射和已发送状态 + + 在处理新的请求时,建议创建新的 AgentRunConverter 实例, + 而不是复用旧实例并调用 reset。 + """ + self._tool_call_id_map.clear() + self._tool_call_started_set.clear() + self._tool_name_to_call_ids.clear() + self._run_id_to_tool_call_id.clear() + + # 
========================================================================= + # 内部工具方法(静态方法) + # ========================================================================= + + @staticmethod + def _format_tool_output(output: Any) -> str: + """格式化工具输出为字符串,优先提取常见字段或 content 属性,最后回退到 JSON/str。""" + if output is None: + return "" + # dict-like + if isinstance(output, dict): + for key in ("content", "result", "output"): + if key in output: + v = output[key] + if isinstance(v, (dict, list)): + return json.dumps(v, ensure_ascii=False) + return str(v) if v is not None else "" + try: + return json.dumps(output, ensure_ascii=False) + except Exception: + return str(output) + + # 对象有 content 属性 + if hasattr(output, "content"): + c = AgentRunConverter._get_message_content(output) + if isinstance(c, (dict, list)): + try: + return json.dumps(c, ensure_ascii=False) + except Exception: + return str(c) + return c or "" + + try: + return str(output) + except Exception: + return "" + + @staticmethod + def _safe_json_dumps(obj: Any) -> str: + """JSON 序列化兜底,无法序列化则回退到 str。""" + try: + return json.dumps(obj, ensure_ascii=False, default=str) + except Exception: + try: + return str(obj) + except Exception: + return "" + + @staticmethod + def _filter_tool_input(tool_input: Any) -> Any: + """过滤工具输入中的内部字段,只保留用户传入的实际参数。 + + Args: + tool_input: 工具输入(可能是 dict 或其他类型) + + Returns: + 过滤后的工具输入 + """ + if not isinstance(tool_input, dict): + return tool_input + + filtered = {} + for key, value in tool_input.items(): + # 跳过内部字段 + if key in _TOOL_INPUT_INTERNAL_KEYS: + continue + # 跳过所有下划线前缀的内部字段(包含单下划线与双下划线) + if key.startswith("_"): + continue + filtered[key] = value + + return filtered + + @staticmethod + def _extract_tool_call_id(tool_input: Any) -> Optional[str]: + """从工具输入中提取原始的 tool_call_id。 + + MCP 工具会在 input 中注入 runtime 对象,其中包含 LLM 返回的原始 tool_call_id。 + 使用这个 ID 可以保证工具调用事件的 ID 一致性。 + + Args: + tool_input: 工具输入(可能是 dict 或其他类型) + + Returns: + tool_call_id 或 None + """ + if not isinstance(tool_input, 
dict): + return None + + # 尝试从 runtime 对象中提取 tool_call_id + runtime = tool_input.get("runtime") + if runtime is not None and hasattr(runtime, "tool_call_id"): + tc_id = runtime.tool_call_id + if isinstance(tc_id, str) and tc_id: + return tc_id + + return None + + @staticmethod + def _extract_content(chunk: Any) -> Optional[str]: + """从 chunk 中提取文本内容""" + if chunk is None: + return None + + if hasattr(chunk, "content"): + content = chunk.content + if isinstance(content, str): + return content if content else None + if isinstance(content, list): + text_parts = [] + for item in content: + if isinstance(item, str): + text_parts.append(item) + elif isinstance(item, dict) and item.get("type") == "text": + text_parts.append(item.get("text", "")) + return "".join(text_parts) if text_parts else None + + return None + + @staticmethod + def _extract_tool_call_chunks(chunk: Any) -> List[Dict]: + """从 AIMessageChunk 中提取工具调用增量""" + tool_calls = [] + + if hasattr(chunk, "tool_call_chunks") and chunk.tool_call_chunks: + for tc in chunk.tool_call_chunks: + if isinstance(tc, dict): + tool_calls.append(tc) + else: + tool_calls.append({ + "id": getattr(tc, "id", None), + "name": getattr(tc, "name", None), + "args": getattr(tc, "args", None), + "index": getattr(tc, "index", None), + }) + + return tool_calls + + @staticmethod + def _get_message_type(msg: Any) -> str: + """获取消息类型""" + if hasattr(msg, "type"): + return str(msg.type).lower() + + if isinstance(msg, dict): + msg_type = msg.get("type", msg.get("role", "")) + return str(msg_type).lower() + + class_name = type(msg).__name__.lower() + if "ai" in class_name or "assistant" in class_name: + return "ai" + if "tool" in class_name: + return "tool" + if "human" in class_name or "user" in class_name: + return "human" + + return "unknown" + + @staticmethod + def _get_message_content(msg: Any) -> Optional[str]: + """获取消息内容""" + if hasattr(msg, "content"): + content = msg.content + if isinstance(content, str): + return content + return 
str(content) if content else None + + if isinstance(msg, dict): + return msg.get("content") + + return None + + @staticmethod + def _get_message_tool_calls(msg: Any) -> List[Dict]: + """获取消息中的工具调用""" + if hasattr(msg, "tool_calls") and msg.tool_calls: + tool_calls = [] + for tc in msg.tool_calls: + if isinstance(tc, dict): + tool_calls.append(tc) + else: + tool_calls.append({ + "id": getattr(tc, "id", None), + "name": getattr(tc, "name", None), + "args": getattr(tc, "args", None), + }) + return tool_calls + + if isinstance(msg, dict) and msg.get("tool_calls"): + return msg["tool_calls"] + + return [] + + @staticmethod + def _get_tool_call_id(msg: Any) -> Optional[str]: + """获取 ToolMessage 的 tool_call_id""" + if hasattr(msg, "tool_call_id"): + return msg.tool_call_id + + if isinstance(msg, dict): + return msg.get("tool_call_id") + + return None + + # ========================================================================= + # 事件格式检测(静态方法) + # ========================================================================= + + @staticmethod + def _event_to_dict(event: Any) -> Dict[str, Any]: + """将 StreamEvent 或 dict 标准化为 dict 以便后续处理""" + if isinstance(event, dict): + return event + + result: Dict[str, Any] = {} + # 常见属性映射,兼容多种 StreamEvent 实现 + if hasattr(event, "event"): + result["event"] = getattr(event, "event") + if hasattr(event, "data"): + result["data"] = getattr(event, "data") + if hasattr(event, "name"): + result["name"] = getattr(event, "name") + if hasattr(event, "run_id"): + result["run_id"] = getattr(event, "run_id") + + return result + + @staticmethod + def is_astream_events_format(event_dict: Dict[str, Any]) -> bool: + """检测是否是 astream_events 格式的事件 + + astream_events 格式特征:有 "event" 字段,值以 "on_" 开头 + """ + event_type = event_dict.get("event", "") + return isinstance(event_type, str) and event_type.startswith("on_") + + @staticmethod + def is_stream_updates_format(event_dict: Dict[str, Any]) -> bool: + """检测是否是 stream/astream(stream_mode="updates") 格式的事件 + + 
updates 格式特征:{node_name: {messages_key: [...]}} 或 {node_name: state_dict} + 没有 "event" 字段,键是 node 名称(如 "model", "agent", "tools"),值是 state 更新 + + 与 values 格式的区别: + - updates: {node_name: {messages: [...]}} - 嵌套结构 + - values: {messages: [...]} - 扁平结构 + """ + if "event" in event_dict: + return False + + # 如果直接包含 "messages" 键且值是 list,这是 values 格式,不是 updates + if "messages" in event_dict and isinstance( + event_dict["messages"], list + ): + return False + + # 检查是否有类似 node 更新的结构 + for key, value in event_dict.items(): + if key == "__end__": + continue + # value 应该是一个 dict(state 更新),包含 messages 等字段 + if isinstance(value, dict): + return True + + return False + + @staticmethod + def is_stream_values_format(event_dict: Dict[str, Any]) -> bool: + """检测是否是 stream/astream(stream_mode="values") 格式的事件 + + values 格式特征:直接是完整 state,如 {messages: [...], ...} + 没有 "event" 字段,直接包含 "messages" 或类似的 state 字段 + + 与 updates 格式的区别: + - values: {messages: [...]} - 扁平结构,messages 值直接是 list + - updates: {node_name: {messages: [...]}} - 嵌套结构 + """ + if "event" in event_dict: + return False + + # 检查是否直接包含 messages 列表(扁平结构) + if "messages" in event_dict and isinstance( + event_dict["messages"], list + ): + return True + + return False + + # ========================================================================= + # 事件转换器(静态方法) + # ========================================================================= + + @staticmethod + def _convert_stream_updates_event( + event_dict: Dict[str, Any], + messages_key: str = "messages", + ) -> Iterator[Union[AgentResult, str]]: + """转换 stream/astream(stream_mode="updates") 格式的单个事件 + + Args: + event_dict: 事件字典,格式为 {node_name: state_update} + messages_key: state 中消息列表的 key + + Yields: + str (文本内容) 或 AgentResult (事件) + + Note: + 在 updates 模式下,工具调用和结果在不同的事件中: + - AI 消息包含 tool_calls(仅发送 TOOL_CALL_START + TOOL_CALL_ARGS) + - Tool 消息包含结果(发送 TOOL_CALL_RESULT + TOOL_CALL_END) + """ + for node_name, state_update in event_dict.items(): + if node_name == "__end__": + 
continue + + if not isinstance(state_update, dict): + continue + + messages = state_update.get(messages_key, []) + if not isinstance(messages, list): + # 尝试其他常见的 key + for alt_key in ("message", "output", "response"): + if alt_key in state_update: + alt_value = state_update[alt_key] + if isinstance(alt_value, list): + messages = alt_value + break + elif hasattr(alt_value, "content"): + messages = [alt_value] + break + + for msg in messages: + msg_type = AgentRunConverter._get_message_type(msg) + + if msg_type == "ai": + # 文本内容 + content = AgentRunConverter._get_message_content(msg) + if content: + yield content + + # 工具调用(仅发送 START 和 ARGS,END 在收到结果后发送) + for tc in AgentRunConverter._get_message_tool_calls(msg): + tc_id = tc.get("id", "") + tc_name = tc.get("name", "") + tc_args = tc.get("args", {}) + + if tc_id: + # 发送带有完整参数的 TOOL_CALL_CHUNK + args_str = "" + if tc_args: + args_str = ( + AgentRunConverter._safe_json_dumps(tc_args) + if isinstance(tc_args, dict) + else str(tc_args) + ) + yield AgentResult( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": tc_id, + "name": tc_name, + "args_delta": args_str, + }, + ) + + elif msg_type == "tool": + # 工具结果 + tool_call_id = AgentRunConverter._get_tool_call_id(msg) + if tool_call_id: + tool_content = AgentRunConverter._get_message_content( + msg + ) + yield AgentResult( + event=EventType.TOOL_RESULT, + data={ + "id": tool_call_id, + "result": ( + str(tool_content) if tool_content else "" + ), + }, + ) + + @staticmethod + def _convert_stream_values_event( + event_dict: Dict[str, Any], + messages_key: str = "messages", + ) -> Iterator[Union[AgentResult, str]]: + """转换 stream/astream(stream_mode="values") 格式的单个事件 + + Args: + event_dict: 事件字典,格式为完整的 state dict + messages_key: state 中消息列表的 key + + Yields: + str (文本内容) 或 AgentResult (事件) + + Note: + 在 values 模式下,工具调用和结果可能在同一事件中或不同事件中。 + 我们只处理最后一条消息。 + """ + messages = event_dict.get(messages_key, []) + if not isinstance(messages, list): + return + + # 对于 values 
模式,我们只关心最后一条消息(通常是最新的) + if not messages: + return + + last_msg = messages[-1] + msg_type = AgentRunConverter._get_message_type(last_msg) + + if msg_type == "ai": + content = AgentRunConverter._get_message_content(last_msg) + if content: + yield content + + # 工具调用 + for tc in AgentRunConverter._get_message_tool_calls(last_msg): + tc_id = tc.get("id", "") + tc_name = tc.get("name", "") + tc_args = tc.get("args", {}) + + if tc_id: + # 发送带有完整参数的 TOOL_CALL_CHUNK + args_str = "" + if tc_args: + args_str = ( + AgentRunConverter._safe_json_dumps(tc_args) + if isinstance(tc_args, dict) + else str(tc_args) + ) + yield AgentResult( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": tc_id, + "name": tc_name, + "args_delta": args_str, + }, + ) + + elif msg_type == "tool": + tool_call_id = AgentRunConverter._get_tool_call_id(last_msg) + if tool_call_id: + tool_content = AgentRunConverter._get_message_content(last_msg) + yield AgentResult( + event=EventType.TOOL_RESULT, + data={ + "id": tool_call_id, + "result": str(tool_content) if tool_content else "", + }, + ) + + @staticmethod + def _convert_astream_events_event( + event_dict: Dict[str, Any], + tool_call_id_map: Optional[Dict[int, str]] = None, + tool_call_started_set: Optional[set] = None, + tool_name_to_call_ids: Optional[Dict[str, List[str]]] = None, + run_id_to_tool_call_id: Optional[Dict[str, str]] = None, + ) -> Iterator[Union[AgentResult, str]]: + """转换 astream_events 格式的单个事件 + + Args: + event_dict: 事件字典,格式为 {"event": "on_xxx", "data": {...}} + tool_call_id_map: 可选的 index -> tool_call_id 映射字典。 + 在流式工具调用中,第一个 chunk 有 id,后续只有 index。 + 此映射用于确保所有 chunk 使用一致的 tool_call_id。 + tool_call_started_set: 可选的已发送 TOOL_CALL_START 的 tool_call_id 集合。 + 用于确保每个工具调用只发送一次 TOOL_CALL_START。 + tool_name_to_call_ids: 可选的 tool_name -> [tool_call_id] 队列映射。 + 用于在 on_tool_start 中查找对应的 tool_call_id。 + run_id_to_tool_call_id: 可选的 run_id -> tool_call_id 映射。 + 用于在 on_tool_end 中查找对应的 tool_call_id。 + + Yields: + str (文本内容) 或 AgentResult (事件) + """ + 
event_type = event_dict.get("event", "") + data = event_dict.get("data", {}) + + # 1. LangGraph 格式: on_chat_model_stream + if event_type == "on_chat_model_stream": + chunk = data.get("chunk") + if chunk: + # 文本内容 + content = AgentRunConverter._extract_content(chunk) + if content: + yield content + + # 流式工具调用参数 + for tc in AgentRunConverter._extract_tool_call_chunks(chunk): + tc_index = tc.get("index") + tc_raw_id = tc.get("id") + tc_name = tc.get("name", "") + tc_args = tc.get("args", "") + + # 解析 tool_call_id: + # 1. 如果有 id 且非空,使用它并更新映射 + # 2. 如果 id 为空但有 index,从映射中查找 + # 3. 最后回退到使用 index 字符串 + if tc_raw_id: + tc_id = tc_raw_id + # 更新映射(如果提供了映射字典) + # 重要:即使这个 chunk 没有 args,也要更新映射, + # 因为后续 chunk 可能只有 index 没有 id + if ( + tool_call_id_map is not None + and tc_index is not None + ): + tool_call_id_map[tc_index] = tc_id + elif tc_index is not None: + # 从映射中查找,如果没有则使用 index + if ( + tool_call_id_map is not None + and tc_index in tool_call_id_map + ): + tc_id = tool_call_id_map[tc_index] + else: + tc_id = str(tc_index) + else: + tc_id = "" + + if not tc_id: + continue + + # 流式工具调用:第一个 chunk 包含 id 和 name,后续只有 args_delta + # 协议层会自动处理 START/END 边界事件 + is_first_chunk = ( + tc_raw_id + and tc_name + and ( + tool_call_started_set is None + or tc_id not in tool_call_started_set + ) + ) + + if is_first_chunk: + if tool_call_started_set is not None: + tool_call_started_set.add(tc_id) + # 记录 tool_name -> tool_call_id 映射,用于 on_tool_start 查找 + if tool_name_to_call_ids is not None and tc_name: + if tc_name not in tool_name_to_call_ids: + tool_name_to_call_ids[tc_name] = [] + tool_name_to_call_ids[tc_name].append(tc_id) + # 第一个 chunk 包含 id 和 name + args_delta = "" + if tc_args: + args_delta = ( + AgentRunConverter._safe_json_dumps(tc_args) + if isinstance(tc_args, (dict, list)) + else str(tc_args) + ) + yield AgentResult( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": tc_id, + "name": tc_name, + "args_delta": args_delta, + }, + ) + elif tc_args: + # 后续 chunk 只有 args_delta + 
args_delta = ( + AgentRunConverter._safe_json_dumps(tc_args) + if isinstance(tc_args, (dict, list)) + else str(tc_args) + ) + yield AgentResult( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": tc_id, + "args_delta": args_delta, + }, + ) + + # 2. LangChain 格式: on_chain_stream + elif ( + event_type == "on_chain_stream" + and event_dict.get("name") == "model" + ): + chunk_data = data.get("chunk", {}) + if isinstance(chunk_data, dict): + messages = chunk_data.get("messages", []) + + for msg in messages: + content = AgentRunConverter._get_message_content(msg) + if content: + yield content + + for tc in AgentRunConverter._get_message_tool_calls(msg): + tc_id = tc.get("id", "") + tc_name = tc.get("name", "") + tc_args = tc.get("args", {}) + + if tc_id: + # 检查是否已经发送过这个 tool call + already_started = ( + tool_call_started_set is not None + and tc_id in tool_call_started_set + ) + + if not already_started: + # 标记为已开始,防止 on_tool_start 重复发送 + if tool_call_started_set is not None: + tool_call_started_set.add(tc_id) + + # 记录 tool_name -> tool_call_id 映射 + if ( + tool_name_to_call_ids is not None + and tc_name + ): + tool_name_to_call_ids.setdefault( + tc_name, [] + ).append(tc_id) + + args_delta = "" + if tc_args: + args_delta = ( + AgentRunConverter._safe_json_dumps( + tc_args + ) + if isinstance(tc_args, dict) + else str(tc_args) + ) + yield AgentResult( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": tc_id, + "name": tc_name, + "args_delta": args_delta, + }, + ) + + # 3. 
工具开始 + elif event_type == "on_tool_start": + run_id = event_dict.get("run_id", "") + tool_name = event_dict.get("name", "") + tool_input_raw = data.get("input", {}) + # 优先使用 runtime 中的原始 tool_call_id,保证 ID 一致性 + tool_call_id = AgentRunConverter._extract_tool_call_id( + tool_input_raw + ) + + # 如果 runtime.tool_call_id 不可用,尝试从 tool_name_to_call_ids 映射中查找 + # 这用于处理非 MCP 工具的情况,其中 on_chat_model_stream 已经发送了 TOOL_CALL_START + if ( + not tool_call_id + and tool_name_to_call_ids is not None + and tool_name + ): + call_ids = tool_name_to_call_ids.get(tool_name, []) + if call_ids: + # 使用队列中的第一个 ID(FIFO),并从队列中移除 + tool_call_id = call_ids.pop(0) + + # 最后回退到 run_id + if not tool_call_id: + tool_call_id = run_id + + # 记录 run_id -> tool_call_id 映射,用于 on_tool_end 查找 + if run_id_to_tool_call_id is not None and run_id and tool_call_id: + run_id_to_tool_call_id[run_id] = tool_call_id + + # 过滤掉内部字段(如 MCP 注入的 runtime) + tool_input = AgentRunConverter._filter_tool_input(tool_input_raw) + + if tool_call_id: + # 检查是否已在 on_chat_model_stream 中发送过 + already_started = ( + tool_call_started_set is not None + and tool_call_id in tool_call_started_set + ) + + if not already_started: + # 非流式场景或未收到流式事件,发送完整的 TOOL_CALL_CHUNK + if tool_call_started_set is not None: + tool_call_started_set.add(tool_call_id) + + args_delta = "" + if tool_input: + args_delta = ( + AgentRunConverter._safe_json_dumps(tool_input) + if isinstance(tool_input, dict) + else str(tool_input) + ) + yield AgentResult( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": tool_call_id, + "name": tool_name, + "args_delta": args_delta, + }, + ) + # 协议层会自动处理边界事件,无需手动发送 TOOL_CALL_END + + # 4. 
工具结束 + elif event_type == "on_tool_end": + run_id = event_dict.get("run_id", "") + output = data.get("output", "") + tool_input_raw = data.get("input", {}) + # 优先使用 runtime 中的原始 tool_call_id,保证 ID 一致性 + tool_call_id = AgentRunConverter._extract_tool_call_id( + tool_input_raw + ) + + # 如果 runtime.tool_call_id 不可用,尝试从 run_id_to_tool_call_id 映射中查找 + # 这个映射在 on_tool_start 中建立 + if ( + not tool_call_id + and run_id_to_tool_call_id is not None + and run_id + ): + tool_call_id = run_id_to_tool_call_id.get(run_id) + + # 最后回退到 run_id + if not tool_call_id: + tool_call_id = run_id + + if tool_call_id: + # 工具执行完成后发送结果 + yield AgentResult( + event=EventType.TOOL_RESULT, + data={ + "id": tool_call_id, + "result": AgentRunConverter._format_tool_output(output), + }, + ) + + # 5. LLM 结束 + elif event_type == "on_chat_model_end": + # 无状态模式下不处理,避免重复 + pass + + # 6. 工具错误 + elif event_type == "on_tool_error": + run_id = event_dict.get("run_id", "") + error = data.get("error") + tool_input_raw = data.get("input", {}) + tool_name = event_dict.get("name", "") + # 优先使用 runtime 中的原始 tool_call_id + tool_call_id = AgentRunConverter._extract_tool_call_id( + tool_input_raw + ) + + # 如果 runtime.tool_call_id 不可用,尝试从 run_id_to_tool_call_id 映射中查找 + if ( + not tool_call_id + and run_id_to_tool_call_id is not None + and run_id + ): + tool_call_id = run_id_to_tool_call_id.get(run_id) + + # 最后回退到 run_id + if not tool_call_id: + tool_call_id = run_id + + # 格式化错误信息 + error_message = "" + if error is not None: + if isinstance(error, Exception): + error_message = f"{type(error).__name__}: {str(error)}" + elif isinstance(error, str): + error_message = error + else: + error_message = str(error) + + # 发送 ERROR 事件 + yield AgentResult( + event=EventType.ERROR, + data={ + "message": ( + f"Tool '{tool_name}' error: {error_message}" + if tool_name + else error_message + ), + "code": "TOOL_ERROR", + "tool_call_id": tool_call_id, + }, + ) + + # 7. 
LLM 错误 + elif event_type == "on_llm_error": + error = data.get("error") + error_message = "" + if error is not None: + if isinstance(error, Exception): + error_message = f"{type(error).__name__}: {str(error)}" + elif isinstance(error, str): + error_message = error + else: + error_message = str(error) + + yield AgentResult( + event=EventType.ERROR, + data={ + "message": f"LLM error: {error_message}", + "code": "LLM_ERROR", + }, + ) + + # 8. Chain 错误 + elif event_type == "on_chain_error": + error = data.get("error") + chain_name = event_dict.get("name", "") + error_message = "" + if error is not None: + if isinstance(error, Exception): + error_message = f"{type(error).__name__}: {str(error)}" + elif isinstance(error, str): + error_message = error + else: + error_message = str(error) + + yield AgentResult( + event=EventType.ERROR, + data={ + "message": ( + f"Chain '{chain_name}' error: {error_message}" + if chain_name + else error_message + ), + "code": "CHAIN_ERROR", + }, + ) + + # 9. Retriever 错误 + elif event_type == "on_retriever_error": + error = data.get("error") + retriever_name = event_dict.get("name", "") + error_message = "" + if error is not None: + if isinstance(error, Exception): + error_message = f"{type(error).__name__}: {str(error)}" + elif isinstance(error, str): + error_message = error + else: + error_message = str(error) + + yield AgentResult( + event=EventType.ERROR, + data={ + "message": ( + f"Retriever '{retriever_name}' error: {error_message}" + if retriever_name + else error_message + ), + "code": "RETRIEVER_ERROR", + }, + ) + + # ========================================================================= + # 主要 API(静态方法) + # ========================================================================= + + @staticmethod + def to_agui_events( + event: Union[Dict[str, Any], Any], + messages_key: str = "messages", + tool_call_id_map: Optional[Dict[int, str]] = None, + tool_call_started_set: Optional[set] = None, + tool_name_to_call_ids: 
Optional[Dict[str, List[str]]] = None, + run_id_to_tool_call_id: Optional[Dict[str, str]] = None, + ) -> Iterator[Union[AgentResult, str]]: + """将 LangGraph/LangChain 流式事件转换为 AG-UI 协议事件 + + 支持多种调用方式产生的事件格式: + - agent.astream_events(input, version="v2") + - agent.stream(input, stream_mode="updates") + - agent.astream(input, stream_mode="updates") + - agent.stream(input, stream_mode="values") + - agent.astream(input, stream_mode="values") + + Args: + event: LangGraph/LangChain 流式事件(StreamEvent 对象或 Dict) + messages_key: state 中消息列表的 key,默认 "messages" + tool_call_id_map: 可选的 index -> tool_call_id 映射字典,用于流式工具调用 + 的 ID 一致性。如果提供,函数会自动更新此映射。 + tool_call_started_set: 可选的已发送 TOOL_CALL_START 的 tool_call_id 集合。 + 用于确保每个工具调用只发送一次 TOOL_CALL_START, + 并在正确的时机发送 TOOL_CALL_END。 + tool_name_to_call_ids: 可选的 tool_name -> [tool_call_id] 队列映射。 + 用于在 on_tool_start 中查找对应的 tool_call_id。 + run_id_to_tool_call_id: 可选的 run_id -> tool_call_id 映射。 + 用于在 on_tool_end 中查找对应的 tool_call_id。 + + Yields: + str (文本内容) 或 AgentResult (AG-UI 事件) + + Example: + >>> # 使用 astream_events(推荐使用 AgentRunConverter 类) + >>> async for event in agent.astream_events(input, version="v2"): + ... for item in AgentRunConverter.to_agui_events(event): + ... yield item + + >>> # 使用 stream (updates 模式) + >>> for event in agent.stream(input, stream_mode="updates"): + ... for item in AgentRunConverter.to_agui_events(event): + ... yield item + + >>> # 使用 astream (updates 模式) + >>> async for event in agent.astream(input, stream_mode="updates"): + ... for item in AgentRunConverter.to_agui_events(event): + ... 
yield item + """ + event_dict = AgentRunConverter._event_to_dict(event) + + # 根据事件格式选择对应的转换器 + if AgentRunConverter.is_astream_events_format(event_dict): + # astream_events 格式:{"event": "on_xxx", "data": {...}} + yield from AgentRunConverter._convert_astream_events_event( + event_dict, + tool_call_id_map, + tool_call_started_set, + tool_name_to_call_ids, + run_id_to_tool_call_id, + ) + + elif AgentRunConverter.is_stream_updates_format(event_dict): + # stream/astream(stream_mode="updates") 格式:{node_name: state_update} + yield from AgentRunConverter._convert_stream_updates_event( + event_dict, messages_key + ) + + elif AgentRunConverter.is_stream_values_format(event_dict): + # stream/astream(stream_mode="values") 格式:完整 state dict + yield from AgentRunConverter._convert_stream_values_event( + event_dict, messages_key + ) diff --git a/agentrun/server/__init__.py b/agentrun/server/__init__.py index f959e29..96ceac3 100644 --- a/agentrun/server/__init__.py +++ b/agentrun/server/__init__.py @@ -1,53 +1,147 @@ """AgentRun Server 模块 / AgentRun Server Module -提供 HTTP Server 集成能力,支持符合 AgentRun 规范的 Agent 调用接口。 +提供 HTTP Server 集成能力,支持符合 AgentRun 规范的 Agent 调用接口。 +支持 OpenAI Chat Completions 和 AG-UI 两种协议。 -Example (基本使用): ->>> from agentrun.server import AgentRunServer, AgentRequest, AgentResponse +Example (基本使用 - 返回字符串): +>>> from agentrun.server import AgentRunServer, AgentRequest >>> ->>> def invoke_agent(request: AgentRequest) -> AgentResponse: -... # 实现你的 Agent 逻辑 -... return AgentResponse(...) +>>> def invoke_agent(request: AgentRequest): +... return "Hello, world!" >>> >>> server = AgentRunServer(invoke_agent=invoke_agent) ->>> server.start(host="0.0.0.0", port=8080) +>>> server.start(port=9000) -Example (异步处理): ->>> async def invoke_agent(request: AgentRequest) -> AgentResponse: -... # 异步实现你的 Agent 逻辑 -... return AgentResponse(...) +Example (流式输出): +>>> def invoke_agent(request: AgentRequest): +... for word in ["Hello", ", ", "world", "!"]: +... 
yield word >>> ->>> server = AgentRunServer(invoke_agent=invoke_agent) ->>> server.start() +>>> AgentRunServer(invoke_agent=invoke_agent).start() + +Example (使用事件): +>>> from agentrun.server import AgentEvent, EventType +>>> +>>> async def invoke_agent(request: AgentRequest): +... # 发送自定义事件(如步骤开始) +... yield AgentEvent( +... event=EventType.CUSTOM, +... data={"name": "step_started", "value": {"step": "processing"}} +... ) +... +... # 流式输出内容 +... yield "Hello, " +... yield "world!" +... +... # 发送步骤结束事件 +... yield AgentEvent( +... event=EventType.CUSTOM, +... data={"name": "step_finished", "value": {"step": "processing"}} +... ) + +Example (工具调用事件): +>>> async def invoke_agent(request: AgentRequest): +... # 完整工具调用 +... yield AgentEvent( +... event=EventType.TOOL_CALL, +... data={"id": "call_1", "name": "get_time", "args": '{"timezone": "UTC"}'} +... ) +... +... # 执行工具 +... result = "2024-01-01 12:00:00" +... +... # 工具调用结果 +... yield AgentEvent( +... event=EventType.TOOL_RESULT, +... data={"id": "call_1", "result": result} +... ) +... +... yield f"当前时间: {result}" + +Example (流式工具输出): +>>> async def invoke_agent(request: AgentRequest): +... # 发起工具调用 +... yield AgentEvent( +... event=EventType.TOOL_CALL, +... data={"id": "call_1", "name": "run_code", "args": '{"code": "..."}'} +... ) +... +... # 流式输出执行过程 +... yield AgentEvent( +... event=EventType.TOOL_RESULT_CHUNK, +... data={"id": "call_1", "delta": "Step 1: Compiling...\\n"} +... ) +... yield AgentEvent( +... event=EventType.TOOL_RESULT_CHUNK, +... data={"id": "call_1", "delta": "Step 2: Running...\\n"} +... ) +... +... # 最终结果(标识流式输出结束) +... yield AgentEvent( +... event=EventType.TOOL_RESULT, +... data={"id": "call_1", "result": "Execution completed."} +... ) + +Example (HITL - 请求人类介入): +>>> async def invoke_agent(request: AgentRequest): +... # 请求用户确认 +... yield AgentEvent( +... event=EventType.HITL, +... data={ +... "id": "hitl_1", +... "tool_call_id": "call_delete", # 可选 +... "type": "confirmation", +... 
"prompt": "确认删除文件?", +... "options": ["确认", "取消"] +... } +... ) +... # 用户响应将通过下一轮对话的 messages 传回 -Example (流式响应): +Example (访问原始请求): >>> async def invoke_agent(request: AgentRequest): -... if request.stream: -... async def stream(): -... for chunk in generate_chunks(): -... yield AgentStreamResponse(...) -... return stream() -... return AgentResponse(...)""" +... # 访问当前协议 +... protocol = request.protocol # "openai" 或 "agui" +... +... # 访问原始请求头 +... auth = request.raw_request.headers.get("Authorization") +... +... # 访问查询参数 +... params = request.raw_request.query_params +... +... # 访问客户端 IP +... client_ip = request.raw_request.client.host if request.raw_request.client else None +... +... return "Hello, world!" +""" +from ..utils.helper import MergeOptions +from .agui_normalizer import AguiEventNormalizer +from .agui_protocol import AGUIProtocolHandler from .model import ( + AgentEvent, + AgentEventItem, AgentRequest, - AgentResponse, - AgentResponseChoice, - AgentResponseUsage, AgentResult, - AgentRunResult, - AgentStreamIterator, - AgentStreamResponse, - AgentStreamResponseChoice, - AgentStreamResponseDelta, + AgentResultItem, + AgentReturnType, + AGUIProtocolConfig, + AsyncAgentEventGenerator, + AsyncAgentResultGenerator, + EventType, Message, MessageRole, + OpenAIProtocolConfig, + ProtocolConfig, + ServerConfig, + SyncAgentEventGenerator, + SyncAgentResultGenerator, Tool, ToolCall, ) from .openai_protocol import OpenAIProtocolHandler from .protocol import ( AsyncInvokeAgentHandler, + BaseProtocolHandler, InvokeAgentHandler, ProtocolHandler, SyncInvokeAgentHandler, @@ -57,26 +151,41 @@ __all__ = [ # Server "AgentRunServer", + # Config + "ServerConfig", + "ProtocolConfig", + "OpenAIProtocolConfig", + "AGUIProtocolConfig", # Request/Response Models "AgentRequest", - "AgentResponse", - "AgentResponseChoice", - "AgentResponseUsage", - "AgentRunResult", - "AgentStreamResponse", - "AgentStreamResponseChoice", - "AgentStreamResponseDelta", + "AgentEvent", + "AgentResult", 
# 兼容别名 "Message", "MessageRole", "Tool", "ToolCall", + # Event Types + "EventType", # Type Aliases - "AgentResult", - "AgentStreamIterator", + "AgentEventItem", + "AgentResultItem", # 兼容别名 + "AgentReturnType", + "SyncAgentEventGenerator", + "SyncAgentResultGenerator", # 兼容别名 + "AsyncAgentEventGenerator", + "AsyncAgentResultGenerator", # 兼容别名 "InvokeAgentHandler", "AsyncInvokeAgentHandler", "SyncInvokeAgentHandler", - # Protocol + # Protocol Base "ProtocolHandler", + "BaseProtocolHandler", + # Protocol - OpenAI "OpenAIProtocolHandler", + # Protocol - AG-UI + "AGUIProtocolHandler", + # Event Normalizer + "AguiEventNormalizer", + # Helpers + "MergeOptions", ] diff --git a/agentrun/server/agui_normalizer.py b/agentrun/server/agui_normalizer.py new file mode 100644 index 0000000..e053a12 --- /dev/null +++ b/agentrun/server/agui_normalizer.py @@ -0,0 +1,180 @@ +"""AG-UI 事件规范化器 + +提供事件流规范化功能,确保事件符合 AG-UI 协议的顺序要求。 + +主要功能: +- 追踪工具调用状态 +- 在 TOOL_RESULT 前确保工具调用已开始 +- 自动补充缺失的状态 + +注意:边界事件(如 TEXT_MESSAGE_START/END、TOOL_CALL_START/END) +由协议层(agui_protocol.py)自动生成,不需要用户关心。 + +使用示例: + + >>> from agentrun.server.agui_normalizer import AguiEventNormalizer + >>> + >>> normalizer = AguiEventNormalizer() + >>> for event in raw_events: + ... for normalized_event in normalizer.normalize(event): + ... yield normalized_event +""" + +from typing import Any, Dict, Iterator, List, Optional, Set, Union + +from .model import AgentEvent, EventType + + +class AguiEventNormalizer: + """AG-UI 事件规范化器 + + 追踪工具调用状态,确保事件顺序正确: + 1. 追踪已开始的工具调用 + 2. 确保 TOOL_RESULT 前工具调用存在 + + 协议层会自动处理边界事件(START/END),这个类主要用于 + 高级用户需要手动控制事件流时。 + + Example: + >>> normalizer = AguiEventNormalizer() + >>> for event in agent_events: + ... for normalized in normalizer.normalize(event): + ... 
yield normalized + """ + + def __init__(self): + # 已看到的工具调用 ID 集合 + self._seen_tool_calls: Set[str] = set() + # 活跃的工具调用信息(tool_call_id -> tool_call_name) + self._active_tool_calls: Dict[str, str] = {} + + def normalize( + self, + event: Union[AgentEvent, str, Dict[str, Any]], + ) -> Iterator[AgentEvent]: + """规范化单个事件 + + 将事件标准化为 AgentEvent,并追踪工具调用状态。 + + Args: + event: 原始事件(AgentEvent、str 或 dict) + + Yields: + 规范化后的事件 + """ + # 将事件标准化为 AgentEvent + normalized_event = self._to_agent_event(event) + if normalized_event is None: + return + + # 根据事件类型进行处理 + event_type = normalized_event.event + + if event_type == EventType.TOOL_CALL_CHUNK: + yield from self._handle_tool_call_chunk(normalized_event) + + elif event_type == EventType.TOOL_CALL: + yield from self._handle_tool_call(normalized_event) + + elif event_type == EventType.TOOL_RESULT: + yield from self._handle_tool_result(normalized_event) + + else: + # 其他事件类型直接传递 + yield normalized_event + + def _to_agent_event( + self, event: Union[AgentEvent, str, Dict[str, Any]] + ) -> Optional[AgentEvent]: + """将事件转换为 AgentEvent""" + if isinstance(event, AgentEvent): + return event + + if isinstance(event, str): + # 字符串转为 TEXT + return AgentEvent( + event=EventType.TEXT, + data={"delta": event}, + ) + + if isinstance(event, dict): + event_type = event.get("event") + if event_type is None: + return None + + # 尝试解析 event_type + if isinstance(event_type, str): + try: + event_type = EventType(event_type) + except ValueError: + try: + event_type = EventType[event_type] + except KeyError: + return None + + return AgentEvent( + event=event_type, + data=event.get("data", {}), + ) + + return None + + def _handle_tool_call(self, event: AgentEvent) -> Iterator[AgentEvent]: + """处理 TOOL_CALL 事件 + + 记录工具调用并直接传递 + """ + tool_call_id = event.data.get("id", "") + tool_call_name = event.data.get("name", "") + + if tool_call_id: + self._seen_tool_calls.add(tool_call_id) + self._active_tool_calls[tool_call_id] = tool_call_name + + yield event + 
+ def _handle_tool_call_chunk( + self, event: AgentEvent + ) -> Iterator[AgentEvent]: + """处理 TOOL_CALL_CHUNK 事件 + + 记录工具调用并直接传递 + """ + tool_call_id = event.data.get("id", "") + tool_call_name = event.data.get("name", "") + + if tool_call_id: + self._seen_tool_calls.add(tool_call_id) + if tool_call_name: + self._active_tool_calls[tool_call_id] = tool_call_name + + yield event + + def _handle_tool_result(self, event: AgentEvent) -> Iterator[AgentEvent]: + """处理 TOOL_RESULT 事件 + + 标记工具调用完成 + """ + tool_call_id = event.data.get("id", "") + + if tool_call_id: + # 标记工具调用已完成(从活跃列表移除) + self._active_tool_calls.pop(tool_call_id, None) + + yield event + + def get_active_tool_calls(self) -> List[str]: + """获取当前活跃(未结束)的工具调用 ID 列表""" + return list(self._active_tool_calls.keys()) + + def get_seen_tool_calls(self) -> List[str]: + """获取所有已见过的工具调用 ID 列表""" + return list(self._seen_tool_calls) + + def reset(self): + """重置状态 + + 在处理新的请求时,建议创建新的实例而不是复用。 + """ + self._seen_tool_calls.clear() + self._active_tool_calls.clear() diff --git a/agentrun/server/agui_protocol.py b/agentrun/server/agui_protocol.py new file mode 100644 index 0000000..51a88a6 --- /dev/null +++ b/agentrun/server/agui_protocol.py @@ -0,0 +1,1039 @@ +"""AG-UI (Agent-User Interaction Protocol) 协议实现 + +AG-UI 是一种开源、轻量级、基于事件的协议,用于标准化 AI Agent 与前端应用之间的交互。 +参考: https://docs.ag-ui.com/ + +本实现使用 ag-ui-protocol 包提供的事件类型和编码器, +将 AgentResult 事件转换为 AG-UI SSE 格式。 +""" + +from dataclasses import dataclass, field +from typing import ( + Any, + AsyncIterator, + Dict, + Iterator, + List, + Optional, + TYPE_CHECKING, +) +import uuid + +from ag_ui.core import AssistantMessage +from ag_ui.core import CustomEvent as AguiCustomEvent +from ag_ui.core import EventType as AguiEventType +from ag_ui.core import Message as AguiMessage +from ag_ui.core import MessagesSnapshotEvent +from ag_ui.core import RawEvent as AguiRawEvent +from ag_ui.core import ( + RunErrorEvent, + RunFinishedEvent, + RunStartedEvent, + StateDeltaEvent, + 
StateSnapshotEvent, + StepFinishedEvent, + StepStartedEvent, + SystemMessage, + TextMessageContentEvent, + TextMessageEndEvent, + TextMessageStartEvent, +) +from ag_ui.core import Tool as AguiTool +from ag_ui.core import ToolCall as AguiToolCall +from ag_ui.core import ( + ToolCallArgsEvent, + ToolCallEndEvent, + ToolCallResultEvent, + ToolCallStartEvent, +) +from ag_ui.core import ToolMessage as AguiToolMessage +from ag_ui.core import UserMessage +from ag_ui.encoder import EventEncoder +from fastapi import APIRouter, Request +from fastapi.responses import StreamingResponse +import pydash + +from ..utils.helper import merge, MergeOptions +from .model import ( + AgentEvent, + AgentRequest, + EventType, + Message, + MessageRole, + ServerConfig, + Tool, + ToolCall, +) +from .protocol import BaseProtocolHandler + +if TYPE_CHECKING: + from .invoker import AgentInvoker + + +# ============================================================================ +# AG-UI 协议处理器 +# ============================================================================ + +DEFAULT_PREFIX = "/ag-ui/agent" + + +@dataclass +class TextState: + started: bool = False + ended: bool = False + message_id: str = field(default_factory=lambda: str(uuid.uuid4())) + + +@dataclass +class ToolCallState: + name: str = "" + started: bool = False + ended: bool = False + has_result: bool = False + is_hitl: bool = False + + +@dataclass +class StreamStateMachine: + copilotkit_compatibility: bool + text: TextState = field(default_factory=TextState) + tool_call_states: Dict[str, ToolCallState] = field(default_factory=dict) + tool_result_chunks: Dict[str, List[str]] = field(default_factory=dict) + uuid_to_tool_call_id: Dict[str, str] = field(default_factory=dict) + run_errored: bool = False + active_tool_id: Optional[str] = None + pending_events: List["AgentEvent"] = field(default_factory=list) + + @staticmethod + def _is_uuid_like(value: Optional[str]) -> bool: + if not value: + return False + try: + 
uuid.UUID(str(value)) + return True + except (ValueError, TypeError, AttributeError): + return False + + def resolve_tool_id(self, tool_id: str, tool_name: str) -> str: + """将 UUID 形式的 ID 映射到已有的 call_xxx ID,避免模糊匹配。""" + if not tool_id: + return "" + if not self._is_uuid_like(tool_id): + return tool_id + if tool_id in self.uuid_to_tool_call_id: + return self.uuid_to_tool_call_id[tool_id] + + candidates = [ + existing_id + for existing_id, state in self.tool_call_states.items() + if not self._is_uuid_like(existing_id) + and state.started + and (state.name == tool_name or not tool_name) + ] + if len(candidates) == 1: + self.uuid_to_tool_call_id[tool_id] = candidates[0] + return candidates[0] + return tool_id + + def end_all_tools( + self, encoder: EventEncoder, exclude: Optional[str] = None + ) -> Iterator[str]: + for tool_id, state in self.tool_call_states.items(): + if exclude and tool_id == exclude: + continue + if state.started and not state.ended: + yield encoder.encode(ToolCallEndEvent(tool_call_id=tool_id)) + state.ended = True + + def ensure_text_started(self, encoder: EventEncoder) -> Iterator[str]: + if not self.text.started or self.text.ended: + if self.text.ended: + self.text = TextState() + yield encoder.encode( + TextMessageStartEvent( + message_id=self.text.message_id, + role="assistant", + ) + ) + self.text.started = True + self.text.ended = False + + def end_text_if_open(self, encoder: EventEncoder) -> Iterator[str]: + if self.text.started and not self.text.ended: + yield encoder.encode( + TextMessageEndEvent(message_id=self.text.message_id) + ) + self.text.ended = True + + def cache_tool_result_chunk(self, tool_id: str, delta: str) -> None: + if not tool_id or delta is None: + return + if delta: + self.tool_result_chunks.setdefault(tool_id, []).append(delta) + + def pop_tool_result_chunks(self, tool_id: str) -> str: + return "".join(self.tool_result_chunks.pop(tool_id, [])) + + +class AGUIProtocolHandler(BaseProtocolHandler): + """AG-UI 协议处理器 + + 实现 
AG-UI (Agent-User Interaction Protocol) 兼容接口。 + 参考: https://docs.ag-ui.com/ + + 使用 ag-ui-protocol 包提供的事件类型和编码器。 + + 特点: + - 基于事件的流式通信 + - 完整支持所有 AG-UI 事件类型 + - 支持状态同步 + - 支持工具调用 + + Example: + >>> from agentrun.server import AgentRunServer, AGUIProtocolHandler + >>> + >>> server = AgentRunServer( + ... invoke_agent=my_agent, + ... protocols=[AGUIProtocolHandler()] + ... ) + >>> server.start(port=8000) + # 可访问: POST http://localhost:8000/ag-ui/agent + """ + + name = "ag-ui" + + def __init__(self, config: Optional[ServerConfig] = None): + self._config = config.agui if config else None + self._encoder = EventEncoder() + # 是否串行化工具调用(兼容 CopilotKit 等前端) + self._copilotkit_compatibility = pydash.get( + self._config, "copilotkit_compatibility", False + ) + + def get_prefix(self) -> str: + """AG-UI 协议建议使用 /ag-ui/agent 前缀""" + return pydash.get(self._config, "prefix", DEFAULT_PREFIX) + + def as_fastapi_router(self, agent_invoker: "AgentInvoker") -> APIRouter: + """创建 AG-UI 协议的 FastAPI Router""" + router = APIRouter() + + @router.post("") + async def run_agent(request: Request): + """AG-UI 运行 Agent 端点 + + 接收 AG-UI 格式的请求,返回 SSE 事件流。 + """ + sse_headers = { + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + } + + try: + request_data = await request.json() + agent_request, context = await self.parse_request( + request, request_data + ) + + # 使用 invoke_stream 获取流式结果 + event_stream = self._format_stream( + agent_invoker.invoke_stream(agent_request), + context, + ) + + return StreamingResponse( + event_stream, + media_type=self._encoder.get_content_type(), + headers=sse_headers, + ) + + except ValueError as e: + return StreamingResponse( + self._error_stream(str(e)), + media_type=self._encoder.get_content_type(), + headers=sse_headers, + ) + except Exception as e: + return StreamingResponse( + self._error_stream(f"Internal error: {str(e)}"), + media_type=self._encoder.get_content_type(), + headers=sse_headers, + ) + + @router.get("/health") 
+ async def health_check(): + """健康检查端点""" + return {"status": "ok", "protocol": "ag-ui", "version": "1.0"} + + return router + + async def parse_request( + self, + request: Request, + request_data: Dict[str, Any], + ) -> tuple[AgentRequest, Dict[str, Any]]: + """解析 AG-UI 格式的请求 + + Args: + request: FastAPI Request 对象 + request_data: HTTP 请求体 JSON 数据 + + Returns: + tuple: (AgentRequest, context) + """ + # 创建上下文 + context = { + "thread_id": request_data.get("threadId") or str(uuid.uuid4()), + "run_id": request_data.get("runId") or str(uuid.uuid4()), + } + + # 解析消息列表 + messages = self._parse_messages(request_data.get("messages", [])) + + # 解析工具列表 + tools = self._parse_tools(request_data.get("tools")) + + # 构建 AgentRequest + agent_request = AgentRequest( + protocol="agui", # 设置协议名称 + messages=messages, + stream=True, # AG-UI 总是流式 + tools=tools, + raw_request=request, # 保留原始请求对象 + ) + + return agent_request, context + + def _parse_messages( + self, raw_messages: List[Dict[str, Any]] + ) -> List[Message]: + """解析消息列表 + + Args: + raw_messages: 原始消息数据 + + Returns: + 标准化的消息列表 + """ + messages = [] + + for msg_data in raw_messages: + if not isinstance(msg_data, dict): + continue + + role_str = msg_data.get("role", "user") + try: + role = MessageRole(role_str) + except ValueError: + role = MessageRole.USER + + # 解析 tool_calls + tool_calls = None + if msg_data.get("toolCalls"): + tool_calls = [ + ToolCall( + id=tc.get("id", ""), + type=tc.get("type", "function"), + function=tc.get("function", {}), + ) + for tc in msg_data["toolCalls"] + ] + + messages.append( + Message( + id=msg_data.get("id"), + role=role, + content=msg_data.get("content"), + name=msg_data.get("name"), + tool_calls=tool_calls, + tool_call_id=msg_data.get("toolCallId"), + ) + ) + + return messages + + def _parse_tools( + self, raw_tools: Optional[List[Dict[str, Any]]] + ) -> Optional[List[Tool]]: + """解析工具列表 + + Args: + raw_tools: 原始工具数据 + + Returns: + 标准化的工具列表 + """ + if not raw_tools: + return None + + tools 
= [] + for tool_data in raw_tools: + if not isinstance(tool_data, dict): + continue + + tools.append( + Tool( + type=tool_data.get("type", "function"), + function=tool_data.get("function", {}), + ) + ) + + return tools if tools else None + + async def _format_stream( + self, + event_stream: AsyncIterator[AgentEvent], + context: Dict[str, Any], + ) -> AsyncIterator[str]: + """将 AgentEvent 流转换为 AG-UI SSE 格式 + + 自动生成边界事件: + - RUN_STARTED / RUN_FINISHED(生命周期) + - TEXT_MESSAGE_START / TEXT_MESSAGE_END(文本边界) + - TOOL_CALL_START / TOOL_CALL_END(工具调用边界) + + 注意:RUN_ERROR 之后不能再发送任何事件(包括 RUN_FINISHED) + + Args: + event_stream: AgentEvent 流 + context: 上下文信息 + + Yields: + SSE 格式的字符串 + """ + state = StreamStateMachine( + copilotkit_compatibility=self._copilotkit_compatibility + ) + + # 发送 RUN_STARTED + yield self._encoder.encode( + RunStartedEvent( + thread_id=context.get("thread_id"), + run_id=context.get("run_id"), + ) + ) + + # 辅助函数:处理队列中的所有事件 + def process_pending_queue() -> Iterator[str]: + """处理队列中的所有待处理事件""" + while state.pending_events: + pending_event = state.pending_events.pop(0) + pending_tool_id = ( + pending_event.data.get("id", "") + if pending_event.data + else "" + ) + + # 如果是新的工具调用,设置为活跃 + if ( + pending_event.event == EventType.TOOL_CALL_CHUNK + or pending_event.event == EventType.TOOL_CALL + ) and state.active_tool_id is None: + state.active_tool_id = pending_tool_id + + for sse_data in self._process_event_with_boundaries( + pending_event, + context, + state, + ): + if sse_data: + yield sse_data + + # 如果处理的是 TOOL_RESULT,检查是否需要继续处理队列 + if pending_event.event == EventType.TOOL_RESULT: + if pending_tool_id == state.active_tool_id: + state.active_tool_id = None + + async for event in event_stream: + # RUN_ERROR 后不再处理任何事件 + if state.run_errored: + continue + + # 检查是否是错误事件 + if event.event == EventType.ERROR: + state.run_errored = True + + # 在 copilotkit_compatibility=True 模式下,实现严格的工具调用序列化 + # 当一个工具调用正在进行时,其他工具的事件会被放入队列 + if self._copilotkit_compatibility and not 
state.run_errored: + original_tool_id = ( + event.data.get("id", "") if event.data else "" + ) + tool_name = event.data.get("name", "") if event.data else "" + resolved_tool_id = state.resolve_tool_id( + original_tool_id, tool_name + ) + if resolved_tool_id and event.data is not None: + event.data["id"] = resolved_tool_id + tool_id = resolved_tool_id + else: + tool_id = original_tool_id + + # 处理 TOOL_CALL_CHUNK 事件 + if event.event == EventType.TOOL_CALL_CHUNK: + if state.active_tool_id is None: + # 没有活跃的工具调用,直接处理 + state.active_tool_id = tool_id + elif tool_id != state.active_tool_id: + # 有其他活跃的工具调用,放入队列 + state.pending_events.append(event) + continue + # 如果是同一个工具调用,继续处理 + + # 处理 TOOL_CALL 事件 + elif event.event == EventType.TOOL_CALL: + # TOOL_CALL 事件主要用于 UUID 到 call_xxx ID 的映射 + # 在 copilotkit 模式下: + # 1. 结束当前活跃的工具调用(如果有) + # 2. 处理队列中的事件 + # 3. 不将 TOOL_CALL 事件的 tool_id 设置为活跃工具 + if self._copilotkit_compatibility: + if state.active_tool_id is not None: + for sse_data in state.end_all_tools(self._encoder): + yield sse_data + state.active_tool_id = None + # 处理队列中的事件 + if state.pending_events: + for sse_data in process_pending_queue(): + yield sse_data + + # 处理 TOOL_RESULT 事件 + elif event.event == EventType.TOOL_RESULT: + actual_tool_id = resolved_tool_id or tool_id + + # 如果不是当前活跃工具的结果,放入队列 + if ( + state.active_tool_id is not None + and actual_tool_id != state.active_tool_id + ): + state.pending_events.append(event) + continue + + # 标记工具调用已有结果 + if ( + actual_tool_id + and actual_tool_id in state.tool_call_states + ): + state.tool_call_states[actual_tool_id].has_result = True + + # 处理当前事件 + for sse_data in self._process_event_with_boundaries( + event, + context, + state, + ): + if sse_data: + yield sse_data + + # 如果这是当前活跃工具的结果,处理队列中的事件 + if actual_tool_id == state.active_tool_id: + state.active_tool_id = None + # 处理队列中的事件 + for sse_data in process_pending_queue(): + yield sse_data + continue + + # 处理非工具相关事件(如 TEXT) + # 需要先处理队列中的所有事件 + elif event.event == 
EventType.TEXT: + # 先处理队列中的所有事件 + for sse_data in process_pending_queue(): + yield sse_data + # 清除活跃工具 ID(因为我们要处理文本了) + state.active_tool_id = None + + # 处理边界事件注入 + for sse_data in self._process_event_with_boundaries( + event, + context, + state, + ): + if sse_data: + yield sse_data + + # 在 copilotkit 兼容模式下,如果当前没有活跃工具且队列中有事件,处理队列 + if ( + self._copilotkit_compatibility + and state.active_tool_id is None + and state.pending_events + ): + for sse_data in process_pending_queue(): + yield sse_data + + # RUN_ERROR 后不发送任何清理事件 + if state.run_errored: + return + + # 结束所有未结束的工具调用 + for sse_data in state.end_all_tools(self._encoder): + yield sse_data + + # 发送 TEXT_MESSAGE_END(如果有文本消息且未结束) + for sse_data in state.end_text_if_open(self._encoder): + yield sse_data + + # 发送 RUN_FINISHED + yield self._encoder.encode( + RunFinishedEvent( + thread_id=context.get("thread_id"), + run_id=context.get("run_id"), + ) + ) + + def _process_event_with_boundaries( + self, + event: AgentEvent, + context: Dict[str, Any], + state: StreamStateMachine, + ) -> Iterator[str]: + """处理事件并注入边界事件""" + import json + + # RAW 事件直接透传 + if event.event == EventType.RAW: + raw_data = event.data.get("raw", "") + if raw_data: + if not raw_data.endswith("\n\n"): + raw_data = raw_data.rstrip("\n") + "\n\n" + yield raw_data + return + + # TEXT 事件:在首个 TEXT 前注入 TEXT_MESSAGE_START + # AG-UI 协议要求:发送 TEXT_MESSAGE_START 前必须先结束所有未结束的 TOOL_CALL + if event.event == EventType.TEXT: + for sse_data in state.end_all_tools(self._encoder): + yield sse_data + + for sse_data in state.ensure_text_started(self._encoder): + yield sse_data + + agui_event = TextMessageContentEvent( + message_id=state.text.message_id, + delta=event.data.get("delta", ""), + ) + if event.addition: + event_dict = agui_event.model_dump( + by_alias=True, exclude_none=True + ) + event_dict = self._apply_addition( + event_dict, + event.addition, + event.addition_merge_options, + ) + json_str = json.dumps(event_dict, ensure_ascii=False) + yield f"data: 
{json_str}\n\n" + else: + yield self._encoder.encode(agui_event) + return + + # TOOL_CALL_CHUNK 事件:在首个 CHUNK 前注入 TOOL_CALL_START + if event.event == EventType.TOOL_CALL_CHUNK: + tool_id_raw = event.data.get("id", "") + tool_name = event.data.get("name", "") + resolved_tool_id = state.resolve_tool_id(tool_id_raw, tool_name) + tool_id = resolved_tool_id or tool_id_raw + if tool_id and event.data is not None: + event.data["id"] = tool_id + + for sse_data in state.end_text_if_open(self._encoder): + yield sse_data + + if ( + state.copilotkit_compatibility + and state._is_uuid_like(tool_id_raw) + and tool_name + ): + for existing_id, call_state in state.tool_call_states.items(): + if ( + not state._is_uuid_like(existing_id) + and call_state.name == tool_name + and call_state.started + ): + if not call_state.ended: + args_delta = event.data.get("args_delta", "") + if args_delta: + yield self._encoder.encode( + ToolCallArgsEvent( + tool_call_id=existing_id, + delta=args_delta, + ) + ) + return + + need_start = False + current_state = state.tool_call_states.get(tool_id) + if tool_id: + if current_state is None or current_state.ended: + need_start = True + + if need_start: + if state.copilotkit_compatibility: + for sse_data in state.end_all_tools(self._encoder): + yield sse_data + + yield self._encoder.encode( + ToolCallStartEvent( + tool_call_id=tool_id, + tool_call_name=tool_name, + ) + ) + state.tool_call_states[tool_id] = ToolCallState( + name=tool_name, + started=True, + ended=False, + ) + + yield self._encoder.encode( + ToolCallArgsEvent( + tool_call_id=tool_id, + delta=event.data.get("args_delta", ""), + ) + ) + return + + # TOOL_CALL 事件:完整的工具调用事件 + if event.event == EventType.TOOL_CALL: + tool_id_raw = event.data.get("id", "") + tool_name = event.data.get("name", "") + tool_args = event.data.get("args", "") + resolved_tool_id = state.resolve_tool_id(tool_id_raw, tool_name) + tool_id = resolved_tool_id or tool_id_raw + if tool_id and event.data is not None: + 
event.data["id"] = tool_id + + for sse_data in state.end_text_if_open(self._encoder): + yield sse_data + + # 在 CopilotKit 兼容模式下,检查 UUID 映射 + if ( + state.copilotkit_compatibility + and state._is_uuid_like(tool_id_raw) + and tool_name + ): + for existing_id, call_state in state.tool_call_states.items(): + if ( + not state._is_uuid_like(existing_id) + and call_state.name == tool_name + and call_state.started + ): + if not call_state.ended: + # UUID 事件可能包含参数,发送参数事件 + if tool_args: + yield self._encoder.encode( + ToolCallArgsEvent( + tool_call_id=existing_id, + delta=tool_args, + ) + ) + return # UUID 事件已完成处理,不创建新的工具调用 + + need_start = False + current_state = state.tool_call_states.get(tool_id) + if tool_id: + if current_state is None or current_state.ended: + need_start = True + + if need_start: + if state.copilotkit_compatibility: + for sse_data in state.end_all_tools(self._encoder): + yield sse_data + + yield self._encoder.encode( + ToolCallStartEvent( + tool_call_id=tool_id, + tool_call_name=tool_name, + ) + ) + state.tool_call_states[tool_id] = ToolCallState( + name=tool_name, + started=True, + ended=False, + ) + + # 发送工具参数(如果存在) + if tool_args: + yield self._encoder.encode( + ToolCallArgsEvent( + tool_call_id=tool_id, + delta=tool_args, + ) + ) + return + + # TOOL_RESULT_CHUNK 事件:工具执行过程中的流式输出 + if event.event == EventType.TOOL_RESULT_CHUNK: + tool_id = event.data.get("id", "") + delta = event.data.get("delta", "") + state.cache_tool_result_chunk(tool_id, delta) + return + + # HITL 事件:请求人类介入 + if event.event == EventType.HITL: + hitl_id = event.data.get("id", "") + tool_call_id = event.data.get("tool_call_id", "") + hitl_type = event.data.get("type", "confirmation") + prompt = event.data.get("prompt", "") + options = event.data.get("options") + default = event.data.get("default") + timeout = event.data.get("timeout") + schema = event.data.get("schema") + + for sse_data in state.end_text_if_open(self._encoder): + yield sse_data + + if tool_call_id and tool_call_id 
in state.tool_call_states: + tool_state = state.tool_call_states[tool_call_id] + if tool_state.started and not tool_state.ended: + yield self._encoder.encode( + ToolCallEndEvent(tool_call_id=tool_call_id) + ) + tool_state.ended = True + tool_state.is_hitl = True + tool_state.has_result = False + return + + import json as json_module + + args_dict: Dict[str, Any] = { + "type": hitl_type, + "prompt": prompt, + } + if options: + args_dict["options"] = options + if default is not None: + args_dict["default"] = default + if timeout is not None: + args_dict["timeout"] = timeout + if schema: + args_dict["schema"] = schema + + args_json = json_module.dumps(args_dict, ensure_ascii=False) + actual_id = tool_call_id or hitl_id + + yield self._encoder.encode( + ToolCallStartEvent( + tool_call_id=actual_id, + tool_call_name=f"hitl_{hitl_type}", + ) + ) + yield self._encoder.encode( + ToolCallArgsEvent( + tool_call_id=actual_id, + delta=args_json, + ) + ) + yield self._encoder.encode(ToolCallEndEvent(tool_call_id=actual_id)) + + state.tool_call_states[actual_id] = ToolCallState( + name=f"hitl_{hitl_type}", + started=True, + ended=True, + has_result=False, + is_hitl=True, + ) + return + + # TOOL_RESULT 事件:确保当前工具调用已结束 + if event.event == EventType.TOOL_RESULT: + tool_id = event.data.get("id", "") + tool_name = event.data.get("name", "") + actual_tool_id = ( + state.resolve_tool_id(tool_id, tool_name) + if state.copilotkit_compatibility + else tool_id + ) + if actual_tool_id and event.data is not None: + event.data["id"] = actual_tool_id + + for sse_data in state.end_text_if_open(self._encoder): + yield sse_data + + if state.copilotkit_compatibility: + for sse_data in state.end_all_tools( + self._encoder, exclude=actual_tool_id + ): + yield sse_data + + tool_state = ( + state.tool_call_states.get(actual_tool_id) + if actual_tool_id + else None + ) + if actual_tool_id and tool_state is None: + yield self._encoder.encode( + ToolCallStartEvent( + tool_call_id=actual_tool_id, + 
tool_call_name=tool_name or "", + ) + ) + tool_state = ToolCallState( + name=tool_name, started=True, ended=False + ) + state.tool_call_states[actual_tool_id] = tool_state + + if tool_state and tool_state.started and not tool_state.ended: + yield self._encoder.encode( + ToolCallEndEvent(tool_call_id=actual_tool_id) + ) + tool_state.ended = True + + final_result = event.data.get("content") or event.data.get( + "result", "" + ) + if actual_tool_id: + cached_chunks = state.pop_tool_result_chunks(actual_tool_id) + if cached_chunks: + final_result = cached_chunks + final_result + + yield self._encoder.encode( + ToolCallResultEvent( + message_id=event.data.get( + "message_id", f"tool-result-{actual_tool_id}" + ), + tool_call_id=actual_tool_id, + content=final_result, + role="tool", + ) + ) + return + + # ERROR 事件 + if event.event == EventType.ERROR: + yield self._encoder.encode( + RunErrorEvent( + message=event.data.get("message", ""), + code=event.data.get("code"), + ) + ) + return + + # STATE 事件 + if event.event == EventType.STATE: + if "snapshot" in event.data: + yield self._encoder.encode( + StateSnapshotEvent(snapshot=event.data.get("snapshot", {})) + ) + elif "delta" in event.data: + yield self._encoder.encode( + StateDeltaEvent(delta=event.data.get("delta", [])) + ) + else: + yield self._encoder.encode( + StateSnapshotEvent(snapshot=event.data) + ) + return + + # CUSTOM 事件 + if event.event == EventType.CUSTOM: + yield self._encoder.encode( + AguiCustomEvent( + name=event.data.get("name", "custom"), + value=event.data.get("value"), + ) + ) + return + + # 其他未知事件 + event_name = ( + event.event.value + if hasattr(event.event, "value") + else str(event.event) + ) + yield self._encoder.encode( + AguiCustomEvent( + name=event_name, + value=event.data, + ) + ) + + def _convert_messages_for_snapshot( + self, messages: List[Dict[str, Any]] + ) -> List[AguiMessage]: + """将消息列表转换为 ag-ui-protocol 格式 + + Args: + messages: 消息字典列表 + + Returns: + ag-ui-protocol 消息列表 + """ + result 
= [] + for msg in messages: + if not isinstance(msg, dict): + continue + + role = msg.get("role", "user") + content = msg.get("content", "") + msg_id = msg.get("id", str(uuid.uuid4())) + + if role == "user": + result.append( + UserMessage(id=msg_id, role="user", content=content) + ) + elif role == "assistant": + result.append( + AssistantMessage( + id=msg_id, + role="assistant", + content=content, + ) + ) + elif role == "system": + result.append( + SystemMessage(id=msg_id, role="system", content=content) + ) + elif role == "tool": + result.append( + AguiToolMessage( + id=msg_id, + role="tool", + content=content, + tool_call_id=msg.get("tool_call_id", ""), + ) + ) + + return result + + def _apply_addition( + self, + event_data: Dict[str, Any], + addition: Optional[Dict[str, Any]], + merge_options: Optional[MergeOptions] = None, + ) -> Dict[str, Any]: + """应用 addition 字段 + + Args: + event_data: 原始事件数据 + addition: 附加字段 + merge_options: 合并选项,透传给 utils.helper.merge + + Returns: + 合并后的事件数据 + """ + if not addition: + return event_data + + return merge(event_data, addition, **(merge_options or {})) + + async def _error_stream(self, message: str) -> AsyncIterator[str]: + """生成错误事件流 + + Args: + message: 错误消息 + + Yields: + SSE 格式的错误事件 + """ + thread_id = str(uuid.uuid4()) + run_id = str(uuid.uuid4()) + + # 生命周期开始 + yield self._encoder.encode( + RunStartedEvent(thread_id=thread_id, run_id=run_id) + ) + + # 错误事件 + yield self._encoder.encode( + RunErrorEvent(message=message, code="REQUEST_ERROR") + ) diff --git a/agentrun/server/invoker.py b/agentrun/server/invoker.py index 9664365..438c509 100644 --- a/agentrun/server/invoker.py +++ b/agentrun/server/invoker.py @@ -1,14 +1,29 @@ """Agent 调用器 / Agent Invoker -负责处理 Agent 调用的通用逻辑。 -Handles common logic for agent invocations. 
class AgentInvoker:
    """Runs the user-supplied agent handler and normalises its output.

    This layer only emits core events (TEXT, TOOL_CALL_CHUNK, ERROR, ...);
    boundary and lifecycle events (RUN_STARTED, TEXT_MESSAGE_START, etc.)
    are produced by the protocol layer.
    """

    def __init__(self, invoke_agent: "InvokeAgentHandler"):
        """Store the handler and detect whether it is asynchronous.

        Args:
            invoke_agent: user handler; may be a plain function, a coroutine
                function, or an async generator function.
        """
        self.invoke_agent = invoke_agent
        # Either a coroutine function or an async generator counts as async.
        self.is_async = inspect.iscoroutinefunction(
            invoke_agent
        ) or inspect.isasyncgenfunction(invoke_agent)

    async def invoke(
        self, request: "AgentRequest"
    ) -> "Union[List[AgentEvent], AsyncGenerator[AgentEvent, None]]":
        """Invoke the agent and return either a list or an event stream.

        Returns:
            A list of AgentEvents for non-streaming handlers, otherwise an
            async generator of AgentEvents.
        """
        raw = await self._call_handler(request)
        if self._is_iterator(raw):
            return self._wrap_stream(raw)
        return self._wrap_non_stream(raw)

    async def invoke_stream(
        self, request: "AgentRequest"
    ) -> "AsyncGenerator[AgentEvent, None]":
        """Always-streaming variant of :meth:`invoke`.

        Non-streaming results are flattened into events; any exception is
        converted into a single ERROR event.

        Yields:
            AgentEvent instances.
        """
        try:
            raw = await self._call_handler(request)

            if not self._is_iterator(raw):
                for ev in self._wrap_non_stream(raw):
                    yield ev
                return

            async for item in self._iterate_async(raw):
                if item is None:
                    continue
                if isinstance(item, str):
                    # Skip empty strings; wrap the rest as TEXT deltas.
                    if item:
                        yield AgentEvent(
                            event=EventType.TEXT, data={"delta": item}
                        )
                elif isinstance(item, AgentEvent):
                    for ev in self._process_user_event(item):
                        yield ev
        except Exception as e:
            from agentrun.utils.log import logger

            logger.error(f"Agent 调用出错: {e}", exc_info=True)
            yield AgentEvent(
                event=EventType.ERROR,
                data={"message": str(e), "code": type(e).__name__},
            )

    def _process_user_event(
        self, event: "AgentEvent"
    ) -> "Iterator[AgentEvent]":
        """Expand TOOL_CALL into one TOOL_CALL_CHUNK; pass others through.

        The first chunk carries the tool name, as downstream protocol
        handlers expect.
        """
        if event.event != EventType.TOOL_CALL:
            yield event
            return

        payload = event.data
        yield AgentEvent(
            event=EventType.TOOL_CALL_CHUNK,
            data={
                "id": payload.get("id", str(uuid.uuid4())),
                "name": payload.get("name", ""),
                "args_delta": payload.get("args", ""),
            },
        )

    async def _call_handler(self, request: "AgentRequest") -> Any:
        """Call the user handler, off-loading sync handlers to the executor.

        Returns:
            The handler's raw return value (value, awaitable result, or
            (a)sync iterator).
        """
        if not self.is_async:
            # Sync handler: run in the default executor so the event loop
            # is never blocked.
            sync_handler = cast("SyncInvokeAgentHandler", self.invoke_agent)
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, sync_handler, request)

        async_handler = cast("AsyncInvokeAgentHandler", self.invoke_agent)
        raw = async_handler(request)
        if inspect.isawaitable(raw):
            return await cast("Awaitable[Any]", raw)
        # Async generators (and anything else) are returned as-is.
        return raw

    def _wrap_non_stream(self, result: Any) -> "List[AgentEvent]":
        """Normalise a non-streaming return value into AgentEvents.

        Strings become TEXT events, AgentEvents are expanded via
        :meth:`_process_user_event`, and lists are handled element-wise.
        """
        events: "List[AgentEvent]" = []

        if result is None:
            return events

        if isinstance(result, str):
            events.append(
                AgentEvent(event=EventType.TEXT, data={"delta": result})
            )
        elif isinstance(result, AgentEvent):
            events.extend(self._process_user_event(result))
        elif isinstance(result, list):
            for item in result:
                if isinstance(item, AgentEvent):
                    events.extend(self._process_user_event(item))
                elif isinstance(item, str) and item:
                    events.append(
                        AgentEvent(event=EventType.TEXT, data={"delta": item})
                    )

        return events

    async def _wrap_stream(
        self, iterator: Any
    ) -> "AsyncGenerator[AgentEvent, None]":
        """Normalise a (a)sync iterator of items into an AgentEvent stream."""
        async for item in self._iterate_async(iterator):
            if item is None:
                continue
            if isinstance(item, str):
                if item:
                    yield AgentEvent(
                        event=EventType.TEXT, data={"delta": item}
                    )
            elif isinstance(item, AgentEvent):
                for ev in self._process_user_event(item):
                    yield ev

    async def _iterate_async(
        self, content: "Union[Iterator[Any], AsyncIterator[Any]]"
    ) -> "AsyncGenerator[Any, None]":
        """Iterate sync and async iterables through one async interface.

        For sync iterables, each ``next()`` call runs in the executor so the
        event loop never blocks; StopIteration is mapped to a sentinel.
        """
        if hasattr(content, "__aiter__"):
            async for chunk in content:
                yield chunk
            return

        loop = asyncio.get_running_loop()
        it = iter(content)
        done = object()  # sentinel marking iterator exhaustion

        def step() -> Any:
            try:
                return next(it)
            except StopIteration:
                return done

        while True:
            chunk = await loop.run_in_executor(None, step)
            if chunk is done:
                return
            yield chunk

    def _is_iterator(self, obj: Any) -> bool:
        """True when ``obj`` should be treated as a stream of items.

        Strings, bytes, dicts, lists and single AgentEvents are value
        results even though some of them are iterable.
        """
        if isinstance(obj, (str, bytes, dict, list, AgentEvent)):
            return False
        return hasattr(obj, "__iter__") or hasattr(obj, "__aiter__")
class AGUIProtocolConfig(ProtocolConfig):
    """AG-UI protocol configuration.

    ``prefix`` defaults to "/ag-ui/agent". ``copilotkit_compatibility``
    enables legacy CopilotKit behaviour (default False keeps standard AG-UI
    semantics with parallel tool calls): end other active tool calls before
    a new TOOL_CALL_START, map LangChain UUID-format ids to call_xxx ids,
    and queue events of other tools until the current one finishes — all to
    satisfy the strict validation of CopilotKit-style frontends.
    """

    enable: bool = True
    prefix: Optional[str] = "/ag-ui/agent"
    copilotkit_compatibility: bool = False


class ServerConfig(BaseModel):
    """Top-level server configuration: per-protocol settings plus CORS."""

    openai: Optional["OpenAIProtocolConfig"] = None
    agui: Optional["AGUIProtocolConfig"] = None
    cors_origins: Optional[List[str]] = None


class ToolCall(BaseModel):
    """A single tool call (OpenAI-compatible shape)."""

    id: str
    type: str = "function"
    function: Dict[str, Any]


class Message(BaseModel):
    """Normalised chat message, compatible with AG-UI and OpenAI formats.

    ``id`` is the AG-UI style message id; ``content`` is either plain text
    or a list of multimodal content parts; ``tool_calls`` is set on
    assistant messages and ``tool_call_id`` on tool messages.
    """

    id: Optional[str] = None
    role: MessageRole
    content: Optional[Union[str, List[Dict[str, Any]]]] = None
    name: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None
    tool_call_id: Optional[str] = None


class Tool(BaseModel):
    """Tool definition, compatible with AG-UI and OpenAI formats."""

    type: str = "function"
    function: Dict[str, Any]
class EventType(str, Enum):
    """Protocol-agnostic event types.

    The framework translates these core events into the concrete wire
    format (OpenAI, AG-UI, ...), so handlers only express semantics.
    Boundary events (message start/end, run lifecycle) are emitted by the
    protocol layer and never produced by users directly.
    """

    # -- core events --------------------------------------------------------
    TEXT = "TEXT"  # a chunk of text content
    TOOL_CALL = "TOOL_CALL"  # complete tool call (id, name, args)
    TOOL_CALL_CHUNK = "TOOL_CALL_CHUNK"  # streamed fragment of tool-call args
    TOOL_RESULT = "TOOL_RESULT"  # final tool result (ends a result stream)
    TOOL_RESULT_CHUNK = "TOOL_RESULT_CHUNK"  # streamed fragment of a result

    ERROR = "ERROR"  # error event
    STATE = "STATE"  # state update (snapshot or delta)

    # -- human-in-the-loop ---------------------------------------------------
    HITL = "HITL"  # request human intervention

    # -- extension -----------------------------------------------------------
    CUSTOM = "CUSTOM"  # custom event (handled properly by the protocol layer)
    RAW = "RAW"  # raw protocol payload, passed straight through
class AgentEvent(BaseModel):
    """A single, protocol-agnostic agent execution event.

    The framework converts AgentEvents into the concrete protocol format
    (OpenAI, AG-UI, ...), so the same handler serves every protocol.

    Attributes:
        event: the event type.
        data: event payload; expected keys depend on ``event``:
            - TEXT: ``{"delta": str}``
            - TOOL_CALL: ``{"id", "name", "args"}``
            - TOOL_CALL_CHUNK: ``{"id", "name"?, "args_delta"}`` (first
              chunk carries the name)
            - TOOL_RESULT: ``{"id", "result"}`` — marks the end of a tool's
              streamed output; any cached TOOL_RESULT_CHUNK deltas are
              prepended to ``result`` (``result`` may be ``""`` to use only
              the cached chunks)
            - TOOL_RESULT_CHUNK: ``{"id", "delta"}`` — cached, not sent
              immediately; flushed by the matching TOOL_RESULT
            - ERROR: ``{"message", "code"?}``
            - STATE: ``{"snapshot"}`` or ``{"delta"}``
            - HITL: ``{"id", "type", "prompt", "tool_call_id"?, "options"?,
              "schema"?}`` — either reference an existing tool call via
              ``tool_call_id`` or create a standalone HITL call; the human
              response comes back as a tool message in the next turn
            - CUSTOM: ``{"name", "value"}``
            - RAW: ``{"raw": str}`` — written to the response stream as-is
        addition: optional protocol-specific extra fields merged into the
            outgoing payload.
        addition_merge_options: options forwarded to ``utils.helper.merge``
            (deep merge by default; ``no_new_field=True`` restricts the
            merge to fields that already exist).

    Example:
        >>> AgentEvent(event=EventType.TEXT, data={"delta": "Hello!"})
    """

    event: EventType
    data: Dict[str, Any] = Field(default_factory=dict)
    addition: Optional[Dict[str, Any]] = None
    addition_merge_options: Optional[MergeOptions] = None


# Backwards-compatible alias
AgentResult = AgentEvent


# ============================================================================
# AgentRequest (normalised request)
# ============================================================================


class AgentRequest(BaseModel):
    """Protocol-agnostic agent request.

    Unifies the OpenAI and AG-UI request shapes so a single handler can
    serve both.

    Attributes:
        protocol: name of the protocol handling this call ("openai",
            "agui", ...), letting handlers branch on protocol if needed.
        messages: normalised conversation history.
        stream: whether the caller asked for streaming output.
        tools: tools the agent may call.
        raw_request: the original Starlette ``Request`` for access to
            headers, query params, client address, or the raw JSON body.

    Example:
        >>> def invoke_agent(request: AgentRequest):
        ...     return f"echo: {request.messages[-1].content}"
    """

    model_config = {"arbitrary_types_allowed": True}

    protocol: str = Field("unknown", description="当前交互协议名称")

    messages: List[Message] = Field(
        default_factory=list, description="对话历史消息列表"
    )
    stream: bool = Field(False, description="是否使用流式输出")
    tools: Optional[List[Tool]] = Field(None, description="可用的工具列表")

    raw_request: Optional[Request] = Field(
        None, description="原始 HTTP 请求对象(Starlette Request)"
    )
# Single result item: plain text or a structured event.
AgentEventItem = Union[str, AgentEvent]
AgentResultItem = AgentEventItem  # backwards-compatible alias

# Streaming generator aliases (sync / async), with compatible legacy names.
SyncAgentEventGenerator = Generator[AgentEventItem, None, None]
SyncAgentResultGenerator = SyncAgentEventGenerator  # backwards-compatible alias

AsyncAgentEventGenerator = AsyncGenerator[AgentEventItem, None]
AsyncAgentResultGenerator = AsyncAgentEventGenerator  # backwards-compatible alias


# Everything an agent handler may return:
# simple values (string, event(s), or a ready-made protocol dict) or a
# sync/async stream of items.
AgentReturnType = Union[
    str,
    AgentEvent,
    List[AgentEvent],
    Dict[str, Any],
    Iterator[AgentEventItem],
    AsyncIterator[AgentEventItem],
    SyncAgentEventGenerator,
    AsyncAgentEventGenerator,
]
DEFAULT_PREFIX = "/openai/v1"


def __init__(self, config: Optional[ServerConfig] = None):
    """Keep only the OpenAI-specific slice of the server config."""
    self.config = config.openai if config else None

def get_prefix(self) -> str:
    """Route prefix for the OpenAI-compatible endpoints.

    Returns:
        The configured prefix, or "/openai/v1" by default.
    """
    # NOTE(review): pydash.get falls back to the default when the path is
    # missing — confirm the intended result when config.prefix is
    # explicitly None.
    return pydash.get(self.config, "prefix", DEFAULT_PREFIX)

def get_model_name(self) -> str:
    """Model name reported by /models and used as the request default."""
    return pydash.get(self.config, "model_name", "agentrun")

async def parse_request(
    self,
    request: Request,
    request_data: Dict[str, Any],
) -> tuple[AgentRequest, Dict[str, Any]]:
    """Parse an OpenAI-format request body.

    Args:
        request: the FastAPI/Starlette request object (kept on the
            resulting AgentRequest for raw access).
        request_data: parsed JSON body.

    Returns:
        ``(agent_request, context)`` where ``context`` carries the
        response id, model name and creation timestamp used when
        formatting the reply.

    Raises:
        ValueError: when the body is missing ``messages`` or a message
            is malformed.
    """
    if "messages" not in request_data:
        raise ValueError("Missing required field: messages")

    # Per-call formatting context shared by the response builders.
    context = {
        "response_id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
        "model": request_data.get("model", self.get_model_name()),
        "created": int(time.time()),
    }

    agent_request = AgentRequest(
        protocol="openai",
        messages=self._parse_messages(request_data["messages"]),
        stream=request_data.get("stream", False),
        tools=self._parse_tools(request_data.get("tools")),
        raw_request=request,  # keep the original request reachable
    )

    return agent_request, context

def _parse_messages(
    self, raw_messages: List[Dict[str, Any]]
) -> List[Message]:
    """Validate and convert raw message dicts into Message objects.

    Raises:
        ValueError: for non-dict entries, missing roles, or unknown roles.
    """
    parsed: List[Message] = []

    for msg_data in raw_messages:
        if not isinstance(msg_data, dict):
            raise ValueError(f"Invalid message format: {msg_data}")
        if "role" not in msg_data:
            raise ValueError("Message missing 'role' field")

        try:
            role = MessageRole(msg_data["role"])
        except ValueError as e:
            raise ValueError(
                f"Invalid message role: {msg_data['role']}"
            ) from e

        tool_calls = None
        if msg_data.get("tool_calls"):
            tool_calls = [
                ToolCall(
                    id=tc.get("id", ""),
                    type=tc.get("type", "function"),
                    function=tc.get("function", {}),
                )
                for tc in msg_data["tool_calls"]
            ]

        parsed.append(
            Message(
                role=role,
                content=msg_data.get("content"),
                name=msg_data.get("name"),
                tool_calls=tool_calls,
                tool_call_id=msg_data.get("tool_call_id"),
            )
        )

    return parsed

def _parse_tools(
    self, raw_tools: Optional[List[Dict[str, Any]]]
) -> Optional[List[Tool]]:
    """Convert raw tool dicts into Tool objects.

    Non-dict entries are skipped; returns None when nothing remains.
    """
    if not raw_tools:
        return None

    tools = [
        Tool(
            type=tool_data.get("type", "function"),
            function=tool_data.get("function", {}),
        )
        for tool_data in raw_tools
        if isinstance(tool_data, dict)
    ]

    return tools or None
role + if not sent_role: + delta["role"] = "assistant" + sent_role = True + + content = event.data.get("delta", "") + if content: + delta["content"] = content + has_text = True + + # 应用 addition + if event.addition: + delta = self._apply_addition( + delta, + event.addition, + event.addition_merge_options, + ) - def _format_model_response( - self, response: Any, request: AgentRequest - ) -> Dict[str, Any]: - """格式化 ModelResponse 为 OpenAI 格式 + yield self._build_chunk(context, delta) + continue - ModelResponse 本身已经是 OpenAI 格式,直接转换为字典即可。 + # TOOL_CALL_CHUNK 事件 + if event.event == EventType.TOOL_CALL_CHUNK: + tool_id = event.data.get("id", "") + tool_name = event.data.get("name", "") + args_delta = event.data.get("args_delta", "") - Args: - response: litellm 的 ModelResponse 对象 - request: 原始请求 + delta = {} - Returns: - Dict: OpenAI 格式的响应字典 - """ - # 方式 1: 如果有 model_dump 方法 (Pydantic) - if hasattr(response, "model_dump"): - return response.model_dump(exclude_none=True) - - # 方式 2: 如果有 dict 方法 - if hasattr(response, "dict"): - return response.dict(exclude_none=True) - - # 方式 3: 手动转换 (litellm ModelResponse) - result = { - "id": getattr( - response, "id", f"chatcmpl-{int(time.time() * 1000)}" - ), - "object": getattr(response, "object", "chat.completion"), - "created": getattr(response, "created", int(time.time())), - "model": getattr( - response, "model", request.model or "agentrun-model" - ), - "choices": [], - } - - # 转换 choices - if hasattr(response, "choices"): - for choice in response.choices: - choice_dict = { - "index": getattr(choice, "index", 0), - "finish_reason": getattr(choice, "finish_reason", None), - } - - # 转换 message - if hasattr(choice, "message"): - msg = choice.message - choice_dict["message"] = { - "role": getattr(msg, "role", "assistant"), - "content": getattr(msg, "content", None), + # 首次见到这个工具调用 + if tool_id and tool_id not in tool_call_states: + tool_call_index += 1 + tool_call_states[tool_id] = { + "started": True, + "index": tool_call_index, } - # 
可选字段 - if hasattr(msg, "tool_calls") and msg.tool_calls: - choice_dict["message"]["tool_calls"] = msg.tool_calls - - result["choices"].append(choice_dict) - - # 转换 usage - if hasattr(response, "usage") and response.usage: - usage = response.usage - result["usage"] = { - "prompt_tokens": getattr(usage, "prompt_tokens", 0), - "completion_tokens": getattr(usage, "completion_tokens", 0), - "total_tokens": getattr(usage, "total_tokens", 0), - } + has_tool_calls = True + + # 发送工具调用开始(包含 id, name) + delta["tool_calls"] = [{ + "index": tool_call_index, + "id": tool_id, + "type": "function", + "function": {"name": tool_name, "arguments": ""}, + }] + yield self._build_chunk(context, delta) + delta = {} + + # 发送参数增量 + if args_delta: + current_index = tool_call_states.get(tool_id, {}).get( + "index", tool_call_index + ) + delta["tool_calls"] = [{ + "index": current_index, + "function": {"arguments": args_delta}, + }] + + # 应用 addition + if event.addition: + delta = self._apply_addition( + delta, + event.addition, + event.addition_merge_options, + ) - return result + yield self._build_chunk(context, delta) + continue - def _is_iterator(self, obj: Any) -> bool: - """检查对象是否是迭代器 + # TOOL_RESULT 事件:OpenAI 协议通常不在流中输出工具结果 + if event.event == EventType.TOOL_RESULT: + continue - Args: - obj: 要检查的对象 + # TOOL_RESULT_CHUNK 事件:OpenAI 协议不支持流式工具输出 + if event.event == EventType.TOOL_RESULT_CHUNK: + continue - Returns: - bool: 是否是迭代器 - """ - # 检查是否是迭代器或生成器 - return ( - hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes, dict)) - ) or hasattr(obj, "__aiter__") + # HITL 事件:OpenAI 协议不支持 + if event.event == EventType.HITL: + continue - async def _format_stream_content( - self, - content: Union[Iterator[str], AsyncIterator[str]], - request: AgentRequest, - ) -> AsyncIterator[str]: - """格式化流式 content 为 OpenAI SSE 格式 + # 其他事件忽略 + # (ERROR, STATE, CUSTOM 等不直接映射到 OpenAI 格式) - 将字符串迭代器转换为 OpenAI 流式响应格式。 + # 流结束后发送 finish_reason 和 [DONE] + if has_tool_calls: + yield self._build_chunk(context, 
{}, finish_reason="tool_calls") + elif has_text: + yield self._build_chunk(context, {}, finish_reason="stop") + yield "data: [DONE]\n\n" + + def _build_chunk( + self, + context: Dict[str, Any], + delta: Dict[str, Any], + finish_reason: Optional[str] = None, + ) -> str: + """构建 OpenAI 流式响应块 Args: - content: 字符串迭代器 (同步或异步) - request: 原始请求 + context: 上下文信息 + delta: delta 数据 + finish_reason: 结束原因 - Yields: - SSE 格式的数据行 + Returns: + SSE 格式的字符串 """ - response_id = f"chatcmpl-{int(time.time() * 1000)}" - created = int(time.time()) - model = request.model or "agentrun-model" - - # 发送第一个 chunk (包含 role) - first_chunk = { - "id": response_id, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": [{ - "index": 0, - "delta": {"role": "assistant"}, - "finish_reason": None, - }], - } - yield f"data: {json.dumps(first_chunk, ensure_ascii=False)}\n\n" - - # 检查是否是异步迭代器 - if hasattr(content, "__aiter__"): - async for chunk in content: # type: ignore - if chunk: # 跳过空字符串 - data = { - "id": response_id, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": [{ - "index": 0, - "delta": {"content": chunk}, - "finish_reason": None, - }], - } - yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n" - else: - # 同步迭代器 - for chunk in content: # type: ignore - if chunk: - data = { - "id": response_id, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": [{ - "index": 0, - "delta": {"content": chunk}, - "finish_reason": None, - }], - } - yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n" - - # 发送结束 chunk - final_chunk = { - "id": response_id, + chunk = { + "id": context.get( + "response_id", f"chatcmpl-{uuid.uuid4().hex[:8]}" + ), "object": "chat.completion.chunk", - "created": created, - "model": model, + "created": context.get("created", int(time.time())), + "model": context.get("model", "agentrun"), "choices": [{ "index": 0, - "delta": {}, - "finish_reason": "stop", + 
"delta": delta, + "finish_reason": finish_reason, }], } - yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n" - - # 发送结束标记 - yield "data: [DONE]\n\n" + json_str = json.dumps(chunk, ensure_ascii=False) + return f"data: {json_str}\n\n" - def _wrap_string_response( - self, content: str, request: AgentRequest - ) -> AgentResponse: - """将字符串包装成 AgentResponse - - Args: - content: 响应内容字符串 - request: 原始请求 - - Returns: - AgentResponse: 包装后的响应对象 - """ - return AgentResponse(content=content) - - def _ensure_openai_format( - self, response: AgentResponse, request: AgentRequest + def _format_non_stream( + self, + events: List[AgentEvent], + context: Dict[str, Any], ) -> Dict[str, Any]: - """确保 AgentResponse 符合 OpenAI 格式 + """将 AgentEvent 列表转换为 OpenAI 非流式响应 - 如果用户只填充了 content,自动补充 OpenAI 必需字段。 - 如果用户已填充完整字段,直接使用。 + 自动追踪工具调用状态。 Args: - response: Agent 返回的响应对象 - request: 原始请求 + events: AgentEvent 列表 + context: 上下文信息 Returns: - Dict: OpenAI 格式的响应字典 - """ - # 如果用户只提供了 content,构造完整的 OpenAI 格式 - if response.content and not response.choices: - return { - "id": response.id or f"chatcmpl-{int(time.time() * 1000)}", - "object": response.object or "chat.completion", - "created": response.created or int(time.time()), - "model": response.model or request.model or "agentrun-model", - "choices": [{ - "index": 0, - "message": { - "role": "assistant", - "content": response.content, - }, - "finish_reason": "stop", - }], - "usage": ( - json.loads(response.usage.model_dump_json()) - if response.usage - else None - ), - } - - # 用户提供了完整字段,使用 JSON 序列化避免对象嵌套问题 - json_str = response.model_dump_json(exclude_none=True) - result = json.loads(json_str) - - # 确保必需字段存在 - if "id" not in result: - result["id"] = f"chatcmpl-{int(time.time() * 1000)}" - if "object" not in result: - result["object"] = "chat.completion" - if "created" not in result: - result["created"] = int(time.time()) - if "model" not in result: - result["model"] = request.model or "agentrun-model" - - # 移除 content 和 extra (OpenAI 
格式中不需要) - result.pop("content", None) - result.pop("extra", None) - - return result - - def _is_custom_stream_wrapper(self, obj: Any) -> bool: - """检查是否是 Model Service 的 CustomStreamWrapper""" - # CustomStreamWrapper 的特征 - return ( - hasattr(obj, "__aiter__") - and type(obj).__name__ == "CustomStreamWrapper" - ) - - async def _format_model_stream( - self, stream_wrapper: Any, request: AgentRequest - ) -> AsyncIterator[str]: - """格式化 Model Service 的流式响应 - - CustomStreamWrapper 返回的 chunk 已经是完整的 OpenAI 格式对象。 + OpenAI 格式的响应字典 """ - async for chunk in stream_wrapper: - # chunk 是 litellm 的 ModelResponse 或字典 - if isinstance(chunk, dict): - # 已经是字典,直接格式化为 SSE - yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n" - elif hasattr(chunk, "model_dump"): - # Pydantic 对象 - chunk_dict = chunk.model_dump(exclude_none=True) - yield f"data: {json.dumps(chunk_dict, ensure_ascii=False)}\n\n" - elif hasattr(chunk, "dict"): - # 旧版 Pydantic - chunk_dict = chunk.dict(exclude_none=True) - yield f"data: {json.dumps(chunk_dict, ensure_ascii=False)}\n\n" - else: - # 手动转换对象为字典 - chunk_dict = { - "id": getattr( - chunk, "id", f"chatcmpl-{int(time.time() * 1000)}" - ), - "object": getattr(chunk, "object", "chat.completion.chunk"), - "created": getattr(chunk, "created", int(time.time())), - "model": getattr( - chunk, "model", request.model or "agentrun-model" - ), - "choices": [], - } - - if hasattr(chunk, "choices"): - for choice in chunk.choices: - choice_dict = { - "index": getattr(choice, "index", 0), - "finish_reason": getattr( - choice, "finish_reason", None - ), + content_parts: List[str] = [] + # 工具调用状态:{tool_id: {id, name, arguments}} + tool_call_map: Dict[str, Dict[str, Any]] = {} + has_tool_calls = False + + for event in events: + if event.event == EventType.TEXT: + content_parts.append(event.data.get("delta", "")) + + elif event.event == EventType.TOOL_CALL_CHUNK: + tool_id = event.data.get("id", "") + tool_name = event.data.get("name", "") + args_delta = 
event.data.get("args_delta", "") + + if tool_id: + if tool_id not in tool_call_map: + tool_call_map[tool_id] = { + "id": tool_id, + "type": "function", + "function": {"name": tool_name, "arguments": ""}, } + has_tool_calls = True - if hasattr(choice, "delta"): - delta = choice.delta - delta_dict = {} - if hasattr(delta, "role") and delta.role: - delta_dict["role"] = delta.role - if hasattr(delta, "content") and delta.content: - delta_dict["content"] = delta.content - if ( - hasattr(delta, "tool_calls") - and delta.tool_calls - ): - delta_dict["tool_calls"] = delta.tool_calls - choice_dict["delta"] = delta_dict - - chunk_dict["choices"].append(choice_dict) - - yield f"data: {json.dumps(chunk_dict, ensure_ascii=False)}\n\n" - - # 发送结束标记 - yield "data: [DONE]\n\n" - - async def _format_stream_response( - self, result: AgentResult, request: AgentRequest - ) -> AsyncIterator[str]: - """格式化流式响应 + if args_delta: + tool_call_map[tool_id]["function"][ + "arguments" + ] += args_delta - Args: - result: 流式迭代器,支持: - - Iterator[str]/AsyncIterator[str]: 流式字符串 - - Iterator[AgentStreamResponse]: 流式响应对象 - - CustomStreamWrapper: Model Service 流式响应 - request: 原始请求 + # 构建响应 + content = "".join(content_parts) if content_parts else None + finish_reason = "tool_calls" if has_tool_calls else "stop" - Yields: - SSE 格式的数据行 - """ - # 检查是否是 CustomStreamWrapper (Model Service 流式响应) - if self._is_custom_stream_wrapper(result): - async for chunk in self._format_model_stream(result, request): - yield chunk - return - - response_id = f"chatcmpl-{int(time.time() * 1000)}" - created = int(time.time()) - model = request.model or "agentrun-model" - - # 检查是否是异步迭代器 - if hasattr(result, "__aiter__"): - first_chunk = True - async for chunk in result: # type: ignore - # 如果是字符串,包装成 AgentStreamResponse - if isinstance(chunk, str): - if first_chunk: - # 第一个 chunk: 发送 role - yield self._format_sse_chunk( - AgentStreamResponse( - id=response_id, - created=created, - model=model, - choices=[ - 
AgentStreamResponseChoice( - index=0, - delta=AgentStreamResponseDelta( - role=MessageRole.ASSISTANT, - ), - finish_reason=None, - ) - ], - ) - ) - first_chunk = False - - # 发送内容 chunk - if chunk: # 跳过空字符串 - yield self._format_sse_chunk( - AgentStreamResponse( - id=response_id, - created=created, - model=model, - choices=[ - AgentStreamResponseChoice( - index=0, - delta=AgentStreamResponseDelta( - content=chunk - ), - finish_reason=None, - ) - ], - ) - ) + message: Dict[str, Any] = { + "role": "assistant", + "content": content, + } - # 如果是 AgentStreamResponse,直接序列化 - elif isinstance(chunk, AgentStreamResponse): - yield self._format_sse_chunk(chunk) - - # 发送结束 chunk - yield self._format_sse_chunk( - AgentStreamResponse( - id=response_id, - created=created, - model=model, - choices=[ - AgentStreamResponseChoice( - index=0, - delta=AgentStreamResponseDelta(), - finish_reason="stop", - ) - ], - ) - ) - # 发送结束标记 - yield "data: [DONE]\n\n" - - # 同步迭代器 - elif hasattr(result, "__iter__"): - first_chunk = True - for chunk in result: # type: ignore - # 如果是字符串,包装成 AgentStreamResponse - if isinstance(chunk, str): - if first_chunk: - yield self._format_sse_chunk( - AgentStreamResponse( - id=response_id, - created=created, - model=model, - choices=[ - AgentStreamResponseChoice( - index=0, - delta=AgentStreamResponseDelta( - role=MessageRole.ASSISTANT, - ), - finish_reason=None, - ) - ], - ) - ) - first_chunk = False - - if chunk: - yield self._format_sse_chunk( - AgentStreamResponse( - id=response_id, - created=created, - model=model, - choices=[ - AgentStreamResponseChoice( - index=0, - delta=AgentStreamResponseDelta( - content=chunk - ), - finish_reason=None, - ) - ], - ) - ) + if tool_call_map: + message["tool_calls"] = list(tool_call_map.values()) - elif isinstance(chunk, AgentStreamResponse): - yield self._format_sse_chunk(chunk) - - # 发送结束 chunk - yield self._format_sse_chunk( - AgentStreamResponse( - id=response_id, - created=created, - model=model, - choices=[ - 
AgentStreamResponseChoice( - index=0, - delta=AgentStreamResponseDelta(), - finish_reason="stop", - ) - ], - ) - ) - yield "data: [DONE]\n\n" + response = { + "id": context.get( + "response_id", f"chatcmpl-{uuid.uuid4().hex[:12]}" + ), + "object": "chat.completion", + "created": context.get("created", int(time.time())), + "model": context.get("model", "agentrun"), + "choices": [{ + "index": 0, + "message": message, + "finish_reason": finish_reason, + }], + } - else: - raise TypeError( - "Expected Iterator or AsyncIterator for stream response, " - f"got {type(result)}" - ) + return response - def _format_sse_chunk(self, chunk: AgentStreamResponse) -> str: - """格式化单个 SSE chunk + def _apply_addition( + self, + delta: Dict[str, Any], + addition: Optional[Dict[str, Any]], + merge_options: Optional[MergeOptions] = None, + ) -> Dict[str, Any]: + """应用 addition 字段 Args: - chunk: AgentStreamResponse 对象 + delta: 原始 delta 数据 + addition: 附加字段 + merge_options: 合并选项,透传给 utils.helper.merge Returns: - SSE 格式的字符串 + 合并后的 delta 数据 """ - # 使用 Pydantic 的 JSON 序列化,自动处理所有嵌套对象 - json_str = chunk.model_dump_json(exclude_none=True) - json_data = json.loads(json_str) - - # 如果用户只提供了 content,转换为 OpenAI 格式 - if "content" in json_data and "choices" not in json_data: - json_data = { - "id": json_data.get( - "id", f"chatcmpl-{int(time.time() * 1000)}" - ), - "object": json_data.get("object", "chat.completion.chunk"), - "created": json_data.get("created", int(time.time())), - "model": json_data.get("model", "agentrun-model"), - "choices": [{ - "index": 0, - "delta": {"content": json_data["content"]}, - "finish_reason": None, - }], - } - else: - # 移除不属于 OpenAI 格式的字段 - json_data.pop("content", None) - json_data.pop("extra", None) + if not addition: + return delta - return f"data: {json.dumps(json_data, ensure_ascii=False)}\n\n" + return merge(delta, addition, **(merge_options or {})) diff --git a/agentrun/server/protocol.py b/agentrun/server/protocol.py index 3452a62..923028c 100644 --- 
a/agentrun/server/protocol.py +++ b/agentrun/server/protocol.py @@ -1,40 +1,57 @@ """协议抽象层 / Protocol Abstraction Layer -定义协议接口,支持未来扩展多种协议格式(OpenAI, Anthropic, Google 等)。 -Defines protocol interfaces, supporting future expansion of various protocol formats (OpenAI, Anthropic, Google, etc.). +定义协议接口,支持多种协议格式(OpenAI, AG-UI 等)。 -基于 Router 的设计 / Router-based design: -- 每个协议提供自己的 FastAPI Router / Each protocol provides its own FastAPI Router -- Server 负责挂载 Router 并管理路由前缀 / Server mounts Routers and manages route prefixes -- 协议完全自治,无需向 Server 声明接口 / Protocols are fully autonomous, no need to declare interfaces to Server +基于 Router 的设计: +- 每个协议提供自己的 FastAPI Router +- Server 负责挂载 Router 并管理路由前缀 +- 协议完全自治,无需向 Server 声明接口 """ from abc import ABC, abstractmethod -from typing import Awaitable, Callable, TYPE_CHECKING, Union +from typing import Any, Awaitable, Callable, Dict, TYPE_CHECKING, Union -from .model import AgentRequest, AgentResult +from .model import AgentRequest, AgentReturnType if TYPE_CHECKING: - from fastapi import APIRouter + from fastapi import APIRouter, Request from .invoker import AgentInvoker +# ============================================================================ +# 协议处理器基类 +# ============================================================================ + + class ProtocolHandler(ABC): """协议处理器基类 / Protocol Handler Base Class - 基于 Router 的设计 / Router-based design: - 协议通过 as_fastapi_router() 方法提供完整的路由定义,包括所有端点、请求处理、响应格式化等。 - Protocol provides complete route definitions through as_fastapi_router() method, including all endpoints, request handling, response formatting, etc. - - Server 只需挂载 Router 并管理路由前缀,无需了解协议细节。 - Server only needs to mount Router and manage route prefixes, without knowing protocol details. + 基于 Router 的设计: + 协议通过 as_fastapi_router() 方法提供完整的路由定义, + 包括所有端点、请求处理、响应格式化等。 + + Server 只需挂载 Router 并管理路由前缀,无需了解协议细节。 + + Example: + >>> class MyProtocolHandler(ProtocolHandler): + ... name = "my_protocol" + ... + ... 
def as_fastapi_router(self, agent_invoker): + ... router = APIRouter() + ... + ... @router.post("/run") + ... async def run(request: Request): + ... ... + ... + ... return router """ + name: str + @abstractmethod def as_fastapi_router(self, agent_invoker: "AgentInvoker") -> "APIRouter": - """ - 将协议转换为 FastAPI Router + """将协议转换为 FastAPI Router 协议自己决定: - 有哪些端点 @@ -43,54 +60,81 @@ def as_fastapi_router(self, agent_invoker: "AgentInvoker") -> "APIRouter": - 请求/响应处理 Args: - agent_invoker: Agent 调用器,用于执行用户的 invoke_agent + agent_invoker: Agent 调用器,用于执行用户的 invoke_agent Returns: - APIRouter: FastAPI 路由器,包含该协议的所有端点 - - Example: - ```python - def as_fastapi_router(self, agent_invoker): - router = APIRouter() - - @router.post("/chat/completions") - async def chat_completions(request: Request): - data = await request.json() - agent_request = parse_request(data) - result = await agent_invoker.invoke(agent_request) - return format_response(result) - - return router - ``` + APIRouter: FastAPI 路由器,包含该协议的所有端点 """ pass def get_prefix(self) -> str: - """ - 获取协议建议的路由前缀 + """获取协议建议的路由前缀 - Server 会优先使用用户指定的前缀,如果没有指定则使用此建议值。 + Server 会优先使用用户指定的前缀,如果没有指定则使用此建议值。 Returns: - str: 建议的前缀,如 "/v1" 或 "" + str: 建议的前缀,如 "/v1" 或 "" Example: - - OpenAI 协议: "/v1" - - Anthropic 协议: "/anthropic" + - OpenAI 协议: "/openai/v1" + - AG-UI 协议: "/agui/v1" - 无前缀: "" """ return "" +class BaseProtocolHandler(ProtocolHandler): + """协议处理器扩展基类 / Extended Protocol Handler Base Class + + 提供通用的请求解析和响应格式化逻辑。 + 子类需要实现具体的协议转换。 + """ + + async def parse_request( + self, + request: "Request", + request_data: Dict[str, Any], + ) -> tuple[AgentRequest, Dict[str, Any]]: + """解析 HTTP 请求为 AgentRequest + + 子类应该重写此方法来实现协议特定的解析逻辑。 + + Args: + request: FastAPI Request 对象 + request_data: 请求体 JSON 数据 + + Returns: + tuple: (AgentRequest, context) + - AgentRequest: 标准化的请求对象 + - context: 协议特定的上下文信息 + """ + raise NotImplementedError("Subclass must implement parse_request") + + def _is_iterator(self, obj: Any) -> bool: + """检查对象是否是迭代器 + + 
Args: + obj: 要检查的对象 + + Returns: + bool: 是否是迭代器 + """ + return ( + hasattr(obj, "__iter__") + and not isinstance(obj, (str, bytes, dict, list)) + ) or hasattr(obj, "__aiter__") + + +# ============================================================================ # Handler 类型定义 -# 同步 handler: 普通函数,直接返回 AgentResult -SyncInvokeAgentHandler = Callable[[AgentRequest], AgentResult] +# ============================================================================ + + +# 同步 handler: 返回 AgentReturnType +SyncInvokeAgentHandler = Callable[[AgentRequest], AgentReturnType] -# 异步 handler: 协程函数,返回 Awaitable[AgentResult] -AsyncInvokeAgentHandler = Callable[[AgentRequest], Awaitable[AgentResult]] +# 异步 handler: 返回 Awaitable[AgentReturnType] +AsyncInvokeAgentHandler = Callable[[AgentRequest], Awaitable[AgentReturnType]] -# 通用 handler: 可以是同步或异步 -InvokeAgentHandler = Union[ - SyncInvokeAgentHandler, - AsyncInvokeAgentHandler, -] +# 通用 handler: 同步或异步 +InvokeAgentHandler = Union[SyncInvokeAgentHandler, AsyncInvokeAgentHandler] diff --git a/agentrun/server/server.py b/agentrun/server/server.py index 1b6e3e9..cfacef6 100644 --- a/agentrun/server/server.py +++ b/agentrun/server/server.py @@ -1,19 +1,21 @@ """AgentRun HTTP Server / AgentRun HTTP 服务器 -基于 Router 的设计 / Router-based design: -- 每个协议提供自己的 Router / Each protocol provides its own Router -- Server 负责挂载 Router 并管理路由前缀 / Server mounts Routers and manages route prefixes -- 支持多协议同时运行 / Supports running multiple protocols simultaneously +基于 Router 的设计: +- 每个协议提供自己的 Router +- Server 负责挂载 Router 并管理路由前缀 +- 支持多协议同时运行(OpenAI + AG-UI) """ -from typing import Any, Dict, List, Optional +from typing import Any, List, Optional, Sequence from fastapi import FastAPI import uvicorn from agentrun.utils.log import logger +from .agui_protocol import AGUIProtocolHandler from .invoker import AgentInvoker +from .model import ServerConfig from .openai_protocol import OpenAIProtocolHandler from .protocol import InvokeAgentHandler, ProtocolHandler @@ -21,79 
+23,131 @@ class AgentRunServer: """AgentRun HTTP Server / AgentRun HTTP 服务器 - 基于 Router 的架构 / Router-based architecture: - - 每个协议提供完整的 FastAPI Router / Each protocol provides a complete FastAPI Router - - Server 只负责组装和前缀管理 / Server only handles assembly and prefix management - - 易于扩展新协议 / Easy to extend with new protocols + 基于 Router 的架构: + - 每个协议提供完整的 FastAPI Router + - Server 只负责组装和前缀管理 + - 易于扩展新协议 - Example (默认 OpenAI 协议 / Default OpenAI protocol): + Example (最简单用法): >>> def invoke_agent(request: AgentRequest): ... return "Hello, world!" >>> >>> server = AgentRunServer(invoke_agent=invoke_agent) >>> server.start(port=8000) - # 可访问 / Accessible: POST http://localhost:8000/v1/chat/completions + # 可访问: + # POST http://localhost:8000/openai/v1/chat/completions (OpenAI) + # POST http://localhost:8000/agui/v1/run (AG-UI) + + Example (流式输出): + >>> async def invoke_agent(request: AgentRequest): + ... yield "Hello, " + ... yield "world!" + >>> + >>> server = AgentRunServer(invoke_agent=invoke_agent) + >>> server.start(port=8000) - Example (自定义前缀 / Custom prefix): - >>> server = AgentRunServer( - ... invoke_agent=invoke_agent, - ... prefix_overrides={"OpenAIProtocolHandler": "/api/v1"} - ... ) + Example (使用事件): + >>> from agentrun.server import AgentResult, EventType + >>> + >>> async def invoke_agent(request: AgentRequest): + ... yield AgentResult( + ... event=EventType.STEP_STARTED, + ... data={"step_name": "thinking"} + ... ) + ... yield "I'm thinking..." + ... yield AgentResult( + ... event=EventType.STEP_FINISHED, + ... data={"step_name": "thinking"} + ... ) + >>> + >>> server = AgentRunServer(invoke_agent=invoke_agent) >>> server.start(port=8000) - # 可访问 / Accessible: POST http://localhost:8000/api/v1/chat/completions - Example (多协议 / Multiple protocols): + Example (仅 OpenAI 协议): >>> server = AgentRunServer( ... invoke_agent=invoke_agent, - ... protocols=[ - ... OpenAIProtocolHandler(), - ... CustomProtocolHandler(), - ... ] + ... 
protocols=[OpenAIProtocolHandler()] ... ) >>> server.start(port=8000) - Example (集成到现有 FastAPI 应用 / Integrate with existing FastAPI app): + Example (集成到现有 FastAPI 应用): >>> from fastapi import FastAPI >>> >>> app = FastAPI() >>> agent_server = AgentRunServer(invoke_agent=invoke_agent) >>> app.mount("/agent", agent_server.as_fastapi_app()) - # 可访问 / Accessible: POST http://localhost:8000/agent/v1/chat/completions + # 可访问: POST http://localhost:8000/agent/openai/v1/chat/completions + + Example (配置 CORS): + >>> server = AgentRunServer( + ... invoke_agent=invoke_agent, + ... config=ServerConfig(cors_origins=["http://localhost:3000"]) + ... ) """ def __init__( self, invoke_agent: InvokeAgentHandler, protocols: Optional[List[ProtocolHandler]] = None, - prefix_overrides: Optional[Dict[str, str]] = None, + config: Optional[ServerConfig] = None, ): - """初始化 AgentRun Server / Initialize AgentRun Server + """初始化 AgentRun Server Args: - invoke_agent: Agent 调用回调函数 / Agent invocation callback function - - 可以是同步或异步函数 / Can be synchronous or asynchronous function - - 支持返回字符串、AgentResponse 或生成器 / Supports returning string, AgentResponse or generator + invoke_agent: Agent 调用回调函数 + - 可以是同步或异步函数 + - 支持返回字符串或 AgentResult + - 支持使用 yield 进行流式输出 - protocols: 协议处理器列表 / List of protocol handlers - - 默认使用 OpenAI 协议 / Default uses OpenAI protocol - - 可以添加自定义协议 / Can add custom protocols + protocols: 协议处理器列表 + - 默认使用 OpenAI + AG-UI 协议 + - 可以添加自定义协议 - prefix_overrides: 协议前缀覆盖 / Protocol prefix overrides - - 格式 / Format: {协议类名 / protocol class name: 前缀 / prefix} - - 例如 / Example: {"OpenAIProtocolHandler": "/api/v1"} + config: 服务器配置 + - cors_origins: CORS 允许的源列表 + - openai: OpenAI 协议配置 + - agui: AG-UI 协议配置 """ self.app = FastAPI(title="AgentRun Server") self.agent_invoker = AgentInvoker(invoke_agent) - # 默认使用 OpenAI 协议 - if protocols is None: - protocols = [OpenAIProtocolHandler()] + # 配置 CORS + self._setup_cors(config.cors_origins if config else None) - self.prefix_overrides = prefix_overrides or 
{} + # 默认使用 OpenAI 和 AG-UI 协议 + if protocols is None: + protocols = [ + OpenAIProtocolHandler(config), + AGUIProtocolHandler(config), + ] # 挂载所有协议的 Router self._mount_protocols(protocols) + def _setup_cors(self, cors_origins: Optional[Sequence[str]] = None): + """配置 CORS 中间件 + + Args: + cors_origins: 允许的源列表,默认为 ["*"] 允许所有源 + """ + if not cors_origins: + return + + from fastapi.middleware.cors import CORSMiddleware + + origins = list(cors_origins) if cors_origins else ["*"] + + self.app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + expose_headers=["*"], + ) + + logger.debug(f"CORS 已启用,允许的源: {origins}") + def _mount_protocols(self, protocols: List[ProtocolHandler]): """挂载所有协议的路由 @@ -104,49 +158,17 @@ def _mount_protocols(self, protocols: List[ProtocolHandler]): # 获取协议的 Router router = protocol.as_fastapi_router(self.agent_invoker) - # 确定路由前缀 - prefix = self._get_protocol_prefix(protocol) + # 使用协议定义的前缀 + prefix = protocol.get_prefix() # 挂载到主应用 self.app.include_router(router, prefix=prefix) - logger.info( - f"✅ 已挂载协议: {protocol.__class__.__name__} ->" + logger.debug( + f"已挂载协议: {protocol.__class__.__name__} ->" f" {prefix or '(无前缀)'}" ) - def _get_protocol_prefix(self, protocol: ProtocolHandler) -> str: - """获取协议的路由前缀 - - 优先级: - 1. 用户指定的覆盖前缀 - 2. 协议自己的建议前缀 - 3. 基于协议类名的默认前缀 - - Args: - protocol: 协议处理器 - - Returns: - str: 路由前缀 - """ - protocol_name = protocol.__class__.__name__ - - # 1. 检查用户覆盖 - if protocol_name in self.prefix_overrides: - return self.prefix_overrides[protocol_name] - - # 2. 使用协议建议 - suggested_prefix = protocol.get_prefix() - if suggested_prefix: - return suggested_prefix - - # 3. 
默认前缀(基于类名) - # OpenAIProtocolHandler -> /openai - name_without_handler = protocol_name.replace( - "ProtocolHandler", "" - ).replace("Handler", "") - return f"/{name_without_handler.lower()}" - def start( self, host: str = "0.0.0.0", @@ -157,18 +179,12 @@ def start( """启动 HTTP 服务器 Args: - host: 监听地址,默认 0.0.0.0 - port: 监听端口,默认 9000 - log_level: 日志级别,默认 info + host: 监听地址,默认 0.0.0.0 + port: 监听端口,默认 9000 + log_level: 日志级别,默认 info **kwargs: 传递给 uvicorn.run 的其他参数 """ - logger.info(f"🚀 启动 AgentRun Server: http://{host}:{port}") - - # 打印路由信息 - # for route in self.app.routes: - # if hasattr(route, "methods") and hasattr(route, "path"): - # methods = ", ".join(route.methods) # type: ignore - # logger.info(f" {methods:10} {route.path}") # type: ignore + logger.info(f"启动 AgentRun Server: http://{host}:{port}") uvicorn.run( self.app, host=host, port=port, log_level=log_level, **kwargs diff --git a/agentrun/utils/helper.py b/agentrun/utils/helper.py index 92981bd..c1938f1 100644 --- a/agentrun/utils/helper.py +++ b/agentrun/utils/helper.py @@ -4,7 +4,9 @@ This module provides general utility functions. """ -from typing import Optional +from typing import Any, Optional + +from typing_extensions import NotRequired, TypedDict, Unpack def mask_password(password: Optional[str]) -> str: @@ -32,3 +34,75 @@ def mask_password(password: Optional[str]) -> str: if len(password) <= 4: return password[0] + "*" * (len(password) - 2) + password[-1] return password[0:2] + "*" * (len(password) - 4) + password[-2:] + + +class MergeOptions(TypedDict): + concat_list: NotRequired[bool] + no_new_field: NotRequired[bool] + ignore_empty_list: NotRequired[bool] + + +def merge(a: Any, b: Any, **args: Unpack[MergeOptions]) -> Any: + """通用合并函数 / Generic deep merge helper. 
+ + 合并规则概览: + - 若 ``b`` 为 ``None``: 返回 ``a`` + - 若 ``a`` 为 ``None``: 返回 ``b`` + - ``dict``: 递归按 key 深度合并 + - ``list``: 连接列表 ``a + b`` + - ``tuple``: 连接元组 ``a + b`` + - ``set``/``frozenset``: 取并集 + - 具有 ``__dict__`` 的同类型对象: 按属性字典递归合并后构造新实例 + - 其他类型: 直接返回 ``b`` (视为覆盖) + """ + + # None 合并: 保留非 None 一方 + if b is None: + return a + if a is None: + return b + + # dict 深度合并 + if isinstance(a, dict) and isinstance(b, dict): + result: dict[Any, Any] = dict(a) + for key, value in b.items(): + if key in result: + result[key] = merge(result[key], value, **args) + else: + if args.get("no_new_field", False): + continue + result[key] = value + return result + + # list 合并: 连接 + if isinstance(a, list) and isinstance(b, list): + if args.get("concat_list", False): + return [*a, *b] + if args.get("ignore_empty_list", False): + if len(b) == 0: + return a + return b + + # tuple 合并: 连接 + if isinstance(a, tuple) and isinstance(b, tuple): + return (*a, *b) + + # set / frozenset: 并集 + if isinstance(a, set) and isinstance(b, set): + return a | b + if isinstance(a, frozenset) and isinstance(b, frozenset): + return a | b + + # 同类型且具备 __dict__ 的对象: 按属性递归合并, 就地更新 a + if type(a) is type(b) and hasattr(a, "__dict__") and hasattr(b, "__dict__"): + for key, value in b.__dict__.items(): + if key in a.__dict__: + setattr(a, key, merge(getattr(a, key), value, **args)) + else: + if args.get("no_new_field", False): + continue + setattr(a, key, value) + return a + + # 其他情况: 视为覆盖, 返回 b + return b diff --git a/examples/quick_start.py b/examples/quick_start.py index 716731e..dcb4f94 100644 --- a/examples/quick_start.py +++ b/examples/quick_start.py @@ -1,18 +1,28 @@ +"""AgentRun Server 快速开始示例 + +curl http://127.0.0.1:9000/openai/v1/chat/completions -X POST \ + -H "Content-Type: application/json" \ + -d '{"messages": [{"role": "user", "content": "写一段代码,查询现在是几点?"}], "stream": true}' +""" + +import os from typing import Any from langchain.agents import create_agent import pydash from 
import os
from typing import Any

import pydash
from langchain.agents import create_agent

from agentrun.integration.langchain import model
from agentrun.integration.langgraph.agent_converter import AgentRunConverter
from agentrun.server import AgentRequest, AgentRunServer
from agentrun.server.model import ServerConfig
from agentrun.utils.log import logger

# 请替换为您已经创建的 模型 和 沙箱 名称
AGENTRUN_MODEL_NAME = os.getenv("AGENTRUN_MODEL_NAME", "")
SANDBOX_NAME = ""

if AGENTRUN_MODEL_NAME.startswith("<") or not AGENTRUN_MODEL_NAME:
    # 修复: 错误消息与实际变量名保持一致(原消息仍写 MODEL_NAME)
    raise ValueError(
        "请将 AGENTRUN_MODEL_NAME 替换为您已经创建的模型名称"
        "(可通过环境变量 AGENTRUN_MODEL_NAME 配置)"
    )

code_interpreter_tools: list = []
# TODO(review): 原文件在此处有按 SANDBOX_NAME 加载沙箱工具的分支
# (该分支位于本 diff 片段的上下文之外, 请以原文件为准)。
if not SANDBOX_NAME or SANDBOX_NAME.startswith("<"):
    logger.warning("SANDBOX_NAME 未设置或未替换,跳过加载沙箱工具。")


def get_weather_tool():
    """获取天气工具(演示用, 固定返回晴天 25 度)。"""
    import time

    logger.debug("调用获取天气工具")
    time.sleep(5)  # 模拟一次耗时的外部调用
    return {"weather": "晴天,25度"}


agent = create_agent(
    model=model(AGENTRUN_MODEL_NAME),
    tools=[
        *code_interpreter_tools,
        get_weather_tool,
    ],
    system_prompt="你是一个 AgentRun 的 AI 专家,可以通过沙箱运行代码来回答用户的问题。",
)


async def invoke_agent(request: AgentRequest):
    """处理一次 Agent 调用请求。

    Args:
        request: AgentRun 请求, 含消息列表与 stream 标记。

    Returns:
        stream=True 时返回异步事件生成器(经 AgentRunConverter 转换的事件);
        否则返回最后一条消息的文本内容(取不到时为空串)。
    """
    # NOTE(review): msg.role 若为枚举类型, 此处可能需要取 .value
    # (e2e 测试中对 role 做了 .value 转换) —— 请确认
    input_data: Any = {
        "messages": [
            {"role": msg.role, "content": msg.content}
            for msg in request.messages
        ]
    }

    converter = AgentRunConverter()
    if request.stream:

        async def async_generator():
            async for event in agent.astream(input_data, stream_mode="updates"):
                for item in converter.convert(event):
                    yield item

        return async_generator()

    result = await agent.ainvoke(input_data)
    # pydash 路径取最后一条消息内容; 路径不存在时返回默认空串
    return pydash.get(result, "messages[-1].content", "")


AgentRunServer(
    invoke_agent=invoke_agent,
    config=ServerConfig(
        # 部署在 AgentRun 上时, AgentRun 已自动处理跨域, 可省略这一行
        cors_origins=["*"]
    ),
).start()
注解 +no_implicit_optional = False + +# ---------------------------- +# 映射说明(便于审查): +# - `disallow_untyped_defs` 关闭 -> 忽略大量 "no-untyped-def" 报错(函数未注解) +# - `warn_return_any` 关闭 -> 减少因返回 Any 导致的噪音警告 +# - `warn_unused_ignores` 关闭 -> 忽略历史性的不必要 `# type: ignore` 报告 +# - `ignore_missing_imports` 打开 -> 忽略第三方库缺少类型桩的导入错误 +# 若需要更精细的按错误类型过滤,请告知,我可实现输出过滤脚本或逐文件添加 `# type: ignore[code]`。 +# ---------------------------- + +# ---------------------------- +# 全局禁用的 mypy 错误代码(按用户要求忽略以下报错类别) +# 列表:annotation-unchecked, arg-type, assignment, attr-defined, +# call-arg, empty-body, has-type, import-untyped, misc, no-redef, +# return-value, union-attr, valid-type, var-annotated +# 说明:mypy 支持通过 `disable_error_code` 在配置中禁用特定错误代码。 +# 若你的 mypy 版本过旧不支持该项,我们可以改为输出过滤脚本或 `per-file-ignores`。 +# ---------------------------- +disable_error_code = + # annotation-unchecked: 未对未添加注解的函数体进行检查(通常由未开启 check_untyped_defs 导致) + annotation-unchecked, + # arg-type: 调用时传入参数类型与被调用者期望类型不匹配 + arg-type, + # assignment: 赋值操作中类型不兼容 + assignment, + # attr-defined: 访问的属性在类型上未定义或不可确定 + attr-defined, + # call-arg: 调用时缺少必需参数或多传/传错参数名 + call-arg, + # empty-body: 函数/类体为空,可能缺少实现 + empty-body, + # has-type: 无法确定表达式或变量的具体类型 + has-type, + # import-untyped: 导入的第三方库缺少类型桩(stubs) + import-untyped, + # misc: 其他杂项错误 + misc, + # no-redef: 名称重复定义或重定义冲突 + no-redef, + # return-value: 返回值类型与注解不匹配或返回缺失 + return-value, + # union-attr: 在联合类型上的属性访问可能不存在于所有分支 + union-attr, + # valid-type: 无效的类型声明或使用(例如变量被误用为类型) + valid-type, + # var-annotated: 变量缺少必要的类型注解 + var-annotated diff --git a/pyproject.toml b/pyproject.toml index 3e8f7ff..44b5edf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ server = [ "fastapi>=0.104.0", "uvicorn>=0.24.0", + "ag-ui-protocol>=0.1.10", ] langchain = [ diff --git a/tests/e2e/integration/langchain/test_agent_invoke_methods.py b/tests/e2e/integration/langchain/test_agent_invoke_methods.py new file mode 100644 index 0000000..f1293ac --- /dev/null +++ 
def get_weather(city: str) -> Dict[str, Any]:
    """获取指定城市的天气信息(mock: 固定返回晴天 25 度)。

    Args:
        city: 城市名称

    Returns:
        包含城市、天气与温度的字典
    """
    return dict(city=city, weather="晴天", temperature=25)


def get_time() -> str:
    """获取当前时间(mock: 固定时间字符串)。

    Returns:
        当前时间字符串
    """
    return "2024-01-01 12:00:00"


# 注册给测试 agent 的工具列表
TOOLS = [get_weather, get_time]
s.getsockname()[1] + + +def _sse(data: Dict[str, Any]) -> str: + return f"data: {json.dumps(data, ensure_ascii=False)}\n\n" + + +def _build_mock_openai_app() -> FastAPI: + """构建本地 OpenAI 协议兼容的简单服务""" + app = FastAPI() + + def _decide_action(messages: List[Dict[str, Any]]): + for msg in reversed(messages): + if msg.get("role") == "tool": + return {"type": "after_tool", "content": msg.get("content", "")} + + user_msg = next( + (m for m in reversed(messages) if m.get("role") == "user"), {} + ) + content = user_msg.get("content", "") + if "天气" in content or "weather" in content: + return { + "type": "tool_call", + "tool": "get_weather", + "arguments": {"city": "北京"}, + } + if "时间" in content or "time" in content or "几点" in content: + return {"type": "tool_call", "tool": "get_time", "arguments": {}} + return {"type": "chat", "content": content or "你好,我是本地模型"} + + def _chat_response(model: str, content: str) -> Dict[str, Any]: + return { + "id": f"chatcmpl-{uuid.uuid4().hex[:12]}", + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": content}, + "finish_reason": "stop", + }], + } + + def _tool_response( + model: str, tool: str, arguments: Dict[str, Any] + ) -> Dict[str, Any]: + return { + "id": f"chatcmpl-{uuid.uuid4().hex[:12]}", + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": None, + "tool_calls": [{ + "id": "call_1", + "type": "function", + "function": { + "name": tool, + "arguments": json.dumps( + arguments, ensure_ascii=False + ), + }, + }], + }, + "finish_reason": "tool_calls", + }], + } + + async def _stream_chat(model: str, content: str): + yield _sse({ + "id": f"chatcmpl-{uuid.uuid4().hex[:12]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [{ + "index": 0, + "delta": {"role": "assistant"}, + 
"finish_reason": None, + }], + }) + yield _sse({ + "id": f"chatcmpl-{uuid.uuid4().hex[:12]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [{ + "index": 0, + "delta": {"content": content}, + "finish_reason": None, + }], + }) + yield _sse({ + "id": f"chatcmpl-{uuid.uuid4().hex[:12]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], + }) + yield "data: [DONE]\n\n" + + async def _stream_tool(model: str, tool: str, arguments: Dict[str, Any]): + yield _sse({ + "id": f"chatcmpl-{uuid.uuid4().hex[:12]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [{ + "index": 0, + "delta": {"role": "assistant"}, + "finish_reason": None, + }], + }) + yield _sse({ + "id": f"chatcmpl-{uuid.uuid4().hex[:12]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [{ + "index": 0, + "delta": { + "tool_calls": [{ + "index": 0, + "id": "call_1", + "type": "function", + "function": { + "name": tool, + "arguments": json.dumps( + arguments, ensure_ascii=False + ), + }, + }] + }, + "finish_reason": None, + }], + }) + yield _sse({ + "id": f"chatcmpl-{uuid.uuid4().hex[:12]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + {"index": 0, "delta": {}, "finish_reason": "tool_calls"} + ], + }) + yield "data: [DONE]\n\n" + + @app.get("/v1/models") + async def list_models(): + return { + "object": "list", + "data": [ + {"id": "mock-model", "object": "model", "owned_by": "local"} + ], + } + + @app.post("/v1/chat/completions") + async def chat_completions(request: Request): + body = await request.json() + stream = bool(body.get("stream", False)) + model = body.get("model", "mock-model") + messages = body.get("messages", []) + + action = _decide_action(messages) + if action["type"] == "chat": + 
content = action["content"] or "好的,我在本地为你服务。" + if stream: + return StreamingResponse( + _stream_chat(model, content), media_type="text/event-stream" + ) + return JSONResponse(_chat_response(model, content)) + + if action["type"] == "tool_call": + tool = action["tool"] + arguments = action["arguments"] + if stream: + return StreamingResponse( + _stream_tool(model, tool, arguments), + media_type="text/event-stream", + ) + return JSONResponse(_tool_response(model, tool, arguments)) + + # after tool result + content = f"工具结果已收到: {action['content']}" + if stream: + return StreamingResponse( + _stream_chat(model, content), media_type="text/event-stream" + ) + return JSONResponse(_chat_response(model, content)) + + return app + + +def build_agent(model_input: Union[str, Any]): + """创建测试用的 agent""" + from langchain.agents import create_agent + + return create_agent( + model=model(model_input), + tools=TOOLS, + system_prompt="你是一个测试助手。", + ) + + +def parse_sse_events(content: str) -> List[Dict[str, Any]]: + """解析 SSE 响应内容为事件列表""" + events = [] + for line in content.split("\n"): + line = line.strip() + if line.startswith("data:"): + data_str = line[5:].strip() + if data_str and data_str != "[DONE]": + try: + events.append(json.loads(data_str)) + except json.JSONDecodeError: + pass + return events + + +async def request_agui_events( + server_app, + messages: List[Dict[str, str]], + stream: bool = True, +) -> List[Dict[str, Any]]: + """发送 AG-UI 请求并返回事件列表""" + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=server_app), + base_url="http://test", + ) as client: + response = await client.post( + "/ag-ui/agent", + json={"messages": messages, "stream": stream}, + timeout=60.0, + ) + + assert response.status_code == 200 + return parse_sse_events(response.text) + + +def _index(event_types: Sequence[str], target: str) -> int: + """安全获取事件索引,未找到抛出 AssertionError""" + assert target in event_types, f"缺少事件: {target}" + return event_types.index(target) + + +def 
_normalize_agui_event(event: Dict[str, Any]) -> Dict[str, Any]: + """去除可变字段,仅保留语义字段""" + normalized: Dict[str, Any] = {"type": event.get("type")} + if "delta" in event: + normalized["delta"] = event["delta"] + if "result" in event: + normalized["result"] = event["result"] + if "toolCallName" in event: + normalized["toolCallName"] = event["toolCallName"] + if "role" in event: + normalized["role"] = event["role"] + if "toolCallId" in event: + normalized["hasToolCallId"] = bool(event["toolCallId"]) + if "messageId" in event: + normalized["hasMessageId"] = bool(event["messageId"]) + if "threadId" in event: + normalized["hasThreadId"] = bool(event["threadId"]) + if "runId" in event: + normalized["hasRunId"] = bool(event["runId"]) + return normalized + + +AGUI_EXPECTED = { + "text_basic": [[ + {"type": "RUN_STARTED", "hasThreadId": True, "hasRunId": True}, + { + "type": "TEXT_MESSAGE_START", + "role": "assistant", + "hasMessageId": True, + }, + { + "type": "TEXT_MESSAGE_CONTENT", + "delta": "你好,请简单介绍一下你自己", + "hasMessageId": True, + }, + {"type": "TEXT_MESSAGE_END", "hasMessageId": True}, + {"type": "RUN_FINISHED", "hasThreadId": True, "hasRunId": True}, + ]], + "tool_weather": [ + [ + {"type": "RUN_STARTED", "hasThreadId": True, "hasRunId": True}, + { + "type": "TOOL_CALL_ARGS", + "delta": '{"city": "北京"}', + "hasToolCallId": True, + }, + { + "type": "TOOL_CALL_START", + "toolCallName": "get_weather", + "hasToolCallId": True, + }, + { + "type": "TOOL_CALL_ARGS", + "delta": '{"city": "北京"}', + "hasToolCallId": True, + }, + { + "type": "TOOL_CALL_RESULT", + "result": ( + '{"city": "北京", "weather": "晴天", "temperature": 25}' + ), + "hasToolCallId": True, + }, + {"type": "TOOL_CALL_END", "hasToolCallId": True}, + { + "type": "TEXT_MESSAGE_START", + "role": "assistant", + "hasMessageId": True, + }, + { + "type": "TEXT_MESSAGE_CONTENT", + "delta": ( + '工具结果已收到: {"city": "北京", "weather": "晴天",' + ' "temperature": 25}' + ), + "hasMessageId": True, + }, + {"type": 
"TEXT_MESSAGE_END", "hasMessageId": True}, + {"type": "RUN_FINISHED", "hasThreadId": True, "hasRunId": True}, + ], + [ + {"type": "RUN_STARTED", "hasThreadId": True, "hasRunId": True}, + { + "type": "TOOL_CALL_START", + "toolCallName": "get_weather", + "hasToolCallId": True, + }, + { + "type": "TOOL_CALL_ARGS", + "delta": '{"city": "北京"}', + "hasToolCallId": True, + }, + { + "type": "TOOL_CALL_RESULT", + "result": ( + '{"city": "北京", "weather": "晴天", "temperature": 25}' + ), + "hasToolCallId": True, + }, + {"type": "TOOL_CALL_END", "hasToolCallId": True}, + { + "type": "TEXT_MESSAGE_START", + "role": "assistant", + "hasMessageId": True, + }, + { + "type": "TEXT_MESSAGE_CONTENT", + "delta": ( + '工具结果已收到: {"city": "北京", "weather": "晴天",' + ' "temperature": 25}' + ), + "hasMessageId": True, + }, + {"type": "TEXT_MESSAGE_END", "hasMessageId": True}, + {"type": "RUN_FINISHED", "hasThreadId": True, "hasRunId": True}, + ], + ], + "tool_time": [[ + {"type": "RUN_STARTED", "hasThreadId": True, "hasRunId": True}, + { + "type": "TOOL_CALL_START", + "toolCallName": "get_time", + "hasToolCallId": True, + }, + { + "type": "TOOL_CALL_RESULT", + "result": "2024-01-01 12:00:00", + "hasToolCallId": True, + }, + {"type": "TOOL_CALL_END", "hasToolCallId": True}, + { + "type": "TEXT_MESSAGE_START", + "role": "assistant", + "hasMessageId": True, + }, + { + "type": "TEXT_MESSAGE_CONTENT", + "delta": "工具结果已收到: 2024-01-01 12:00:00", + "hasMessageId": True, + }, + {"type": "TEXT_MESSAGE_END", "hasMessageId": True}, + {"type": "RUN_FINISHED", "hasThreadId": True, "hasRunId": True}, + ]], +} + + +def assert_agui_events_exact( + events: List[Dict[str, Any]], case_key: str +) -> None: + normalized = [_normalize_agui_event(e) for e in events] + variants = AGUI_EXPECTED[case_key] + assert ( + normalized in variants + ), f"AGUI events mismatch for {case_key}: {normalized}" + + +def normalize_openai_event_types(chunks: List[Dict[str, Any]]) -> List[str]: + """将 OpenAI 协议的流式分片转换为统一的事件类型序列""" + 
event_types: List[str] = [] + for chunk in chunks: + choices = chunk.get("choices") or [] + if not choices: + continue + choice = choices[0] or {} + delta = choice.get("delta") or {} + finish_reason = choice.get("finish_reason") + + if delta.get("role"): + event_types.append("TEXT_MESSAGE_START") + if delta.get("content"): + event_types.append("TEXT_MESSAGE_CONTENT") + + for tool_call in delta.get("tool_calls") or []: + function = (tool_call or {}).get("function") or {} + name = function.get("name") + arguments = function.get("arguments", "") + if name: + event_types.append("TOOL_CALL_START") + if arguments: + event_types.append("TOOL_CALL_ARGS") + + if finish_reason == "tool_calls": + event_types.append("TOOL_CALL_END") + elif finish_reason == "stop": + event_types.append("TEXT_MESSAGE_END") + + return event_types + + +def _normalize_openai_stream( + chunks: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + normalized: List[Dict[str, Any]] = [] + for chunk in chunks: + choice = (chunk.get("choices") or [{}])[0] or {} + delta = choice.get("delta") or {} + entry: Dict[str, Any] = { + "object": chunk.get("object"), + "finish_reason": choice.get("finish_reason"), + } + if "role" in delta: + entry["delta_role"] = delta["role"] + if "content" in delta: + entry["delta_content"] = delta["content"] + if "tool_calls" in delta: + tools = [] + for tc in delta.get("tool_calls") or []: + func = (tc or {}).get("function") or {} + tools.append({ + "name": func.get("name"), + "arguments": func.get("arguments"), + "has_id": bool((tc or {}).get("id")), + }) + entry["tool_calls"] = tools + normalized.append(entry) + return normalized + + +OPENAI_STREAM_EXPECTED = { + "text_basic": [ + { + "object": "chat.completion.chunk", + "delta_role": "assistant", + "finish_reason": None, + }, + { + "object": "chat.completion.chunk", + "delta_content": "你好,请简单介绍一下你自己", + "finish_reason": None, + }, + { + "object": "chat.completion.chunk", + "finish_reason": "stop", + }, + ], + "tool_weather": [ 
+ { + "object": "chat.completion.chunk", + "tool_calls": [ + {"name": None, "arguments": '{"city": "北京"}', "has_id": False} + ], + "finish_reason": None, + }, + { + "object": "chat.completion.chunk", + "tool_calls": [{ + "name": "get_weather", + "arguments": "", + "has_id": True, + }], + "finish_reason": None, + }, + { + "object": "chat.completion.chunk", + "tool_calls": [ + {"name": None, "arguments": '{"city": "北京"}', "has_id": False} + ], + "finish_reason": None, + }, + { + "object": "chat.completion.chunk", + "finish_reason": "tool_calls", + }, + { + "object": "chat.completion.chunk", + "finish_reason": None, + "delta_role": "assistant", + }, + { + "object": "chat.completion.chunk", + "delta_content": ( + '工具结果已收到: {"city": "北京", "weather": "晴天",' + ' "temperature": 25}' + ), + "finish_reason": None, + }, + {"object": "chat.completion.chunk", "finish_reason": "stop"}, + ], + "tool_time": [ + { + "object": "chat.completion.chunk", + "tool_calls": [{ + "name": "get_time", + "arguments": "", + "has_id": True, + }], + "finish_reason": None, + }, + { + "object": "chat.completion.chunk", + "finish_reason": "tool_calls", + }, + { + "object": "chat.completion.chunk", + "delta_role": "assistant", + "finish_reason": None, + }, + { + "object": "chat.completion.chunk", + "delta_content": "工具结果已收到: 2024-01-01 12:00:00", + "finish_reason": None, + }, + {"object": "chat.completion.chunk", "finish_reason": "stop"}, + ], +} + + +def _normalize_openai_nonstream(resp: Dict[str, Any]) -> Dict[str, Any]: + choice = (resp.get("choices") or [{}])[0] or {} + msg = choice.get("message") or {} + tools_norm = None + if msg.get("tool_calls"): + tools_norm = [] + for tc in msg.get("tool_calls") or []: + func = (tc or {}).get("function") or {} + tools_norm.append({ + "name": func.get("name"), + "arguments": func.get("arguments"), + "has_id": bool((tc or {}).get("id")), + }) + return { + "object": resp.get("object"), + "role": msg.get("role"), + "content": msg.get("content"), + "tool_calls": 
tools_norm, + "finish_reason": choice.get("finish_reason"), + } + + +OPENAI_NONSTREAM_EXPECTED = { + "text_basic": { + "object": "chat.completion", + "role": "assistant", + "content": "你好,请简单介绍一下你自己", + "tool_calls": None, + "finish_reason": "stop", + }, + "tool_weather": { + "object": "chat.completion", + "role": "assistant", + "content": ( + '工具结果已收到: {"city": "北京", "weather": "晴天",' + ' "temperature": 25}' + ), + "tool_calls": [{ + "name": "get_weather", + "arguments": '{"city": "北京"}', + "has_id": True, + }], + "finish_reason": "tool_calls", + }, + "tool_time": { + "object": "chat.completion", + "role": "assistant", + "content": "工具结果已收到: 2024-01-01 12:00:00", + "tool_calls": [{ + "name": "get_time", + "arguments": "", + "has_id": True, + }], + "finish_reason": "tool_calls", + }, +} + + +def assert_openai_text_generation_events(chunks: List[Dict[str, Any]]) -> None: + """校验 OpenAI 协议纯文本流式事件""" + assert ( + _normalize_openai_stream(chunks) == OPENAI_STREAM_EXPECTED["text_basic"] + ) + + +def assert_openai_tool_call_events( + chunks: List[Dict[str, Any]], case_key: str +) -> None: + """校验 OpenAI 协议工具调用流式事件""" + assert _normalize_openai_stream(chunks) == OPENAI_STREAM_EXPECTED[case_key] + + +def assert_openai_text_generation_response(resp: Dict[str, Any]) -> None: + """校验 OpenAI 协议非流式文本响应""" + assert ( + _normalize_openai_nonstream(resp) + == OPENAI_NONSTREAM_EXPECTED["text_basic"] + ) + + +def assert_openai_tool_call_response( + resp: Dict[str, Any], case_key: str +) -> None: + """校验 OpenAI 协议非流式工具调用响应""" + assert ( + _normalize_openai_nonstream(resp) == OPENAI_NONSTREAM_EXPECTED[case_key] + ) + + +AGUI_PROMPT_CASES = [ + ("text_basic", "你好,请简单介绍一下你自己"), + ("tool_weather", "北京的天气怎么样?"), + ("tool_time", "现在几点?"), +] + + +async def request_openai_events( + server_app, + messages: List[Dict[str, str]], + stream: bool = True, +) -> Union[List[Dict[str, Any]], Dict[str, Any]]: + """发送 OpenAI 协议请求并返回流式事件列表或响应""" + payload: Dict[str, Any] = { + "model": "mock-model", + 
"messages": messages, + "stream": stream, + } + + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=server_app), + base_url="http://test", + ) as client: + response = await client.post( + "/openai/v1/chat/completions", + json=payload, + timeout=60.0, + ) + + assert response.status_code == 200 + + if stream: + return parse_sse_events(response.text) + + return response.json() + + +# ============================================================================= +# 模型服务准备 +# ============================================================================= + + +@pytest.fixture(scope="session") +def mock_openai_server(): + """启动本地 OpenAI 协议兼容服务""" + app = _build_mock_openai_app() + port = _find_free_port() + config = uvicorn.Config( + app, host="127.0.0.1", port=port, log_level="warning" + ) + server = uvicorn.Server(config) + + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + + base_url = f"http://127.0.0.1:{port}" + for _ in range(50): + try: + httpx.get(f"{base_url}/v1/models", timeout=0.2) + break + except Exception: + time.sleep(0.1) + + yield base_url + + server.should_exit = True + thread.join(timeout=5) + + +@pytest.fixture(scope="function") +def agent_model(mock_openai_server: str): + """使用本地 OpenAI 服务构造 ModelService""" + base_url = f"{mock_openai_server}/v1" + return ModelService( + model_service_name="mock-model", + model_type=ModelType.LLM, + provider="openai", + provider_settings=ProviderSettings( + api_key="sk-local-key", + base_url=base_url, + model_names=["mock-model"], + ), + ) + + +@pytest.fixture +def server_app_astream_events(agent_model): + """创建使用 astream_events 的服务器(AG-UI/OpenAI 通用)""" + agent = build_agent(agent_model) + + async def invoke_agent(request: AgentRequest): + input_data: Dict[str, Any] = { + "messages": [ + { + "role": ( + msg.role.value + if hasattr(msg.role, "value") + else str(msg.role) + ), + "content": msg.content, + } + for msg in request.messages + ] + } + + converter = AgentRunConverter() 
+ + async def generator(): + async for event in agent.astream_events( + cast(Any, input_data), version="v2" + ): + for item in converter.convert(event): + yield item + + return generator() + + server = AgentRunServer(invoke_agent=invoke_agent) + return server.app + + +# ============================================================================= +# 测试类: astream_events +# ============================================================================= + + +class TestAstreamEvents: + """测试 agent.astream_events 调用方式""" + + @pytest.mark.parametrize( + "case_key,prompt", + AGUI_PROMPT_CASES, + ids=[case[0] for case in AGUI_PROMPT_CASES], + ) + async def test_astream_events( + self, server_app_astream_events, case_key, prompt + ): + """覆盖文本、工具、本地工具的流式事件""" + events = await request_agui_events( + server_app_astream_events, + [{"role": "user", "content": prompt}], + stream=True, + ) + + assert_agui_events_exact(events, case_key) + + +# ============================================================================= +# 测试类: astream (updates 模式) +# ============================================================================= + + +class TestAstreamUpdates: + """测试 agent.astream(stream_mode=\"updates\") 调用方式""" + + @pytest.fixture + def server_app(self, agent_model): + """创建使用 astream 的服务器""" + agent = build_agent(agent_model) + + async def invoke_agent(request: AgentRequest): + input_data: Dict[str, Any] = { + "messages": [ + { + "role": ( + msg.role.value + if hasattr(msg.role, "value") + else str(msg.role) + ), + "content": msg.content, + } + for msg in request.messages + ] + } + + converter = AgentRunConverter() + if request.stream: + + async def generator(): + async for event in agent.astream( + cast(Any, input_data), stream_mode="updates" + ): + for item in converter.convert(event): + yield item + + return generator() + else: + return await agent.ainvoke(cast(Any, input_data)) + + server = AgentRunServer(invoke_agent=invoke_agent) + return server.app + + 
@pytest.mark.parametrize( + "case_key,prompt", + AGUI_PROMPT_CASES, + ids=[case[0] for case in AGUI_PROMPT_CASES], + ) + async def test_astream_updates(self, server_app, case_key, prompt): + """流式 updates 模式覆盖对话与工具场景""" + events = await request_agui_events( + server_app, + [{"role": "user", "content": prompt}], + stream=True, + ) + + assert_agui_events_exact(events, case_key) + + +# ============================================================================= +# 测试类: stream (同步 updates 模式) +# ============================================================================= + + +class TestStreamUpdates: + """测试 agent.stream(stream_mode="updates") 调用方式""" + + @pytest.fixture + def server_app(self, agent_model): + """创建使用 stream 的服务器""" + agent = build_agent(agent_model) + + def invoke_agent(request: AgentRequest): + input_data: Dict[str, Any] = { + "messages": [ + { + "role": ( + msg.role.value + if hasattr(msg.role, "value") + else str(msg.role) + ), + "content": msg.content, + } + for msg in request.messages + ] + } + + converter = AgentRunConverter() + if request.stream: + + def generator(): + for event in agent.stream( + cast(Any, input_data), stream_mode="updates" + ): + for item in converter.convert(event): + yield item + + return generator() + else: + return agent.invoke(cast(Any, input_data)) + + server = AgentRunServer(invoke_agent=invoke_agent) + return server.app + + @pytest.mark.parametrize( + "case_key,prompt", + AGUI_PROMPT_CASES, + ids=[case[0] for case in AGUI_PROMPT_CASES], + ) + async def test_stream_updates(self, server_app, case_key, prompt): + """同步 stream updates 覆盖对话与工具场景""" + events = await request_agui_events( + server_app, + [{"role": "user", "content": prompt}], + stream=True, + ) + + assert_agui_events_exact(events, case_key) + + +# ============================================================================= +# 测试类: invoke/ainvoke (非流式) +# ============================================================================= + + +class TestInvoke: + 
"""测试 agent.invoke/ainvoke 非流式调用方式""" + + @pytest.fixture + def server_app_sync(self, agent_model): + """创建使用 invoke 的服务器""" + agent = build_agent(agent_model) + + async def invoke_agent(request: AgentRequest): + input_data: Dict[str, Any] = { + "messages": [ + { + "role": ( + msg.role.value + if hasattr(msg.role, "value") + else str(msg.role) + ), + "content": msg.content, + } + for msg in request.messages + ] + } + + converter = AgentRunConverter() + + async def generator(): + async for event in agent.astream_events( + cast(Any, input_data), version="v2" + ): + for item in converter.convert(event): + yield item + + return generator() + + server = AgentRunServer(invoke_agent=invoke_agent) + return server.app + + @pytest.fixture + def server_app_async(self, agent_model): + """创建使用 ainvoke 的服务器""" + agent = build_agent(agent_model) + + async def invoke_agent(request: AgentRequest): + input_data: Dict[str, Any] = { + "messages": [ + { + "role": ( + msg.role.value + if hasattr(msg.role, "value") + else str(msg.role) + ), + "content": msg.content, + } + for msg in request.messages + ] + } + + converter = AgentRunConverter() + + async def generator(): + async for event in agent.astream_events( + cast(Any, input_data), version="v2" + ): + for item in converter.convert(event): + yield item + + return generator() + + server = AgentRunServer(invoke_agent=invoke_agent) + return server.app + + @pytest.mark.parametrize( + "case_key,prompt", + AGUI_PROMPT_CASES, + ids=[case[0] for case in AGUI_PROMPT_CASES], + ) + async def test_sync_invoke(self, server_app_sync, case_key, prompt): + """测试同步 invoke 非流式场景""" + events = await request_agui_events( + server_app_sync, + [{"role": "user", "content": prompt}], + stream=False, + ) + + assert_agui_events_exact(events, case_key) + + @pytest.mark.parametrize( + "case_key,prompt", + AGUI_PROMPT_CASES, + ids=[case[0] for case in AGUI_PROMPT_CASES], + ) + async def test_async_invoke(self, server_app_async, case_key, prompt): + """测试异步 
ainvoke 非流式场景""" + events = await request_agui_events( + server_app_async, + [{"role": "user", "content": prompt}], + stream=False, + ) + + assert_agui_events_exact(events, case_key) + + +# ============================================================================= +# 测试类: OpenAI 协议 +# ============================================================================= + + +class TestOpenAIProtocol: + """测试 OpenAI Chat Completions 协议""" + + async def test_nonstream_text_generation(self, server_app_astream_events): + """非流式简单对话""" + resp = await request_openai_events( + server_app_astream_events, + [{"role": "user", "content": "你好,请简单介绍一下你自己"}], + stream=False, + ) + + assert_openai_text_generation_response(cast(Dict[str, Any], resp)) + + async def test_stream_text_generation(self, server_app_astream_events): + """OpenAI 协议纯文本流式输出检查""" + events = cast( + List[Dict[str, Any]], + await request_openai_events( + server_app_astream_events, + [{"role": "user", "content": "你好,请简单介绍一下你自己"}], + stream=True, + ), + ) + + assert_openai_text_generation_events(events) + + async def test_stream_tool_call(self, server_app_astream_events): + """OpenAI 协议工具调用流式输出检查""" + events = cast( + List[Dict[str, Any]], + await request_openai_events( + server_app_astream_events, + [{"role": "user", "content": "北京的天气怎么样?"}], + stream=True, + ), + ) + + assert_openai_tool_call_events(events, "tool_weather") + + async def test_nonstream_tool_call(self, server_app_astream_events): + """非流式工具调用""" + resp = await request_openai_events( + server_app_astream_events, + [{"role": "user", "content": "北京的天气怎么样?"}], + stream=False, + ) + + assert_openai_tool_call_response( + cast(Dict[str, Any], resp), "tool_weather" + ) + + async def test_nonstream_local_tool(self, server_app_astream_events): + """非流式本地工具调用""" + resp = await request_openai_events( + server_app_astream_events, + [{"role": "user", "content": "现在几点?"}], + stream=False, + ) + + assert_openai_tool_call_response( + cast(Dict[str, Any], resp), 
"tool_time" + ) + + async def test_stream_local_tool(self, server_app_astream_events): + """流式本地工具调用""" + events = cast( + List[Dict[str, Any]], + await request_openai_events( + server_app_astream_events, + [{"role": "user", "content": "现在几点?"}], + stream=True, + ), + ) + + assert_openai_tool_call_events(events, "tool_time") diff --git a/tests/unittests/integration/__init__.py b/tests/unittests/integration/__init__.py new file mode 100644 index 0000000..93b16d1 --- /dev/null +++ b/tests/unittests/integration/__init__.py @@ -0,0 +1 @@ +"""LangChain/LangGraph 集成测试""" diff --git a/tests/unittests/integration/conftest.py b/tests/unittests/integration/conftest.py new file mode 100644 index 0000000..579ce11 --- /dev/null +++ b/tests/unittests/integration/conftest.py @@ -0,0 +1,269 @@ +"""LangChain/LangGraph 集成测试的公共 fixtures 和辅助函数 + +提供模拟 LangChain/LangGraph 消息对象的工厂函数和常用测试辅助函数。 +""" + +from typing import Any, Dict, List, Union +from unittest.mock import MagicMock + +import pytest + +from agentrun.integration.langgraph import AgentRunConverter +from agentrun.server.model import AgentEvent, EventType + +# ============================================================================= +# Mock 消息工厂函数 +# ============================================================================= + + +def create_mock_ai_message( + content: str = "", + tool_calls: List[Dict[str, Any]] = None, +) -> MagicMock: + """创建模拟的 AIMessage 对象 + + Args: + content: 消息内容 + tool_calls: 工具调用列表 + + Returns: + MagicMock: 模拟的 AIMessage 对象 + """ + msg = MagicMock() + msg.content = content + msg.type = "ai" + msg.tool_calls = tool_calls or [] + return msg + + +def create_mock_ai_message_chunk( + content: str = "", + tool_call_chunks: List[Dict] = None, +) -> MagicMock: + """创建模拟的 AIMessageChunk 对象(流式输出) + + Args: + content: 内容片段 + tool_call_chunks: 工具调用片段列表 + + Returns: + MagicMock: 模拟的 AIMessageChunk 对象 + """ + chunk = MagicMock() + chunk.content = content + chunk.tool_call_chunks = tool_call_chunks or [] + return 
chunk + + +def create_mock_tool_message(content: str, tool_call_id: str) -> MagicMock: + """创建模拟的 ToolMessage 对象 + + Args: + content: 工具执行结果 + tool_call_id: 工具调用 ID + + Returns: + MagicMock: 模拟的 ToolMessage 对象 + """ + msg = MagicMock() + msg.content = content + msg.type = "tool" + msg.tool_call_id = tool_call_id + return msg + + +# ============================================================================= +# 事件转换辅助函数 +# ============================================================================= + + +def convert_and_collect(events: List[Dict]) -> List[Union[str, AgentEvent]]: + """转换事件列表并收集所有结果 + + Args: + events: LangChain/LangGraph 事件列表 + + Returns: + List: 转换后的 AgentEvent 列表 + """ + results = [] + for event in events: + results.extend(AgentRunConverter.to_agui_events(event)) + return results + + +def filter_agent_events( + results: List[Union[str, AgentEvent]], event_type: EventType +) -> List[AgentEvent]: + """过滤特定类型的 AgentEvent + + Args: + results: 转换结果列表 + event_type: 要过滤的事件类型 + + Returns: + List[AgentEvent]: 过滤后的事件列表 + """ + return [ + r + for r in results + if isinstance(r, AgentEvent) and r.event == event_type + ] + + +def get_event_types(results: List[Union[str, AgentEvent]]) -> List[EventType]: + """获取结果中所有 AgentEvent 的类型 + + Args: + results: 转换结果列表 + + Returns: + List[EventType]: 事件类型列表 + """ + return [r.event for r in results if isinstance(r, AgentEvent)] + + +# ============================================================================= +# astream_events 格式的事件工厂 +# ============================================================================= + + +def create_on_chat_model_stream_event(chunk: MagicMock) -> Dict: + """创建 on_chat_model_stream 事件 + + Args: + chunk: AIMessageChunk 对象 + + Returns: + Dict: astream_events 格式的事件 + """ + return { + "event": "on_chat_model_stream", + "data": {"chunk": chunk}, + } + + +def create_on_tool_start_event( + tool_name: str, + tool_input: Dict, + run_id: str = "run-123", + tool_call_id: str = None, +) -> Dict: + 
"""创建 on_tool_start 事件 + + Args: + tool_name: 工具名称 + tool_input: 工具输入参数 + run_id: 运行 ID + tool_call_id: 工具调用 ID(可选,会放入 metadata) + + Returns: + Dict: astream_events 格式的事件 + """ + event = { + "event": "on_tool_start", + "name": tool_name, + "run_id": run_id, + "data": {"input": tool_input}, + } + if tool_call_id: + event["metadata"] = {"langgraph_tool_call_id": tool_call_id} + return event + + +def create_on_tool_end_event( + output: Any, + run_id: str = "run-123", + tool_call_id: str = None, +) -> Dict: + """创建 on_tool_end 事件 + + Args: + output: 工具输出 + run_id: 运行 ID + tool_call_id: 工具调用 ID(可选,会放入 metadata) + + Returns: + Dict: astream_events 格式的事件 + """ + event = { + "event": "on_tool_end", + "run_id": run_id, + "data": {"output": output}, + } + if tool_call_id: + event["metadata"] = {"langgraph_tool_call_id": tool_call_id} + return event + + +def create_on_tool_error_event( + error: str, + run_id: str = "run-123", +) -> Dict: + """创建 on_tool_error 事件 + + Args: + error: 错误信息 + run_id: 运行 ID + + Returns: + Dict: astream_events 格式的事件 + """ + return { + "event": "on_tool_error", + "run_id": run_id, + "data": {"error": error}, + } + + +# ============================================================================= +# stream_mode 格式的事件工厂 +# ============================================================================= + + +def create_stream_updates_event(node_name: str, messages: List) -> Dict: + """创建 stream_mode="updates" 格式的事件 + + Args: + node_name: 节点名称(如 "model", "agent", "tools") + messages: 消息列表 + + Returns: + Dict: stream_mode="updates" 格式的事件 + """ + return {node_name: {"messages": messages}} + + +def create_stream_values_event(messages: List) -> Dict: + """创建 stream_mode="values" 格式的事件 + + Args: + messages: 消息列表 + + Returns: + Dict: stream_mode="values" 格式的事件 + """ + return {"messages": messages} + + +# ============================================================================= +# Pytest Fixtures +# 
============================================================================= + + +@pytest.fixture +def ai_message_factory(): + """提供 AIMessage 工厂函数""" + return create_mock_ai_message + + +@pytest.fixture +def ai_message_chunk_factory(): + """提供 AIMessageChunk 工厂函数""" + return create_mock_ai_message_chunk + + +@pytest.fixture +def tool_message_factory(): + """提供 ToolMessage 工厂函数""" + return create_mock_tool_message diff --git a/tests/unittests/integration/helpers.py b/tests/unittests/integration/helpers.py new file mode 100644 index 0000000..579ce11 --- /dev/null +++ b/tests/unittests/integration/helpers.py @@ -0,0 +1,269 @@ +"""LangChain/LangGraph 集成测试的公共 fixtures 和辅助函数 + +提供模拟 LangChain/LangGraph 消息对象的工厂函数和常用测试辅助函数。 +""" + +from typing import Any, Dict, List, Union +from unittest.mock import MagicMock + +import pytest + +from agentrun.integration.langgraph import AgentRunConverter +from agentrun.server.model import AgentEvent, EventType + +# ============================================================================= +# Mock 消息工厂函数 +# ============================================================================= + + +def create_mock_ai_message( + content: str = "", + tool_calls: List[Dict[str, Any]] = None, +) -> MagicMock: + """创建模拟的 AIMessage 对象 + + Args: + content: 消息内容 + tool_calls: 工具调用列表 + + Returns: + MagicMock: 模拟的 AIMessage 对象 + """ + msg = MagicMock() + msg.content = content + msg.type = "ai" + msg.tool_calls = tool_calls or [] + return msg + + +def create_mock_ai_message_chunk( + content: str = "", + tool_call_chunks: List[Dict] = None, +) -> MagicMock: + """创建模拟的 AIMessageChunk 对象(流式输出) + + Args: + content: 内容片段 + tool_call_chunks: 工具调用片段列表 + + Returns: + MagicMock: 模拟的 AIMessageChunk 对象 + """ + chunk = MagicMock() + chunk.content = content + chunk.tool_call_chunks = tool_call_chunks or [] + return chunk + + +def create_mock_tool_message(content: str, tool_call_id: str) -> MagicMock: + """创建模拟的 ToolMessage 对象 + + Args: + content: 工具执行结果 + 
tool_call_id: 工具调用 ID + + Returns: + MagicMock: 模拟的 ToolMessage 对象 + """ + msg = MagicMock() + msg.content = content + msg.type = "tool" + msg.tool_call_id = tool_call_id + return msg + + +# ============================================================================= +# 事件转换辅助函数 +# ============================================================================= + + +def convert_and_collect(events: List[Dict]) -> List[Union[str, AgentEvent]]: + """转换事件列表并收集所有结果 + + Args: + events: LangChain/LangGraph 事件列表 + + Returns: + List: 转换后的 AgentEvent 列表 + """ + results = [] + for event in events: + results.extend(AgentRunConverter.to_agui_events(event)) + return results + + +def filter_agent_events( + results: List[Union[str, AgentEvent]], event_type: EventType +) -> List[AgentEvent]: + """过滤特定类型的 AgentEvent + + Args: + results: 转换结果列表 + event_type: 要过滤的事件类型 + + Returns: + List[AgentEvent]: 过滤后的事件列表 + """ + return [ + r + for r in results + if isinstance(r, AgentEvent) and r.event == event_type + ] + + +def get_event_types(results: List[Union[str, AgentEvent]]) -> List[EventType]: + """获取结果中所有 AgentEvent 的类型 + + Args: + results: 转换结果列表 + + Returns: + List[EventType]: 事件类型列表 + """ + return [r.event for r in results if isinstance(r, AgentEvent)] + + +# ============================================================================= +# astream_events 格式的事件工厂 +# ============================================================================= + + +def create_on_chat_model_stream_event(chunk: MagicMock) -> Dict: + """创建 on_chat_model_stream 事件 + + Args: + chunk: AIMessageChunk 对象 + + Returns: + Dict: astream_events 格式的事件 + """ + return { + "event": "on_chat_model_stream", + "data": {"chunk": chunk}, + } + + +def create_on_tool_start_event( + tool_name: str, + tool_input: Dict, + run_id: str = "run-123", + tool_call_id: str = None, +) -> Dict: + """创建 on_tool_start 事件 + + Args: + tool_name: 工具名称 + tool_input: 工具输入参数 + run_id: 运行 ID + tool_call_id: 工具调用 ID(可选,会放入 metadata) + + Returns: + 
Dict: astream_events 格式的事件 + """ + event = { + "event": "on_tool_start", + "name": tool_name, + "run_id": run_id, + "data": {"input": tool_input}, + } + if tool_call_id: + event["metadata"] = {"langgraph_tool_call_id": tool_call_id} + return event + + +def create_on_tool_end_event( + output: Any, + run_id: str = "run-123", + tool_call_id: str = None, +) -> Dict: + """创建 on_tool_end 事件 + + Args: + output: 工具输出 + run_id: 运行 ID + tool_call_id: 工具调用 ID(可选,会放入 metadata) + + Returns: + Dict: astream_events 格式的事件 + """ + event = { + "event": "on_tool_end", + "run_id": run_id, + "data": {"output": output}, + } + if tool_call_id: + event["metadata"] = {"langgraph_tool_call_id": tool_call_id} + return event + + +def create_on_tool_error_event( + error: str, + run_id: str = "run-123", +) -> Dict: + """创建 on_tool_error 事件 + + Args: + error: 错误信息 + run_id: 运行 ID + + Returns: + Dict: astream_events 格式的事件 + """ + return { + "event": "on_tool_error", + "run_id": run_id, + "data": {"error": error}, + } + + +# ============================================================================= +# stream_mode 格式的事件工厂 +# ============================================================================= + + +def create_stream_updates_event(node_name: str, messages: List) -> Dict: + """创建 stream_mode="updates" 格式的事件 + + Args: + node_name: 节点名称(如 "model", "agent", "tools") + messages: 消息列表 + + Returns: + Dict: stream_mode="updates" 格式的事件 + """ + return {node_name: {"messages": messages}} + + +def create_stream_values_event(messages: List) -> Dict: + """创建 stream_mode="values" 格式的事件 + + Args: + messages: 消息列表 + + Returns: + Dict: stream_mode="values" 格式的事件 + """ + return {"messages": messages} + + +# ============================================================================= +# Pytest Fixtures +# ============================================================================= + + +@pytest.fixture +def ai_message_factory(): + """提供 AIMessage 工厂函数""" + return create_mock_ai_message + + +@pytest.fixture 
+def ai_message_chunk_factory(): + """提供 AIMessageChunk 工厂函数""" + return create_mock_ai_message_chunk + + +@pytest.fixture +def tool_message_factory(): + """提供 ToolMessage 工厂函数""" + return create_mock_tool_message diff --git a/tests/unittests/integration/test_agent_converter.py b/tests/unittests/integration/test_agent_converter.py new file mode 100644 index 0000000..4f28b60 --- /dev/null +++ b/tests/unittests/integration/test_agent_converter.py @@ -0,0 +1,1901 @@ +"""测试 convert 函数 / Test convert Function + +测试 convert 函数对不同 LangChain/LangGraph 调用方式返回事件格式的兼容性。 +支持的格式: +- astream_events(version="v2") 格式 +- stream/astream(stream_mode="updates") 格式 +- stream/astream(stream_mode="values") 格式 +""" + +from unittest.mock import MagicMock + +import pytest + +from agentrun.integration.langgraph.agent_converter import AgentRunConverter +from agentrun.server.model import AgentResult, EventType + +# 使用 helpers.py 中的公共 mock 函数 +from .helpers import ( + create_mock_ai_message, + create_mock_ai_message_chunk, + create_mock_tool_message, +) + +# ============================================================================= +# 测试事件格式检测函数 +# ============================================================================= + + +class TestEventFormatDetection: + """测试事件格式检测函数""" + + def test_is_astream_events_format(self): + """测试 astream_events 格式检测""" + # 正确的 astream_events 格式 + assert AgentRunConverter.is_astream_events_format( + {"event": "on_chat_model_stream", "data": {}} + ) + assert AgentRunConverter.is_astream_events_format( + {"event": "on_tool_start", "data": {}} + ) + assert AgentRunConverter.is_astream_events_format( + {"event": "on_tool_end", "data": {}} + ) + assert AgentRunConverter.is_astream_events_format( + {"event": "on_chain_stream", "data": {}} + ) + + # 不是 astream_events 格式 + assert not AgentRunConverter.is_astream_events_format( + {"model": {"messages": []}} + ) + assert not AgentRunConverter.is_astream_events_format({"messages": []}) + assert not 
AgentRunConverter.is_astream_events_format({}) + assert not AgentRunConverter.is_astream_events_format( + {"event": "custom_event"} + ) # 不以 on_ 开头 + + def test_is_stream_updates_format(self): + """测试 stream(updates) 格式检测""" + # 正确的 updates 格式 + assert AgentRunConverter.is_stream_updates_format( + {"model": {"messages": []}} + ) + assert AgentRunConverter.is_stream_updates_format( + {"agent": {"messages": []}} + ) + assert AgentRunConverter.is_stream_updates_format( + {"tools": {"messages": []}} + ) + assert AgentRunConverter.is_stream_updates_format( + {"__end__": {}, "model": {"messages": []}} + ) + + # 不是 updates 格式 + assert not AgentRunConverter.is_stream_updates_format( + {"event": "on_chat_model_stream"} + ) + assert not AgentRunConverter.is_stream_updates_format( + {"messages": []} + ) # 这是 values 格式 + assert not AgentRunConverter.is_stream_updates_format({}) + + def test_is_stream_values_format(self): + """测试 stream(values) 格式检测""" + # 正确的 values 格式 + assert AgentRunConverter.is_stream_values_format({"messages": []}) + assert AgentRunConverter.is_stream_values_format( + {"messages": [MagicMock()]} + ) + + # 不是 values 格式 + assert not AgentRunConverter.is_stream_values_format( + {"event": "on_chat_model_stream"} + ) + assert not AgentRunConverter.is_stream_values_format( + {"model": {"messages": []}} + ) + assert not AgentRunConverter.is_stream_values_format({}) + + +# ============================================================================= +# 测试 astream_events 格式的转换 +# ============================================================================= + + +class TestConvertAstreamEventsFormat: + """测试 astream_events 格式的事件转换""" + + def test_on_chat_model_stream_text_content(self): + """测试 on_chat_model_stream 事件的文本内容提取""" + chunk = create_mock_ai_message_chunk("你好") + event = { + "event": "on_chat_model_stream", + "data": {"chunk": chunk}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0] == "你好" + 
+ def test_on_chat_model_stream_empty_content(self): + """测试 on_chat_model_stream 事件的空内容""" + chunk = create_mock_ai_message_chunk("") + event = { + "event": "on_chat_model_stream", + "data": {"chunk": chunk}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + assert len(results) == 0 + + def test_on_chat_model_stream_with_tool_call_args(self): + """测试 on_chat_model_stream 事件的工具调用参数""" + chunk = create_mock_ai_message_chunk( + "", + tool_call_chunks=[{ + "id": "call_123", + "name": "get_weather", + "args": '{"city": "北京"}', + }], + ) + event = { + "event": "on_chat_model_stream", + "data": {"chunk": chunk}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + # 第一个 chunk 有 id 和 name 时,发送完整的 TOOL_CALL_CHUNK + assert len(results) == 1 + assert isinstance(results[0], AgentResult) + assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["id"] == "call_123" + assert results[0].data["name"] == "get_weather" + assert results[0].data["args_delta"] == '{"city": "北京"}' + + def test_on_tool_start(self): + """测试 on_tool_start 事件""" + event = { + "event": "on_tool_start", + "name": "get_weather", + "run_id": "run_456", + "data": {"input": {"city": "北京"}}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + # 现在是单个 TOOL_CALL_CHUNK(边界事件由协议层自动处理) + assert len(results) == 1 + assert isinstance(results[0], AgentResult) + assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["id"] == "run_456" + assert results[0].data["name"] == "get_weather" + assert "city" in results[0].data["args_delta"] + + def test_on_tool_start_without_input(self): + """测试 on_tool_start 事件(无输入参数)""" + event = { + "event": "on_tool_start", + "name": "get_time", + "run_id": "run_789", + "data": {}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + # 现在是单个 TOOL_CALL_CHUNK(边界事件由协议层自动处理) + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["id"] == 
"run_789" + assert results[0].data["name"] == "get_time" + + def test_on_tool_end(self): + """测试 on_tool_end 事件 + + AG-UI 协议:TOOL_CALL_END 在 on_tool_start 中发送(参数传输完成时), + TOOL_CALL_RESULT 在 on_tool_end 中发送(工具执行完成时)。 + """ + event = { + "event": "on_tool_end", + "run_id": "run_456", + "data": {"output": {"weather": "晴天", "temperature": 25}}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + # on_tool_end 只发送 TOOL_CALL_RESULT + assert len(results) == 1 + + # TOOL_CALL_RESULT + assert results[0].event == EventType.TOOL_RESULT + assert results[0].data["id"] == "run_456" + assert "晴天" in results[0].data["result"] + + def test_on_tool_end_with_string_output(self): + """测试 on_tool_end 事件(字符串输出)""" + event = { + "event": "on_tool_end", + "run_id": "run_456", + "data": {"output": "晴天,25度"}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + # on_tool_end 只发送 TOOL_CALL_RESULT + assert len(results) == 1 + assert results[0].event == EventType.TOOL_RESULT + assert results[0].data["result"] == "晴天,25度" + + def test_on_tool_start_with_non_jsonable_args(self): + """工具输入包含不可 JSON 序列化对象时也能正常转换""" + + class Dummy: + + def __str__(self): + return "dummy_obj" + + event = { + "event": "on_tool_start", + "name": "get_weather", + "run_id": "run_non_json", + "data": {"input": {"obj": Dummy()}}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + # 现在是单个 TOOL_CALL_CHUNK + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["id"] == "run_non_json" + assert "dummy_obj" in results[0].data["args_delta"] + + def test_on_tool_start_filters_internal_runtime_field(self): + """测试 on_tool_start 过滤 MCP 注入的 runtime 等内部字段""" + + class FakeToolRuntime: + """模拟 MCP 的 ToolRuntime 对象""" + + def __str__(self): + return "ToolRuntime(...huge internal state...)" + + event = { + "event": "on_tool_start", + "name": "maps_weather", + "run_id": "run_mcp_tool", + "data": { + "input": { + "city": "北京", # 用户实际参数 + 
"runtime": FakeToolRuntime(), # MCP 注入的内部字段 + "config": {"internal": "state"}, # 另一个内部字段 + "__pregel_runtime": "internal", # LangGraph 内部字段 + } + }, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + # 现在是单个 TOOL_CALL_CHUNK + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["name"] == "maps_weather" + + delta = results[0].data["args_delta"] + # 应该只包含用户参数 city + assert "北京" in delta + # 不应该包含内部字段 + assert "runtime" not in delta.lower() or "ToolRuntime" not in delta + assert "internal" not in delta + assert "__pregel" not in delta + + def test_on_tool_start_uses_runtime_tool_call_id(self): + """测试 on_tool_start 使用 runtime 中的原始 tool_call_id 而非 run_id + + MCP 工具会在 input.runtime 中注入 tool_call_id,这是 LLM 返回的原始 ID。 + 应该优先使用这个 ID,以保证工具调用事件的 ID 一致性。 + """ + + class FakeToolRuntime: + """模拟 MCP 的 ToolRuntime 对象""" + + def __init__(self, tool_call_id: str): + self.tool_call_id = tool_call_id + + original_tool_call_id = "call_original_from_llm_12345" + + event = { + "event": "on_tool_start", + "name": "get_weather", + "run_id": ( + "run_id_different_from_tool_call_id" + ), # run_id 与 tool_call_id 不同 + "data": { + "input": { + "city": "北京", + "runtime": FakeToolRuntime(original_tool_call_id), + } + }, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + # 现在是单个 TOOL_CALL_CHUNK + assert len(results) == 1 + + # 应该使用 runtime 中的原始 tool_call_id,而不是 run_id + assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["id"] == original_tool_call_id + assert results[0].data["name"] == "get_weather" + + def test_on_tool_end_uses_runtime_tool_call_id(self): + """测试 on_tool_end 使用 runtime 中的原始 tool_call_id 而非 run_id""" + + class FakeToolRuntime: + """模拟 MCP 的 ToolRuntime 对象""" + + def __init__(self, tool_call_id: str): + self.tool_call_id = tool_call_id + + original_tool_call_id = "call_original_from_llm_67890" + + event = { + "event": "on_tool_end", + "run_id": 
"run_id_different_from_tool_call_id", + "data": { + "output": {"weather": "晴天", "temp": 25}, + "input": { + "city": "北京", + "runtime": FakeToolRuntime(original_tool_call_id), + }, + }, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + # on_tool_end 只发送 TOOL_CALL_RESULT(TOOL_CALL_END 在 on_tool_start 发送) + assert len(results) == 1 + + # 应该使用 runtime 中的原始 tool_call_id + assert results[0].event == EventType.TOOL_RESULT + assert results[0].data["id"] == original_tool_call_id + + def test_on_tool_start_fallback_to_run_id(self): + """测试当 runtime 中没有 tool_call_id 时,回退使用 run_id""" + event = { + "event": "on_tool_start", + "name": "get_time", + "run_id": "run_789", + "data": {"input": {"timezone": "Asia/Shanghai"}}, # 没有 runtime + } + + results = list(AgentRunConverter.to_agui_events(event)) + + # 现在是单个 TOOL_CALL_CHUNK + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL_CHUNK + # 应该回退使用 run_id + assert results[0].data["id"] == "run_789" + + def test_streaming_tool_call_id_consistency_with_map(self): + """测试流式工具调用的 tool_call_id 一致性(使用映射) + + 在流式工具调用中: + - 第一个 chunk 有 id 但可能没有 args(用于建立映射) + - 后续 chunk 有 args 但 id 为空,只有 index(从映射查找 id) + + 使用 tool_call_id_map 可以确保 ID 一致性。 + """ + # 模拟流式工具调用的多个 chunk + events = [ + # 第一个 chunk: 有 id 和 name,没有 args(只用于建立映射) + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": "call_abc123", + "name": "browser_navigate", + "args": "", + "index": 0, + }], + ) + }, + }, + # 第二个 chunk: id 为空,只有 index 和 args + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": "", + "name": "", + "args": '{"url": "https://', + "index": 0, + }], + ) + }, + }, + # 第三个 chunk: id 为空,继续 args + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": "", + "name": "", + "args": 'example.com"}', + "index": 0, + }], + ) + }, + }, + ] + + # 使用 
tool_call_id_map 来确保 ID 一致性 + tool_call_id_map: Dict[int, str] = {} + all_results = [] + + for event in events: + results = list( + AgentRunConverter.to_agui_events( + event, tool_call_id_map=tool_call_id_map + ) + ) + all_results.extend(results) + + # 验证映射已建立 + assert 0 in tool_call_id_map + assert tool_call_id_map[0] == "call_abc123" + + # 验证:所有 TOOL_CALL_CHUNK 都使用相同的 tool_call_id + chunk_events = [ + r + for r in all_results + if isinstance(r, AgentResult) + and r.event == EventType.TOOL_CALL_CHUNK + ] + + # 应该有 3 个 TOOL_CALL_CHUNK 事件(每个 chunk 一个) + assert len(chunk_events) == 3 + + # 所有事件应该使用相同的 tool_call_id(从映射获取) + for event in chunk_events: + assert event.data["id"] == "call_abc123" + + def test_streaming_tool_call_id_without_map_uses_index(self): + """测试不使用映射时,后续 chunk 回退到 index""" + event = { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": "", + "name": "", + "args": '{"url": "test"}', + "index": 0, + }], + ) + }, + } + + # 不传入 tool_call_id_map + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL_CHUNK + # 回退使用 index + assert results[0].data["id"] == "0" + + def test_streaming_multiple_concurrent_tool_calls(self): + """测试多个并发工具调用(不同 index)的 ID 一致性""" + # 模拟 LLM 同时调用两个工具 + events = [ + # 第一个 chunk: 两个工具调用的 ID + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[ + { + "id": "call_tool1", + "name": "search", + "args": "", + "index": 0, + }, + { + "id": "call_tool2", + "name": "weather", + "args": "", + "index": 1, + }, + ], + ) + }, + }, + # 后续 chunk: 只有 index 和 args + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[ + { + "id": "", + "name": "", + "args": '{"q": "test"', + "index": 0, + }, + ], + ) + }, + }, + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + 
tool_call_chunks=[ + { + "id": "", + "name": "", + "args": '{"city": "北京"', + "index": 1, + }, + ], + ) + }, + }, + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[ + {"id": "", "name": "", "args": "}", "index": 0}, + {"id": "", "name": "", "args": "}", "index": 1}, + ], + ) + }, + }, + ] + + tool_call_id_map: Dict[int, str] = {} + all_results = [] + + for event in events: + results = list( + AgentRunConverter.to_agui_events( + event, tool_call_id_map=tool_call_id_map + ) + ) + all_results.extend(results) + + # 验证映射正确建立 + assert tool_call_id_map[0] == "call_tool1" + assert tool_call_id_map[1] == "call_tool2" + + # 验证所有事件使用正确的 ID + chunk_events = [ + r + for r in all_results + if isinstance(r, AgentResult) + and r.event == EventType.TOOL_CALL_CHUNK + ] + + # 应该有 6 个 TOOL_CALL_CHUNK 事件 + # - 2 个初始 chunk(id + name) + # - 4 个 args chunk + assert len(chunk_events) == 6 + + # 验证每个工具调用使用正确的 ID + tool1_chunks = [e for e in chunk_events if e.data["id"] == "call_tool1"] + tool2_chunks = [e for e in chunk_events if e.data["id"] == "call_tool2"] + + assert len(tool1_chunks) == 3 # 初始 + '{"q": "test"' + '}' + assert len(tool2_chunks) == 3 # 初始 + '{"city": "北京"' + '}' + + def test_agentrun_converter_class(self): + """测试 AgentRunConverter 类的完整功能""" + from agentrun.integration.langchain import AgentRunConverter + + events = [ + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": "call_xyz", + "name": "test_tool", + "args": "", + "index": 0, + }], + ) + }, + }, + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": "", + "name": "", + "args": '{"key": "value"}', + "index": 0, + }], + ) + }, + }, + ] + + converter = AgentRunConverter() + all_results = [] + + for event in events: + results = list(converter.convert(event)) + all_results.extend(results) + + # 验证内部映射 + assert converter._tool_call_id_map[0] 
== "call_xyz" + + # 验证结果 + chunk_events = [ + r + for r in all_results + if isinstance(r, AgentResult) + and r.event == EventType.TOOL_CALL_CHUNK + ] + # 现在有 2 个 chunk 事件(每个 stream chunk 一个) + assert len(chunk_events) == 2 + # 所有事件应该使用相同的 ID + for event in chunk_events: + assert event.data["id"] == "call_xyz" + + # 测试 reset + converter.reset() + assert len(converter._tool_call_id_map) == 0 + + def test_streaming_tool_call_with_first_chunk_having_args(self): + """测试第一个 chunk 同时有 id 和 args 的情况""" + # 有些模型可能在第一个 chunk 就返回完整的工具调用 + event = { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": "call_complete", + "name": "simple_tool", + "args": '{"done": true}', + "index": 0, + }], + ) + }, + } + + tool_call_id_map: Dict[int, str] = {} + tool_call_started_set: set = set() + results = list( + AgentRunConverter.to_agui_events( + event, + tool_call_id_map=tool_call_id_map, + tool_call_started_set=tool_call_started_set, + ) + ) + + # 验证映射被建立 + assert tool_call_id_map[0] == "call_complete" + # 验证 START 已发送 + assert "call_complete" in tool_call_started_set + + # 现在是单个 TOOL_CALL_CHUNK(包含 id, name, args_delta) + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["id"] == "call_complete" + assert results[0].data["name"] == "simple_tool" + assert results[0].data["args_delta"] == '{"done": true}' + + def test_streaming_tool_call_id_none_vs_empty_string(self): + """测试 id 为 None 和空字符串的不同处理""" + events = [ + # id 为 None(建立映射) + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": "call_from_none", + "name": "tool", + "args": "", + "index": 0, + }], + ) + }, + }, + # id 为 None(应该从映射获取) + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": None, + "name": "", + "args": '{"a": 1}', + "index": 0, + }], + ) + }, + }, + ] + + tool_call_id_map: Dict[int, 
str] = {} + all_results = [] + + for event in events: + results = list( + AgentRunConverter.to_agui_events( + event, tool_call_id_map=tool_call_id_map + ) + ) + all_results.extend(results) + + chunk_events = [ + r + for r in all_results + if isinstance(r, AgentResult) + and r.event == EventType.TOOL_CALL_CHUNK + ] + + # 现在有 2 个 chunk 事件(每个 stream chunk 一个) + assert len(chunk_events) == 2 + # 所有事件应该使用相同的 ID(从映射获取) + for event in chunk_events: + assert event.data["id"] == "call_from_none" + + def test_full_tool_call_flow_id_consistency(self): + """测试完整工具调用流程中的 ID 一致性 + + 模拟: + 1. on_chat_model_stream 产生 TOOL_CALL_CHUNK + 2. on_tool_start 不产生事件(已在流式中处理) + 3. on_tool_end 产生 TOOL_RESULT + + 验证所有事件使用相同的 tool_call_id + """ + # 模拟完整的工具调用流程 + events = [ + # 流式工具调用参数(第一个 chunk 有 id 和 name) + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": "call_full_flow", + "name": "test_tool", + "args": "", + "index": 0, + }], + ) + }, + }, + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": "", + "name": "", + "args": '{"param": "value"}', + "index": 0, + }], + ) + }, + }, + # 工具开始(使用 runtime.tool_call_id) + { + "event": "on_tool_start", + "name": "test_tool", + "run_id": "run_123", + "data": { + "input": { + "param": "value", + "runtime": MagicMock(tool_call_id="call_full_flow"), + } + }, + }, + # 工具结束 + { + "event": "on_tool_end", + "run_id": "run_123", + "data": { + "input": { + "param": "value", + "runtime": MagicMock(tool_call_id="call_full_flow"), + }, + "output": "success", + }, + }, + ] + + converter = AgentRunConverter() + all_results = [] + + for event in events: + results = list(converter.convert(event)) + all_results.extend(results) + + # 获取所有工具调用相关事件 + tool_events = [ + r + for r in all_results + if isinstance(r, AgentResult) + and r.event in [EventType.TOOL_CALL_CHUNK, EventType.TOOL_RESULT] + ] + + # 验证所有事件都使用相同的 tool_call_id + for event in 
tool_events: + assert ( + event.data["id"] == "call_full_flow" + ), f"Event {event.event} has wrong id: {event.data.get('id')}" + + # 验证所有事件类型都存在 + event_types = [e.event for e in tool_events] + assert EventType.TOOL_CALL_CHUNK in event_types + assert EventType.TOOL_RESULT in event_types + + # 验证顺序:TOOL_CALL_CHUNK 必须在 TOOL_RESULT 之前 + chunk_idx = event_types.index(EventType.TOOL_CALL_CHUNK) + result_idx = event_types.index(EventType.TOOL_RESULT) + assert ( + chunk_idx < result_idx + ), "TOOL_CALL_CHUNK must come before TOOL_RESULT" + + def test_on_chain_stream_model_node(self): + """测试 on_chain_stream 事件(model 节点)""" + msg = create_mock_ai_message("你好!有什么可以帮你的吗?") + event = { + "event": "on_chain_stream", + "name": "model", + "data": {"chunk": {"messages": [msg]}}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0] == "你好!有什么可以帮你的吗?" + + def test_on_chain_stream_non_model_node(self): + """测试 on_chain_stream 事件(非 model 节点)""" + event = { + "event": "on_chain_stream", + "name": "agent", # 不是 "model" + "data": {"chunk": {"messages": []}}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + assert len(results) == 0 + + def test_on_chat_model_end_ignored(self): + """测试 on_chat_model_end 事件被忽略(避免重复)""" + event = { + "event": "on_chat_model_end", + "data": {"output": create_mock_ai_message("完成")}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + assert len(results) == 0 + + def test_on_chain_stream_tool_call_no_duplicate_with_on_tool_start(self): + """测试 on_chain_stream 中的 tool call 不会与 on_tool_start 重复 + + 这个测试验证:当 on_chain_stream 已经发送了 TOOL_CALL_CHUNK 后, + 后续的 on_tool_start 事件不应该再次发送 TOOL_CALL_CHUNK。 + """ + tool_call_id = "call_abc123" + tool_name = "get_user_name" + + # 使用 converter 实例来维护状态 + converter = AgentRunConverter() + + # 1. 
首先是 on_chain_stream 事件,包含 tool call + msg = create_mock_ai_message( + content="", + tool_calls=[{"id": tool_call_id, "name": tool_name, "args": {}}], + ) + chain_stream_event = { + "event": "on_chain_stream", + "name": "model", + "data": {"chunk": {"messages": [msg]}}, + } + + results1 = list(converter.convert(chain_stream_event)) + + # 应该产生一个 TOOL_CALL_CHUNK + assert len(results1) == 1 + assert results1[0].event == EventType.TOOL_CALL_CHUNK + assert results1[0].data["id"] == tool_call_id + assert results1[0].data["name"] == tool_name + + # 2. 然后是 on_tool_start 事件(使用不同的 run_id,模拟真实场景) + run_id = "run_xyz789" + tool_start_event = { + "event": "on_tool_start", + "run_id": run_id, + "name": tool_name, + "data": {"input": {}}, + } + + results2 = list(converter.convert(tool_start_event)) + + # 不应该产生 TOOL_CALL_CHUNK(因为已经在 on_chain_stream 中发送过了) + # 通过 tool_name_to_call_ids 映射,on_tool_start 应该使用相同的 tool_call_id + assert len(results2) == 0 + + def test_on_chain_stream_tool_call_with_on_tool_end(self): + """测试 on_chain_stream 中的 tool call 与 on_tool_end 的 ID 一致性 + + 这个测试验证:当 on_chain_stream 发送 TOOL_CALL_CHUNK 后, + on_tool_end 应该使用相同的 tool_call_id 发送 TOOL_RESULT。 + """ + tool_call_id = "call_abc123" + tool_name = "get_user_name" + run_id = "run_xyz789" + + # 使用 converter 实例来维护状态 + converter = AgentRunConverter() + + # 1. on_chain_stream 事件 + msg = create_mock_ai_message( + content="", + tool_calls=[{"id": tool_call_id, "name": tool_name, "args": {}}], + ) + chain_stream_event = { + "event": "on_chain_stream", + "name": "model", + "data": {"chunk": {"messages": [msg]}}, + } + list(converter.convert(chain_stream_event)) + + # 2. on_tool_start 事件 + tool_start_event = { + "event": "on_tool_start", + "run_id": run_id, + "name": tool_name, + "data": {"input": {}}, + } + list(converter.convert(tool_start_event)) + + # 3. 
on_tool_end 事件 + tool_end_event = { + "event": "on_tool_end", + "run_id": run_id, + "name": tool_name, + "data": {"output": '{"user_name": "张三"}'}, + } + results3 = list(converter.convert(tool_end_event)) + + # 应该产生一个 TOOL_RESULT,使用原始的 tool_call_id + assert len(results3) == 1 + assert results3[0].event == EventType.TOOL_RESULT + assert results3[0].data["id"] == tool_call_id + assert results3[0].data["result"] == '{"user_name": "张三"}' + + +# ============================================================================= +# 测试 stream/astream(stream_mode="updates") 格式的转换 +# ============================================================================= + + +class TestConvertStreamUpdatesFormat: + """测试 stream(updates) 格式的事件转换""" + + def test_ai_message_text_content(self): + """测试 AI 消息的文本内容""" + msg = create_mock_ai_message("你好!") + event = {"model": {"messages": [msg]}} + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0] == "你好!" + + def test_ai_message_empty_content(self): + """测试 AI 消息的空内容""" + msg = create_mock_ai_message("") + event = {"model": {"messages": [msg]}} + + results = list(AgentRunConverter.to_agui_events(event)) + assert len(results) == 0 + + def test_ai_message_with_tool_calls(self): + """测试 AI 消息包含工具调用""" + msg = create_mock_ai_message( + "", + tool_calls=[{ + "id": "call_abc", + "name": "get_weather", + "args": {"city": "上海"}, + }], + ) + event = {"agent": {"messages": [msg]}} + + results = list(AgentRunConverter.to_agui_events(event)) + + # 现在是单个 TOOL_CALL_CHUNK + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["id"] == "call_abc" + assert results[0].data["name"] == "get_weather" + assert "上海" in results[0].data["args_delta"] + + def test_tool_message_result(self): + """测试工具消息的结果""" + msg = create_mock_tool_message('{"weather": "多云"}', "call_abc") + event = {"tools": {"messages": [msg]}} + + results = 
list(AgentRunConverter.to_agui_events(event)) + + # 现在只有 TOOL_RESULT + assert len(results) == 1 + assert results[0].event == EventType.TOOL_RESULT + assert results[0].data["id"] == "call_abc" + assert "多云" in results[0].data["result"] + + def test_end_node_ignored(self): + """测试 __end__ 节点被忽略""" + event = {"__end__": {"messages": []}} + + results = list(AgentRunConverter.to_agui_events(event)) + assert len(results) == 0 + + def test_multiple_nodes_in_event(self): + """测试一个事件中包含多个节点""" + ai_msg = create_mock_ai_message("正在查询...") + tool_msg = create_mock_tool_message("查询结果", "call_xyz") + event = { + "__end__": {}, + "model": {"messages": [ai_msg]}, + "tools": {"messages": [tool_msg]}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + # 应该有 2 个结果:1 个文本 + 1 个 TOOL_RESULT + assert len(results) == 2 + assert results[0] == "正在查询..." + assert results[1].event == EventType.TOOL_RESULT + + def test_custom_messages_key(self): + """测试自定义 messages_key""" + msg = create_mock_ai_message("自定义消息") + event = {"model": {"custom_messages": [msg]}} + + # 使用默认 key 应该找不到消息 + results = list( + AgentRunConverter.to_agui_events(event, messages_key="messages") + ) + assert len(results) == 0 + + # 使用正确的 key + results = list( + AgentRunConverter.to_agui_events( + event, messages_key="custom_messages" + ) + ) + assert len(results) == 1 + assert results[0] == "自定义消息" + + +# ============================================================================= +# 测试 stream/astream(stream_mode="values") 格式的转换 +# ============================================================================= + + +class TestConvertStreamValuesFormat: + """测试 stream(values) 格式的事件转换""" + + def test_last_ai_message_content(self): + """测试最后一条 AI 消息的内容""" + msg1 = create_mock_ai_message("第一条消息") + msg2 = create_mock_ai_message("最后一条消息") + event = {"messages": [msg1, msg2]} + + results = list(AgentRunConverter.to_agui_events(event)) + + # 只处理最后一条消息 + assert len(results) == 1 + assert results[0] == "最后一条消息" + + 
def test_last_ai_message_with_tool_calls(self): + """测试最后一条 AI 消息包含工具调用""" + msg = create_mock_ai_message( + "", + tool_calls=[ + {"id": "call_def", "name": "search", "args": {"query": "天气"}} + ], + ) + event = {"messages": [msg]} + + results = list(AgentRunConverter.to_agui_events(event)) + + # 现在是单个 TOOL_CALL_CHUNK + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL_CHUNK + + def test_last_tool_message_result(self): + """测试最后一条工具消息的结果""" + ai_msg = create_mock_ai_message("之前的消息") + tool_msg = create_mock_tool_message("工具结果", "call_ghi") + event = {"messages": [ai_msg, tool_msg]} + + results = list(AgentRunConverter.to_agui_events(event)) + + # 只处理最后一条消息(工具消息),现在只有 TOOL_RESULT + assert len(results) == 1 + assert results[0].event == EventType.TOOL_RESULT + + def test_empty_messages(self): + """测试空消息列表""" + event = {"messages": []} + + results = list(AgentRunConverter.to_agui_events(event)) + assert len(results) == 0 + + +# ============================================================================= +# 测试 StreamEvent 对象的转换 +# ============================================================================= + + +class TestConvertStreamEventObject: + """测试 StreamEvent 对象(非 dict)的转换""" + + def test_stream_event_object(self): + """测试 StreamEvent 对象自动转换为 dict""" + # 模拟 StreamEvent 对象 + chunk = create_mock_ai_message_chunk("Hello") + stream_event = MagicMock() + stream_event.event = "on_chat_model_stream" + stream_event.data = {"chunk": chunk} + stream_event.name = "model" + stream_event.run_id = "run_001" + + results = list(AgentRunConverter.to_agui_events(stream_event)) + + assert len(results) == 1 + assert results[0] == "Hello" + + +# ============================================================================= +# 测试完整流程:模拟多个事件的序列 +# ============================================================================= + + +class TestConvertEventSequence: + """测试完整的事件序列转换""" + + def test_astream_events_full_sequence(self): + """测试 astream_events 格式的完整事件序列 + 
+ AG-UI 协议要求的事件顺序: + TOOL_CALL_START → TOOL_CALL_ARGS → TOOL_CALL_END → TOOL_CALL_RESULT + """ + events = [ + # 1. 开始工具调用 + { + "event": "on_tool_start", + "name": "get_weather", + "run_id": "tool_run_1", + "data": {"input": {"city": "北京"}}, + }, + # 2. 工具结束 + { + "event": "on_tool_end", + "run_id": "tool_run_1", + "data": {"output": {"weather": "晴天", "temp": 25}}, + }, + # 3. LLM 流式输出 + { + "event": "on_chat_model_stream", + "data": {"chunk": create_mock_ai_message_chunk("北京")}, + }, + { + "event": "on_chat_model_stream", + "data": {"chunk": create_mock_ai_message_chunk("今天")}, + }, + { + "event": "on_chat_model_stream", + "data": {"chunk": create_mock_ai_message_chunk("晴天")}, + }, + ] + + all_results = [] + for event in events: + all_results.extend(AgentRunConverter.to_agui_events(event)) + + # 验证结果 + # on_tool_start: 1 TOOL_CALL_CHUNK + # on_tool_end: 1 TOOL_RESULT + # 3x on_chat_model_stream: 3 个文本 + assert len(all_results) == 5 + + # 工具调用事件 + assert all_results[0].event == EventType.TOOL_CALL_CHUNK + assert all_results[1].event == EventType.TOOL_RESULT + + # 文本内容 + assert all_results[2] == "北京" + assert all_results[3] == "今天" + assert all_results[4] == "晴天" + + def test_stream_updates_full_sequence(self): + """测试 stream(updates) 格式的完整事件序列""" + events = [ + # 1. Agent 决定调用工具 + { + "agent": { + "messages": [ + create_mock_ai_message( + "", + tool_calls=[{ + "id": "call_001", + "name": "get_weather", + "args": {"city": "上海"}, + }], + ) + ] + } + }, + # 2. 工具执行结果 + { + "tools": { + "messages": [ + create_mock_tool_message( + '{"weather": "多云"}', "call_001" + ) + ] + } + }, + # 3. 
Agent 最终回复 + {"model": {"messages": [create_mock_ai_message("上海今天多云。")]}}, + ] + + all_results = [] + for event in events: + all_results.extend(AgentRunConverter.to_agui_events(event)) + + # 验证结果: + # - 1 TOOL_CALL_CHUNK(工具调用) + # - 1 TOOL_RESULT(工具结果) + # - 1 文本回复 + assert len(all_results) == 3 + + # 工具调用 + assert all_results[0].event == EventType.TOOL_CALL_CHUNK + assert all_results[0].data["name"] == "get_weather" + + # 工具结果 + assert all_results[1].event == EventType.TOOL_RESULT + + # 最终回复 + assert all_results[2] == "上海今天多云。" + + +# ============================================================================= +# 测试边界情况 +# ============================================================================= + + +class TestConvertEdgeCases: + """测试边界情况""" + + def test_empty_event(self): + """测试空事件""" + results = list(AgentRunConverter.to_agui_events({})) + assert len(results) == 0 + + def test_none_values(self): + """测试 None 值""" + event = { + "event": "on_chat_model_stream", + "data": {"chunk": None}, + } + results = list(AgentRunConverter.to_agui_events(event)) + assert len(results) == 0 + + def test_invalid_message_type(self): + """测试无效的消息类型""" + msg = MagicMock() + msg.type = "unknown" + msg.content = "test" + event = {"model": {"messages": [msg]}} + + results = list(AgentRunConverter.to_agui_events(event)) + # unknown 类型不会产生输出 + assert len(results) == 0 + + def test_tool_call_without_id(self): + """测试没有 ID 的工具调用""" + msg = create_mock_ai_message( + "", + tool_calls=[{"name": "test", "args": {}}], # 没有 id + ) + event = {"agent": {"messages": [msg]}} + + results = list(AgentRunConverter.to_agui_events(event)) + # 没有 id 的工具调用应该被跳过 + assert len(results) == 0 + + def test_tool_message_without_tool_call_id(self): + """测试没有 tool_call_id 的工具消息""" + msg = MagicMock() + msg.type = "tool" + msg.content = "result" + msg.tool_call_id = None # 没有 tool_call_id + + event = {"tools": {"messages": [msg]}} + + results = list(AgentRunConverter.to_agui_events(event)) + # 没有 tool_call_id 
的工具消息应该被跳过 + assert len(results) == 0 + + def test_dict_message_format(self): + """测试字典格式的消息(而非对象)""" + event = { + "model": {"messages": [{"type": "ai", "content": "字典格式消息"}]} + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0] == "字典格式消息" + + def test_multimodal_content(self): + """测试多模态内容(list 格式)""" + chunk = MagicMock() + chunk.content = [ + {"type": "text", "text": "这是"}, + {"type": "text", "text": "多模态内容"}, + ] + chunk.tool_call_chunks = [] + + event = { + "event": "on_chat_model_stream", + "data": {"chunk": chunk}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0] == "这是多模态内容" + + def test_output_with_content_attribute(self): + """测试有 content 属性的工具输出""" + output = MagicMock() + output.content = "工具输出内容" + + event = { + "event": "on_tool_end", + "run_id": "run_123", + "data": {"output": output}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + # on_tool_end 只发送 TOOL_CALL_RESULT(TOOL_CALL_END 在 on_tool_start 发送) + assert len(results) == 1 + assert results[0].event == EventType.TOOL_RESULT + assert results[0].data["result"] == "工具输出内容" + + def test_unsupported_stream_mode_messages_format(self): + """测试不支持的 stream_mode='messages' 格式(元组形式) + + stream_mode='messages' 返回 (AIMessageChunk, metadata) 元组, + 不是 dict 格式,to_agui_events 不支持此格式,应该不产生输出。 + """ + # 模拟 stream_mode="messages" 返回的元组格式 + chunk = create_mock_ai_message_chunk("测试内容") + metadata = {"langgraph_node": "model"} + event = (chunk, metadata) # 元组格式 + + # 元组格式会被 _event_to_dict 转换为空字典,因此不产生输出 + results = list(AgentRunConverter.to_agui_events(event)) + assert len(results) == 0 + + def test_unsupported_random_dict_format(self): + """测试不支持的随机字典格式 + + 如果传入的 dict 不匹配任何已知格式,应该不产生输出。 + """ + event = { + "random_key": "random_value", + "another_key": {"nested": "data"}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + assert len(results) == 0 + + +# 
============================================================================= +# 测试 AG-UI 协议事件顺序 +# ============================================================================= + + +class TestAguiEventOrder: + """测试事件顺序 + + 简化后的事件结构: + - TOOL_CALL_CHUNK - 工具调用(包含 id, name, args_delta) + - TOOL_RESULT - 工具执行结果 + + 边界事件(如 TOOL_CALL_START/END)由协议层自动处理。 + """ + + def test_streaming_tool_call_order(self): + """测试流式工具调用的事件顺序 + + TOOL_CALL_CHUNK 应该在 TOOL_RESULT 之前 + """ + events = [ + # 第一个 chunk:包含 id、name,无 args + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": "call_order_test", + "name": "test_tool", + "args": "", + "index": 0, + }], + ) + }, + }, + # 第二个 chunk:有 args + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": "", + "name": "", + "args": '{"key": "value"}', + "index": 0, + }], + ) + }, + }, + # 工具开始执行 + { + "event": "on_tool_start", + "name": "test_tool", + "run_id": "run_order", + "data": { + "input": { + "key": "value", + "runtime": MagicMock(tool_call_id="call_order_test"), + } + }, + }, + # 工具执行完成 + { + "event": "on_tool_end", + "run_id": "run_order", + "data": { + "input": { + "key": "value", + "runtime": MagicMock(tool_call_id="call_order_test"), + }, + "output": "success", + }, + }, + ] + + converter = AgentRunConverter() + all_results = [] + for event in events: + all_results.extend(converter.convert(event)) + + # 提取工具调用相关事件 + tool_events = [ + r + for r in all_results + if isinstance(r, AgentResult) + and r.event in [EventType.TOOL_CALL_CHUNK, EventType.TOOL_RESULT] + ] + + # 验证有这两种事件 + event_types = [e.event for e in tool_events] + assert EventType.TOOL_CALL_CHUNK in event_types + assert EventType.TOOL_RESULT in event_types + + # 找到第一个 TOOL_CALL_CHUNK 和 TOOL_RESULT 的索引 + chunk_idx = event_types.index(EventType.TOOL_CALL_CHUNK) + result_idx = event_types.index(EventType.TOOL_RESULT) + + # 验证顺序:TOOL_CALL_CHUNK 必须在 
TOOL_RESULT 之前 + assert chunk_idx < result_idx, ( + f"TOOL_CALL_CHUNK (idx={chunk_idx}) must come before " + f"TOOL_RESULT (idx={result_idx})" + ) + + def test_streaming_tool_call_start_not_duplicated(self): + """测试流式工具调用时 TOOL_CALL_START 不会重复发送""" + events = [ + # 第一个 chunk:包含 id、name + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": "call_no_dup", + "name": "test_tool", + "args": '{"a": 1}', + "index": 0, + }], + ) + }, + }, + # 工具开始执行(此时 START 已在上面发送,不应重复) + { + "event": "on_tool_start", + "name": "test_tool", + "run_id": "run_no_dup", + "data": { + "input": { + "a": 1, + "runtime": MagicMock(tool_call_id="call_no_dup"), + } + }, + }, + # 工具执行完成 + { + "event": "on_tool_end", + "run_id": "run_no_dup", + "data": { + "input": { + "a": 1, + "runtime": MagicMock(tool_call_id="call_no_dup"), + }, + "output": "done", + }, + }, + ] + + converter = AgentRunConverter() + all_results = [] + for event in events: + all_results.extend(converter.convert(event)) + + # 统计 TOOL_CALL_START 事件的数量 + start_events = [ + r + for r in all_results + if isinstance(r, AgentResult) + and r.event == EventType.TOOL_CALL_CHUNK + ] + + # 应该只有一个 TOOL_CALL_START + assert ( + len(start_events) == 1 + ), f"Expected 1 TOOL_CALL_START, got {len(start_events)}" + + def test_non_streaming_tool_call_order(self): + """测试非流式场景的工具调用事件顺序 + + 在没有 on_chat_model_stream 事件的场景下, + 事件顺序仍应正确:TOOL_CALL_CHUNK → TOOL_RESULT + """ + events = [ + # 直接工具开始(无流式事件) + { + "event": "on_tool_start", + "name": "weather", + "run_id": "run_nonstream", + "data": {"input": {"city": "北京"}}, + }, + # 工具执行完成 + { + "event": "on_tool_end", + "run_id": "run_nonstream", + "data": {"output": "晴天"}, + }, + ] + + converter = AgentRunConverter() + all_results = [] + for event in events: + all_results.extend(converter.convert(event)) + + tool_events = [r for r in all_results if isinstance(r, AgentResult)] + + event_types = [e.event for e in tool_events] + + # 
验证顺序:TOOL_CALL_CHUNK → TOOL_RESULT + assert event_types == [ + EventType.TOOL_CALL_CHUNK, + EventType.TOOL_RESULT, + ], f"Unexpected order: {event_types}" + + def test_multiple_concurrent_tool_calls_order(self): + """测试多个并发工具调用时各自的事件顺序正确""" + events = [ + # 两个工具调用的第一个 chunk + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[ + { + "id": "call_a", + "name": "tool_a", + "args": "", + "index": 0, + }, + { + "id": "call_b", + "name": "tool_b", + "args": "", + "index": 1, + }, + ], + ) + }, + }, + # 两个工具的参数 + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[ + { + "id": "", + "name": "", + "args": '{"x": 1}', + "index": 0, + }, + { + "id": "", + "name": "", + "args": '{"y": 2}', + "index": 1, + }, + ], + ) + }, + }, + # 工具 A 开始 + { + "event": "on_tool_start", + "name": "tool_a", + "run_id": "run_a", + "data": { + "input": { + "x": 1, + "runtime": MagicMock(tool_call_id="call_a"), + } + }, + }, + # 工具 B 开始 + { + "event": "on_tool_start", + "name": "tool_b", + "run_id": "run_b", + "data": { + "input": { + "y": 2, + "runtime": MagicMock(tool_call_id="call_b"), + } + }, + }, + # 工具 A 结束 + { + "event": "on_tool_end", + "run_id": "run_a", + "data": { + "input": { + "x": 1, + "runtime": MagicMock(tool_call_id="call_a"), + }, + "output": "result_a", + }, + }, + # 工具 B 结束 + { + "event": "on_tool_end", + "run_id": "run_b", + "data": { + "input": { + "y": 2, + "runtime": MagicMock(tool_call_id="call_b"), + }, + "output": "result_b", + }, + }, + ] + + converter = AgentRunConverter() + all_results = [] + for event in events: + all_results.extend(converter.convert(event)) + + # 分别验证工具 A 和工具 B 的事件顺序 + for tool_id in ["call_a", "call_b"]: + tool_events = [ + (i, r) + for i, r in enumerate(all_results) + if isinstance(r, AgentResult) and r.data.get("id") == tool_id + ] + + event_types = [e.event for _, e in tool_events] + + # 验证包含所有必需事件 + assert ( + EventType.TOOL_CALL_CHUNK 
in event_types + ), f"Tool {tool_id} missing TOOL_CALL_CHUNK" + assert ( + EventType.TOOL_RESULT in event_types + ), f"Tool {tool_id} missing TOOL_RESULT" + + # 验证顺序:TOOL_CALL_CHUNK 应该在 TOOL_RESULT 之前 + chunk_pos = event_types.index(EventType.TOOL_CALL_CHUNK) + result_pos = event_types.index(EventType.TOOL_RESULT) + + assert ( + chunk_pos < result_pos + ), f"Tool {tool_id}: CHUNK must come before RESULT" + + +# ============================================================================= +# 集成测试:模拟完整流程 +# ============================================================================= + + +class TestConvertIntegration: + """测试 convert 与完整流程的集成""" + + def test_astream_events_full_flow(self): + """测试模拟的 astream_events 完整流程""" + mock_events = [ + { + "event": "on_chat_model_stream", + "data": {"chunk": create_mock_ai_message_chunk("你好")}, + }, + { + "event": "on_chat_model_stream", + "data": {"chunk": create_mock_ai_message_chunk(",")}, + }, + { + "event": "on_chat_model_stream", + "data": {"chunk": create_mock_ai_message_chunk("世界!")}, + }, + ] + + results = [] + for event in mock_events: + results.extend(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 3 + assert "".join(results) == "你好,世界!" 
+ + def test_stream_updates_full_flow(self): + """测试模拟的 stream(updates) 完整流程""" + import json + + mock_events = [ + # Agent 决定调用工具 + { + "agent": { + "messages": [ + create_mock_ai_message( + "", + tool_calls=[{ + "id": "tc_001", + "name": "get_weather", + "args": {"city": "北京"}, + }], + ) + ] + } + }, + # 工具执行结果 + { + "tools": { + "messages": [ + create_mock_tool_message( + json.dumps( + {"city": "北京", "weather": "晴天"}, + ensure_ascii=False, + ), + "tc_001", + ) + ] + } + }, + # Agent 最终回复 + { + "model": { + "messages": [create_mock_ai_message("北京今天天气晴朗。")] + } + }, + ] + + results = [] + for event in mock_events: + results.extend(AgentRunConverter.to_agui_events(event)) + + # 验证事件顺序 + assert len(results) == 3 + + # 工具调用 + assert isinstance(results[0], AgentResult) + assert results[0].event == EventType.TOOL_CALL_CHUNK + + # 工具结果 + assert isinstance(results[1], AgentResult) + assert results[1].event == EventType.TOOL_RESULT + + # 最终文本 + assert results[2] == "北京今天天气晴朗。" + + def test_stream_values_full_flow(self): + """测试模拟的 stream(values) 完整流程""" + mock_events = [ + {"messages": [create_mock_ai_message("")]}, + { + "messages": [ + create_mock_ai_message( + "", + tool_calls=[ + {"id": "tc_002", "name": "get_time", "args": {}} + ], + ) + ] + }, + { + "messages": [ + create_mock_ai_message(""), + create_mock_tool_message("2024-01-01 12:00:00", "tc_002"), + ] + }, + { + "messages": [ + create_mock_ai_message(""), + create_mock_tool_message("2024-01-01 12:00:00", "tc_002"), + create_mock_ai_message("现在是 2024年1月1日。"), + ] + }, + ] + + results = [] + for event in mock_events: + results.extend(AgentRunConverter.to_agui_events(event)) + + # 验证有工具调用事件 + tool_chunks = [ + r + for r in results + if isinstance(r, AgentResult) + and r.event == EventType.TOOL_CALL_CHUNK + ] + assert len(tool_chunks) >= 1 + + # 验证有最终文本 + text_results = [r for r in results if isinstance(r, str) and r] + assert any("2024" in t for t in text_results) diff --git 
"""Integration test: LangChain with the AG-UI protocol.

Optimized with a mock MCP server and a ``ProtocolValidator``:

1. A real ``langchain_mcp_adapters.MultiServerMCPClient``
2. A mock MCP SSE server (Starlette)
3. A mock ``ChatModel`` standing in for the OpenAI API
4. ``ProtocolValidator`` strictly checks event order and content
"""

import json
import socket
import threading
import time
from typing import Any, cast, Dict, List, Optional, Sequence, Union

import httpx
from langchain.agents import create_agent
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.tools import tool
from langchain_mcp_adapters.client import MultiServerMCPClient
from mcp.server import Server
from mcp.server.sse import SseServerTransport
from mcp.types import TextContent
from mcp.types import Tool as MCPTool
import pytest
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Mount, Route
import uvicorn

from agentrun.integration.langchain import AgentRunConverter
from agentrun.server import AgentRequest, AgentRunServer

# =============================================================================
# Protocol Validator
# =============================================================================


class ProtocolValidator:
    """Assertion helpers for validating AG-UI streaming responses in tests."""

    # Sentinel distinguishing "no value seen yet" from a legitimate empty
    # string in all_field_equal.
    _MISSING = object()

    def try_parse_streaming_line(
        self, line: Union[str, dict, list]
    ) -> Union[str, dict, list]:
        """Strip the SSE ``data: `` prefix and decode the JSON payload.

        Non-string values and the ``data: [DONE]`` terminator pass through
        unchanged; strings that do not carry a JSON object/array are also
        returned as-is.
        """
        # isinstance (not type(...) is str) so str subclasses are accepted.
        if not isinstance(line, str) or line.startswith("data: [DONE]"):
            return line
        if line.startswith("data: {") or line.startswith("data: ["):
            json_str = line[len("data: ") :]
            return json.loads(json_str)
        return line

    def valid_json(self, got: Any, expect: Any):
        """Recursively assert that ``got`` matches ``expect``.

        An ``expect`` value of ``"mock-placeholder"`` only requires the
        corresponding position to exist with a non-empty value. Extra keys
        present in ``got`` but absent from ``expect`` are ignored
        (non-strict matching).
        """

        def valid(path: str, got: Any, expect: Any):
            if expect == "mock-placeholder":
                assert got, f"{path} 存在但值为空"
            else:
                got = self.try_parse_streaming_line(got)
                expect = self.try_parse_streaming_line(expect)

                if isinstance(expect, dict):
                    assert isinstance(
                        got, dict
                    ), f"{path} 类型不匹配,期望 dict,实际 {type(got)}"
                    for k, v in expect.items():
                        valid(f"{path}.{k}", got.get(k), v)
                elif isinstance(expect, list):
                    assert isinstance(
                        got, list
                    ), f"{path} 类型不匹配,期望 list,实际 {type(got)}"
                    assert len(got) == len(
                        expect
                    ), f"{path} 列表长度不匹配,期望 {len(expect)},实际 {len(got)}"
                    for i in range(len(expect)):
                        valid(f"{path}[{i}]", got[i], expect[i])
                else:
                    assert (
                        got == expect
                    ), f"{path} 不匹配,期望: {type(expect)} {expect},实际: {got}"

        valid("", got, expect)

    def all_field_equal(self, key: str, arr: list, strict: bool = False):
        """Assert every entry in ``arr`` carries the same non-empty ``key``.

        With ``strict=False`` entries lacking ``key`` are skipped; with
        ``strict=True`` a missing key participates (and fails) like any
        other mismatch.
        """
        value: Any = self._MISSING
        for item in arr:
            data = cast(dict, self.try_parse_streaming_line(item))

            if not strict and key not in data:
                continue

            # Seed the reference value exactly once. Using a sentinel (not
            # "") prevents a legitimate empty value from re-seeding the
            # comparison and masking a later mismatch.
            if value is self._MISSING:
                value = data.get(key)

            assert value == data.get(
                key
            ), f"Field {key} not equal: {value} != {data.get(key)}"
        assert value is not self._MISSING and value, f"Field {key} is empty"


# =============================================================================
# Mock MCP SSE Server
# =============================================================================


def create_mock_mcp_sse_app(tools_config: Dict[str, Any]) -> Starlette:
    """Build a Starlette app serving a mock MCP server over SSE.

    ``tools_config`` maps tool name -> {description, input_schema,
    result_func}; ``result_func`` receives the call arguments and returns
    the tool result (dict results are JSON-encoded).
    """
    mcp_server = Server("mock-mcp-server")

    @mcp_server.list_tools()
    async def list_tools():
        # Advertise every configured tool with its schema.
        tools = []
        for tool_name, tool_info in tools_config.items():
            tools.append(
                MCPTool(
                    name=tool_name,
                    description=tool_info.get("description", ""),
                    inputSchema=tool_info.get(
                        "input_schema", {"type": "object", "properties": {}}
                    ),
                )
            )
        return tools

    @mcp_server.call_tool()
    async def call_tool(name: str, arguments: Dict[str, Any]):
        tool_info = tools_config.get(name, {})
        result_func = tool_info.get("result_func")
        if result_func:
            result = result_func(**arguments)
            if isinstance(result, dict):
                return [
                    TextContent(
                        type="text", text=json.dumps(result, ensure_ascii=False)
                    )
                ]
            return [TextContent(type="text", text=str(result))]
        return [TextContent(type="text", text="No result")]

    sse_transport = SseServerTransport("/messages/")

    async def handle_sse(request: Request) -> Response:
        # request._send is private Starlette API, but it is what the MCP SSE
        # transport expects to drive the ASGI send channel.
        async with sse_transport.connect_sse(
            request.scope, request.receive, request._send
        ) as streams:
            await mcp_server.run(
                streams[0],
                streams[1],
                mcp_server.create_initialization_options(),
            )
        return Response()

    app = Starlette(
        routes=[
            Route("/sse", endpoint=handle_sse, methods=["GET"]),
            Mount("/messages/", app=sse_transport.handle_post_message),
        ]
    )
    return app


def _find_free_port() -> int:
    """Return a currently free TCP port on localhost.

    NOTE(review): subject to a small bind/use race — acceptable in tests.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


@pytest.fixture(scope="module")
def mock_mcp_server():
    """Start the mock MCP SSE server in a background thread.

    Yields the server's base URL and shuts the server down on teardown.
    """
    tools_config = {
        "get_current_time": {
            "description": "获取当前时间",
            "input_schema": {
                "type": "object",
                "properties": {
                    "timezone": {
                        "type": "string",
                        "description": "时区",
                        "default": "UTC",
                    }
                },
            },
            "result_func": lambda timezone="UTC": {
                "timezone": timezone,
                "datetime": "2025-12-17T11:26:10+08:00",
                "day_of_week": "Wednesday",
                "is_dst": False,
            },
        },
        "maps_weather": {
            "description": "获取天气信息",
            "input_schema": {
                "type": "object",
                "properties": {
                    "city": {"type": "string", "description": "城市名称"}
                },
                "required": ["city"],
            },
            "result_func": lambda city: {
                "city": f"{city}市",
                "forecasts": [
                    {
                        "date": "2025-12-17",
                        "week": "3",
                        "dayweather": "多云",
                        "nightweather": "多云",
                        "daytemp": "14",
                        "nighttemp": "4",
                        "daywind": "北",
                        "nightwind": "北",
                        "daypower": "1-3",
                        "nightpower": "1-3",
                        "daytemp_float": "14.0",
                        "nighttemp_float": "4.0",
                    },
                    {
                        "date": "2025-12-18",
                        "week": "4",
                        "dayweather": "晴",
                        "nightweather": "晴",
                        "daytemp": "15",
                        "nighttemp": "6",
                        "daywind": "东",
                        "nightwind": "东",
                        "daypower": "1-3",
                        "nightpower": "1-3",
                        "daytemp_float": "15.0",
                        "nighttemp_float": "6.0",
                    },
                    {
                        "date": "2025-12-19",
                        "week": "5",
                        "dayweather": "晴",
                        "nightweather": "阴",
                        "daytemp": "21",
                        "nighttemp": "12",
                        "daywind": "东南",
                        "nightwind": "东南",
                        "daypower": "1-3",
                        "nightpower": "1-3",
                        "daytemp_float": "21.0",
                        "nighttemp_float": "12.0",
                    },
                    {
                        "date": "2025-12-20",
                        "week": "6",
                        "dayweather": "阴",
                        "nightweather": "小雨",
                        "daytemp": "20",
                        "nighttemp": "8",
                        "daywind": "东北",
                        "nightwind": "东北",
                        "daypower": "1-3",
                        "nightpower": "1-3",
                        "daytemp_float": "20.0",
                        "nighttemp_float": "8.0",
                    },
                ],
            },
        },
    }

    app = create_mock_mcp_sse_app(tools_config)
    port = _find_free_port()
    config = uvicorn.Config(
        app, host="127.0.0.1", port=port, log_level="critical"
    )
    server = uvicorn.Server(config)
    thread = threading.Thread(target=server.run, daemon=True)
    thread.start()

    base_url = f"http://127.0.0.1:{port}"

    # Best-effort readiness probe (up to 10 s).
    start_time = time.time()
    while time.time() - start_time < 10:
        try:
            with httpx.Client() as client:
                resp = client.get(f"{base_url}/sse", timeout=0.5)
            if resp.status_code in (200, 500):
                break
        except httpx.ConnectError:
            # Server not accepting connections yet — retry shortly.
            time.sleep(0.1)
        except httpx.ReadTimeout:
            # /sse streams forever, so a read timeout means the server IS
            # accepting connections — it is ready. (Retrying here would
            # spin the full 10 s on every module setup.)
            break

    yield base_url
    server.should_exit = True
    thread.join(timeout=2)
============================================================================= + + +@tool +def get_user_name(): + """获取用户名称""" + return {"user_name": "张三"} + + +@tool +def get_user_token(user_name: str): + """获取用户的密钥,输入为用户名""" + return "ak_1234asd12341" + + +# ============================================================================= +# Mock Chat Model +# ============================================================================= + + +class MockChatModel(BaseChatModel): + """模拟 ChatOpenAI 的行为""" + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Any = None, + **kwargs: Any, + ) -> ChatResult: + tool_outputs = [m for m in messages if isinstance(m, ToolMessage)] + + # Round 1: Call MCP time + Local name + if len(tool_outputs) == 0: + return ChatResult( + generations=[ + ChatGeneration( + message=AIMessage( + content=( + "我需要获取当前时间、天气信息和您的密钥信息。" + "让我依次处理这些请求。\n\n首先," + "让我获取当前时间:\n\n" + ), + tool_calls=[ + { + "name": "get_current_time", + "args": {"timezone": "Asia/Shanghai"}, + "id": "call_time", + }, + { + "name": "get_user_name", + "args": {}, + "id": "call_name", + }, + ], + ) + ) + ] + ) + + # Round 2: Call MCP weather + Local token + if len(tool_outputs) == 2: + return ChatResult( + generations=[ + ChatGeneration( + message=AIMessage( + content=( + "现在我已获取到当前时间和您的用户名。接下来," + "让我获取天气信息和您的密钥:\n\n" + ), + tool_calls=[ + { + "name": "maps_weather", + "args": {"city": "上海"}, + "id": "call_weather", + }, + { + "name": "get_user_token", + "args": {"user_name": "张三"}, + "id": "call_token", + }, + ], + ) + ) + ] + ) + + # Round 3: Finish + return ChatResult( + generations=[ + ChatGeneration( + message=AIMessage( + content=( + "以下是您请求的信息:\n\n## 当前时间\n-" + " 日期:2025年12月17日(星期三)\n-" + " 时间:11:26:10(北京时间,UTC+8)\n\n##" + " 天气信息(上海市)\n**今日(12月17日)天气:**\n-" + " 白天:多云,14°C,北风1-3级\n- 夜间:多云,4°C," + "北风1-3级\n\n**未来几天预报:**\n- 12月18日:晴," + "6-15°C\n- 12月19日:晴转阴,12-21°C \n-" + " 12月20日:阴转小雨,8-20°C\n\n##" + " 
您的密钥信息\n您的用户密钥为:`ak_1234asd12341`\n\n请注意妥善保管您的密钥信息," + "不要在公共场合泄露。" + ), + ) + ) + ] + ) + + @property + def _llm_type(self) -> str: + return "mock-chat-model" + + def bind_tools(self, tools: Sequence[Any], **kwargs: Any): + return self + + +# ============================================================================= +# Tests +# ============================================================================= + + +class TestLangChainAguiIntegration(ProtocolValidator): + + async def test_multi_tool_query(self, mock_mcp_server): + """测试多工具查询场景 (MCP + Local + MockLLM)""" + + mcp_client = MultiServerMCPClient( + { + "tools": { + "url": f"{mock_mcp_server}/sse", + "transport": "sse", + } + } + ) + mcp_tools = await mcp_client.get_tools() + + async def invoke_agent(request: AgentRequest): + llm = MockChatModel() + tools = [*mcp_tools, get_user_name, get_user_token] + + agent = create_agent( + model=llm, + system_prompt="You are a helpful assistant", + tools=tools, + ) + + input_data = { + "messages": [{ + "role": "user", + "content": ( + "查询当前的时间,并获取天气信息,同时输出我的密钥信息" + ), + }] + } + + converter = AgentRunConverter() + async for event in agent.astream_events(input_data, version="v2"): + for item in converter.convert(event): + yield item + + app = AgentRunServer(invoke_agent=invoke_agent).app + + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + base_url="http://test", + ) as client: + response = await client.post( + "/ag-ui/agent", + json={ + "messages": [{ + "role": "user", + "content": "查询当前的时间,并获取天气信息,同时输出我的密钥信息", + }], + "stream": True, + }, + timeout=60.0, + ) + + assert response.status_code == 200 + + events = [line for line in response.text.split("\n") if line] + expected = [ + ( + "data:" + ' {"type":"RUN_STARTED","threadId":"mock-placeholder","runId":"mock-placeholder"}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_START","messageId":"mock-placeholder","role":"assistant"}' + ), + ( + "data:" + ' 
{"type":"TEXT_MESSAGE_CONTENT","messageId":"mock-placeholder","delta":"我需要获取当前时间、' + "天气信息和您的密钥信息。让我依次处理这些请求。\\n\\n首先," + '让我获取当前时间:\\n\\n"}' + ), + 'data: {"type":"TEXT_MESSAGE_END","messageId":"mock-placeholder"}', + ( + "data:" + ' {"type":"TOOL_CALL_START","toolCallId":"call_time","toolCallName":"get_current_time"}' + ), + ( + "data:" + ' {"type":"TOOL_CALL_ARGS","toolCallId":"call_time","delta":"{\\"timezone\\":' + ' \\"Asia/Shanghai\\"}"}' + ), + ( + "data:" + ' {"type":"TOOL_CALL_START","toolCallId":"call_name","toolCallName":"get_user_name"}' + ), + ( + "data:" + ' {"type":"TOOL_CALL_ARGS","toolCallId":"call_name","delta":""}' + ), + 'data: {"type":"TOOL_CALL_END","toolCallId":"call_name"}', + ( + "data:" + ' {"type":"TOOL_CALL_RESULT","messageId":"tool-result-call_name","toolCallId":"call_name","content":"{\\"user_name\\":' + ' \\"张三\\"}","role":"tool"}' + ), + 'data: {"type":"TOOL_CALL_END","toolCallId":"call_time"}', + ( + "data:" + ' {"type":"TOOL_CALL_RESULT","messageId":"tool-result-call_time","toolCallId":"call_time","content":"mock-placeholder","role":"tool"}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_START","messageId":"mock-placeholder","role":"assistant"}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_CONTENT","messageId":"mock-placeholder","delta":"现在我已获取到当前时间和您的用户名。' + '接下来,让我获取天气信息和您的密钥:\\n\\n"}' + ), + 'data: {"type":"TEXT_MESSAGE_END","messageId":"mock-placeholder"}', + ( + "data:" + ' {"type":"TOOL_CALL_START","toolCallId":"call_weather","toolCallName":"maps_weather"}' + ), + ( + "data:" + ' {"type":"TOOL_CALL_ARGS","toolCallId":"call_weather","delta":"{\\"city\\":' + ' \\"上海\\"}"}' + ), + ( + "data:" + ' {"type":"TOOL_CALL_START","toolCallId":"call_token","toolCallName":"get_user_token"}' + ), + ( + "data:" + ' {"type":"TOOL_CALL_ARGS","toolCallId":"call_token","delta":"{\\"user_name\\":' + ' \\"张三\\"}"}' + ), + 'data: {"type":"TOOL_CALL_END","toolCallId":"call_token"}', + ( + "data:" + ' 
{"type":"TOOL_CALL_RESULT","messageId":"tool-result-call_token","toolCallId":"call_token","content":"ak_1234asd12341","role":"tool"}' + ), + 'data: {"type":"TOOL_CALL_END","toolCallId":"call_weather"}', + ( + "data:" + ' {"type":"TOOL_CALL_RESULT","messageId":"tool-result-call_weather","toolCallId":"call_weather","content":"mock-placeholder","role":"tool"}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_START","messageId":"mock-placeholder","role":"assistant"}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_CONTENT","messageId":"mock-placeholder","delta":"以下是您请求的信息:\\n\\n##' + " 当前时间\\n- 日期:2025年12月17日(星期三)\\n-" + " 时间:11:26:10(北京时间,UTC+8)\\n\\n##" + " 天气信息(上海市)\\n**今日(12月17日)天气:**\\n-" + " 白天:多云,14°C,北风1-3级\\n- 夜间:多云,4°C," + "北风1-3级\\n\\n**未来几天预报:**\\n- 12月18日:晴,6-15°C\\n-" + " 12月19日:晴转阴,12-21°C \\n- 12月20日:阴转小雨," + "8-20°C\\n\\n##" + " 您的密钥信息\\n您的用户密钥为:`ak_1234asd12341`\\n\\n请注意妥善保管您的密钥信息," + '不要在公共场合泄露。"}' + ), + 'data: {"type":"TEXT_MESSAGE_END","messageId":"mock-placeholder"}', + ( + "data:" + ' {"type":"RUN_FINISHED","threadId":"mock-placeholder","runId":"mock-placeholder"}' + ), + ] + + self.valid_json(events, expected) + + self.all_field_equal("runId", events) + self.all_field_equal("threadId", events) + self.all_field_equal("messageId", events[1:4]) + self.all_field_equal("messageId", events[12:15]) + self.all_field_equal("messageId", events[24:27]) diff --git a/tests/unittests/integration/test_langchain_convert.py b/tests/unittests/integration/test_langchain_convert.py new file mode 100644 index 0000000..bcdb421 --- /dev/null +++ b/tests/unittests/integration/test_langchain_convert.py @@ -0,0 +1,1805 @@ +"""测试 convert 函数 / Test convert Function + +测试 convert 函数对不同 LangChain/LangGraph 调用方式返回事件格式的兼容性。 +支持的格式: +- astream_events(version="v2") 格式 +- stream/astream(stream_mode="updates") 格式 +- stream/astream(stream_mode="values") 格式 +""" + +from unittest.mock import MagicMock + +import pytest + +from agentrun.integration.langgraph.agent_converter import 
AgentRunConverter +from agentrun.server.model import AgentResult, EventType +# 使用 conftest.py 中的公共 mock 函数 +from tests.unittests.integration.conftest import ( + create_mock_ai_message, + create_mock_ai_message_chunk, + create_mock_tool_message, +) + +# ============================================================================= +# 测试事件格式检测函数 +# ============================================================================= + + +class TestEventFormatDetection: + """测试事件格式检测函数""" + + def test_is_astream_events_format(self): + """测试 astream_events 格式检测""" + # 正确的 astream_events 格式 + assert AgentRunConverter.is_astream_events_format( + {"event": "on_chat_model_stream", "data": {}} + ) + assert AgentRunConverter.is_astream_events_format( + {"event": "on_tool_start", "data": {}} + ) + assert AgentRunConverter.is_astream_events_format( + {"event": "on_tool_end", "data": {}} + ) + assert AgentRunConverter.is_astream_events_format( + {"event": "on_chain_stream", "data": {}} + ) + + # 不是 astream_events 格式 + assert not AgentRunConverter.is_astream_events_format( + {"model": {"messages": []}} + ) + assert not AgentRunConverter.is_astream_events_format({"messages": []}) + assert not AgentRunConverter.is_astream_events_format({}) + assert not AgentRunConverter.is_astream_events_format( + {"event": "custom_event"} + ) # 不以 on_ 开头 + + def test_is_stream_updates_format(self): + """测试 stream(updates) 格式检测""" + # 正确的 updates 格式 + assert AgentRunConverter.is_stream_updates_format( + {"model": {"messages": []}} + ) + assert AgentRunConverter.is_stream_updates_format( + {"agent": {"messages": []}} + ) + assert AgentRunConverter.is_stream_updates_format( + {"tools": {"messages": []}} + ) + assert AgentRunConverter.is_stream_updates_format( + {"__end__": {}, "model": {"messages": []}} + ) + + # 不是 updates 格式 + assert not AgentRunConverter.is_stream_updates_format( + {"event": "on_chat_model_stream"} + ) + assert not AgentRunConverter.is_stream_updates_format( + {"messages": []} + ) # 这是 
values 格式 + assert not AgentRunConverter.is_stream_updates_format({}) + + def test_is_stream_values_format(self): + """测试 stream(values) 格式检测""" + # 正确的 values 格式 + assert AgentRunConverter.is_stream_values_format({"messages": []}) + assert AgentRunConverter.is_stream_values_format( + {"messages": [MagicMock()]} + ) + + # 不是 values 格式 + assert not AgentRunConverter.is_stream_values_format( + {"event": "on_chat_model_stream"} + ) + assert not AgentRunConverter.is_stream_values_format( + {"model": {"messages": []}} + ) + assert not AgentRunConverter.is_stream_values_format({}) + + +# ============================================================================= +# 测试 astream_events 格式的转换 +# ============================================================================= + + +class TestConvertAstreamEventsFormat: + """测试 astream_events 格式的事件转换""" + + def test_on_chat_model_stream_text_content(self): + """测试 on_chat_model_stream 事件的文本内容提取""" + chunk = create_mock_ai_message_chunk("你好") + event = { + "event": "on_chat_model_stream", + "data": {"chunk": chunk}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0] == "你好" + + def test_on_chat_model_stream_empty_content(self): + """测试 on_chat_model_stream 事件的空内容""" + chunk = create_mock_ai_message_chunk("") + event = { + "event": "on_chat_model_stream", + "data": {"chunk": chunk}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + assert len(results) == 0 + + def test_on_chat_model_stream_with_tool_call_args(self): + """测试 on_chat_model_stream 事件的工具调用参数""" + chunk = create_mock_ai_message_chunk( + "", + tool_call_chunks=[{ + "id": "call_123", + "name": "get_weather", + "args": '{"city": "北京"}', + }], + ) + event = { + "event": "on_chat_model_stream", + "data": {"chunk": chunk}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + # 第一个 chunk 有 id 和 name 时,发送完整的 TOOL_CALL_CHUNK + assert len(results) == 1 + assert isinstance(results[0], 
AgentResult) + assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["id"] == "call_123" + assert results[0].data["name"] == "get_weather" + assert results[0].data["args_delta"] == '{"city": "北京"}' + + def test_on_tool_start(self): + """测试 on_tool_start 事件""" + event = { + "event": "on_tool_start", + "name": "get_weather", + "run_id": "run_456", + "data": {"input": {"city": "北京"}}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + # 现在是单个 TOOL_CALL_CHUNK(边界事件由协议层自动处理) + assert len(results) == 1 + assert isinstance(results[0], AgentResult) + assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["id"] == "run_456" + assert results[0].data["name"] == "get_weather" + assert "city" in results[0].data["args_delta"] + + def test_on_tool_start_without_input(self): + """测试 on_tool_start 事件(无输入参数)""" + event = { + "event": "on_tool_start", + "name": "get_time", + "run_id": "run_789", + "data": {}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + # 现在是单个 TOOL_CALL_CHUNK(边界事件由协议层自动处理) + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["id"] == "run_789" + assert results[0].data["name"] == "get_time" + + def test_on_tool_end(self): + """测试 on_tool_end 事件 + + AG-UI 协议:TOOL_CALL_END 在 on_tool_start 中发送(参数传输完成时), + TOOL_CALL_RESULT 在 on_tool_end 中发送(工具执行完成时)。 + """ + event = { + "event": "on_tool_end", + "run_id": "run_456", + "data": {"output": {"weather": "晴天", "temperature": 25}}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + # on_tool_end 只发送 TOOL_CALL_RESULT + assert len(results) == 1 + + # TOOL_CALL_RESULT + assert results[0].event == EventType.TOOL_RESULT + assert results[0].data["id"] == "run_456" + assert "晴天" in results[0].data["result"] + + def test_on_tool_end_with_string_output(self): + """测试 on_tool_end 事件(字符串输出)""" + event = { + "event": "on_tool_end", + "run_id": "run_456", + "data": {"output": "晴天,25度"}, 
+ } + + results = list(AgentRunConverter.to_agui_events(event)) + + # on_tool_end 只发送 TOOL_CALL_RESULT + assert len(results) == 1 + assert results[0].event == EventType.TOOL_RESULT + assert results[0].data["result"] == "晴天,25度" + + def test_on_tool_start_with_non_jsonable_args(self): + """工具输入包含不可 JSON 序列化对象时也能正常转换""" + + class Dummy: + + def __str__(self): + return "dummy_obj" + + event = { + "event": "on_tool_start", + "name": "get_weather", + "run_id": "run_non_json", + "data": {"input": {"obj": Dummy()}}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + # 现在是单个 TOOL_CALL_CHUNK + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["id"] == "run_non_json" + assert "dummy_obj" in results[0].data["args_delta"] + + def test_on_tool_start_filters_internal_runtime_field(self): + """测试 on_tool_start 过滤 MCP 注入的 runtime 等内部字段""" + + class FakeToolRuntime: + """模拟 MCP 的 ToolRuntime 对象""" + + def __str__(self): + return "ToolRuntime(...huge internal state...)" + + event = { + "event": "on_tool_start", + "name": "maps_weather", + "run_id": "run_mcp_tool", + "data": { + "input": { + "city": "北京", # 用户实际参数 + "runtime": FakeToolRuntime(), # MCP 注入的内部字段 + "config": {"internal": "state"}, # 另一个内部字段 + "__pregel_runtime": "internal", # LangGraph 内部字段 + } + }, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + # 现在是单个 TOOL_CALL_CHUNK + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["name"] == "maps_weather" + + delta = results[0].data["args_delta"] + # 应该只包含用户参数 city + assert "北京" in delta + # 不应该包含内部字段 + assert "runtime" not in delta.lower() or "ToolRuntime" not in delta + assert "internal" not in delta + assert "__pregel" not in delta + + def test_on_tool_start_uses_runtime_tool_call_id(self): + """测试 on_tool_start 使用 runtime 中的原始 tool_call_id 而非 run_id + + MCP 工具会在 input.runtime 中注入 tool_call_id,这是 LLM 返回的原始 ID。 + 应该优先使用这个 
ID,以保证工具调用事件的 ID 一致性。 + """ + + class FakeToolRuntime: + """模拟 MCP 的 ToolRuntime 对象""" + + def __init__(self, tool_call_id: str): + self.tool_call_id = tool_call_id + + original_tool_call_id = "call_original_from_llm_12345" + + event = { + "event": "on_tool_start", + "name": "get_weather", + "run_id": ( + "run_id_different_from_tool_call_id" + ), # run_id 与 tool_call_id 不同 + "data": { + "input": { + "city": "北京", + "runtime": FakeToolRuntime(original_tool_call_id), + } + }, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + # 现在是单个 TOOL_CALL_CHUNK + assert len(results) == 1 + + # 应该使用 runtime 中的原始 tool_call_id,而不是 run_id + assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["id"] == original_tool_call_id + assert results[0].data["name"] == "get_weather" + + def test_on_tool_end_uses_runtime_tool_call_id(self): + """测试 on_tool_end 使用 runtime 中的原始 tool_call_id 而非 run_id""" + + class FakeToolRuntime: + """模拟 MCP 的 ToolRuntime 对象""" + + def __init__(self, tool_call_id: str): + self.tool_call_id = tool_call_id + + original_tool_call_id = "call_original_from_llm_67890" + + event = { + "event": "on_tool_end", + "run_id": "run_id_different_from_tool_call_id", + "data": { + "output": {"weather": "晴天", "temp": 25}, + "input": { + "city": "北京", + "runtime": FakeToolRuntime(original_tool_call_id), + }, + }, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + # on_tool_end 只发送 TOOL_CALL_RESULT(TOOL_CALL_END 在 on_tool_start 发送) + assert len(results) == 1 + + # 应该使用 runtime 中的原始 tool_call_id + assert results[0].event == EventType.TOOL_RESULT + assert results[0].data["id"] == original_tool_call_id + + def test_on_tool_start_fallback_to_run_id(self): + """测试当 runtime 中没有 tool_call_id 时,回退使用 run_id""" + event = { + "event": "on_tool_start", + "name": "get_time", + "run_id": "run_789", + "data": {"input": {"timezone": "Asia/Shanghai"}}, # 没有 runtime + } + + results = list(AgentRunConverter.to_agui_events(event)) + + # 现在是单个 
TOOL_CALL_CHUNK + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL_CHUNK + # 应该回退使用 run_id + assert results[0].data["id"] == "run_789" + + def test_streaming_tool_call_id_consistency_with_map(self): + """测试流式工具调用的 tool_call_id 一致性(使用映射) + + 在流式工具调用中: + - 第一个 chunk 有 id 但可能没有 args(用于建立映射) + - 后续 chunk 有 args 但 id 为空,只有 index(从映射查找 id) + + 使用 tool_call_id_map 可以确保 ID 一致性。 + """ + # 模拟流式工具调用的多个 chunk + events = [ + # 第一个 chunk: 有 id 和 name,没有 args(只用于建立映射) + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": "call_abc123", + "name": "browser_navigate", + "args": "", + "index": 0, + }], + ) + }, + }, + # 第二个 chunk: id 为空,只有 index 和 args + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": "", + "name": "", + "args": '{"url": "https://', + "index": 0, + }], + ) + }, + }, + # 第三个 chunk: id 为空,继续 args + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": "", + "name": "", + "args": 'example.com"}', + "index": 0, + }], + ) + }, + }, + ] + + # 使用 tool_call_id_map 来确保 ID 一致性 + tool_call_id_map: Dict[int, str] = {} + all_results = [] + + for event in events: + results = list( + AgentRunConverter.to_agui_events( + event, tool_call_id_map=tool_call_id_map + ) + ) + all_results.extend(results) + + # 验证映射已建立 + assert 0 in tool_call_id_map + assert tool_call_id_map[0] == "call_abc123" + + # 验证:所有 TOOL_CALL_CHUNK 都使用相同的 tool_call_id + chunk_events = [ + r + for r in all_results + if isinstance(r, AgentResult) + and r.event == EventType.TOOL_CALL_CHUNK + ] + + # 应该有 3 个 TOOL_CALL_CHUNK 事件(每个 chunk 一个) + assert len(chunk_events) == 3 + + # 所有事件应该使用相同的 tool_call_id(从映射获取) + for event in chunk_events: + assert event.data["id"] == "call_abc123" + + def test_streaming_tool_call_id_without_map_uses_index(self): + """测试不使用映射时,后续 chunk 回退到 index""" + event = { + "event": 
"on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": "", + "name": "", + "args": '{"url": "test"}', + "index": 0, + }], + ) + }, + } + + # 不传入 tool_call_id_map + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL_CHUNK + # 回退使用 index + assert results[0].data["id"] == "0" + + def test_streaming_multiple_concurrent_tool_calls(self): + """测试多个并发工具调用(不同 index)的 ID 一致性""" + # 模拟 LLM 同时调用两个工具 + events = [ + # 第一个 chunk: 两个工具调用的 ID + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[ + { + "id": "call_tool1", + "name": "search", + "args": "", + "index": 0, + }, + { + "id": "call_tool2", + "name": "weather", + "args": "", + "index": 1, + }, + ], + ) + }, + }, + # 后续 chunk: 只有 index 和 args + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[ + { + "id": "", + "name": "", + "args": '{"q": "test"', + "index": 0, + }, + ], + ) + }, + }, + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[ + { + "id": "", + "name": "", + "args": '{"city": "北京"', + "index": 1, + }, + ], + ) + }, + }, + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[ + {"id": "", "name": "", "args": "}", "index": 0}, + {"id": "", "name": "", "args": "}", "index": 1}, + ], + ) + }, + }, + ] + + tool_call_id_map: Dict[int, str] = {} + all_results = [] + + for event in events: + results = list( + AgentRunConverter.to_agui_events( + event, tool_call_id_map=tool_call_id_map + ) + ) + all_results.extend(results) + + # 验证映射正确建立 + assert tool_call_id_map[0] == "call_tool1" + assert tool_call_id_map[1] == "call_tool2" + + # 验证所有事件使用正确的 ID + chunk_events = [ + r + for r in all_results + if isinstance(r, AgentResult) + and r.event == EventType.TOOL_CALL_CHUNK + ] + + # 应该有 6 
个 TOOL_CALL_CHUNK 事件 + # - 2 个初始 chunk(id + name) + # - 4 个 args chunk + assert len(chunk_events) == 6 + + # 验证每个工具调用使用正确的 ID + tool1_chunks = [e for e in chunk_events if e.data["id"] == "call_tool1"] + tool2_chunks = [e for e in chunk_events if e.data["id"] == "call_tool2"] + + assert len(tool1_chunks) == 3 # 初始 + '{"q": "test"' + '}' + assert len(tool2_chunks) == 3 # 初始 + '{"city": "北京"' + '}' + + def test_agentrun_converter_class(self): + """测试 AgentRunConverter 类的完整功能""" + from agentrun.integration.langchain import AgentRunConverter + + events = [ + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": "call_xyz", + "name": "test_tool", + "args": "", + "index": 0, + }], + ) + }, + }, + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": "", + "name": "", + "args": '{"key": "value"}', + "index": 0, + }], + ) + }, + }, + ] + + converter = AgentRunConverter() + all_results = [] + + for event in events: + results = list(converter.convert(event)) + all_results.extend(results) + + # 验证内部映射 + assert converter._tool_call_id_map[0] == "call_xyz" + + # 验证结果 + chunk_events = [ + r + for r in all_results + if isinstance(r, AgentResult) + and r.event == EventType.TOOL_CALL_CHUNK + ] + # 现在有 2 个 chunk 事件(每个 stream chunk 一个) + assert len(chunk_events) == 2 + # 所有事件应该使用相同的 ID + for event in chunk_events: + assert event.data["id"] == "call_xyz" + + # 测试 reset + converter.reset() + assert len(converter._tool_call_id_map) == 0 + + def test_streaming_tool_call_with_first_chunk_having_args(self): + """测试第一个 chunk 同时有 id 和 args 的情况""" + # 有些模型可能在第一个 chunk 就返回完整的工具调用 + event = { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": "call_complete", + "name": "simple_tool", + "args": '{"done": true}', + "index": 0, + }], + ) + }, + } + + tool_call_id_map: Dict[int, str] = {} + tool_call_started_set: set = 
set() + results = list( + AgentRunConverter.to_agui_events( + event, + tool_call_id_map=tool_call_id_map, + tool_call_started_set=tool_call_started_set, + ) + ) + + # 验证映射被建立 + assert tool_call_id_map[0] == "call_complete" + # 验证 START 已发送 + assert "call_complete" in tool_call_started_set + + # 现在是单个 TOOL_CALL_CHUNK(包含 id, name, args_delta) + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["id"] == "call_complete" + assert results[0].data["name"] == "simple_tool" + assert results[0].data["args_delta"] == '{"done": true}' + + def test_streaming_tool_call_id_none_vs_empty_string(self): + """测试 id 为 None 和空字符串的不同处理""" + events = [ + # id 为 None(建立映射) + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": "call_from_none", + "name": "tool", + "args": "", + "index": 0, + }], + ) + }, + }, + # id 为 None(应该从映射获取) + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": None, + "name": "", + "args": '{"a": 1}', + "index": 0, + }], + ) + }, + }, + ] + + tool_call_id_map: Dict[int, str] = {} + all_results = [] + + for event in events: + results = list( + AgentRunConverter.to_agui_events( + event, tool_call_id_map=tool_call_id_map + ) + ) + all_results.extend(results) + + chunk_events = [ + r + for r in all_results + if isinstance(r, AgentResult) + and r.event == EventType.TOOL_CALL_CHUNK + ] + + # 现在有 2 个 chunk 事件(每个 stream chunk 一个) + assert len(chunk_events) == 2 + # 所有事件应该使用相同的 ID(从映射获取) + for event in chunk_events: + assert event.data["id"] == "call_from_none" + + def test_full_tool_call_flow_id_consistency(self): + """测试完整工具调用流程中的 ID 一致性 + + 模拟: + 1. on_chat_model_stream 产生 TOOL_CALL_CHUNK + 2. on_tool_start 不产生事件(已在流式中处理) + 3. 
on_tool_end 产生 TOOL_RESULT + + 验证所有事件使用相同的 tool_call_id + """ + # 模拟完整的工具调用流程 + events = [ + # 流式工具调用参数(第一个 chunk 有 id 和 name) + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": "call_full_flow", + "name": "test_tool", + "args": "", + "index": 0, + }], + ) + }, + }, + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": "", + "name": "", + "args": '{"param": "value"}', + "index": 0, + }], + ) + }, + }, + # 工具开始(使用 runtime.tool_call_id) + { + "event": "on_tool_start", + "name": "test_tool", + "run_id": "run_123", + "data": { + "input": { + "param": "value", + "runtime": MagicMock(tool_call_id="call_full_flow"), + } + }, + }, + # 工具结束 + { + "event": "on_tool_end", + "run_id": "run_123", + "data": { + "input": { + "param": "value", + "runtime": MagicMock(tool_call_id="call_full_flow"), + }, + "output": "success", + }, + }, + ] + + converter = AgentRunConverter() + all_results = [] + + for event in events: + results = list(converter.convert(event)) + all_results.extend(results) + + # 获取所有工具调用相关事件 + tool_events = [ + r + for r in all_results + if isinstance(r, AgentResult) + and r.event in [EventType.TOOL_CALL_CHUNK, EventType.TOOL_RESULT] + ] + + # 验证所有事件都使用相同的 tool_call_id + for event in tool_events: + assert ( + event.data["id"] == "call_full_flow" + ), f"Event {event.event} has wrong id: {event.data.get('id')}" + + # 验证所有事件类型都存在 + event_types = [e.event for e in tool_events] + assert EventType.TOOL_CALL_CHUNK in event_types + assert EventType.TOOL_RESULT in event_types + + # 验证顺序:TOOL_CALL_CHUNK 必须在 TOOL_RESULT 之前 + chunk_idx = event_types.index(EventType.TOOL_CALL_CHUNK) + result_idx = event_types.index(EventType.TOOL_RESULT) + assert ( + chunk_idx < result_idx + ), "TOOL_CALL_CHUNK must come before TOOL_RESULT" + + def test_on_chain_stream_model_node(self): + """测试 on_chain_stream 事件(model 节点)""" + msg = 
create_mock_ai_message("你好!有什么可以帮你的吗?") + event = { + "event": "on_chain_stream", + "name": "model", + "data": {"chunk": {"messages": [msg]}}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0] == "你好!有什么可以帮你的吗?" + + def test_on_chain_stream_non_model_node(self): + """测试 on_chain_stream 事件(非 model 节点)""" + event = { + "event": "on_chain_stream", + "name": "agent", # 不是 "model" + "data": {"chunk": {"messages": []}}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + assert len(results) == 0 + + def test_on_chat_model_end_ignored(self): + """测试 on_chat_model_end 事件被忽略(避免重复)""" + event = { + "event": "on_chat_model_end", + "data": {"output": create_mock_ai_message("完成")}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + assert len(results) == 0 + + +# ============================================================================= +# 测试 stream/astream(stream_mode="updates") 格式的转换 +# ============================================================================= + + +class TestConvertStreamUpdatesFormat: + """测试 stream(updates) 格式的事件转换""" + + def test_ai_message_text_content(self): + """测试 AI 消息的文本内容""" + msg = create_mock_ai_message("你好!") + event = {"model": {"messages": [msg]}} + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0] == "你好!" 
+ + def test_ai_message_empty_content(self): + """测试 AI 消息的空内容""" + msg = create_mock_ai_message("") + event = {"model": {"messages": [msg]}} + + results = list(AgentRunConverter.to_agui_events(event)) + assert len(results) == 0 + + def test_ai_message_with_tool_calls(self): + """测试 AI 消息包含工具调用""" + msg = create_mock_ai_message( + "", + tool_calls=[{ + "id": "call_abc", + "name": "get_weather", + "args": {"city": "上海"}, + }], + ) + event = {"agent": {"messages": [msg]}} + + results = list(AgentRunConverter.to_agui_events(event)) + + # 现在是单个 TOOL_CALL_CHUNK + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["id"] == "call_abc" + assert results[0].data["name"] == "get_weather" + assert "上海" in results[0].data["args_delta"] + + def test_tool_message_result(self): + """测试工具消息的结果""" + msg = create_mock_tool_message('{"weather": "多云"}', "call_abc") + event = {"tools": {"messages": [msg]}} + + results = list(AgentRunConverter.to_agui_events(event)) + + # 现在只有 TOOL_RESULT + assert len(results) == 1 + assert results[0].event == EventType.TOOL_RESULT + assert results[0].data["id"] == "call_abc" + assert "多云" in results[0].data["result"] + + def test_end_node_ignored(self): + """测试 __end__ 节点被忽略""" + event = {"__end__": {"messages": []}} + + results = list(AgentRunConverter.to_agui_events(event)) + assert len(results) == 0 + + def test_multiple_nodes_in_event(self): + """测试一个事件中包含多个节点""" + ai_msg = create_mock_ai_message("正在查询...") + tool_msg = create_mock_tool_message("查询结果", "call_xyz") + event = { + "__end__": {}, + "model": {"messages": [ai_msg]}, + "tools": {"messages": [tool_msg]}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + # 应该有 2 个结果:1 个文本 + 1 个 TOOL_RESULT + assert len(results) == 2 + assert results[0] == "正在查询..." 
+ assert results[1].event == EventType.TOOL_RESULT + + def test_custom_messages_key(self): + """测试自定义 messages_key""" + msg = create_mock_ai_message("自定义消息") + event = {"model": {"custom_messages": [msg]}} + + # 使用默认 key 应该找不到消息 + results = list( + AgentRunConverter.to_agui_events(event, messages_key="messages") + ) + assert len(results) == 0 + + # 使用正确的 key + results = list( + AgentRunConverter.to_agui_events( + event, messages_key="custom_messages" + ) + ) + assert len(results) == 1 + assert results[0] == "自定义消息" + + +# ============================================================================= +# 测试 stream/astream(stream_mode="values") 格式的转换 +# ============================================================================= + + +class TestConvertStreamValuesFormat: + """测试 stream(values) 格式的事件转换""" + + def test_last_ai_message_content(self): + """测试最后一条 AI 消息的内容""" + msg1 = create_mock_ai_message("第一条消息") + msg2 = create_mock_ai_message("最后一条消息") + event = {"messages": [msg1, msg2]} + + results = list(AgentRunConverter.to_agui_events(event)) + + # 只处理最后一条消息 + assert len(results) == 1 + assert results[0] == "最后一条消息" + + def test_last_ai_message_with_tool_calls(self): + """测试最后一条 AI 消息包含工具调用""" + msg = create_mock_ai_message( + "", + tool_calls=[ + {"id": "call_def", "name": "search", "args": {"query": "天气"}} + ], + ) + event = {"messages": [msg]} + + results = list(AgentRunConverter.to_agui_events(event)) + + # 现在是单个 TOOL_CALL_CHUNK + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL_CHUNK + + def test_last_tool_message_result(self): + """测试最后一条工具消息的结果""" + ai_msg = create_mock_ai_message("之前的消息") + tool_msg = create_mock_tool_message("工具结果", "call_ghi") + event = {"messages": [ai_msg, tool_msg]} + + results = list(AgentRunConverter.to_agui_events(event)) + + # 只处理最后一条消息(工具消息),现在只有 TOOL_RESULT + assert len(results) == 1 + assert results[0].event == EventType.TOOL_RESULT + + def test_empty_messages(self): + """测试空消息列表""" + event = 
{"messages": []} + + results = list(AgentRunConverter.to_agui_events(event)) + assert len(results) == 0 + + +# ============================================================================= +# 测试 StreamEvent 对象的转换 +# ============================================================================= + + +class TestConvertStreamEventObject: + """测试 StreamEvent 对象(非 dict)的转换""" + + def test_stream_event_object(self): + """测试 StreamEvent 对象自动转换为 dict""" + # 模拟 StreamEvent 对象 + chunk = create_mock_ai_message_chunk("Hello") + stream_event = MagicMock() + stream_event.event = "on_chat_model_stream" + stream_event.data = {"chunk": chunk} + stream_event.name = "model" + stream_event.run_id = "run_001" + + results = list(AgentRunConverter.to_agui_events(stream_event)) + + assert len(results) == 1 + assert results[0] == "Hello" + + +# ============================================================================= +# 测试完整流程:模拟多个事件的序列 +# ============================================================================= + + +class TestConvertEventSequence: + """测试完整的事件序列转换""" + + def test_astream_events_full_sequence(self): + """测试 astream_events 格式的完整事件序列 + + AG-UI 协议要求的事件顺序: + TOOL_CALL_START → TOOL_CALL_ARGS → TOOL_CALL_END → TOOL_CALL_RESULT + """ + events = [ + # 1. 开始工具调用 + { + "event": "on_tool_start", + "name": "get_weather", + "run_id": "tool_run_1", + "data": {"input": {"city": "北京"}}, + }, + # 2. 工具结束 + { + "event": "on_tool_end", + "run_id": "tool_run_1", + "data": {"output": {"weather": "晴天", "temp": 25}}, + }, + # 3. 
LLM 流式输出 + { + "event": "on_chat_model_stream", + "data": {"chunk": create_mock_ai_message_chunk("北京")}, + }, + { + "event": "on_chat_model_stream", + "data": {"chunk": create_mock_ai_message_chunk("今天")}, + }, + { + "event": "on_chat_model_stream", + "data": {"chunk": create_mock_ai_message_chunk("晴天")}, + }, + ] + + all_results = [] + for event in events: + all_results.extend(AgentRunConverter.to_agui_events(event)) + + # 验证结果 + # on_tool_start: 1 TOOL_CALL_CHUNK + # on_tool_end: 1 TOOL_RESULT + # 3x on_chat_model_stream: 3 个文本 + assert len(all_results) == 5 + + # 工具调用事件 + assert all_results[0].event == EventType.TOOL_CALL_CHUNK + assert all_results[1].event == EventType.TOOL_RESULT + + # 文本内容 + assert all_results[2] == "北京" + assert all_results[3] == "今天" + assert all_results[4] == "晴天" + + def test_stream_updates_full_sequence(self): + """测试 stream(updates) 格式的完整事件序列""" + events = [ + # 1. Agent 决定调用工具 + { + "agent": { + "messages": [ + create_mock_ai_message( + "", + tool_calls=[{ + "id": "call_001", + "name": "get_weather", + "args": {"city": "上海"}, + }], + ) + ] + } + }, + # 2. 工具执行结果 + { + "tools": { + "messages": [ + create_mock_tool_message( + '{"weather": "多云"}', "call_001" + ) + ] + } + }, + # 3. 
Agent 最终回复 + {"model": {"messages": [create_mock_ai_message("上海今天多云。")]}}, + ] + + all_results = [] + for event in events: + all_results.extend(AgentRunConverter.to_agui_events(event)) + + # 验证结果: + # - 1 TOOL_CALL_CHUNK(工具调用) + # - 1 TOOL_RESULT(工具结果) + # - 1 文本回复 + assert len(all_results) == 3 + + # 工具调用 + assert all_results[0].event == EventType.TOOL_CALL_CHUNK + assert all_results[0].data["name"] == "get_weather" + + # 工具结果 + assert all_results[1].event == EventType.TOOL_RESULT + + # 最终回复 + assert all_results[2] == "上海今天多云。" + + +# ============================================================================= +# 测试边界情况 +# ============================================================================= + + +class TestConvertEdgeCases: + """测试边界情况""" + + def test_empty_event(self): + """测试空事件""" + results = list(AgentRunConverter.to_agui_events({})) + assert len(results) == 0 + + def test_none_values(self): + """测试 None 值""" + event = { + "event": "on_chat_model_stream", + "data": {"chunk": None}, + } + results = list(AgentRunConverter.to_agui_events(event)) + assert len(results) == 0 + + def test_invalid_message_type(self): + """测试无效的消息类型""" + msg = MagicMock() + msg.type = "unknown" + msg.content = "test" + event = {"model": {"messages": [msg]}} + + results = list(AgentRunConverter.to_agui_events(event)) + # unknown 类型不会产生输出 + assert len(results) == 0 + + def test_tool_call_without_id(self): + """测试没有 ID 的工具调用""" + msg = create_mock_ai_message( + "", + tool_calls=[{"name": "test", "args": {}}], # 没有 id + ) + event = {"agent": {"messages": [msg]}} + + results = list(AgentRunConverter.to_agui_events(event)) + # 没有 id 的工具调用应该被跳过 + assert len(results) == 0 + + def test_tool_message_without_tool_call_id(self): + """测试没有 tool_call_id 的工具消息""" + msg = MagicMock() + msg.type = "tool" + msg.content = "result" + msg.tool_call_id = None # 没有 tool_call_id + + event = {"tools": {"messages": [msg]}} + + results = list(AgentRunConverter.to_agui_events(event)) + # 没有 tool_call_id 
的工具消息应该被跳过 + assert len(results) == 0 + + def test_dict_message_format(self): + """测试字典格式的消息(而非对象)""" + event = { + "model": {"messages": [{"type": "ai", "content": "字典格式消息"}]} + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0] == "字典格式消息" + + def test_multimodal_content(self): + """测试多模态内容(list 格式)""" + chunk = MagicMock() + chunk.content = [ + {"type": "text", "text": "这是"}, + {"type": "text", "text": "多模态内容"}, + ] + chunk.tool_call_chunks = [] + + event = { + "event": "on_chat_model_stream", + "data": {"chunk": chunk}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0] == "这是多模态内容" + + def test_output_with_content_attribute(self): + """测试有 content 属性的工具输出""" + output = MagicMock() + output.content = "工具输出内容" + + event = { + "event": "on_tool_end", + "run_id": "run_123", + "data": {"output": output}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + # on_tool_end 只发送 TOOL_CALL_RESULT(TOOL_CALL_END 在 on_tool_start 发送) + assert len(results) == 1 + assert results[0].event == EventType.TOOL_RESULT + assert results[0].data["result"] == "工具输出内容" + + def test_unsupported_stream_mode_messages_format(self): + """测试不支持的 stream_mode='messages' 格式(元组形式) + + stream_mode='messages' 返回 (AIMessageChunk, metadata) 元组, + 不是 dict 格式,to_agui_events 不支持此格式,应该不产生输出。 + """ + # 模拟 stream_mode="messages" 返回的元组格式 + chunk = create_mock_ai_message_chunk("测试内容") + metadata = {"langgraph_node": "model"} + event = (chunk, metadata) # 元组格式 + + # 元组格式会被 _event_to_dict 转换为空字典,因此不产生输出 + results = list(AgentRunConverter.to_agui_events(event)) + assert len(results) == 0 + + def test_unsupported_random_dict_format(self): + """测试不支持的随机字典格式 + + 如果传入的 dict 不匹配任何已知格式,应该不产生输出。 + """ + event = { + "random_key": "random_value", + "another_key": {"nested": "data"}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + assert len(results) == 0 + + +# 
============================================================================= +# 测试 AG-UI 协议事件顺序 +# ============================================================================= + + +class TestAguiEventOrder: + """测试事件顺序 + + 简化后的事件结构: + - TOOL_CALL_CHUNK - 工具调用(包含 id, name, args_delta) + - TOOL_RESULT - 工具执行结果 + + 边界事件(如 TOOL_CALL_START/END)由协议层自动处理。 + """ + + def test_streaming_tool_call_order(self): + """测试流式工具调用的事件顺序 + + TOOL_CALL_CHUNK 应该在 TOOL_RESULT 之前 + """ + events = [ + # 第一个 chunk:包含 id、name,无 args + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": "call_order_test", + "name": "test_tool", + "args": "", + "index": 0, + }], + ) + }, + }, + # 第二个 chunk:有 args + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": "", + "name": "", + "args": '{"key": "value"}', + "index": 0, + }], + ) + }, + }, + # 工具开始执行 + { + "event": "on_tool_start", + "name": "test_tool", + "run_id": "run_order", + "data": { + "input": { + "key": "value", + "runtime": MagicMock(tool_call_id="call_order_test"), + } + }, + }, + # 工具执行完成 + { + "event": "on_tool_end", + "run_id": "run_order", + "data": { + "input": { + "key": "value", + "runtime": MagicMock(tool_call_id="call_order_test"), + }, + "output": "success", + }, + }, + ] + + converter = AgentRunConverter() + all_results = [] + for event in events: + all_results.extend(converter.convert(event)) + + # 提取工具调用相关事件 + tool_events = [ + r + for r in all_results + if isinstance(r, AgentResult) + and r.event in [EventType.TOOL_CALL_CHUNK, EventType.TOOL_RESULT] + ] + + # 验证有这两种事件 + event_types = [e.event for e in tool_events] + assert EventType.TOOL_CALL_CHUNK in event_types + assert EventType.TOOL_RESULT in event_types + + # 找到第一个 TOOL_CALL_CHUNK 和 TOOL_RESULT 的索引 + chunk_idx = event_types.index(EventType.TOOL_CALL_CHUNK) + result_idx = event_types.index(EventType.TOOL_RESULT) + + # 验证顺序:TOOL_CALL_CHUNK 必须在 
TOOL_RESULT 之前 + assert chunk_idx < result_idx, ( + f"TOOL_CALL_CHUNK (idx={chunk_idx}) must come before " + f"TOOL_RESULT (idx={result_idx})" + ) + + def test_streaming_tool_call_start_not_duplicated(self): + """测试流式工具调用时 TOOL_CALL_START 不会重复发送""" + events = [ + # 第一个 chunk:包含 id、name + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[{ + "id": "call_no_dup", + "name": "test_tool", + "args": '{"a": 1}', + "index": 0, + }], + ) + }, + }, + # 工具开始执行(此时 START 已在上面发送,不应重复) + { + "event": "on_tool_start", + "name": "test_tool", + "run_id": "run_no_dup", + "data": { + "input": { + "a": 1, + "runtime": MagicMock(tool_call_id="call_no_dup"), + } + }, + }, + # 工具执行完成 + { + "event": "on_tool_end", + "run_id": "run_no_dup", + "data": { + "input": { + "a": 1, + "runtime": MagicMock(tool_call_id="call_no_dup"), + }, + "output": "done", + }, + }, + ] + + converter = AgentRunConverter() + all_results = [] + for event in events: + all_results.extend(converter.convert(event)) + + # 统计 TOOL_CALL_START 事件的数量 + start_events = [ + r + for r in all_results + if isinstance(r, AgentResult) + and r.event == EventType.TOOL_CALL_CHUNK + ] + + # 应该只有一个 TOOL_CALL_START + assert ( + len(start_events) == 1 + ), f"Expected 1 TOOL_CALL_START, got {len(start_events)}" + + def test_non_streaming_tool_call_order(self): + """测试非流式场景的工具调用事件顺序 + + 在没有 on_chat_model_stream 事件的场景下, + 事件顺序仍应正确:TOOL_CALL_CHUNK → TOOL_RESULT + """ + events = [ + # 直接工具开始(无流式事件) + { + "event": "on_tool_start", + "name": "weather", + "run_id": "run_nonstream", + "data": {"input": {"city": "北京"}}, + }, + # 工具执行完成 + { + "event": "on_tool_end", + "run_id": "run_nonstream", + "data": {"output": "晴天"}, + }, + ] + + converter = AgentRunConverter() + all_results = [] + for event in events: + all_results.extend(converter.convert(event)) + + tool_events = [r for r in all_results if isinstance(r, AgentResult)] + + event_types = [e.event for e in tool_events] + + # 
验证顺序:TOOL_CALL_CHUNK → TOOL_RESULT + assert event_types == [ + EventType.TOOL_CALL_CHUNK, + EventType.TOOL_RESULT, + ], f"Unexpected order: {event_types}" + + def test_multiple_concurrent_tool_calls_order(self): + """测试多个并发工具调用时各自的事件顺序正确""" + events = [ + # 两个工具调用的第一个 chunk + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[ + { + "id": "call_a", + "name": "tool_a", + "args": "", + "index": 0, + }, + { + "id": "call_b", + "name": "tool_b", + "args": "", + "index": 1, + }, + ], + ) + }, + }, + # 两个工具的参数 + { + "event": "on_chat_model_stream", + "data": { + "chunk": MagicMock( + content="", + tool_call_chunks=[ + { + "id": "", + "name": "", + "args": '{"x": 1}', + "index": 0, + }, + { + "id": "", + "name": "", + "args": '{"y": 2}', + "index": 1, + }, + ], + ) + }, + }, + # 工具 A 开始 + { + "event": "on_tool_start", + "name": "tool_a", + "run_id": "run_a", + "data": { + "input": { + "x": 1, + "runtime": MagicMock(tool_call_id="call_a"), + } + }, + }, + # 工具 B 开始 + { + "event": "on_tool_start", + "name": "tool_b", + "run_id": "run_b", + "data": { + "input": { + "y": 2, + "runtime": MagicMock(tool_call_id="call_b"), + } + }, + }, + # 工具 A 结束 + { + "event": "on_tool_end", + "run_id": "run_a", + "data": { + "input": { + "x": 1, + "runtime": MagicMock(tool_call_id="call_a"), + }, + "output": "result_a", + }, + }, + # 工具 B 结束 + { + "event": "on_tool_end", + "run_id": "run_b", + "data": { + "input": { + "y": 2, + "runtime": MagicMock(tool_call_id="call_b"), + }, + "output": "result_b", + }, + }, + ] + + converter = AgentRunConverter() + all_results = [] + for event in events: + all_results.extend(converter.convert(event)) + + # 分别验证工具 A 和工具 B 的事件顺序 + for tool_id in ["call_a", "call_b"]: + tool_events = [ + (i, r) + for i, r in enumerate(all_results) + if isinstance(r, AgentResult) and r.data.get("id") == tool_id + ] + + event_types = [e.event for _, e in tool_events] + + # 验证包含所有必需事件 + assert ( + EventType.TOOL_CALL_CHUNK 
in event_types + ), f"Tool {tool_id} missing TOOL_CALL_CHUNK" + assert ( + EventType.TOOL_RESULT in event_types + ), f"Tool {tool_id} missing TOOL_RESULT" + + # 验证顺序:TOOL_CALL_CHUNK 应该在 TOOL_RESULT 之前 + chunk_pos = event_types.index(EventType.TOOL_CALL_CHUNK) + result_pos = event_types.index(EventType.TOOL_RESULT) + + assert ( + chunk_pos < result_pos + ), f"Tool {tool_id}: CHUNK must come before RESULT" + + +# ============================================================================= +# 集成测试:模拟完整流程 +# ============================================================================= + + +class TestConvertIntegration: + """测试 convert 与完整流程的集成""" + + def test_astream_events_full_flow(self): + """测试模拟的 astream_events 完整流程""" + mock_events = [ + { + "event": "on_chat_model_stream", + "data": {"chunk": create_mock_ai_message_chunk("你好")}, + }, + { + "event": "on_chat_model_stream", + "data": {"chunk": create_mock_ai_message_chunk(",")}, + }, + { + "event": "on_chat_model_stream", + "data": {"chunk": create_mock_ai_message_chunk("世界!")}, + }, + ] + + results = [] + for event in mock_events: + results.extend(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 3 + assert "".join(results) == "你好,世界!" 
+ + def test_stream_updates_full_flow(self): + """测试模拟的 stream(updates) 完整流程""" + import json + + mock_events = [ + # Agent 决定调用工具 + { + "agent": { + "messages": [ + create_mock_ai_message( + "", + tool_calls=[{ + "id": "tc_001", + "name": "get_weather", + "args": {"city": "北京"}, + }], + ) + ] + } + }, + # 工具执行结果 + { + "tools": { + "messages": [ + create_mock_tool_message( + json.dumps( + {"city": "北京", "weather": "晴天"}, + ensure_ascii=False, + ), + "tc_001", + ) + ] + } + }, + # Agent 最终回复 + { + "model": { + "messages": [create_mock_ai_message("北京今天天气晴朗。")] + } + }, + ] + + results = [] + for event in mock_events: + results.extend(AgentRunConverter.to_agui_events(event)) + + # 验证事件顺序 + assert len(results) == 3 + + # 工具调用 + assert isinstance(results[0], AgentResult) + assert results[0].event == EventType.TOOL_CALL_CHUNK + + # 工具结果 + assert isinstance(results[1], AgentResult) + assert results[1].event == EventType.TOOL_RESULT + + # 最终文本 + assert results[2] == "北京今天天气晴朗。" + + def test_stream_values_full_flow(self): + """测试模拟的 stream(values) 完整流程""" + mock_events = [ + {"messages": [create_mock_ai_message("")]}, + { + "messages": [ + create_mock_ai_message( + "", + tool_calls=[ + {"id": "tc_002", "name": "get_time", "args": {}} + ], + ) + ] + }, + { + "messages": [ + create_mock_ai_message(""), + create_mock_tool_message("2024-01-01 12:00:00", "tc_002"), + ] + }, + { + "messages": [ + create_mock_ai_message(""), + create_mock_tool_message("2024-01-01 12:00:00", "tc_002"), + create_mock_ai_message("现在是 2024年1月1日。"), + ] + }, + ] + + results = [] + for event in mock_events: + results.extend(AgentRunConverter.to_agui_events(event)) + + # 验证有工具调用事件 + tool_chunks = [ + r + for r in results + if isinstance(r, AgentResult) + and r.event == EventType.TOOL_CALL_CHUNK + ] + assert len(tool_chunks) >= 1 + + # 验证有最终文本 + text_results = [r for r in results if isinstance(r, str) and r] + assert any("2024" in t for t in text_results) diff --git 
a/tests/unittests/integration/test_langgraph_events.py b/tests/unittests/integration/test_langgraph_events.py new file mode 100644 index 0000000..3bae1d8 --- /dev/null +++ b/tests/unittests/integration/test_langgraph_events.py @@ -0,0 +1,911 @@ +"""测试 LangGraph 事件到 AgentEvent 的转换 + +本文件专注于测试 LangGraph/LangChain 事件到 AgentEvent 的转换, +确保每种输入事件类型都有明确的预期输出。 + +简化后的事件结构: +- TOOL_CALL_CHUNK: 工具调用(包含 id, name, args_delta) +- TOOL_RESULT: 工具执行结果(包含 id, result) +- TEXT: 文本内容(字符串) + +边界事件(如 TOOL_CALL_START/END)由协议层自动生成,转换器不再输出这些事件。 +""" + +from typing import Dict, List, Union +from unittest.mock import MagicMock + +import pytest + +from agentrun.integration.langgraph import AgentRunConverter +from agentrun.server.model import AgentEvent, EventType + +# 使用 helpers.py 中的公共函数 +from .helpers import convert_and_collect +from .helpers import create_mock_ai_message as create_ai_message +from .helpers import create_mock_ai_message_chunk as create_ai_message_chunk +from .helpers import create_mock_tool_message as create_tool_message +from .helpers import filter_agent_events + +# ============================================================================= +# 测试 on_chat_model_stream 事件(流式文本输出) +# ============================================================================= + + +class TestOnChatModelStreamText: + """测试 on_chat_model_stream 事件的文本内容输出""" + + def test_simple_text_content(self): + """测试简单文本内容 + + 输入: on_chat_model_stream with content="你好" + 输出: "你好" (字符串) + """ + event = { + "event": "on_chat_model_stream", + "data": {"chunk": create_ai_message_chunk("你好")}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0] == "你好" + + def test_empty_content_no_output(self): + """测试空内容不产生输出 + + 输入: on_chat_model_stream with content="" + 输出: (无) + """ + event = { + "event": "on_chat_model_stream", + "data": {"chunk": create_ai_message_chunk("")}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert 
len(results) == 0 + + def test_multiple_stream_chunks(self): + """测试多个流式 chunk + + 输入: 多个 on_chat_model_stream + 输出: 多个字符串 + """ + events = [ + { + "event": "on_chat_model_stream", + "data": {"chunk": create_ai_message_chunk("你")}, + }, + { + "event": "on_chat_model_stream", + "data": {"chunk": create_ai_message_chunk("好")}, + }, + { + "event": "on_chat_model_stream", + "data": {"chunk": create_ai_message_chunk("!")}, + }, + ] + + results = convert_and_collect(events) + + assert len(results) == 3 + assert results[0] == "你" + assert results[1] == "好" + assert results[2] == "!" + + +# ============================================================================= +# 测试 on_chat_model_stream 事件(流式工具调用) +# ============================================================================= + + +class TestOnChatModelStreamToolCall: + """测试 on_chat_model_stream 事件的工具调用输出""" + + def test_tool_call_first_chunk_with_id_and_name(self): + """测试工具调用第一个 chunk(包含 id 和 name) + + 输入: tool_call_chunk with id="call_123", name="get_weather", args="" + 输出: TOOL_CALL_CHUNK with id="call_123", name="get_weather", args_delta="" + """ + event = { + "event": "on_chat_model_stream", + "data": { + "chunk": create_ai_message_chunk( + tool_call_chunks=[{ + "id": "call_123", + "name": "get_weather", + "args": "", + "index": 0, + }] + ) + }, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert isinstance(results[0], AgentEvent) + assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["id"] == "call_123" + assert results[0].data["name"] == "get_weather" + assert results[0].data["args_delta"] == "" + + def test_tool_call_subsequent_chunk_with_args(self): + """测试工具调用后续 chunk(只有 args) + + 输入: 后续 chunk with args='{"city": "北京"}' + 输出: TOOL_CALL_CHUNK with args_delta='{"city": "北京"}' + """ + # 首先发送第一个 chunk 建立映射 + first_chunk = { + "event": "on_chat_model_stream", + "data": { + "chunk": create_ai_message_chunk( + tool_call_chunks=[{ + "id": 
"call_123", + "name": "get_weather", + "args": "", + "index": 0, + }] + ) + }, + } + + # 后续 chunk + second_chunk = { + "event": "on_chat_model_stream", + "data": { + "chunk": create_ai_message_chunk( + tool_call_chunks=[{ + "id": "", # 后续 chunk 没有 id + "name": "", + "args": '{"city": "北京"}', + "index": 0, + }] + ) + }, + } + + tool_call_id_map: Dict[int, str] = {} + results1 = list( + AgentRunConverter.to_agui_events( + first_chunk, tool_call_id_map=tool_call_id_map + ) + ) + results2 = list( + AgentRunConverter.to_agui_events( + second_chunk, tool_call_id_map=tool_call_id_map + ) + ) + + # 第一个 chunk 产生一个事件 + assert len(results1) == 1 + assert results1[0].data["id"] == "call_123" + + # 第二个 chunk 也产生一个事件,使用映射的 id + assert len(results2) == 1 + assert results2[0].event == EventType.TOOL_CALL_CHUNK + assert results2[0].data["id"] == "call_123" + assert results2[0].data["args_delta"] == '{"city": "北京"}' + + def test_tool_call_complete_in_one_chunk(self): + """测试完整工具调用在一个 chunk 中 + + 输入: chunk with id, name, and complete args + 输出: 单个 TOOL_CALL_CHUNK + """ + event = { + "event": "on_chat_model_stream", + "data": { + "chunk": create_ai_message_chunk( + tool_call_chunks=[{ + "id": "call_456", + "name": "get_time", + "args": '{"timezone": "UTC"}', + "index": 0, + }] + ) + }, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["id"] == "call_456" + assert results[0].data["name"] == "get_time" + assert results[0].data["args_delta"] == '{"timezone": "UTC"}' + + def test_multiple_concurrent_tool_calls(self): + """测试多个并发工具调用 + + 输入: 一个 chunk 包含两个工具调用 + 输出: 两个 TOOL_CALL_CHUNK + """ + event = { + "event": "on_chat_model_stream", + "data": { + "chunk": create_ai_message_chunk( + tool_call_chunks=[ + { + "id": "call_a", + "name": "tool_a", + "args": "", + "index": 0, + }, + { + "id": "call_b", + "name": "tool_b", + "args": "", + "index": 1, + }, + ] + ) + }, + } + 
+ results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 2 + assert results[0].data["id"] == "call_a" + assert results[0].data["name"] == "tool_a" + assert results[1].data["id"] == "call_b" + assert results[1].data["name"] == "tool_b" + + +# ============================================================================= +# 测试 on_tool_start 事件 +# ============================================================================= + + +class TestOnToolStart: + """测试 on_tool_start 事件的转换""" + + def test_simple_tool_start(self): + """测试简单的工具启动事件 + + 输入: on_tool_start with input + 输出: 单个 TOOL_CALL_CHUNK(如果未在流式中发送过) + """ + event = { + "event": "on_tool_start", + "name": "get_weather", + "run_id": "run_123", + "data": {"input": {"city": "北京"}}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["id"] == "run_123" + assert results[0].data["name"] == "get_weather" + assert "北京" in results[0].data["args_delta"] + + def test_tool_start_with_runtime_tool_call_id(self): + """测试使用 runtime.tool_call_id 的工具启动 + + 输入: on_tool_start with runtime.tool_call_id + 输出: TOOL_CALL_CHUNK 使用 runtime 中的 id + """ + + class FakeRuntime: + tool_call_id = "call_original_id" + + event = { + "event": "on_tool_start", + "name": "get_weather", + "run_id": "run_123", # 这个不应该被使用 + "data": {"input": {"city": "北京", "runtime": FakeRuntime()}}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert ( + results[0].data["id"] == "call_original_id" + ) # 使用 runtime 中的 id + + def test_tool_start_no_duplicate_if_already_started(self): + """测试如果工具已在流式中开始,不重复发送 + + 输入: on_tool_start after streaming chunk + 输出: 无(因为已在流式中发送过) + """ + tool_call_started_set = {"call_123"} + + event = { + "event": "on_tool_start", + "name": "get_weather", + "run_id": "run_123", + "data": { + "input": { + "city": "北京", + "runtime": 
MagicMock(tool_call_id="call_123"), + } + }, + } + + results = list( + AgentRunConverter.to_agui_events( + event, tool_call_started_set=tool_call_started_set + ) + ) + + # 已经在流式中发送过,不再发送 + assert len(results) == 0 + + def test_tool_start_without_input(self): + """测试无输入参数的工具启动 + + 输入: on_tool_start with empty input + 输出: TOOL_CALL_CHUNK with empty args_delta + """ + event = { + "event": "on_tool_start", + "name": "get_time", + "run_id": "run_456", + "data": {}, # 无输入 + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["name"] == "get_time" + assert results[0].data["args_delta"] == "" + + +# ============================================================================= +# 测试 on_tool_end 事件 +# ============================================================================= + + +class TestOnToolEnd: + """测试 on_tool_end 事件的转换""" + + def test_simple_tool_end(self): + """测试简单的工具结束事件 + + 输入: on_tool_end with output + 输出: TOOL_RESULT + """ + event = { + "event": "on_tool_end", + "run_id": "run_123", + "data": {"output": {"result": "晴天", "temp": 25}}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.TOOL_RESULT + assert results[0].data["id"] == "run_123" + assert "晴天" in results[0].data["result"] + + def test_tool_end_with_string_output(self): + """测试字符串输出的工具结束 + + 输入: on_tool_end with string output + 输出: TOOL_RESULT with string result + """ + event = { + "event": "on_tool_end", + "run_id": "run_456", + "data": {"output": "操作成功"}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.TOOL_RESULT + assert results[0].data["result"] == "操作成功" + + def test_tool_end_with_runtime_tool_call_id(self): + """测试使用 runtime.tool_call_id 的工具结束 + + 输入: on_tool_end with runtime.tool_call_id + 输出: TOOL_RESULT 使用 
runtime 中的 id + """ + + class FakeRuntime: + tool_call_id = "call_original_id" + + event = { + "event": "on_tool_end", + "run_id": "run_123", + "data": { + "input": {"runtime": FakeRuntime()}, + "output": "结果", + }, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].data["id"] == "call_original_id" + + +# ============================================================================= +# 测试 stream_mode="updates" 格式 +# ============================================================================= + + +class TestStreamUpdatesFormat: + """测试 stream(stream_mode="updates") 格式的事件转换""" + + def test_ai_message_with_text(self): + """测试 AI 消息的文本内容 + + 输入: {model: {messages: [AIMessage(content="你好")]}} + 输出: "你好" (字符串) + """ + msg = create_ai_message("你好") + event = {"model": {"messages": [msg]}} + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0] == "你好" + + def test_ai_message_with_tool_calls(self): + """测试 AI 消息的工具调用 + + 输入: {agent: {messages: [AIMessage with tool_calls]}} + 输出: TOOL_CALL_CHUNK + """ + msg = create_ai_message( + tool_calls=[{ + "id": "call_abc", + "name": "search", + "args": {"query": "天气"}, + }] + ) + event = {"agent": {"messages": [msg]}} + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["id"] == "call_abc" + assert results[0].data["name"] == "search" + + def test_tool_message_result(self): + """测试工具消息的结果 + + 输入: {tools: {messages: [ToolMessage]}} + 输出: TOOL_RESULT + """ + msg = create_tool_message('{"weather": "晴天"}', "call_xyz") + event = {"tools": {"messages": [msg]}} + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.TOOL_RESULT + assert results[0].data["id"] == "call_xyz" + + +# 
============================================================================= +# 测试 stream_mode="values" 格式 +# ============================================================================= + + +class TestStreamValuesFormat: + """测试 stream(stream_mode="values") 格式的事件转换""" + + def test_values_format_last_message(self): + """测试 values 格式只处理最后一条消息 + + 输入: {messages: [msg1, msg2, msg3]} + 输出: 只处理 msg3 + """ + msg1 = create_ai_message("第一条") + msg2 = create_ai_message("第二条") + msg3 = create_ai_message("第三条") + event = {"messages": [msg1, msg2, msg3]} + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0] == "第三条" + + def test_values_format_tool_call(self): + """测试 values 格式的工具调用 + + 输入: {messages: [AIMessage with tool_calls]} + 输出: TOOL_CALL_CHUNK + """ + msg = create_ai_message( + tool_calls=[{"id": "call_123", "name": "test", "args": {}}] + ) + event = {"messages": [msg]} + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL_CHUNK + + +# ============================================================================= +# 测试 AgentRunConverter 类 +# ============================================================================= + + +class TestAgentRunConverterClass: + """测试 AgentRunConverter 类的功能""" + + def test_converter_maintains_state(self): + """测试转换器维护状态(tool_call_id_map)""" + converter = AgentRunConverter() + + # 第一个 chunk 建立映射 + event1 = { + "event": "on_chat_model_stream", + "data": { + "chunk": create_ai_message_chunk( + tool_call_chunks=[{ + "id": "call_stateful", + "name": "test", + "args": "", + "index": 0, + }] + ) + }, + } + + results1 = list(converter.convert(event1)) + assert converter._tool_call_id_map[0] == "call_stateful" + + # 第二个 chunk 使用映射 + event2 = { + "event": "on_chat_model_stream", + "data": { + "chunk": create_ai_message_chunk( + tool_call_chunks=[{ + "id": "", + "name": "", + "args": '{"a": 1}', + "index": 0, + }] 
+ ) + }, + } + + results2 = list(converter.convert(event2)) + assert results2[0].data["id"] == "call_stateful" + + def test_converter_reset(self): + """测试转换器重置""" + converter = AgentRunConverter() + + # 建立一些状态 + event = { + "event": "on_chat_model_stream", + "data": { + "chunk": create_ai_message_chunk( + tool_call_chunks=[{ + "id": "call_to_reset", + "name": "test", + "args": "", + "index": 0, + }] + ) + }, + } + list(converter.convert(event)) + + assert len(converter._tool_call_id_map) > 0 + + # 重置 + converter.reset() + assert len(converter._tool_call_id_map) == 0 + assert len(converter._tool_call_started_set) == 0 + + +# ============================================================================= +# 测试完整流程 +# ============================================================================= + + +class TestCompleteFlow: + """测试完整的工具调用流程""" + + def test_streaming_tool_call_complete_flow(self): + """测试流式工具调用的完整流程 + + 流程: + 1. on_chat_model_stream (id + name) + 2. on_chat_model_stream (args) + 3. on_tool_start (不产生事件,因为已发送) + 4. on_tool_end (TOOL_RESULT) + """ + converter = AgentRunConverter() + + events = [ + # 1. 流式工具调用开始 + { + "event": "on_chat_model_stream", + "data": { + "chunk": create_ai_message_chunk( + tool_call_chunks=[{ + "id": "call_flow_test", + "name": "weather", + "args": "", + "index": 0, + }] + ) + }, + }, + # 2. 流式工具调用参数 + { + "event": "on_chat_model_stream", + "data": { + "chunk": create_ai_message_chunk( + tool_call_chunks=[{ + "id": "", + "name": "", + "args": '{"city": "北京"}', + "index": 0, + }] + ) + }, + }, + # 3. 工具开始执行(已在流式中发送,不产生事件) + { + "event": "on_tool_start", + "name": "weather", + "run_id": "run_flow", + "data": { + "input": { + "city": "北京", + "runtime": MagicMock(tool_call_id="call_flow_test"), + } + }, + }, + # 4. 
工具执行完成 + { + "event": "on_tool_end", + "run_id": "run_flow", + "data": { + "input": { + "city": "北京", + "runtime": MagicMock(tool_call_id="call_flow_test"), + }, + "output": "晴天", + }, + }, + ] + + all_results = [] + for event in events: + all_results.extend(converter.convert(event)) + + # 预期结果: + # - 2 个 TOOL_CALL_CHUNK(流式) + # - 1 个 TOOL_RESULT + chunk_events = filter_agent_events( + all_results, EventType.TOOL_CALL_CHUNK + ) + result_events = filter_agent_events(all_results, EventType.TOOL_RESULT) + + assert len(chunk_events) == 2 + assert len(result_events) == 1 + + # 验证 ID 一致性 + for event in chunk_events + result_events: + assert event.data["id"] == "call_flow_test" + + def test_non_streaming_tool_call_complete_flow(self): + """测试非流式工具调用的完整流程 + + 流程: + 1. on_tool_start (TOOL_CALL_CHUNK) + 2. on_tool_end (TOOL_RESULT) + """ + events = [ + { + "event": "on_tool_start", + "name": "calculator", + "run_id": "run_calc", + "data": {"input": {"a": 1, "b": 2}}, + }, + { + "event": "on_tool_end", + "run_id": "run_calc", + "data": {"output": 3}, + }, + ] + + all_results = convert_and_collect(events) + + chunk_events = filter_agent_events( + all_results, EventType.TOOL_CALL_CHUNK + ) + result_events = filter_agent_events(all_results, EventType.TOOL_RESULT) + + assert len(chunk_events) == 1 + assert len(result_events) == 1 + + # 验证顺序 + chunk_idx = all_results.index(chunk_events[0]) + result_idx = all_results.index(result_events[0]) + assert chunk_idx < result_idx + + +# ============================================================================= +# 测试错误事件 +# ============================================================================= + + +class TestErrorEvents: + """测试 LangChain 错误事件的转换""" + + def test_on_tool_error(self): + """测试 on_tool_error 事件 + + 输入: on_tool_error with error + 输出: ERROR 事件 + """ + event = { + "event": "on_tool_error", + "name": "weather_tool", + "run_id": "run_123", + "data": { + "error": ValueError("Invalid city name"), + "input": {"city": 
"invalid"}, + }, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.ERROR + assert "weather_tool" in results[0].data["message"] + assert "ValueError" in results[0].data["message"] + assert results[0].data["code"] == "TOOL_ERROR" + assert results[0].data["tool_call_id"] == "run_123" + + def test_on_tool_error_with_runtime_tool_call_id(self): + """测试 on_tool_error 使用 runtime 中的 tool_call_id""" + + class FakeRuntime: + tool_call_id = "call_original_id" + + event = { + "event": "on_tool_error", + "name": "search_tool", + "run_id": "run_456", + "data": { + "error": Exception("API timeout"), + "input": {"query": "test", "runtime": FakeRuntime()}, + }, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].data["tool_call_id"] == "call_original_id" + + def test_on_tool_error_with_string_error(self): + """测试 on_tool_error 使用字符串错误""" + event = { + "event": "on_tool_error", + "name": "calc_tool", + "run_id": "run_789", + "data": { + "error": "Division by zero", + "input": {}, + }, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert "Division by zero" in results[0].data["message"] + + def test_on_llm_error(self): + """测试 on_llm_error 事件 + + 输入: on_llm_error with error + 输出: ERROR 事件 + """ + event = { + "event": "on_llm_error", + "run_id": "run_llm", + "data": { + "error": RuntimeError("API rate limit exceeded"), + }, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.ERROR + assert "LLM error" in results[0].data["message"] + assert "RuntimeError" in results[0].data["message"] + assert results[0].data["code"] == "LLM_ERROR" + + def test_on_chain_error(self): + """测试 on_chain_error 事件 + + 输入: on_chain_error with error + 输出: ERROR 事件 + """ + event = { + "event": "on_chain_error", + "name": "agent_chain", + 
"run_id": "run_chain", + "data": { + "error": KeyError("missing_key"), + }, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.ERROR + assert "agent_chain" in results[0].data["message"] + assert "KeyError" in results[0].data["message"] + assert results[0].data["code"] == "CHAIN_ERROR" + + def test_on_retriever_error(self): + """测试 on_retriever_error 事件 + + 输入: on_retriever_error with error + 输出: ERROR 事件 + """ + event = { + "event": "on_retriever_error", + "name": "vector_store", + "run_id": "run_retriever", + "data": { + "error": ConnectionError("Database connection failed"), + }, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.ERROR + assert "vector_store" in results[0].data["message"] + assert "ConnectionError" in results[0].data["message"] + assert results[0].data["code"] == "RETRIEVER_ERROR" + + def test_tool_error_in_complete_flow(self): + """测试完整流程中的工具错误 + + 流程: + 1. on_tool_start (TOOL_CALL_CHUNK) + 2. 
on_tool_error (ERROR) + """ + events = [ + { + "event": "on_tool_start", + "name": "risky_tool", + "run_id": "run_risky", + "data": {"input": {"param": "test"}}, + }, + { + "event": "on_tool_error", + "name": "risky_tool", + "run_id": "run_risky", + "data": { + "error": RuntimeError("Tool execution failed"), + "input": {"param": "test"}, + }, + }, + ] + + all_results = convert_and_collect(events) + + chunk_events = filter_agent_events( + all_results, EventType.TOOL_CALL_CHUNK + ) + error_events = filter_agent_events(all_results, EventType.ERROR) + + assert len(chunk_events) == 1 + assert len(error_events) == 1 + assert chunk_events[0].data["id"] == "run_risky" + assert error_events[0].data["tool_call_id"] == "run_risky" diff --git a/tests/unittests/integration/test_langgraph_to_agent_event.py b/tests/unittests/integration/test_langgraph_to_agent_event.py new file mode 100644 index 0000000..a45f65a --- /dev/null +++ b/tests/unittests/integration/test_langgraph_to_agent_event.py @@ -0,0 +1,909 @@ +"""测试 LangGraph 事件到 AgentEvent 的转换 + +本文件专注于测试 LangGraph/LangChain 事件到 AgentEvent 的转换, +确保每种输入事件类型都有明确的预期输出。 + +简化后的事件结构: +- TOOL_CALL_CHUNK: 工具调用(包含 id, name, args_delta) +- TOOL_RESULT: 工具执行结果(包含 id, result) +- TEXT: 文本内容(字符串) + +边界事件(如 TOOL_CALL_START/END)由协议层自动生成,转换器不再输出这些事件。 +""" + +from typing import Dict +from unittest.mock import MagicMock + +from agentrun.integration.langgraph import AgentRunConverter +from agentrun.server.model import AgentEvent, EventType + +# 使用 conftest.py 中的公共函数 +from .conftest import convert_and_collect +from .conftest import create_mock_ai_message as create_ai_message +from .conftest import create_mock_ai_message_chunk as create_ai_message_chunk +from .conftest import create_mock_tool_message as create_tool_message +from .conftest import filter_agent_events + +# ============================================================================= +# 测试 on_chat_model_stream 事件(流式文本输出) +# 
============================================================================= + + +class TestOnChatModelStreamText: + """测试 on_chat_model_stream 事件的文本内容输出""" + + def test_simple_text_content(self): + """测试简单文本内容 + + 输入: on_chat_model_stream with content="你好" + 输出: "你好" (字符串) + """ + event = { + "event": "on_chat_model_stream", + "data": {"chunk": create_ai_message_chunk("你好")}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0] == "你好" + + def test_empty_content_no_output(self): + """测试空内容不产生输出 + + 输入: on_chat_model_stream with content="" + 输出: (无) + """ + event = { + "event": "on_chat_model_stream", + "data": {"chunk": create_ai_message_chunk("")}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 0 + + def test_multiple_stream_chunks(self): + """测试多个流式 chunk + + 输入: 多个 on_chat_model_stream + 输出: 多个字符串 + """ + events = [ + { + "event": "on_chat_model_stream", + "data": {"chunk": create_ai_message_chunk("你")}, + }, + { + "event": "on_chat_model_stream", + "data": {"chunk": create_ai_message_chunk("好")}, + }, + { + "event": "on_chat_model_stream", + "data": {"chunk": create_ai_message_chunk("!")}, + }, + ] + + results = convert_and_collect(events) + + assert len(results) == 3 + assert results[0] == "你" + assert results[1] == "好" + assert results[2] == "!" 
+ + +# ============================================================================= +# 测试 on_chat_model_stream 事件(流式工具调用) +# ============================================================================= + + +class TestOnChatModelStreamToolCall: + """测试 on_chat_model_stream 事件的工具调用输出""" + + def test_tool_call_first_chunk_with_id_and_name(self): + """测试工具调用第一个 chunk(包含 id 和 name) + + 输入: tool_call_chunk with id="call_123", name="get_weather", args="" + 输出: TOOL_CALL_CHUNK with id="call_123", name="get_weather", args_delta="" + """ + event = { + "event": "on_chat_model_stream", + "data": { + "chunk": create_ai_message_chunk( + tool_call_chunks=[{ + "id": "call_123", + "name": "get_weather", + "args": "", + "index": 0, + }] + ) + }, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert isinstance(results[0], AgentEvent) + assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["id"] == "call_123" + assert results[0].data["name"] == "get_weather" + assert results[0].data["args_delta"] == "" + + def test_tool_call_subsequent_chunk_with_args(self): + """测试工具调用后续 chunk(只有 args) + + 输入: 后续 chunk with args='{"city": "北京"}' + 输出: TOOL_CALL_CHUNK with args_delta='{"city": "北京"}' + """ + # 首先发送第一个 chunk 建立映射 + first_chunk = { + "event": "on_chat_model_stream", + "data": { + "chunk": create_ai_message_chunk( + tool_call_chunks=[{ + "id": "call_123", + "name": "get_weather", + "args": "", + "index": 0, + }] + ) + }, + } + + # 后续 chunk + second_chunk = { + "event": "on_chat_model_stream", + "data": { + "chunk": create_ai_message_chunk( + tool_call_chunks=[{ + "id": "", # 后续 chunk 没有 id + "name": "", + "args": '{"city": "北京"}', + "index": 0, + }] + ) + }, + } + + tool_call_id_map: Dict[int, str] = {} + results1 = list( + AgentRunConverter.to_agui_events( + first_chunk, tool_call_id_map=tool_call_id_map + ) + ) + results2 = list( + AgentRunConverter.to_agui_events( + second_chunk, 
tool_call_id_map=tool_call_id_map + ) + ) + + # 第一个 chunk 产生一个事件 + assert len(results1) == 1 + assert results1[0].data["id"] == "call_123" + + # 第二个 chunk 也产生一个事件,使用映射的 id + assert len(results2) == 1 + assert results2[0].event == EventType.TOOL_CALL_CHUNK + assert results2[0].data["id"] == "call_123" + assert results2[0].data["args_delta"] == '{"city": "北京"}' + + def test_tool_call_complete_in_one_chunk(self): + """测试完整工具调用在一个 chunk 中 + + 输入: chunk with id, name, and complete args + 输出: 单个 TOOL_CALL_CHUNK + """ + event = { + "event": "on_chat_model_stream", + "data": { + "chunk": create_ai_message_chunk( + tool_call_chunks=[{ + "id": "call_456", + "name": "get_time", + "args": '{"timezone": "UTC"}', + "index": 0, + }] + ) + }, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["id"] == "call_456" + assert results[0].data["name"] == "get_time" + assert results[0].data["args_delta"] == '{"timezone": "UTC"}' + + def test_multiple_concurrent_tool_calls(self): + """测试多个并发工具调用 + + 输入: 一个 chunk 包含两个工具调用 + 输出: 两个 TOOL_CALL_CHUNK + """ + event = { + "event": "on_chat_model_stream", + "data": { + "chunk": create_ai_message_chunk( + tool_call_chunks=[ + { + "id": "call_a", + "name": "tool_a", + "args": "", + "index": 0, + }, + { + "id": "call_b", + "name": "tool_b", + "args": "", + "index": 1, + }, + ] + ) + }, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 2 + assert results[0].data["id"] == "call_a" + assert results[0].data["name"] == "tool_a" + assert results[1].data["id"] == "call_b" + assert results[1].data["name"] == "tool_b" + + +# ============================================================================= +# 测试 on_tool_start 事件 +# ============================================================================= + + +class TestOnToolStart: + """测试 on_tool_start 事件的转换""" + + def test_simple_tool_start(self): 
+ """测试简单的工具启动事件 + + 输入: on_tool_start with input + 输出: 单个 TOOL_CALL_CHUNK(如果未在流式中发送过) + """ + event = { + "event": "on_tool_start", + "name": "get_weather", + "run_id": "run_123", + "data": {"input": {"city": "北京"}}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["id"] == "run_123" + assert results[0].data["name"] == "get_weather" + assert "北京" in results[0].data["args_delta"] + + def test_tool_start_with_runtime_tool_call_id(self): + """测试使用 runtime.tool_call_id 的工具启动 + + 输入: on_tool_start with runtime.tool_call_id + 输出: TOOL_CALL_CHUNK 使用 runtime 中的 id + """ + + class FakeRuntime: + tool_call_id = "call_original_id" + + event = { + "event": "on_tool_start", + "name": "get_weather", + "run_id": "run_123", # 这个不应该被使用 + "data": {"input": {"city": "北京", "runtime": FakeRuntime()}}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert ( + results[0].data["id"] == "call_original_id" + ) # 使用 runtime 中的 id + + def test_tool_start_no_duplicate_if_already_started(self): + """测试如果工具已在流式中开始,不重复发送 + + 输入: on_tool_start after streaming chunk + 输出: 无(因为已在流式中发送过) + """ + tool_call_started_set = {"call_123"} + + event = { + "event": "on_tool_start", + "name": "get_weather", + "run_id": "run_123", + "data": { + "input": { + "city": "北京", + "runtime": MagicMock(tool_call_id="call_123"), + } + }, + } + + results = list( + AgentRunConverter.to_agui_events( + event, tool_call_started_set=tool_call_started_set + ) + ) + + # 已经在流式中发送过,不再发送 + assert len(results) == 0 + + def test_tool_start_without_input(self): + """测试无输入参数的工具启动 + + 输入: on_tool_start with empty input + 输出: TOOL_CALL_CHUNK with empty args_delta + """ + event = { + "event": "on_tool_start", + "name": "get_time", + "run_id": "run_456", + "data": {}, # 无输入 + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + 
assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["name"] == "get_time" + assert results[0].data["args_delta"] == "" + + +# ============================================================================= +# 测试 on_tool_end 事件 +# ============================================================================= + + +class TestOnToolEnd: + """测试 on_tool_end 事件的转换""" + + def test_simple_tool_end(self): + """测试简单的工具结束事件 + + 输入: on_tool_end with output + 输出: TOOL_RESULT + """ + event = { + "event": "on_tool_end", + "run_id": "run_123", + "data": {"output": {"result": "晴天", "temp": 25}}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.TOOL_RESULT + assert results[0].data["id"] == "run_123" + assert "晴天" in results[0].data["result"] + + def test_tool_end_with_string_output(self): + """测试字符串输出的工具结束 + + 输入: on_tool_end with string output + 输出: TOOL_RESULT with string result + """ + event = { + "event": "on_tool_end", + "run_id": "run_456", + "data": {"output": "操作成功"}, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.TOOL_RESULT + assert results[0].data["result"] == "操作成功" + + def test_tool_end_with_runtime_tool_call_id(self): + """测试使用 runtime.tool_call_id 的工具结束 + + 输入: on_tool_end with runtime.tool_call_id + 输出: TOOL_RESULT 使用 runtime 中的 id + """ + + class FakeRuntime: + tool_call_id = "call_original_id" + + event = { + "event": "on_tool_end", + "run_id": "run_123", + "data": { + "input": {"runtime": FakeRuntime()}, + "output": "结果", + }, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].data["id"] == "call_original_id" + + +# ============================================================================= +# 测试 stream_mode="updates" 格式 +# ============================================================================= + + +class 
TestStreamUpdatesFormat: + """测试 stream(stream_mode="updates") 格式的事件转换""" + + def test_ai_message_with_text(self): + """测试 AI 消息的文本内容 + + 输入: {model: {messages: [AIMessage(content="你好")]}} + 输出: "你好" (字符串) + """ + msg = create_ai_message("你好") + event = {"model": {"messages": [msg]}} + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0] == "你好" + + def test_ai_message_with_tool_calls(self): + """测试 AI 消息的工具调用 + + 输入: {agent: {messages: [AIMessage with tool_calls]}} + 输出: TOOL_CALL_CHUNK + """ + msg = create_ai_message( + tool_calls=[{ + "id": "call_abc", + "name": "search", + "args": {"query": "天气"}, + }] + ) + event = {"agent": {"messages": [msg]}} + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["id"] == "call_abc" + assert results[0].data["name"] == "search" + + def test_tool_message_result(self): + """测试工具消息的结果 + + 输入: {tools: {messages: [ToolMessage]}} + 输出: TOOL_RESULT + """ + msg = create_tool_message('{"weather": "晴天"}', "call_xyz") + event = {"tools": {"messages": [msg]}} + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.TOOL_RESULT + assert results[0].data["id"] == "call_xyz" + + +# ============================================================================= +# 测试 stream_mode="values" 格式 +# ============================================================================= + + +class TestStreamValuesFormat: + """测试 stream(stream_mode="values") 格式的事件转换""" + + def test_values_format_last_message(self): + """测试 values 格式只处理最后一条消息 + + 输入: {messages: [msg1, msg2, msg3]} + 输出: 只处理 msg3 + """ + msg1 = create_ai_message("第一条") + msg2 = create_ai_message("第二条") + msg3 = create_ai_message("第三条") + event = {"messages": [msg1, msg2, msg3]} + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) 
== 1 + assert results[0] == "第三条" + + def test_values_format_tool_call(self): + """测试 values 格式的工具调用 + + 输入: {messages: [AIMessage with tool_calls]} + 输出: TOOL_CALL_CHUNK + """ + msg = create_ai_message( + tool_calls=[{"id": "call_123", "name": "test", "args": {}}] + ) + event = {"messages": [msg]} + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL_CHUNK + + +# ============================================================================= +# 测试 AgentRunConverter 类 +# ============================================================================= + + +class TestAgentRunConverterClass: + """测试 AgentRunConverter 类的功能""" + + def test_converter_maintains_state(self): + """测试转换器维护状态(tool_call_id_map)""" + converter = AgentRunConverter() + + # 第一个 chunk 建立映射 + event1 = { + "event": "on_chat_model_stream", + "data": { + "chunk": create_ai_message_chunk( + tool_call_chunks=[{ + "id": "call_stateful", + "name": "test", + "args": "", + "index": 0, + }] + ) + }, + } + + results1 = list(converter.convert(event1)) + assert converter._tool_call_id_map[0] == "call_stateful" + + # 第二个 chunk 使用映射 + event2 = { + "event": "on_chat_model_stream", + "data": { + "chunk": create_ai_message_chunk( + tool_call_chunks=[{ + "id": "", + "name": "", + "args": '{"a": 1}', + "index": 0, + }] + ) + }, + } + + results2 = list(converter.convert(event2)) + assert results2[0].data["id"] == "call_stateful" + + def test_converter_reset(self): + """测试转换器重置""" + converter = AgentRunConverter() + + # 建立一些状态 + event = { + "event": "on_chat_model_stream", + "data": { + "chunk": create_ai_message_chunk( + tool_call_chunks=[{ + "id": "call_to_reset", + "name": "test", + "args": "", + "index": 0, + }] + ) + }, + } + list(converter.convert(event)) + + assert len(converter._tool_call_id_map) > 0 + + # 重置 + converter.reset() + assert len(converter._tool_call_id_map) == 0 + assert len(converter._tool_call_started_set) == 0 + + +# 
============================================================================= +# 测试完整流程 +# ============================================================================= + + +class TestCompleteFlow: + """测试完整的工具调用流程""" + + def test_streaming_tool_call_complete_flow(self): + """测试流式工具调用的完整流程 + + 流程: + 1. on_chat_model_stream (id + name) + 2. on_chat_model_stream (args) + 3. on_tool_start (不产生事件,因为已发送) + 4. on_tool_end (TOOL_RESULT) + """ + converter = AgentRunConverter() + + events = [ + # 1. 流式工具调用开始 + { + "event": "on_chat_model_stream", + "data": { + "chunk": create_ai_message_chunk( + tool_call_chunks=[{ + "id": "call_flow_test", + "name": "weather", + "args": "", + "index": 0, + }] + ) + }, + }, + # 2. 流式工具调用参数 + { + "event": "on_chat_model_stream", + "data": { + "chunk": create_ai_message_chunk( + tool_call_chunks=[{ + "id": "", + "name": "", + "args": '{"city": "北京"}', + "index": 0, + }] + ) + }, + }, + # 3. 工具开始执行(已在流式中发送,不产生事件) + { + "event": "on_tool_start", + "name": "weather", + "run_id": "run_flow", + "data": { + "input": { + "city": "北京", + "runtime": MagicMock(tool_call_id="call_flow_test"), + } + }, + }, + # 4. 工具执行完成 + { + "event": "on_tool_end", + "run_id": "run_flow", + "data": { + "input": { + "city": "北京", + "runtime": MagicMock(tool_call_id="call_flow_test"), + }, + "output": "晴天", + }, + }, + ] + + all_results = [] + for event in events: + all_results.extend(converter.convert(event)) + + # 预期结果: + # - 2 个 TOOL_CALL_CHUNK(流式) + # - 1 个 TOOL_RESULT + chunk_events = filter_agent_events( + all_results, EventType.TOOL_CALL_CHUNK + ) + result_events = filter_agent_events(all_results, EventType.TOOL_RESULT) + + assert len(chunk_events) == 2 + assert len(result_events) == 1 + + # 验证 ID 一致性 + for event in chunk_events + result_events: + assert event.data["id"] == "call_flow_test" + + def test_non_streaming_tool_call_complete_flow(self): + """测试非流式工具调用的完整流程 + + 流程: + 1. on_tool_start (TOOL_CALL_CHUNK) + 2. 
on_tool_end (TOOL_RESULT) + """ + events = [ + { + "event": "on_tool_start", + "name": "calculator", + "run_id": "run_calc", + "data": {"input": {"a": 1, "b": 2}}, + }, + { + "event": "on_tool_end", + "run_id": "run_calc", + "data": {"output": 3}, + }, + ] + + all_results = convert_and_collect(events) + + chunk_events = filter_agent_events( + all_results, EventType.TOOL_CALL_CHUNK + ) + result_events = filter_agent_events(all_results, EventType.TOOL_RESULT) + + assert len(chunk_events) == 1 + assert len(result_events) == 1 + + # 验证顺序 + chunk_idx = all_results.index(chunk_events[0]) + result_idx = all_results.index(result_events[0]) + assert chunk_idx < result_idx + + +# ============================================================================= +# 测试错误事件 +# ============================================================================= + + +class TestErrorEvents: + """测试 LangChain 错误事件的转换""" + + def test_on_tool_error(self): + """测试 on_tool_error 事件 + + 输入: on_tool_error with error + 输出: ERROR 事件 + """ + event = { + "event": "on_tool_error", + "name": "weather_tool", + "run_id": "run_123", + "data": { + "error": ValueError("Invalid city name"), + "input": {"city": "invalid"}, + }, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.ERROR + assert "weather_tool" in results[0].data["message"] + assert "ValueError" in results[0].data["message"] + assert results[0].data["code"] == "TOOL_ERROR" + assert results[0].data["tool_call_id"] == "run_123" + + def test_on_tool_error_with_runtime_tool_call_id(self): + """测试 on_tool_error 使用 runtime 中的 tool_call_id""" + + class FakeRuntime: + tool_call_id = "call_original_id" + + event = { + "event": "on_tool_error", + "name": "search_tool", + "run_id": "run_456", + "data": { + "error": Exception("API timeout"), + "input": {"query": "test", "runtime": FakeRuntime()}, + }, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert 
len(results) == 1 + assert results[0].data["tool_call_id"] == "call_original_id" + + def test_on_tool_error_with_string_error(self): + """测试 on_tool_error 使用字符串错误""" + event = { + "event": "on_tool_error", + "name": "calc_tool", + "run_id": "run_789", + "data": { + "error": "Division by zero", + "input": {}, + }, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert "Division by zero" in results[0].data["message"] + + def test_on_llm_error(self): + """测试 on_llm_error 事件 + + 输入: on_llm_error with error + 输出: ERROR 事件 + """ + event = { + "event": "on_llm_error", + "run_id": "run_llm", + "data": { + "error": RuntimeError("API rate limit exceeded"), + }, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.ERROR + assert "LLM error" in results[0].data["message"] + assert "RuntimeError" in results[0].data["message"] + assert results[0].data["code"] == "LLM_ERROR" + + def test_on_chain_error(self): + """测试 on_chain_error 事件 + + 输入: on_chain_error with error + 输出: ERROR 事件 + """ + event = { + "event": "on_chain_error", + "name": "agent_chain", + "run_id": "run_chain", + "data": { + "error": KeyError("missing_key"), + }, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.ERROR + assert "agent_chain" in results[0].data["message"] + assert "KeyError" in results[0].data["message"] + assert results[0].data["code"] == "CHAIN_ERROR" + + def test_on_retriever_error(self): + """测试 on_retriever_error 事件 + + 输入: on_retriever_error with error + 输出: ERROR 事件 + """ + event = { + "event": "on_retriever_error", + "name": "vector_store", + "run_id": "run_retriever", + "data": { + "error": ConnectionError("Database connection failed"), + }, + } + + results = list(AgentRunConverter.to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.ERROR + assert 
"vector_store" in results[0].data["message"] + assert "ConnectionError" in results[0].data["message"] + assert results[0].data["code"] == "RETRIEVER_ERROR" + + def test_tool_error_in_complete_flow(self): + """测试完整流程中的工具错误 + + 流程: + 1. on_tool_start (TOOL_CALL_CHUNK) + 2. on_tool_error (ERROR) + """ + events = [ + { + "event": "on_tool_start", + "name": "risky_tool", + "run_id": "run_risky", + "data": {"input": {"param": "test"}}, + }, + { + "event": "on_tool_error", + "name": "risky_tool", + "run_id": "run_risky", + "data": { + "error": RuntimeError("Tool execution failed"), + "input": {"param": "test"}, + }, + }, + ] + + all_results = convert_and_collect(events) + + chunk_events = filter_agent_events( + all_results, EventType.TOOL_CALL_CHUNK + ) + error_events = filter_agent_events(all_results, EventType.ERROR) + + assert len(chunk_events) == 1 + assert len(error_events) == 1 + assert chunk_events[0].data["id"] == "run_risky" + assert error_events[0].data["tool_call_id"] == "run_risky" diff --git a/tests/unittests/server/test_agui_event_sequence.py b/tests/unittests/server/test_agui_event_sequence.py new file mode 100644 index 0000000..41b407b --- /dev/null +++ b/tests/unittests/server/test_agui_event_sequence.py @@ -0,0 +1,2182 @@ +"""AG-UI 事件序列测试 + +基于 AG-UI 官方验证器 (verifyEvents) 的规则进行测试。 + +## AG-UI 官方验证规则 + +1. **RUN 生命周期** + - 第一个事件必须是 RUN_STARTED(或 RUN_ERROR) + - RUN_STARTED 不能在活跃 run 期间发送(必须先 RUN_FINISHED) + - RUN_FINISHED 后可以发送新的 RUN_STARTED(新 run) + - RUN_ERROR 后不能再发送任何事件 + - RUN_FINISHED 后不能再发送事件(除了新的 RUN_STARTED) + - RUN_FINISHED 前必须结束所有活跃的 messages、tool calls、steps + +2. **TEXT_MESSAGE 规则**(每个 messageId 独立跟踪) + - TEXT_MESSAGE_START: 同一个 messageId 不能重复开始 + - TEXT_MESSAGE_CONTENT: 必须有对应的活跃 messageId + - TEXT_MESSAGE_END: 必须有对应的活跃 messageId + - TEXT_MESSAGE_END 必须在 TOOL_CALL_START 之前发送 + +3. 
**TOOL_CALL 规则**(每个 toolCallId 独立跟踪) + - TOOL_CALL_START: 同一个 toolCallId 不能重复开始 + - TOOL_CALL_ARGS: 必须有对应的活跃 toolCallId + - TOOL_CALL_END: 必须有对应的活跃 toolCallId + - TOOL_CALL_START 前必须先结束其他活跃的工具调用(串行化) + +4. **串行化要求**(兼容 CopilotKit 等前端) + - TEXT_MESSAGE 和 TOOL_CALL 不能并行存在 + - 多个 TOOL_CALL 必须串行执行(在新工具开始前必须结束其他工具) + - 注意:AG-UI 协议本身支持并行,但某些前端实现强制串行 + +5. **STEP 规则** + - STEP_STARTED: 同一个 stepName 不能重复开始 + - STEP_FINISHED: 必须有对应的活跃 stepName + +## 测试覆盖矩阵 + +| 规则 | 测试 | +|------|------| +| RUN_STARTED 是第一个事件 | test_run_started_is_first | +| RUN_FINISHED 是最后一个事件 | test_run_finished_is_last | +| RUN_ERROR 后不能发送事件 | test_no_events_after_run_error | +| TEXT_MESSAGE 后 TOOL_CALL | test_text_then_tool_call | +| 多个 TOOL_CALL 串行 | test_tool_calls_serialized | +| RUN_FINISHED 前结束所有 | test_run_finished_ends_all | +""" + +import json +from typing import List + +import pytest + +from agentrun.server import ( + AgentEvent, + AgentRequest, + AgentRunServer, + AGUIProtocolConfig, + EventType, + ServerConfig, +) + + +def parse_sse_line(line: str) -> dict: + """解析 SSE 行""" + if line.startswith("data: "): + return json.loads(line[6:]) + return {} + + +def get_event_types(lines: List[str]) -> List[str]: + """提取所有事件类型""" + types = [] + for line in lines: + if line.startswith("data: "): + data = json.loads(line[6:]) + types.append(data.get("type", "")) + return types + + +class TestAguiEventSequence: + """AG-UI 事件序列测试""" + + # ==================== 基本序列测试 ==================== + + @pytest.mark.asyncio + async def test_pure_text_stream(self): + """测试纯文本流的事件序列 + + 预期:RUN_STARTED → TEXT_MESSAGE_START → TEXT_MESSAGE_CONTENT* → TEXT_MESSAGE_END → RUN_FINISHED + """ + + async def invoke_agent(request: AgentRequest): + yield "Hello " + yield "World" + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": 
"Hi"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + assert types[0] == "RUN_STARTED" + assert types[1] == "TEXT_MESSAGE_START" + assert types[2] == "TEXT_MESSAGE_CONTENT" + assert types[3] == "TEXT_MESSAGE_CONTENT" + assert types[4] == "TEXT_MESSAGE_END" + assert types[5] == "RUN_FINISHED" + + @pytest.mark.asyncio + async def test_pure_tool_call(self): + """测试纯工具调用的事件序列 + + 预期:RUN_STARTED → TOOL_CALL_START → TOOL_CALL_ARGS → TOOL_CALL_END → TOOL_CALL_RESULT → RUN_FINISHED + """ + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool", "args_delta": "{}"}, + ) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-1", "result": "done"}, + ) + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "call tool"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + assert types == [ + "RUN_STARTED", + "TOOL_CALL_START", + "TOOL_CALL_ARGS", + "TOOL_CALL_END", + "TOOL_CALL_RESULT", + "RUN_FINISHED", + ] + + # ==================== 文本和工具调用交错测试 ==================== + + @pytest.mark.asyncio + async def test_text_then_tool_call(self): + """测试 文本 → 工具调用 + + AG-UI 协议要求:发送 TOOL_CALL_START 前必须先发送 TEXT_MESSAGE_END + """ + + async def invoke_agent(request: AgentRequest): + yield "思考中..." 
+ yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "search", "args_delta": "{}"}, + ) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-1", "result": "found"}, + ) + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "search"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证 TEXT_MESSAGE_END 在 TOOL_CALL_START 之前 + text_end_idx = types.index("TEXT_MESSAGE_END") + tool_start_idx = types.index("TOOL_CALL_START") + assert ( + text_end_idx < tool_start_idx + ), "TEXT_MESSAGE_END must come before TOOL_CALL_START" + + @pytest.mark.asyncio + async def test_tool_call_then_text(self): + """测试 工具调用 → 文本 + + 关键点:工具调用后的文本需要新的 TEXT_MESSAGE_START + """ + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "calc", "args_delta": "{}"}, + ) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-1", "result": "42"}, + ) + yield "答案是 42" + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "calculate"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证工具调用后有新的 TEXT_MESSAGE_START + assert "TEXT_MESSAGE_START" in types + text_start_idx = types.index("TEXT_MESSAGE_START") + tool_result_idx = types.index("TOOL_CALL_RESULT") + assert ( + text_start_idx > tool_result_idx + ), "TEXT_MESSAGE_START must come after TOOL_CALL_RESULT" + + @pytest.mark.asyncio + async def test_text_tool_text(self): + """测试 文本 → 工具调用 → 文本 + + 
AG-UI 协议要求: + 1. 发送 TOOL_CALL_START 前必须先发送 TEXT_MESSAGE_END + 2. 工具调用后的新文本需要新的 TEXT_MESSAGE_START + """ + + async def invoke_agent(request: AgentRequest): + yield "让我查一下..." + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "search", "args_delta": "{}"}, + ) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-1", "result": "晴天"}, + ) + yield "今天是晴天。" + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "weather"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证有两个 TEXT_MESSAGE_START 和两个 TEXT_MESSAGE_END + assert types.count("TEXT_MESSAGE_START") == 2 + assert types.count("TEXT_MESSAGE_END") == 2 + + # 验证 messageId 不同 + message_ids = [] + for line in lines: + if line.startswith("data: "): + data = json.loads(line[6:]) + if data.get("type") == "TEXT_MESSAGE_START": + message_ids.append(data.get("messageId")) + + assert len(message_ids) == 2 + assert ( + message_ids[0] != message_ids[1] + ), "Second text message should have different messageId" + + # ==================== 多工具调用测试 ==================== + + @pytest.mark.asyncio + async def test_sequential_tool_calls(self): + """测试串行工具调用 + + 场景:工具1完成后再调用工具2 + """ + + async def invoke_agent(request: AgentRequest): + # 工具1 + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool1", "args_delta": "{}"}, + ) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-1", "result": "result1"}, + ) + # 工具2 + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-2", "name": "tool2", "args_delta": "{}"}, + ) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-2", "result": "result2"}, + ) + + server = 
AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "run tools"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证两个完整的工具调用序列 + assert types.count("TOOL_CALL_START") == 2 + assert types.count("TOOL_CALL_END") == 2 + assert types.count("TOOL_CALL_RESULT") == 2 + + @pytest.mark.asyncio + async def test_tool_chunk_then_text_without_result(self): + """测试 工具调用(无结果)→ 文本 + + AG-UI 协议要求:发送 TEXT_MESSAGE_START 前必须先发送 TOOL_CALL_END + 场景:发送工具调用 chunk 后直接输出文本,没有等待结果 + """ + + async def invoke_agent(request: AgentRequest): + # 发送工具调用 + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "async_tool", "args_delta": "{}"}, + ) + # 直接输出文本(没有 TOOL_RESULT) + yield "工具已触发,无需等待结果。" + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "async"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证 TOOL_CALL_END 在 TEXT_MESSAGE_START 之前 + tool_end_idx = types.index("TOOL_CALL_END") + text_start_idx = types.index("TEXT_MESSAGE_START") + assert ( + tool_end_idx < text_start_idx + ), "TOOL_CALL_END must come before TEXT_MESSAGE_START" + + @pytest.mark.asyncio + async def test_parallel_tool_calls(self): + """测试并行工具调用 + + 场景:同时开始多个工具调用,然后返回结果 + """ + + async def invoke_agent(request: AgentRequest): + # 两个工具同时开始 + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool1", "args_delta": "{}"}, + ) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-2", "name": "tool2", "args_delta": "{}"}, + ) + # 结果陆续返回 
+ yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-1", "result": "result1"}, + ) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-2", "result": "result2"}, + ) + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "parallel"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证两个工具调用都正确关闭 + assert types.count("TOOL_CALL_START") == 2 + assert types.count("TOOL_CALL_END") == 2 + assert types.count("TOOL_CALL_RESULT") == 2 + + # ==================== 状态和错误事件测试 ==================== + + @pytest.mark.asyncio + async def test_text_then_state(self): + """测试 文本 → 状态更新 + + 问题:STATE 事件是否需要先关闭 TEXT_MESSAGE? + """ + + async def invoke_agent(request: AgentRequest): + yield "处理中..." + yield AgentEvent( + event=EventType.STATE, + data={"snapshot": {"progress": 50}}, + ) + yield "完成!" + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "state"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证事件序列 + assert "STATE_SNAPSHOT" in types + assert "RUN_STARTED" in types + assert "RUN_FINISHED" in types + + @pytest.mark.asyncio + async def test_text_then_error(self): + """测试 文本 → 错误 + + AG-UI 协议允许 RUN_ERROR 在任何时候发送,不需要先关闭 TEXT_MESSAGE + RUN_ERROR 后不能再发送任何事件 + """ + + async def invoke_agent(request: AgentRequest): + yield "处理中..." 
+ yield AgentEvent( + event=EventType.ERROR, + data={"message": "出错了", "code": "ERR001"}, + ) + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "error"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证错误事件存在 + assert "RUN_ERROR" in types + + # RUN_ERROR 是最后一个事件 + assert types[-1] == "RUN_ERROR" + + # 没有 RUN_FINISHED(RUN_ERROR 后不能发送任何事件) + assert "RUN_FINISHED" not in types + + @pytest.mark.asyncio + async def test_tool_call_then_error(self): + """测试 工具调用 → 错误 + + AG-UI 协议允许 RUN_ERROR 在任何时候发送,不需要先发送 TOOL_CALL_END + RUN_ERROR 后不能再发送任何事件 + """ + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "risky_tool", "args_delta": "{}"}, + ) + yield AgentEvent( + event=EventType.ERROR, + data={"message": "工具执行失败", "code": "TOOL_ERROR"}, + ) + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "error"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证错误事件存在 + assert "RUN_ERROR" in types + + # RUN_ERROR 是最后一个事件 + assert types[-1] == "RUN_ERROR" + + # 没有 RUN_FINISHED + assert "RUN_FINISHED" not in types + + @pytest.mark.asyncio + async def test_text_then_custom(self): + """测试 文本 → 自定义事件 + + 问题:CUSTOM 事件是否需要先关闭 TEXT_MESSAGE? + """ + + async def invoke_agent(request: AgentRequest): + yield "处理中..." + yield AgentEvent( + event=EventType.CUSTOM, + data={"name": "progress", "value": {"percent": 50}}, + ) + yield "继续..." 
+ + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "custom"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证自定义事件存在 + assert "CUSTOM" in types + + # ==================== 边界情况测试 ==================== + + @pytest.mark.asyncio + async def test_empty_text_ignored(self): + """测试空文本被忽略""" + + async def invoke_agent(request: AgentRequest): + yield "" + yield "Hello" + yield "" + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "empty"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 只有一个 TEXT_MESSAGE_CONTENT(非空的那个) + assert types.count("TEXT_MESSAGE_CONTENT") == 1 + + @pytest.mark.asyncio + async def test_tool_call_without_result(self): + """测试没有结果的工具调用 + + 场景:只有 TOOL_CALL_CHUNK,没有 TOOL_RESULT + """ + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": "tc-1", + "name": "fire_and_forget", + "args_delta": "{}", + }, + ) + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "fire"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证工具调用被正确关闭 + assert "TOOL_CALL_START" in types + assert "TOOL_CALL_END" in types + + @pytest.mark.asyncio + async def test_complex_sequence(self): + """测试复杂序列 + + 文本 → 工具1 → 文本 → 工具2 → 工具3(并行) 
→ 文本 + AG-UI 允许 TEXT_MESSAGE 和 TOOL_CALL 并行,所以文本消息可以持续 + """ + + async def invoke_agent(request: AgentRequest): + yield "分析问题..." + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "analyze", "args_delta": "{}"}, + ) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-1", "result": "分析完成"}, + ) + yield "开始搜索..." + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-2", "name": "search1", "args_delta": "{}"}, + ) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-3", "name": "search2", "args_delta": "{}"}, + ) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-2", "result": "搜索1完成"}, + ) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-3", "result": "搜索2完成"}, + ) + yield "综合结果..." + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "complex"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证基本结构 + assert types[0] == "RUN_STARTED" + assert types[-1] == "RUN_FINISHED" + + # AG-UI 允许并行,所以可能只有一个文本消息(持续) + assert types.count("TEXT_MESSAGE_START") >= 1 + assert types.count("TEXT_MESSAGE_END") >= 1 + assert types.count("TEXT_MESSAGE_CONTENT") == 3 # 三段文本 + + # 验证工具调用数量 + assert types.count("TOOL_CALL_START") == 3 + assert types.count("TOOL_CALL_END") == 3 + assert types.count("TOOL_CALL_RESULT") == 3 + + @pytest.mark.asyncio + async def test_tool_result_without_start(self): + """测试直接发送 TOOL_RESULT(没有 TOOL_CALL_CHUNK) + + 场景:用户直接发送 TOOL_RESULT,没有先发送 TOOL_CALL_CHUNK + 预期:系统自动补充 TOOL_CALL_START 和 TOOL_CALL_END + """ + + async def invoke_agent(request: AgentRequest): + # 直接发送 TOOL_RESULT,没有 TOOL_CALL_CHUNK + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-orphan", 
"result": "孤立的结果"}, + ) + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "orphan"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证系统自动补充了 TOOL_CALL_START 和 TOOL_CALL_END + assert "TOOL_CALL_START" in types + assert "TOOL_CALL_END" in types + assert "TOOL_CALL_RESULT" in types + + # 验证顺序:START -> END -> RESULT + start_idx = types.index("TOOL_CALL_START") + end_idx = types.index("TOOL_CALL_END") + result_idx = types.index("TOOL_CALL_RESULT") + assert start_idx < end_idx < result_idx + + @pytest.mark.asyncio + async def test_text_then_tool_result_directly(self): + """测试 文本 → 直接 TOOL_RESULT + + 场景:先输出文本,然后直接发送 TOOL_RESULT(没有 TOOL_CALL_CHUNK) + AG-UI 要求 TEXT_MESSAGE_END 在 TOOL_CALL_START 之前 + 系统自动补充 TOOL_CALL_START 和 TOOL_CALL_END + """ + + async def invoke_agent(request: AgentRequest): + yield "执行结果:" + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-direct", "result": "直接结果"}, + ) + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "direct"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证 TEXT_MESSAGE_END 在 TOOL_CALL_START 之前 + text_end_idx = types.index("TEXT_MESSAGE_END") + tool_start_idx = types.index("TOOL_CALL_START") + assert text_end_idx < tool_start_idx + + @pytest.mark.asyncio + async def test_multiple_parallel_tools_then_text(self): + """测试多个并行工具调用后输出文本 + + 场景:同时开始多个工具调用,然后输出文本 + 在 copilotkit_compatibility=True(默认)模式下,第二个工具的事件会被放入队列, + 等到第一个工具调用收到 RESULT 后再处理。 + + 由于没有 RESULT 事件,队列中的 tc-b 不会被处理,直到文本事件到来时 + 
需要先结束 tc-a,然后处理队列中的 tc-b,再结束 tc-b,最后输出文本。 + + 输入事件: + - tc-a CHUNK (START) + - tc-b CHUNK (放入队列,等待 tc-a 完成) + - TEXT "工具已触发" + + 预期输出: + - tc-a START -> tc-a ARGS -> tc-a END -> tc-b START -> tc-b ARGS -> tc-b END + - TEXT_MESSAGE_START -> TEXT_MESSAGE_CONTENT -> TEXT_MESSAGE_END + """ + + async def invoke_agent(request: AgentRequest): + # 并行工具调用 + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-a", "name": "tool_a", "args_delta": "{}"}, + ) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-b", "name": "tool_b", "args_delta": "{}"}, + ) + # 直接输出文本(没有等待结果) + yield "工具已触发" + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "parallel"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证两个工具都被关闭了 + assert types.count("TOOL_CALL_END") == 2 + + # 验证所有 TOOL_CALL_END 在 TEXT_MESSAGE_START 之前 + text_start_idx = types.index("TEXT_MESSAGE_START") + for i, t in enumerate(types): + if t == "TOOL_CALL_END": + assert i < text_start_idx, ( + f"TOOL_CALL_END at {i} must come before TEXT_MESSAGE_START" + f" at {text_start_idx}" + ) + + @pytest.mark.asyncio + async def test_text_and_tool_interleaved_with_error(self): + """测试文本和工具交错后发生错误 + + 场景:文本 → 工具调用(未完成)→ 错误 + AG-UI 允许 RUN_ERROR 在任何时候发送,不需要先结束其他事件 + """ + + async def invoke_agent(request: AgentRequest): + yield "开始处理..." 
+ yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": "tc-fail", + "name": "failing_tool", + "args_delta": "{}", + }, + ) + yield AgentEvent( + event=EventType.ERROR, + data={"message": "处理失败", "code": "PROCESS_ERROR"}, + ) + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "fail"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证错误事件存在 + assert "RUN_ERROR" in types + + # RUN_ERROR 是最后一个事件 + assert types[-1] == "RUN_ERROR" + + # 没有 RUN_FINISHED + assert "RUN_FINISHED" not in types + + @pytest.mark.asyncio + async def test_state_between_text_chunks(self): + """测试在文本流中间发送状态事件 + + 场景:文本 → 状态 → 文本(同一个消息) + 预期:状态事件不会打断文本消息 + """ + + async def invoke_agent(request: AgentRequest): + yield "第一部分" + yield AgentEvent( + event=EventType.STATE, + data={"snapshot": {"progress": 50}}, + ) + yield "第二部分" + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "state"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证只有一个文本消息(状态事件没有打断) + assert types.count("TEXT_MESSAGE_START") == 1 + assert types.count("TEXT_MESSAGE_END") == 1 + + # 验证状态事件存在 + assert "STATE_SNAPSHOT" in types + + @pytest.mark.asyncio + async def test_custom_between_text_chunks(self): + """测试在文本流中间发送自定义事件 + + 场景:文本 → 自定义 → 文本(同一个消息) + 预期:自定义事件不会打断文本消息 + """ + + async def invoke_agent(request: AgentRequest): + yield "第一部分" + yield AgentEvent( + event=EventType.CUSTOM, + data={"name": "metrics", "value": {"tokens": 100}}, + ) + yield "第二部分" + + server = 
AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "custom"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证只有一个文本消息(自定义事件没有打断) + assert types.count("TEXT_MESSAGE_START") == 1 + assert types.count("TEXT_MESSAGE_END") == 1 + + # 验证自定义事件存在 + assert "CUSTOM" in types + + @pytest.mark.asyncio + async def test_no_events_after_run_error(self): + """测试 RUN_ERROR 后不再发送任何事件 + + AG-UI 协议规则:RUN_ERROR 是终结事件,之后不能再发送任何事件 + (包括 TEXT_MESSAGE_START、TEXT_MESSAGE_END、RUN_FINISHED 等) + """ + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.ERROR, + data={"message": "发生错误", "code": "TEST_ERROR"}, + ) + # 错误后继续输出文本(应该被忽略) + yield "这段文本不应该出现" + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "error"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证 RUN_ERROR 存在 + assert "RUN_ERROR" in types + + # 验证 RUN_ERROR 是最后一个事件 + assert types[-1] == "RUN_ERROR" + + # 验证没有 RUN_FINISHED(RUN_ERROR 后不应该有) + assert "RUN_FINISHED" not in types + + # 验证没有 TEXT_MESSAGE_START(错误后的文本应该被忽略) + assert "TEXT_MESSAGE_START" not in types + + @pytest.mark.asyncio + async def test_text_error_text_ignored(self): + """测试 文本 → 错误 → 文本(后续文本被忽略) + + 场景:先输出文本,发生错误,然后继续输出文本 + 预期: + 1. AG-UI 允许 RUN_ERROR 在任何时候发送 + 2. 错误后的文本被忽略 + 3. 没有 RUN_FINISHED + """ + + async def invoke_agent(request: AgentRequest): + yield "处理中..." 
+ yield AgentEvent( + event=EventType.ERROR, + data={"message": "处理失败", "code": "PROCESS_ERROR"}, + ) + yield "这段不应该出现" + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "error"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证基本结构 + assert "RUN_STARTED" in types + assert "TEXT_MESSAGE_START" in types + assert "TEXT_MESSAGE_CONTENT" in types + assert "RUN_ERROR" in types + + # 验证 RUN_ERROR 是最后一个事件 + assert types[-1] == "RUN_ERROR" + + # 验证没有 RUN_FINISHED + assert "RUN_FINISHED" not in types + + # 验证只有一个文本消息(错误后的不应该出现) + assert types.count("TEXT_MESSAGE_START") == 1 + assert types.count("TEXT_MESSAGE_CONTENT") == 1 + + @pytest.mark.asyncio + async def test_tool_error_tool_ignored(self): + """测试 工具调用 → 错误 → 工具调用(后续工具被忽略) + + 场景:开始工具调用,发生错误,然后继续工具调用 + 预期: + 1. TOOL_CALL_END 在 RUN_ERROR 之前 + 2. 错误后的工具调用被忽略 + 3. 
没有 RUN_FINISHED + """ + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool1", "args_delta": "{}"}, + ) + yield AgentEvent( + event=EventType.ERROR, + data={"message": "工具失败", "code": "TOOL_ERROR"}, + ) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-2", "name": "tool2", "args_delta": "{}"}, + ) + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "error"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证 RUN_ERROR 是最后一个事件 + assert types[-1] == "RUN_ERROR" + + # 验证没有 RUN_FINISHED + assert "RUN_FINISHED" not in types + + # 验证只有一个工具调用(错误后的不应该出现) + assert types.count("TOOL_CALL_START") == 1 + + @pytest.mark.asyncio + async def test_tool_calls_serialized_copilotkit_mode(self): + """测试工具调用串行化(CopilotKit 兼容模式) + + 场景:发送 tc-1 的 CHUNK,然后发送 tc-2 的 CHUNK(没有显式结束 tc-1) + 预期:在 copilotkit_compatibility=True 模式下,发送 tc-2 START 前会自动结束 tc-1 + + 注意:AG-UI 协议本身支持并行工具调用,但某些前端实现(如 CopilotKit) + 强制要求串行化。为了兼容性,我们提供 copilotkit_compatibility 模式。 + """ + from agentrun.server import AGUIProtocolConfig, ServerConfig + + async def invoke_agent(request: AgentRequest): + # 第一个工具调用开始 + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool1", "args_delta": '{"a": 1}'}, + ) + # 第二个工具调用开始(会自动先结束 tc-1) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-2", "name": "tool2", "args_delta": '{"b": 2}'}, + ) + # 结果(顺序返回) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-1", "result": "result1"}, + ) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-2", "result": "result2"}, + ) + + # 启用 CopilotKit 兼容模式 + config = ServerConfig( + 
agui=AGUIProtocolConfig(copilotkit_compatibility=True) + ) + server = AgentRunServer(invoke_agent=invoke_agent, config=config) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "parallel"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证两个工具调用都存在 + assert types.count("TOOL_CALL_START") == 2 + assert types.count("TOOL_CALL_END") == 2 + assert types.count("TOOL_CALL_RESULT") == 2 + + # 验证串行化工具调用的顺序: + # tc-1 START -> tc-1 ARGS -> tc-1 END -> tc-2 START -> tc-2 ARGS -> tc-1 RESULT -> tc-2 END -> tc-2 RESULT + # 关键验证:tc-1 END 在 tc-2 START 之前(串行化) + tc1_end_idx = None + tc2_start_idx = None + for i, line in enumerate(lines): + if "TOOL_CALL_END" in line and "tc-1" in line: + tc1_end_idx = i + if "TOOL_CALL_START" in line and "tc-2" in line: + tc2_start_idx = i + + assert tc1_end_idx is not None, "tc-1 TOOL_CALL_END not found" + assert tc2_start_idx is not None, "tc-2 TOOL_CALL_START not found" + # 串行化工具调用:tc-1 END 应该在 tc-2 START 之前 + assert ( + tc1_end_idx < tc2_start_idx + ), "Serialized tool calls: tc-1 END should come before tc-2 START" + + # ==================== AG-UI 官方验证器规则测试 ==================== + + @pytest.mark.asyncio + async def test_run_started_is_first(self): + """AG-UI 规则:第一个事件必须是 RUN_STARTED""" + + async def invoke_agent(request: AgentRequest): + yield "Hello" + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "test"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 第一个事件必须是 RUN_STARTED + assert types[0] == "RUN_STARTED" + + @pytest.mark.asyncio + async def 
test_run_finished_is_last_normal(self): + """AG-UI 规则:正常结束时 RUN_FINISHED 是最后一个事件""" + + async def invoke_agent(request: AgentRequest): + yield "Hello" + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "test"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 最后一个事件是 RUN_FINISHED + assert types[-1] == "RUN_FINISHED" + + @pytest.mark.asyncio + async def test_run_error_is_last_on_error(self): + """AG-UI 规则:错误时 RUN_ERROR 是最后一个事件""" + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.ERROR, + data={"message": "Error", "code": "ERR"}, + ) + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "test"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 最后一个事件是 RUN_ERROR + assert types[-1] == "RUN_ERROR" + # 没有 RUN_FINISHED + assert "RUN_FINISHED" not in types + + @pytest.mark.asyncio + async def test_run_finished_ends_all_messages(self): + """AG-UI 规则:RUN_FINISHED 前必须结束所有 TEXT_MESSAGE""" + + async def invoke_agent(request: AgentRequest): + # 只发送 TEXT,不显式结束 + yield "Hello" + yield "World" + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "test"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证 TEXT_MESSAGE_END 在 RUN_FINISHED 之前 + assert 
"TEXT_MESSAGE_END" in types + text_end_idx = types.index("TEXT_MESSAGE_END") + run_finished_idx = types.index("RUN_FINISHED") + assert text_end_idx < run_finished_idx + + @pytest.mark.asyncio + async def test_run_finished_ends_all_tool_calls(self): + """AG-UI 规则:RUN_FINISHED 前必须结束所有 TOOL_CALL""" + + async def invoke_agent(request: AgentRequest): + # 只发送 TOOL_CALL_CHUNK,不显式结束 + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool1", "args_delta": "{}"}, + ) + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "test"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证 TOOL_CALL_END 在 RUN_FINISHED 之前 + assert "TOOL_CALL_END" in types + tool_end_idx = types.index("TOOL_CALL_END") + run_finished_idx = types.index("RUN_FINISHED") + assert tool_end_idx < run_finished_idx + + @pytest.mark.asyncio + async def test_text_message_id_consistency(self): + """AG-UI 规则:TEXT_MESSAGE 的 messageId 必须一致""" + + async def invoke_agent(request: AgentRequest): + yield "Hello" + yield " World" + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "test"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + + # 提取所有 messageId + message_ids = [] + for line in lines: + data = parse_sse_line(line) + if data and "messageId" in data: + if data.get("type") in [ + "TEXT_MESSAGE_START", + "TEXT_MESSAGE_CONTENT", + "TEXT_MESSAGE_END", + ]: + message_ids.append(data["messageId"]) + + # 所有 messageId 应该相同(同一个消息) + assert len(set(message_ids)) == 1 + + @pytest.mark.asyncio + async def 
test_tool_call_id_consistency(self): + """AG-UI 规则:TOOL_CALL 的 toolCallId 必须一致""" + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool1", "args_delta": '{"a":'}, + ) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool1", "args_delta": "1}"}, + ) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-1", "result": "done"}, + ) + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "test"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + + # 提取所有 toolCallId + tool_call_ids = [] + for line in lines: + data = parse_sse_line(line) + if data and "toolCallId" in data: + tool_call_ids.append(data["toolCallId"]) + + # 所有 toolCallId 应该相同(同一个工具调用) + assert len(set(tool_call_ids)) == 1 + assert tool_call_ids[0] == "tc-1" + + @pytest.mark.asyncio + async def test_text_tool_text_sequence(self): + """AG-UI 规则:TEXT_MESSAGE 和 TOOL_CALL 不能并行 + + 文本 → 工具调用 → 文本 需要两个独立的 TEXT_MESSAGE + """ + + async def invoke_agent(request: AgentRequest): + yield "开始..." + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool1", "args_delta": "{}"}, + ) + yield "继续..." 
# 工具调用后的文本 + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "test"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证有两个独立的 TEXT_MESSAGE + assert types.count("TEXT_MESSAGE_START") == 2 + assert types.count("TEXT_MESSAGE_CONTENT") == 2 + assert types.count("TEXT_MESSAGE_END") == 2 + + # 验证第一个 TEXT_MESSAGE_END 在 TOOL_CALL_START 之前 + first_text_end_idx = types.index("TEXT_MESSAGE_END") + tool_start_idx = types.index("TOOL_CALL_START") + assert first_text_end_idx < tool_start_idx + + @pytest.mark.asyncio + async def test_multiple_tool_calls_parallel(self): + """AG-UI 规则:多个 TOOL_CALL 可以并行(但在 copilotkit_compatibility=True 时会串行) + + 在 copilotkit_compatibility=True(默认)模式下,其他工具的事件会被放入队列, + 等到当前工具调用收到 RESULT 后再处理。 + + 输入事件: + - tc-1 CHUNK (START) + - tc-2 CHUNK (放入队列) + - tc-3 CHUNK (放入队列) + - tc-2 RESULT (放入队列,因为 tc-1 还没有 RESULT) + - tc-1 RESULT (处理队列中的事件) + - tc-3 RESULT + + 预期输出: + - tc-1 START -> tc-1 ARGS -> tc-1 END -> tc-1 RESULT + - tc-2 START -> tc-2 ARGS -> tc-2 END -> tc-2 RESULT + - tc-3 START -> tc-3 ARGS -> tc-3 END -> tc-3 RESULT + """ + + async def invoke_agent(request: AgentRequest): + # 开始第一个工具 + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool1", "args_delta": "{}"}, + ) + # 开始第二个工具(会被放入队列) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-2", "name": "tool2", "args_delta": "{}"}, + ) + # 开始第三个工具(会被放入队列) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-3", "name": "tool3", "args_delta": "{}"}, + ) + # 结果陆续返回 + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-2", "result": "result2"}, + ) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-1", "result": "result1"}, + ) + 
yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-3", "result": "result3"}, + ) + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "test"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证三个工具调用都存在 + assert types.count("TOOL_CALL_START") == 3 + assert types.count("TOOL_CALL_END") == 3 + assert types.count("TOOL_CALL_RESULT") == 3 + + @pytest.mark.asyncio + async def test_same_tool_call_id_not_duplicated(self): + """AG-UI 规则:同一个 toolCallId 不能重复 START""" + + async def invoke_agent(request: AgentRequest): + # 发送多个相同 ID 的 CHUNK(应该只生成一个 START) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool1", "args_delta": '{"a":'}, + ) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool1", "args_delta": "1}"}, + ) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-1", "result": "done"}, + ) + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "test"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 只应该有一个 TOOL_CALL_START + assert types.count("TOOL_CALL_START") == 1 + # 但有两个 TOOL_CALL_ARGS + assert types.count("TOOL_CALL_ARGS") == 2 + + @pytest.mark.asyncio + async def test_interleaved_tool_calls_with_repeated_args(self): + """测试交错的工具调用(带重复的 ARGS 事件) + + 场景:模拟 LangChain 流式输出时可能产生的交错事件序列 + 在 copilotkit_compatibility=True(默认)模式下,其他工具的事件会被放入队列, + 等到当前工具调用收到 RESULT 后再处理。 + + 输入事件: + - tc-1 CHUNK (START) + - tc-2 CHUNK (放入队列,等待 tc-1 完成) + - tc-1 CHUNK 
(ARGS,同一个工具调用,继续处理) + - tc-3 CHUNK (放入队列,等待 tc-1 完成) + - tc-1 RESULT (处理队列中的 tc-2 和 tc-3) + - tc-2 RESULT + - tc-3 RESULT + + 预期输出: + - tc-1 START -> tc-1 ARGS -> tc-1 ARGS -> tc-1 END -> tc-1 RESULT + - tc-2 START -> tc-2 ARGS -> tc-2 END -> tc-2 RESULT + - tc-3 START -> tc-3 ARGS -> tc-3 END -> tc-3 RESULT + """ + + async def invoke_agent(request: AgentRequest): + # 第一个工具调用 + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "get_time", "args_delta": '{"tz":'}, + ) + # 第二个工具调用(会被放入队列) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-2", "name": "get_user", "args_delta": ""}, + ) + # tc-1 的额外 ARGS(同一个工具调用,继续处理) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": "tc-1", + "name": "get_time", + "args_delta": '"Asia"}', + }, + ) + # 第三个工具调用(会被放入队列) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": "tc-3", + "name": "get_token", + "args_delta": '{"user":"test"}', + }, + ) + # 结果 + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-1", "result": "time result"}, + ) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-2", "result": "user result"}, + ) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-3", "result": "token result"}, + ) + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "test"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证事件序列正确 + assert types[0] == "RUN_STARTED" + assert types[-1] == "RUN_FINISHED" + + # 每个工具调用只有一个 START + assert types.count("TOOL_CALL_START") == 3 + assert types.count("TOOL_CALL_END") == 3 + assert types.count("TOOL_CALL_RESULT") == 3 + + # 验证串行化:每个 ARGS 都有对应的活跃 START + tool_states = {} + for line in lines: + data = 
parse_sse_line(line) + event_type = data.get("type", "") + tool_id = data.get("toolCallId", "") + + if event_type == "TOOL_CALL_START": + tool_states[tool_id] = {"started": True, "ended": False} + elif event_type == "TOOL_CALL_END": + if tool_id in tool_states: + tool_states[tool_id]["ended"] = True + elif event_type == "TOOL_CALL_ARGS": + assert ( + tool_id in tool_states + ), f"ARGS for {tool_id} without START" + assert not tool_states[tool_id][ + "ended" + ], f"ARGS for {tool_id} after END" + + @pytest.mark.asyncio + async def test_text_tool_interleaved_complex(self): + """测试复杂的文本和工具调用交错场景 + + 场景:模拟真实的 LLM 输出 + 在 copilotkit_compatibility=True(默认)模式下,其他工具的事件会被放入队列, + 等到当前工具调用收到 RESULT 后再处理。 + + 输入事件: + 1. 文本 "让我查一下..." + 2. tc-a CHUNK (START) + 3. tc-b CHUNK (放入队列,等待 tc-a 完成) + 4. tc-a CHUNK (ARGS,同一个工具调用,继续处理) + 5. tc-a RESULT (处理队列中的 tc-b) + 6. tc-b RESULT + 7. 文本 "根据结果..." + + 预期输出: + - TEXT_MESSAGE_START -> TEXT_MESSAGE_CONTENT -> TEXT_MESSAGE_END + - tc-a START -> tc-a ARGS -> tc-a ARGS -> tc-a END -> tc-a RESULT + - tc-b START -> tc-b ARGS -> tc-b END -> tc-b RESULT + - TEXT_MESSAGE_START -> TEXT_MESSAGE_CONTENT -> TEXT_MESSAGE_END + """ + + async def invoke_agent(request: AgentRequest): + # 第一段文本 + yield "让我查一下..." + # 工具调用 A + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-a", "name": "search", "args_delta": '{"q":'}, + ) + # 工具调用 B(会被放入队列) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": "tc-b", + "name": "lookup", + "args_delta": '{"id": 1}', + }, + ) + # A 的额外 ARGS(同一个工具调用,继续处理) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-a", "name": "search", "args_delta": '"test"}'}, + ) + # 结果 + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-a", "result": "search result"}, + ) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-b", "result": "lookup result"}, + ) + # 第二段文本 + yield "根据结果..." 
+ + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "complex"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证基本事件存在 + assert types[0] == "RUN_STARTED" + assert types[-1] == "RUN_FINISHED" + + # 验证文本消息 + assert types.count("TEXT_MESSAGE_START") == 2 + assert types.count("TEXT_MESSAGE_END") == 2 + + # 验证工具调用 + # 每个工具调用只有一个 START + assert types.count("TOOL_CALL_START") == 2 # tc-a 一次, tc-b 一次 + assert types.count("TOOL_CALL_END") == 2 + assert types.count("TOOL_CALL_RESULT") == 2 + assert types.count("TOOL_CALL_ARGS") == 3 # tc-a 两次, tc-b 一次 + + @pytest.mark.asyncio + async def test_tool_call_args_after_end(self): + """测试工具调用结束后收到 ARGS 事件 + + 场景:LangChain 交错输出时,可能在 tc-1 END 后收到 tc-1 的 ARGS + 在 copilotkit_compatibility=True(默认)模式下,其他工具的事件会被放入队列, + 等到当前工具调用收到 RESULT 后再处理。 + + 输入事件: + - tc-1 CHUNK (START) + - tc-2 CHUNK (放入队列,等待 tc-1 完成) + - tc-1 CHUNK (ARGS,同一个工具调用,继续处理) + - tc-1 RESULT (处理队列中的 tc-2) + - tc-2 RESULT + + 预期输出: + - tc-1 START -> tc-1 ARGS -> tc-1 ARGS -> tc-1 END -> tc-1 RESULT + - tc-2 START -> tc-2 ARGS -> tc-2 END -> tc-2 RESULT + """ + + async def invoke_agent(request: AgentRequest): + # tc-1 开始 + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool1", "args_delta": '{"a":'}, + ) + # tc-2(会被放入队列) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-2", "name": "tool2", "args_delta": '{"b": 2}'}, + ) + # tc-1 的额外 ARGS(同一个工具调用,继续处理) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool1", "args_delta": "1}"}, + ) + # 结果 + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-1", "result": "result1"}, + ) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-2", 
"result": "result2"}, + ) + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "test"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证事件序列 + assert types[0] == "RUN_STARTED" + assert types[-1] == "RUN_FINISHED" + + # 每个工具调用只有一个 START + assert types.count("TOOL_CALL_START") == 2 + assert types.count("TOOL_CALL_END") == 2 + assert types.count("TOOL_CALL_RESULT") == 2 + + # 验证每个 ARGS 都有对应的活跃 START + # 检查事件序列中没有 ARGS 出现在 END 之后(对于同一个 tool_id) + tool_states = {} # tool_id -> {"started": bool, "ended": bool} + for line in lines: + data = parse_sse_line(line) + event_type = data.get("type", "") + tool_id = data.get("toolCallId", "") + + if event_type == "TOOL_CALL_START": + tool_states[tool_id] = {"started": True, "ended": False} + elif event_type == "TOOL_CALL_END": + if tool_id in tool_states: + tool_states[tool_id]["ended"] = True + elif event_type == "TOOL_CALL_ARGS": + # ARGS 必须在 START 之后,END 之前 + assert ( + tool_id in tool_states + ), f"ARGS for {tool_id} without START" + assert not tool_states[tool_id][ + "ended" + ], f"ARGS for {tool_id} after END" + + @pytest.mark.asyncio + async def test_parallel_tool_calls_standard_mode(self): + """测试标准模式下的并行工具调用 + + 场景:默认模式(copilotkit_compatibility=False)允许并行工具调用 + 预期:tc-2 START 可以在 tc-1 END 之前发送 + """ + + async def invoke_agent(request: AgentRequest): + # 第一个工具调用开始 + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool1", "args_delta": '{"a": 1}'}, + ) + # 第二个工具调用开始(并行) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-2", "name": "tool2", "args_delta": '{"b": 2}'}, + ) + # 结果 + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-1", "result": "result1"}, + ) + yield AgentEvent( + 
event=EventType.TOOL_RESULT, + data={"id": "tc-2", "result": "result2"}, + ) + + # 使用默认配置(标准模式) + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "parallel"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证两个工具调用都存在 + assert types.count("TOOL_CALL_START") == 2 + assert types.count("TOOL_CALL_END") == 2 + assert types.count("TOOL_CALL_RESULT") == 2 + + # 验证并行工具调用的顺序: + # tc-1 START -> tc-1 ARGS -> tc-2 START -> tc-2 ARGS -> tc-1 END -> tc-1 RESULT -> tc-2 END -> tc-2 RESULT + # 关键验证:tc-2 START 在 tc-1 END 之前(并行) + tc1_end_idx = None + tc2_start_idx = None + for i, line in enumerate(lines): + if "TOOL_CALL_END" in line and "tc-1" in line: + tc1_end_idx = i + if "TOOL_CALL_START" in line and "tc-2" in line: + tc2_start_idx = i + + assert tc1_end_idx is not None, "tc-1 TOOL_CALL_END not found" + assert tc2_start_idx is not None, "tc-2 TOOL_CALL_START not found" + # 并行工具调用:tc-2 START 应该在 tc-1 END 之前 + assert ( + tc2_start_idx < tc1_end_idx + ), "Parallel tool calls: tc-2 START should come before tc-1 END" + + @pytest.mark.asyncio + async def test_langchain_duplicate_tool_call_with_uuid_id_copilotkit_mode( + self, + ): + """测试 LangChain 重复工具调用(使用 UUID 格式 ID)- CopilotKit 兼容模式 + + 场景:模拟 LangChain 流式输出时可能产生的重复事件: + - 流式 chunk 使用 call_xxx ID + - on_tool_start 使用 UUID 格式的 run_id + - 两者对应同一个逻辑工具调用 + + 预期:在 copilotkit_compatibility=True 模式下,UUID 格式的重复事件应该被忽略或合并到已有的 call_xxx ID + """ + from agentrun.server import AGUIProtocolConfig, ServerConfig + + async def invoke_agent(request: AgentRequest): + # 流式 chunk:使用 call_xxx ID + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": "call_abc123", + "name": "get_weather", + "args_delta": '{"city": "Beijing"}', + }, + ) + # on_tool_start:使用 UUID 格式的 
run_id(同一个工具调用) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": "a1b2c3d4-e5f6-7890-abcd-ef1234567890", + "name": "get_weather", + "args_delta": '{"city": "Beijing"}', + }, + ) + # on_tool_end:使用 UUID 格式的 run_id + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={ + "id": "a1b2c3d4-e5f6-7890-abcd-ef1234567890", + "name": "get_weather", + "result": "Sunny, 25°C", + }, + ) + + # 启用 CopilotKit 兼容模式 + config = ServerConfig( + agui=AGUIProtocolConfig(copilotkit_compatibility=True) + ) + server = AgentRunServer(invoke_agent=invoke_agent, config=config) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "weather"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证只有一个工具调用(UUID 重复事件被合并) + assert ( + types.count("TOOL_CALL_START") == 1 + ), f"Expected 1 TOOL_CALL_START, got {types.count('TOOL_CALL_START')}" + assert ( + types.count("TOOL_CALL_END") == 1 + ), f"Expected 1 TOOL_CALL_END, got {types.count('TOOL_CALL_END')}" + assert ( + types.count("TOOL_CALL_RESULT") == 1 + ), f"Expected 1 TOOL_CALL_RESULT, got {types.count('TOOL_CALL_RESULT')}" + + # 验证使用的是 call_xxx ID,不是 UUID + for line in lines: + if "TOOL_CALL_START" in line: + data = json.loads(line.replace("data: ", "")) + assert ( + data["toolCallId"] == "call_abc123" + ), f"Expected call_abc123, got {data['toolCallId']}" + if "TOOL_CALL_RESULT" in line: + data = json.loads(line.replace("data: ", "")) + assert ( + data["toolCallId"] == "call_abc123" + ), f"Expected call_abc123, got {data['toolCallId']}" + + # @pytest.mark.asyncio + # async def test_langchain_multiple_tools_with_uuid_ids_copilotkit_mode(self): + # """测试 LangChain 多个工具调用(使用 UUID 格式 ID)- CopilotKit 兼容模式 + + # 场景:模拟 LangChain 并行调用多个工具时的事件序列: + # - 流式 chunk 使用 call_xxx ID + # - on_tool_start/end 使用 UUID 格式的 run_id + # 
- 需要正确匹配每个工具的 ID + + # 预期:在 copilotkit_compatibility=True 模式下,每个工具调用应该使用正确的 ID,不会混淆 + # """ + # from agentrun.server import AGUIProtocolConfig, ServerConfig + + # async def invoke_agent(request: AgentRequest): + # # 第一个工具的流式 chunk + # yield AgentEvent( + # event=EventType.TOOL_CALL_CHUNK, + # data={ + # "id": "call_tool1", + # "name": "get_weather", + # "args_delta": '{"city": "Beijing"}', + # }, + # ) + # # 第二个工具的流式 chunk + # yield AgentEvent( + # event=EventType.TOOL_CALL_CHUNK, + # data={ + # "id": "call_tool2", + # "name": "get_time", + # "args_delta": '{"timezone": "UTC"}', + # }, + # ) + # # 第一个工具的 on_tool_start(UUID) + # yield AgentEvent( + # event=EventType.TOOL_CALL, # 改为 TOOL_CALL + # data={ + # "id": "uuid-weather-123", + # "name": "get_weather", + # "args": '{"city": "Beijing"}', + # }, + # ) + # # 第二个工具的 on_tool_start(UUID) + # yield AgentEvent( + # event=EventType.TOOL_CALL, # 改为 TOOL_CALL + # data={ + # "id": "uuid-time-456", + # "name": "get_time", + # "args": '{"timezone": "UTC"}', + # }, + # ) + # # 第一个工具的 on_tool_end(UUID) + # yield AgentEvent( + # event=EventType.TOOL_RESULT, + # data={ + # "id": "uuid-weather-123", + # "name": "get_weather", + # "result": "Sunny, 25°C", + # }, + # ) + # # 第二个工具的 on_tool_end(UUID) + # yield AgentEvent( + # event=EventType.TOOL_RESULT, + # data={ + # "id": "uuid-time-456", + # "name": "get_time", + # "result": "12:00 UTC", + # }, + # ) + + # # 启用 CopilotKit 兼容模式 + # config = ServerConfig( + # agui=AGUIProtocolConfig(copilotkit_compatibility=True) + # ) + # server = AgentRunServer(invoke_agent=invoke_agent, config=config) + # app = server.as_fastapi_app() + # from fastapi.testclient import TestClient + + # client = TestClient(app) + # response = client.post( + # "/ag-ui/agent", + # json={"messages": [{"role": "user", "content": "tools"}]}, + # ) + + # lines = [line async for line in response.aiter_lines() if line] + # types = get_event_types(lines) + + # # 验证只有两个工具调用(UUID 重复事件被合并) + # assert ( + # 
types.count("TOOL_CALL_START") == 2 + # ), f"Expected 2 TOOL_CALL_START, got {types.count('TOOL_CALL_START')}" + # assert ( + # types.count("TOOL_CALL_END") == 2 + # ), f"Expected 2 TOOL_CALL_END, got {types.count('TOOL_CALL_END')}" + # assert ( + # types.count("TOOL_CALL_RESULT") == 2 + # ), f"Expected 2 TOOL_CALL_RESULT, got {types.count('TOOL_CALL_RESULT')}" + + # # 验证使用的是 call_xxx ID,不是 UUID + # tool_call_ids = [] + # for line in lines: + # if "TOOL_CALL_START" in line: + # data = json.loads(line.replace("data: ", "")) + # tool_call_ids.append(data["toolCallId"]) + + # assert ( + # "call_tool1" in tool_call_ids or "call_tool2" in tool_call_ids + # ), f"Expected call_xxx IDs, got {tool_call_ids}" + + @pytest.mark.asyncio + async def test_langchain_uuid_not_deduplicated_standard_mode(self): + """测试标准模式下,UUID 格式 ID 不会被去重 + + 场景:默认模式(copilotkit_compatibility=False)下,UUID 格式的 ID 应该被视为独立的工具调用 + """ + + async def invoke_agent(request: AgentRequest): + # 流式 chunk:使用 call_xxx ID + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": "call_abc123", + "name": "get_weather", + "args_delta": '{"city": "Beijing"}', + }, + ) + # on_tool_start:使用 UUID 格式的 run_id(同一个工具调用) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": "a1b2c3d4-e5f6-7890-abcd-ef1234567890", + "name": "get_weather", + "args_delta": '{"city": "Beijing"}', + }, + ) + # 两个工具的结果 + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={ + "id": "call_abc123", + "name": "get_weather", + "result": "Sunny, 25°C", + }, + ) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={ + "id": "a1b2c3d4-e5f6-7890-abcd-ef1234567890", + "name": "get_weather", + "result": "Sunny, 25°C", + }, + ) + + # 使用默认配置(标准模式) + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "weather"}]}, + ) + + 
lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 禁用串行化时,UUID 格式的 ID 不会被去重,所以有两个工具调用 + assert ( + types.count("TOOL_CALL_START") == 2 + ), f"Expected 2 TOOL_CALL_START, got {types.count('TOOL_CALL_START')}" + assert ( + types.count("TOOL_CALL_END") == 2 + ), f"Expected 2 TOOL_CALL_END, got {types.count('TOOL_CALL_END')}" + assert ( + types.count("TOOL_CALL_RESULT") == 2 + ), f"Expected 2 TOOL_CALL_RESULT, got {types.count('TOOL_CALL_RESULT')}" + + @pytest.mark.asyncio + async def test_tool_result_after_another_tool_started(self): + """测试在另一个工具调用开始后收到之前工具的 RESULT + + 场景: + 1. tc-1 START, ARGS, END + 2. tc-2 START, ARGS + 3. tc-1 ARGS (tc-1 已结束,会重新开始) + 4. tc-1 RESULT (此时 tc-1 是活跃的,需要先结束) + + 预期:在发送 tc-1 RESULT 前,应该先结束所有活跃的工具调用 + """ + + async def invoke_agent(request: AgentRequest): + # tc-1 开始 + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": "call_tc1", + "name": "tool1", + "args_delta": '{"a": 1}', + }, + ) + # tc-2 开始(会自动结束 tc-1) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": "call_tc2", + "name": "tool2", + "args_delta": '{"b": 2}', + }, + ) + # tc-1 的额外 ARGS(tc-1 已结束,会重新开始) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "call_tc1", "name": "tool1", "args_delta": ""}, + ) + # tc-1 的 RESULT(此时 tc-1 是活跃的) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "call_tc1", "result": "result1"}, + ) + # tc-2 的 RESULT + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "call_tc2", "result": "result2"}, + ) + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + client = TestClient(app) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "test"}]}, + ) + + lines = [line async for line in response.aiter_lines() if line] + types = get_event_types(lines) + + # 验证事件序列 + assert types[0] == "RUN_STARTED" + assert 
types[-1] == "RUN_FINISHED" + + # 验证没有连续的 TOOL_CALL_START 和 TOOL_CALL_RESULT(中间应该有 TOOL_CALL_END) + for i in range(len(types) - 1): + if types[i] == "TOOL_CALL_START": + # 下一个不应该是 TOOL_CALL_RESULT + assert types[i + 1] != "TOOL_CALL_RESULT", ( + "TOOL_CALL_RESULT should not immediately follow" + f" TOOL_CALL_START at index {i}" + ) + + # 验证所有 TOOL_CALL_RESULT 之前都有对应的 TOOL_CALL_END + assert types.count("TOOL_CALL_RESULT") == 2 + assert types.count("TOOL_CALL_END") >= types.count("TOOL_CALL_RESULT") diff --git a/tests/unittests/server/test_agui_normalizer.py b/tests/unittests/server/test_agui_normalizer.py new file mode 100644 index 0000000..d785910 --- /dev/null +++ b/tests/unittests/server/test_agui_normalizer.py @@ -0,0 +1,454 @@ +"""测试 AG-UI 事件规范化器 + +测试 AguiEventNormalizer 类的功能: +- 追踪工具调用状态 +- 字符串和字典输入转换 +- 状态重置 + +注意:边界事件(TOOL_CALL_START/END、TEXT_MESSAGE_START/END) +现在由协议层自动生成,AguiEventNormalizer 主要用于状态追踪。 +""" + +import pytest + +from agentrun.server import AgentEvent, AguiEventNormalizer, EventType + + +class TestAguiEventNormalizer: + """测试 AguiEventNormalizer 类""" + + def test_pass_through_text_events(self): + """测试文本事件直接传递""" + normalizer = AguiEventNormalizer() + + event = AgentEvent( + event=EventType.TEXT, + data={"delta": "Hello"}, + ) + results = list(normalizer.normalize(event)) + + assert len(results) == 1 + assert results[0].event == EventType.TEXT + assert results[0].data["delta"] == "Hello" + + def test_pass_through_custom_events(self): + """测试自定义事件直接传递""" + normalizer = AguiEventNormalizer() + + event = AgentEvent( + event=EventType.CUSTOM, + data={"name": "step_started", "value": {"step": "test"}}, + ) + results = list(normalizer.normalize(event)) + + assert len(results) == 1 + assert results[0].event == EventType.CUSTOM + + def test_tool_call_chunk_tracking(self): + """测试 TOOL_CALL_CHUNK 状态追踪""" + normalizer = AguiEventNormalizer() + + # 发送 CHUNK 事件 + event = AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "call_1", "name": 
"test", "args_delta": '{"x": 1}'}, + ) + results = list(normalizer.normalize(event)) + + # 事件直接传递 + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL_CHUNK + assert results[0].data["id"] == "call_1" + + # 状态被追踪 + assert "call_1" in normalizer.get_seen_tool_calls() + assert "call_1" in normalizer.get_active_tool_calls() + + def test_tool_call_tracking(self): + """测试 TOOL_CALL 状态追踪""" + normalizer = AguiEventNormalizer() + + event = AgentEvent( + event=EventType.TOOL_CALL, + data={"id": "call_2", "name": "search", "args": '{"q": "hello"}'}, + ) + results = list(normalizer.normalize(event)) + + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL + + # 状态被追踪 + assert "call_2" in normalizer.get_seen_tool_calls() + assert "call_2" in normalizer.get_active_tool_calls() + + def test_tool_result_marks_tool_call_complete(self): + """测试 TOOL_RESULT 标记工具调用完成""" + normalizer = AguiEventNormalizer() + + # 先发送工具调用 + chunk_event = AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "call_1", "name": "test", "args_delta": "{}"}, + ) + list(normalizer.normalize(chunk_event)) + assert "call_1" in normalizer.get_active_tool_calls() + + # 发送结果 + result_event = AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "call_1", "result": "success"}, + ) + results = list(normalizer.normalize(result_event)) + + assert len(results) == 1 + assert results[0].event == EventType.TOOL_RESULT + + # 工具调用不再活跃(但仍在已见列表中) + assert "call_1" not in normalizer.get_active_tool_calls() + assert "call_1" in normalizer.get_seen_tool_calls() + + def test_multiple_concurrent_tool_calls(self): + """测试多个并发工具调用追踪""" + normalizer = AguiEventNormalizer() + + # 开始两个工具调用 + for tool_id in ["call_a", "call_b"]: + event = AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": tool_id, + "name": f"tool_{tool_id}", + "args_delta": "{}", + }, + ) + list(normalizer.normalize(event)) + + # 两个都应该是活跃的 + assert len(normalizer.get_active_tool_calls()) == 2 + assert 
"call_a" in normalizer.get_active_tool_calls() + assert "call_b" in normalizer.get_active_tool_calls() + + # 结束其中一个 + result_event = AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "call_a", "result": "done"}, + ) + list(normalizer.normalize(result_event)) + + # call_a 不再活跃,call_b 仍然活跃 + assert len(normalizer.get_active_tool_calls()) == 1 + assert "call_a" not in normalizer.get_active_tool_calls() + assert "call_b" in normalizer.get_active_tool_calls() + + def test_string_input_converted_to_text(self): + """测试字符串输入自动转换为文本事件""" + normalizer = AguiEventNormalizer() + + results = list(normalizer.normalize("Hello")) + + assert len(results) == 1 + assert results[0].event == EventType.TEXT + assert results[0].data["delta"] == "Hello" + + def test_dict_input_converted_to_agent_event(self): + """测试字典输入自动转换为 AgentEvent""" + normalizer = AguiEventNormalizer() + + event_dict = { + "event": EventType.TEXT, + "data": {"delta": "Hello from dict"}, + } + results = list(normalizer.normalize(event_dict)) + + assert len(results) == 1 + assert results[0].event == EventType.TEXT + assert results[0].data["delta"] == "Hello from dict" + + def test_dict_with_string_event_type(self): + """测试字符串事件类型的字典转换""" + normalizer = AguiEventNormalizer() + + event_dict = { + "event": "CUSTOM", + "data": {"name": "test"}, + } + results = list(normalizer.normalize(event_dict)) + + assert len(results) == 1 + assert results[0].event == EventType.CUSTOM + + def test_invalid_dict_returns_nothing(self): + """测试无效字典不产生事件""" + normalizer = AguiEventNormalizer() + + # 缺少 event 字段 + event_dict = {"data": {"delta": "Hello"}} + results = list(normalizer.normalize(event_dict)) + + assert len(results) == 0 + + def test_dict_with_invalid_event_type_value(self): + """测试字典中无效的事件类型值""" + normalizer = AguiEventNormalizer() + + # 事件类型既不是有效的 EventType 值也不是有效的枚举名称 + event_dict = { + "event": "INVALID_EVENT_TYPE", + "data": {"delta": "Hello"}, + } + results = list(normalizer.normalize(event_dict)) + + # 无效事件类型应该返回空 
+ assert len(results) == 0 + + def test_dict_with_invalid_event_type_key(self): + """测试字典中无效的事件类型键(尝试通过枚举名称)""" + normalizer = AguiEventNormalizer() + + # 尝试使用无效的枚举名称 + event_dict = { + "event": "NONEXISTENT", + "data": {"delta": "Hello"}, + } + results = list(normalizer.normalize(event_dict)) + + # 无效事件类型应该返回空 + assert len(results) == 0 + + def test_normalize_with_non_standard_input(self): + """测试非标准输入类型""" + normalizer = AguiEventNormalizer() + + # 传入整数 + results = list(normalizer.normalize(123)) + assert len(results) == 0 + + # 传入 None + results = list(normalizer.normalize(None)) + assert len(results) == 0 + + # 传入列表 + results = list(normalizer.normalize([1, 2, 3])) + assert len(results) == 0 + + def test_reset_clears_state(self): + """测试 reset 清空状态""" + normalizer = AguiEventNormalizer() + + # 添加一些状态 + event = AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "call_1", "name": "test", "args_delta": "{}"}, + ) + list(normalizer.normalize(event)) + assert len(normalizer.get_active_tool_calls()) == 1 + assert len(normalizer.get_seen_tool_calls()) == 1 + + # 重置 + normalizer.reset() + + # 状态应该清空 + assert len(normalizer.get_active_tool_calls()) == 0 + assert len(normalizer.get_seen_tool_calls()) == 0 + + def test_tool_call_with_empty_id(self): + """测试空 tool_call_id 的 TOOL_CALL 事件""" + normalizer = AguiEventNormalizer() + + event = AgentEvent( + event=EventType.TOOL_CALL, + data={"id": "", "name": "test", "args": "{}"}, # 空 id + ) + results = list(normalizer.normalize(event)) + + assert len(results) == 1 + # 空 id 不会被添加到追踪列表 + assert len(normalizer.get_seen_tool_calls()) == 0 + assert len(normalizer.get_active_tool_calls()) == 0 + + def test_tool_call_chunk_with_empty_id(self): + """测试空 tool_call_id 的 TOOL_CALL_CHUNK 事件""" + normalizer = AguiEventNormalizer() + + event = AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "", "name": "test", "args_delta": "{}"}, # 空 id + ) + results = list(normalizer.normalize(event)) + + assert len(results) == 1 + 
# 空 id 不会被添加到追踪列表 + assert len(normalizer.get_seen_tool_calls()) == 0 + + def test_tool_call_chunk_without_name(self): + """测试没有 name 的 TOOL_CALL_CHUNK 事件""" + normalizer = AguiEventNormalizer() + + event = AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "call_1", "args_delta": "{}"}, # 没有 name + ) + results = list(normalizer.normalize(event)) + + assert len(results) == 1 + # id 会被追踪,但没有 name + assert "call_1" in normalizer.get_seen_tool_calls() + # 没有 name 时不会添加到 active_tool_calls 的名称映射 + assert "call_1" not in normalizer.get_active_tool_calls() + + def test_tool_result_with_empty_id(self): + """测试空 tool_call_id 的 TOOL_RESULT 事件""" + normalizer = AguiEventNormalizer() + + # 先添加一个工具调用 + chunk_event = AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "call_1", "name": "test", "args_delta": "{}"}, + ) + list(normalizer.normalize(chunk_event)) + assert "call_1" in normalizer.get_active_tool_calls() + + # 发送空 id 的结果 + result_event = AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "", "result": "done"}, # 空 id + ) + results = list(normalizer.normalize(result_event)) + + assert len(results) == 1 + # call_1 仍然是活跃的(因为结果的 id 为空) + assert "call_1" in normalizer.get_active_tool_calls() + + def test_dict_with_event_type_enum(self): + """测试字典中直接使用 EventType 枚举(覆盖 106->115 分支)""" + normalizer = AguiEventNormalizer() + + event_dict = { + "event": EventType.TEXT, # 直接使用枚举,不是字符串 + "data": {"delta": "Hello"}, + } + results = list(normalizer.normalize(event_dict)) + + assert len(results) == 1 + assert results[0].event == EventType.TEXT + assert results[0].data["delta"] == "Hello" + + def test_dict_with_event_type_enum_custom(self): + """测试字典中直接使用 EventType.CUSTOM 枚举""" + normalizer = AguiEventNormalizer() + + event_dict = { + "event": EventType.CUSTOM, # 直接使用枚举 + "data": {"name": "test_event", "value": {"key": "value"}}, + } + results = list(normalizer.normalize(event_dict)) + + assert len(results) == 1 + assert results[0].event == EventType.CUSTOM + 
assert results[0].data["name"] == "test_event" + + def test_dict_with_event_type_enum_tool_call(self): + """测试字典中直接使用 EventType.TOOL_CALL 枚举""" + normalizer = AguiEventNormalizer() + + event_dict = { + "event": EventType.TOOL_CALL, # 直接使用枚举 + "data": {"id": "tc-1", "name": "test", "args": "{}"}, + } + results = list(normalizer.normalize(event_dict)) + + assert len(results) == 1 + assert results[0].event == EventType.TOOL_CALL + + def test_complete_tool_call_sequence(self): + """测试完整的工具调用序列追踪""" + normalizer = AguiEventNormalizer() + all_results = [] + + # 完整的工具调用序列 + events = [ + AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": "call_1", + "name": "get_time", + "args_delta": '{"tz":', + }, + ), + AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "call_1", "args_delta": '"UTC"}'}, + ), + AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "call_1", "result": "12:00"}, + ), + ] + + for event in events: + all_results.extend(normalizer.normalize(event)) + + # 事件保持原样传递 + assert len(all_results) == 3 + event_types = [e.event for e in all_results] + assert event_types == [ + EventType.TOOL_CALL_CHUNK, + EventType.TOOL_CALL_CHUNK, + EventType.TOOL_RESULT, + ] + + # 最终状态:工具调用完成 + assert "call_1" not in normalizer.get_active_tool_calls() + assert "call_1" in normalizer.get_seen_tool_calls() + + +class TestAguiEventNormalizerWithAguiProtocol: + """使用 ag-ui-protocol 验证事件结构的测试""" + + @pytest.fixture + def ag_ui_available(self): + """检查 ag-ui-protocol 是否可用""" + try: + from ag_ui.core import ToolCallArgsEvent, ToolCallResultEvent + + return True + except ImportError: + pytest.skip("ag-ui-protocol not installed") + + def test_tool_call_events_have_valid_structure(self, ag_ui_available): + """测试工具调用事件结构有效""" + from ag_ui.core import ToolCallArgsEvent, ToolCallResultEvent + + normalizer = AguiEventNormalizer() + + events = [ + AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "call_1", "name": "test", "args_delta": '{"x": 1}'}, + ), + 
AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "call_1", "result": "success"}, + ), + ] + + all_results = [] + for event in events: + all_results.extend(normalizer.normalize(event)) + + # 验证 TOOL_CALL_CHUNK 可以映射到 ag-ui + chunk_result = all_results[0] + args_event = ToolCallArgsEvent( + tool_call_id=chunk_result.data["id"], + delta=chunk_result.data["args_delta"], + ) + assert args_event.tool_call_id == "call_1" + + # 验证 TOOL_RESULT 可以映射到 ag-ui + result_result = all_results[1] + result_event = ToolCallResultEvent( + message_id="msg_1", + tool_call_id=result_result.data["id"], + content=result_result.data["result"], + ) + assert result_event.tool_call_id == "call_1" diff --git a/tests/unittests/server/test_agui_protocol.py b/tests/unittests/server/test_agui_protocol.py new file mode 100644 index 0000000..3d413de --- /dev/null +++ b/tests/unittests/server/test_agui_protocol.py @@ -0,0 +1,1203 @@ +"""AG-UI 协议处理器测试 + +测试 AGUIProtocolHandler 的各种功能。 +""" + +import json +from typing import cast + +from fastapi.testclient import TestClient +import pytest + +from agentrun.server import ( + AgentEvent, + AgentRequest, + AgentRunServer, + AGUIProtocolHandler, + EventType, + ServerConfig, +) + + +class TestAGUIProtocolHandler: + """测试 AGUIProtocolHandler""" + + def test_get_prefix_default(self): + """测试默认前缀""" + handler = AGUIProtocolHandler() + assert handler.get_prefix() == "/ag-ui/agent" + + def test_get_prefix_custom(self): + """测试自定义前缀""" + from agentrun.server.model import AGUIProtocolConfig + + config = ServerConfig(agui=AGUIProtocolConfig(prefix="/custom/agui")) + handler = AGUIProtocolHandler(config) + assert handler.get_prefix() == "/custom/agui" + + +class TestAGUIProtocolEndpoints: + """测试 AG-UI 协议端点""" + + def get_client(self, invoke_agent): + server = AgentRunServer(invoke_agent=invoke_agent) + return TestClient(server.as_fastapi_app()) + + @pytest.mark.asyncio + async def test_health_check(self): + """测试健康检查端点""" + + def invoke_agent(request: 
AgentRequest): + return "Hello" + + client = self.get_client(invoke_agent) + response = client.get("/ag-ui/agent/health") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert data["protocol"] == "ag-ui" + assert data["version"] == "1.0" + + @pytest.mark.asyncio + async def test_value_error_handling(self): + """测试 ValueError 处理""" + + def invoke_agent(request: AgentRequest): + raise ValueError("Test error") + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hello"}]}, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 应该包含 RUN_ERROR 事件 + types = [] + for line in lines: + if line.startswith("data: "): + data = json.loads(line[6:]) + types.append(data.get("type")) + + assert "RUN_ERROR" in types + + @pytest.mark.asyncio + async def test_general_exception_handling(self): + """测试一般异常处理(在 invoke_agent 中抛出异常)""" + + def invoke_agent(request: AgentRequest): + raise RuntimeError("Internal error") + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hello"}]}, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 应该包含 RUN_ERROR 事件 + types = [] + for line in lines: + if line.startswith("data: "): + data = json.loads(line[6:]) + types.append(data.get("type")) + + assert "RUN_ERROR" in types + + @pytest.mark.asyncio + async def test_exception_in_parse_request(self): + """测试 parse_request 中的异常处理(覆盖 155-156 行) + + 通过发送一个会导致 parse_request 抛出非 ValueError 异常的请求 + """ + from unittest.mock import AsyncMock, patch + + def invoke_agent(request: AgentRequest): + return "Hello" + + client = self.get_client(invoke_agent) + + # 模拟 parse_request 抛出 RuntimeError + with patch.object( + AGUIProtocolHandler, + "parse_request", + new_callable=AsyncMock, + 
side_effect=RuntimeError("Unexpected error"), + ): + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hello"}]}, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 应该包含 RUN_ERROR 事件 + types = [] + for line in lines: + if line.startswith("data: "): + data = json.loads(line[6:]) + types.append(data.get("type")) + + assert "RUN_ERROR" in types + # 错误消息应该包含 "Internal error" + for line in lines: + if line.startswith("data: "): + data = json.loads(line[6:]) + if data.get("type") == "RUN_ERROR": + assert "Internal error" in data.get("message", "") + break + + @pytest.mark.asyncio + async def test_parse_messages_with_non_dict(self): + """测试解析消息时跳过非字典项""" + + captured_request = {} + + def invoke_agent(request: AgentRequest): + captured_request["messages"] = request.messages + return "Done" + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={ + "messages": [ + "invalid_message", # 非字典项,应该被跳过 + {"role": "user", "content": "Hello"}, + ], + }, + ) + + assert response.status_code == 200 + assert len(captured_request["messages"]) == 1 + + @pytest.mark.asyncio + async def test_parse_messages_with_invalid_role(self): + """测试解析消息时处理无效角色""" + + captured_request = {} + + def invoke_agent(request: AgentRequest): + captured_request["messages"] = request.messages + return "Done" + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={ + "messages": [ + {"role": "invalid_role", "content": "Hello"}, # 无效角色 + ], + }, + ) + + assert response.status_code == 200 + # 无效角色应该默认为 user + from agentrun.server.model import MessageRole + + assert captured_request["messages"][0].role == MessageRole.USER + + @pytest.mark.asyncio + async def test_parse_messages_with_tool_calls(self): + """测试解析带有 toolCalls 的消息""" + + captured_request = {} + + def invoke_agent(request: AgentRequest): + captured_request["messages"] = 
request.messages + return "Done" + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={ + "messages": [ + { + "id": "msg-1", + "role": "assistant", + "content": None, + "toolCalls": [{ + "id": "call_123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"city": "Beijing"}', + }, + }], + }, + { + "id": "msg-2", + "role": "tool", + "content": "Sunny", + "toolCallId": "call_123", + }, + ], + }, + ) + + assert response.status_code == 200 + assert len(captured_request["messages"]) == 2 + assert captured_request["messages"][0].tool_calls is not None + assert captured_request["messages"][0].tool_calls[0].id == "call_123" + assert captured_request["messages"][1].tool_call_id == "call_123" + + @pytest.mark.asyncio + async def test_parse_tools(self): + """测试解析工具列表""" + + captured_request = {} + + def invoke_agent(request: AgentRequest): + captured_request["tools"] = request.tools + return "Done" + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "tools": [{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather", + }, + }], + }, + ) + + assert response.status_code == 200 + assert captured_request["tools"] is not None + assert len(captured_request["tools"]) == 1 + + @pytest.mark.asyncio + async def test_parse_tools_with_non_dict(self): + """测试解析工具列表时跳过非字典项""" + + captured_request = {} + + def invoke_agent(request: AgentRequest): + captured_request["tools"] = request.tools + return "Done" + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "tools": [ + "invalid_tool", + {"type": "function", "function": {"name": "valid"}}, + ], + }, + ) + + assert response.status_code == 200 + assert captured_request["tools"] is not None + assert len(captured_request["tools"]) == 1 + + 
@pytest.mark.asyncio + async def test_parse_tools_empty_after_filter(self): + """测试解析工具列表后为空时返回 None""" + + captured_request = {} + + def invoke_agent(request: AgentRequest): + captured_request["tools"] = request.tools + return "Done" + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "tools": ["invalid1", "invalid2"], + }, + ) + + assert response.status_code == 200 + assert captured_request["tools"] is None + + @pytest.mark.asyncio + async def test_state_delta_event(self): + """测试 STATE delta 事件""" + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.STATE, + data={"delta": [{"op": "add", "path": "/count", "value": 1}]}, + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 查找 STATE_DELTA 事件 + found_delta = False + for line in lines: + if line.startswith("data: "): + data = json.loads(line[6:]) + if data.get("type") == "STATE_DELTA": + found_delta = True + assert data["delta"] == [ + {"op": "add", "path": "/count", "value": 1} + ] + + assert found_delta + + @pytest.mark.asyncio + async def test_state_snapshot_fallback(self): + """测试 STATE 事件没有 snapshot 或 delta 时的回退""" + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.STATE, + data={ + "custom_key": "custom_value" + }, # 既不是 snapshot 也不是 delta + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 应该作为 STATE_SNAPSHOT 处理 + found_snapshot = False + for line in lines: + if line.startswith("data: "): + data = json.loads(line[6:]) + if 
data.get("type") == "STATE_SNAPSHOT": + found_snapshot = True + assert data["snapshot"]["custom_key"] == "custom_value" + + assert found_snapshot + + @pytest.mark.asyncio + async def test_unknown_event_type(self): + """测试未知事件类型转换为 CUSTOM""" + + # 直接测试 _process_event_with_boundaries 方法 + # 由于我们无法直接发送未知事件类型,我们测试 CUSTOM 事件的正常处理 + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.CUSTOM, + data={"name": "unknown_event", "value": {"data": "test"}}, + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 查找 CUSTOM 事件 + found_custom = False + for line in lines: + if line.startswith("data: "): + data = json.loads(line[6:]) + if data.get("type") == "CUSTOM": + found_custom = True + assert data["name"] == "unknown_event" + + assert found_custom + + @pytest.mark.asyncio + async def test_addition_merge_overrides(self): + """测试 addition 默认合并覆盖字段""" + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.TEXT, + data={"delta": "Hello"}, + addition={"custom": "value", "delta": "overwritten"}, + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 查找 TEXT_MESSAGE_CONTENT 事件 + for line in lines: + if line.startswith("data: "): + data = json.loads(line[6:]) + if data.get("type") == "TEXT_MESSAGE_CONTENT": + assert "custom" in data + assert data["delta"] == "overwritten" + break + + @pytest.mark.asyncio + async def test_addition_protocol_only_mode(self): + """测试 addition PROTOCOL_ONLY 模式""" + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.TEXT, + data={"delta": 
"Hello"}, + addition={ + "delta": "overwritten", # 已存在的字段会被覆盖 + "new_field": "ignored", # 新字段会被忽略 + }, + addition_merge_options={"no_new_field": True}, + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 查找 TEXT_MESSAGE_CONTENT 事件 + for line in lines: + if line.startswith("data: "): + data = json.loads(line[6:]) + if data.get("type") == "TEXT_MESSAGE_CONTENT": + assert data["delta"] == "overwritten" + assert "new_field" not in data + break + + @pytest.mark.asyncio + async def test_raw_event_with_newline(self): + """测试 RAW 事件自动添加换行符""" + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.RAW, + data={"raw": '{"custom": "data"}'}, # 没有换行符 + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + + assert response.status_code == 200 + # RAW 事件应该被正确处理 + content = response.text + assert '{"custom": "data"}' in content + + @pytest.mark.asyncio + async def test_raw_event_with_trailing_newlines(self): + """测试 RAW 事件带有尾随换行符时的处理""" + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.RAW, + data={"raw": '{"custom": "data"}\n'}, # 只有一个换行符 + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + + assert response.status_code == 200 + content = response.text + assert '{"custom": "data"}' in content + + @pytest.mark.asyncio + async def test_raw_event_already_has_double_newline(self): + """测试 RAW 事件已经有双换行符时不再添加""" + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.RAW, + data={"raw": '{"custom": "data"}\n\n'}, # 已经有双换行符 + ) + + client = self.get_client(invoke_agent) + 
response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + + assert response.status_code == 200 + content = response.text + assert '{"custom": "data"}' in content + + @pytest.mark.asyncio + async def test_raw_event_empty(self): + """测试空 RAW 事件""" + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.RAW, + data={"raw": ""}, # 空内容 + ) + yield "Hello" + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 应该正常完成,包含 Hello 文本 + types = [] + for line in lines: + if line.startswith("data: "): + data = json.loads(line[6:]) + types.append(data.get("type")) + + assert "TEXT_MESSAGE_CONTENT" in types + + @pytest.mark.asyncio + async def test_thread_id_and_run_id_from_request(self): + """测试使用请求中的 threadId 和 runId""" + + async def invoke_agent(request: AgentRequest): + yield "Hello" + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "threadId": "custom-thread-123", + "runId": "custom-run-456", + }, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 检查 RUN_STARTED 事件 + for line in lines: + if line.startswith("data: "): + data = json.loads(line[6:]) + if data.get("type") == "RUN_STARTED": + assert data["threadId"] == "custom-thread-123" + assert data["runId"] == "custom-run-456" + break + + @pytest.mark.asyncio + async def test_new_text_message_after_tool_result(self): + """测试工具结果后的新文本消息有新的 messageId""" + + async def invoke_agent(request: AgentRequest): + yield "First message" + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool", "args_delta": "{}"}, + ) + yield AgentEvent( + 
event=EventType.TOOL_RESULT, + data={"id": "tc-1", "result": "done"}, + ) + # 这应该触发 TEXT_MESSAGE_END 然后 TEXT_MESSAGE_START(新消息) + yield "Second message" + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 收集所有 messageId + message_ids = [] + for line in lines: + if line.startswith("data: "): + data = json.loads(line[6:]) + if data.get("type") in [ + "TEXT_MESSAGE_START", + "TEXT_MESSAGE_CONTENT", + "TEXT_MESSAGE_END", + ]: + message_ids.append(data.get("messageId")) + + # 第一段和第二段文本应该有相同的 messageId(AG-UI 允许并行) + # 实际上在当前实现中,工具调用后的文本是同一个消息的延续 + assert len(message_ids) >= 2 + + +class TestAGUIProtocolErrorStream: + """测试 AG-UI 协议错误流""" + + def get_client(self, invoke_agent): + server = AgentRunServer(invoke_agent=invoke_agent) + return TestClient(server.as_fastapi_app()) + + @pytest.mark.asyncio + async def test_error_stream_on_json_parse_error(self): + """测试 JSON 解析错误时的错误流""" + + def invoke_agent(request: AgentRequest): + return "Hello" + + client = self.get_client(invoke_agent) + # 发送无效的 JSON + response = client.post( + "/ag-ui/agent", + content="invalid json", + headers={"Content-Type": "application/json"}, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 应该包含 RUN_STARTED 和 RUN_ERROR + types = [] + for line in lines: + if line.startswith("data: "): + data = json.loads(line[6:]) + types.append(data.get("type")) + + assert "RUN_STARTED" in types + assert "RUN_ERROR" in types + + @pytest.mark.asyncio + async def test_error_stream_format(self): + """测试错误流的格式""" + + def invoke_agent(request: AgentRequest): + raise ValueError("Test validation error") + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + + 
assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 检查错误事件的格式 + for line in lines: + if line.startswith("data: "): + data = json.loads(line[6:]) + if data.get("type") == "RUN_ERROR": + # 错误事件应该包含 message + assert "message" in data + break + + +class TestAGUIProtocolApplyAddition: + """测试 _apply_addition 方法""" + + def test_apply_addition_default_merge(self): + """默认合并应覆盖已有字段""" + handler = AGUIProtocolHandler() + + event_data = {"delta": "Hello", "type": "TEXT_MESSAGE_CONTENT"} + addition = {"delta": "overwritten", "new_field": "added"} + + result = handler._apply_addition( + event_data.copy(), + addition.copy(), + ) + + assert result["delta"] == "overwritten" + assert result["new_field"] == "added" + assert result["type"] == "TEXT_MESSAGE_CONTENT" + + def test_apply_addition_merge_options_none(self): + """显式传入 merge_options=None 仍按默认合并""" + handler = AGUIProtocolHandler() + + event_data = {"delta": "Hello", "type": "TEXT_MESSAGE_CONTENT"} + addition = {"delta": "overwritten", "new_field": "added"} + + result = handler._apply_addition( + event_data.copy(), + addition.copy(), + ) + + assert result["delta"] == "overwritten" + assert result["new_field"] == "added" + + def test_apply_addition_protocol_only_mode(self): + """测试 PROTOCOL_ONLY 模式(覆盖 616->620 分支)""" + handler = AGUIProtocolHandler() + + event_data = {"delta": "Hello", "type": "TEXT_MESSAGE_CONTENT"} + addition = {"delta": "overwritten", "new_field": "ignored"} + + result = handler._apply_addition( + event_data.copy(), + addition.copy(), + {"no_new_field": True}, + ) + + # delta 被覆盖 + assert result["delta"] == "overwritten" + # new_field 不存在(被忽略) + assert "new_field" not in result + # type 保持不变 + assert result["type"] == "TEXT_MESSAGE_CONTENT" + + +class TestAGUIProtocolConvertMessages: + """测试消息转换功能""" + + def test_convert_messages_for_snapshot(self): + """测试 _convert_messages_for_snapshot 方法""" + handler = AGUIProtocolHandler() + + messages = [ + 
{"role": "user", "content": "Hello", "id": "msg-1"}, + {"role": "assistant", "content": "Hi there", "id": "msg-2"}, + {"role": "system", "content": "You are helpful", "id": "msg-3"}, + { + "role": "tool", + "content": "Result", + "id": "msg-4", + "tool_call_id": "tc-1", + }, + ] + + result = handler._convert_messages_for_snapshot(messages) + + assert len(result) == 4 + assert result[0].role == "user" + assert result[1].role == "assistant" + assert result[2].role == "system" + assert result[3].role == "tool" + + def test_convert_messages_with_non_dict(self): + """测试 _convert_messages_for_snapshot 跳过非字典项""" + handler = AGUIProtocolHandler() + + messages = [ + "invalid", + {"role": "user", "content": "Hello", "id": "msg-1"}, + 123, + {"role": "assistant", "content": "Hi", "id": "msg-2"}, + ] + + result = handler._convert_messages_for_snapshot(messages) + + assert len(result) == 2 + + def test_convert_messages_with_missing_id(self): + """测试 _convert_messages_for_snapshot 自动生成 id""" + handler = AGUIProtocolHandler() + + messages = [ + {"role": "user", "content": "Hello"}, # 没有 id + ] + + result = handler._convert_messages_for_snapshot(messages) + + assert len(result) == 1 + assert result[0].id is not None + assert len(result[0].id) > 0 + + def test_convert_messages_with_unknown_role(self): + """测试 _convert_messages_for_snapshot 处理未知角色""" + handler = AGUIProtocolHandler() + + messages = [ + {"role": "unknown", "content": "Hello", "id": "msg-1"}, + {"role": "user", "content": "World", "id": "msg-2"}, + ] + + result = handler._convert_messages_for_snapshot(messages) + + # 未知角色的消息会被跳过 + assert len(result) == 1 + assert result[0].role == "user" + + +class TestAGUIProtocolUnknownEvent: + """测试未知事件类型处理""" + + def get_client(self, invoke_agent): + server = AgentRunServer(invoke_agent=invoke_agent) + return TestClient(server.as_fastapi_app()) + + @pytest.mark.asyncio + async def test_tool_result_event(self): + """测试 TOOL_RESULT 事件(使用 content 字段)""" + + async def 
invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool", "args_delta": "{}"}, + ) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={ + "id": "tc-1", + "content": "Tool content result", + }, # 使用 content + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 查找 TOOL_CALL_RESULT 事件 + for line in lines: + if line.startswith("data: "): + data = json.loads(line[6:]) + if data.get("type") == "TOOL_CALL_RESULT": + assert data["content"] == "Tool content result" + break + + +class TestAGUIProtocolTextMessageRestart: + """测试文本消息重新开始""" + + def get_client(self, invoke_agent): + server = AgentRunServer(invoke_agent=invoke_agent) + return TestClient(server.as_fastapi_app()) + + @pytest.mark.asyncio + async def test_text_message_restart_after_end(self): + """测试文本消息结束后重新开始新消息 + + 当 text_state["ended"] 为 True 时,新的 TEXT 事件应该开始新消息 + """ + + async def invoke_agent(request: AgentRequest): + # 第一段文本 + yield "First" + # 工具调用会导致文本消息结束 + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool", "args_delta": "{}"}, + ) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-1", "result": "done"}, + ) + # 第二段文本应该开始新消息 + yield "Second" + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 统计 TEXT_MESSAGE_START 事件数量 + start_count = 0 + message_ids = set() + for line in lines: + if line.startswith("data: "): + data = json.loads(line[6:]) + if data.get("type") == "TEXT_MESSAGE_START": + start_count += 1 + message_ids.add(data.get("messageId")) + + # 应该只有一个 
TEXT_MESSAGE_START(AG-UI 允许并行) + # 但如果实现了重新开始逻辑,可能有两个 + assert start_count >= 1 + + @pytest.mark.asyncio + async def test_state_delta_event(self): + """测试 STATE delta 事件""" + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.STATE, + data={ + "delta": [{"op": "replace", "path": "/count", "value": 10}] + }, + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 查找 STATE_DELTA 事件 + found = False + for line in lines: + if line.startswith("data: "): + data = json.loads(line[6:]) + if data.get("type") == "STATE_DELTA": + found = True + assert data["delta"][0]["op"] == "replace" + break + + assert found + + +class TestAGUIProtocolExceptionHandling: + """测试异常处理""" + + def get_client(self, invoke_agent): + server = AgentRunServer(invoke_agent=invoke_agent) + return TestClient(server.as_fastapi_app()) + + @pytest.mark.asyncio + async def test_general_exception_returns_error_stream(self): + """测试一般异常返回错误流""" + + def invoke_agent(request: AgentRequest): + raise RuntimeError("Unexpected error") + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 应该包含 RUN_ERROR + types = [ + json.loads(line[6:]).get("type") + for line in lines + if line.startswith("data: ") + ] + assert "RUN_ERROR" in types + + +class TestAGUIProtocolUnknownEventType: + """测试未知事件类型处理(覆盖第 531 行) + + TOOL_CALL 事件没有在 _process_event_with_boundaries 中被显式处理, + 所以它会走到第 531 行的 else 分支,被转换为 CUSTOM 事件。 + """ + + def get_client(self, invoke_agent): + server = AgentRunServer(invoke_agent=invoke_agent) + return TestClient(server.as_fastapi_app()) + + def 
test_process_event_with_boundaries_unknown_event(self): + """直接测试 _process_event_with_boundaries 处理未知事件类型 + + TOOL_CALL 事件没有在 _process_event_with_boundaries 中被显式处理, + 所以它会走到 else 分支,被转换为 CUSTOM 事件。 + """ + handler = AGUIProtocolHandler() + + # 创建 TOOL_CALL 事件(在协议层没有被显式处理) + event = AgentEvent( + event=EventType.TOOL_CALL, + data={"id": "tc-1", "name": "test_tool", "args": '{"x": 1}'}, + ) + + context = {"thread_id": "test-thread", "run_id": "test-run"} + + # 创建 StreamStateMachine 对象 + from agentrun.server.agui_protocol import StreamStateMachine + + state = StreamStateMachine(copilotkit_compatibility=False) + + # 调用方法 + results = list( + handler._process_event_with_boundaries(event, context, state) + ) + + # TOOL_CALL 现在被实际处理,生成 TOOL_CALL_START 和 TOOL_CALL_ARGS 事件 + assert len(results) == 2 + # 解析第一个 SSE 数据 (TOOL_CALL_START) + sse_data_1 = results[0] + assert sse_data_1.startswith("data: ") + data_1 = json.loads(sse_data_1[6:].strip()) + assert data_1["type"] == "TOOL_CALL_START" + assert data_1["toolCallId"] == "tc-1" + assert data_1["toolCallName"] == "test_tool" + + # 解析第二个 SSE 数据 (TOOL_CALL_ARGS) + sse_data_2 = results[1] + assert sse_data_2.startswith("data: ") + data_2 = json.loads(sse_data_2[6:].strip()) + assert data_2["type"] == "TOOL_CALL_ARGS" + assert data_2["toolCallId"] == "tc-1" + assert data_2["delta"] == '{"x": 1}' + + @pytest.mark.asyncio + async def test_tool_call_event_expanded_by_invoker(self): + """测试 TOOL_CALL 事件被 invoker 展开为 TOOL_CALL_CHUNK + + 通过端到端测试验证 TOOL_CALL 被正确处理。 + """ + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.TOOL_CALL, + data={"id": "tc-1", "name": "test_tool", "args": '{"x": 1}'}, + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # TOOL_CALL 会被 invoker 展开为 TOOL_CALL_CHUNK + # 
所以应该有 TOOL_CALL_START 和 TOOL_CALL_ARGS 事件 + types = [] + for line in lines: + if line.startswith("data: "): + data = json.loads(line[6:]) + types.append(data.get("type")) + + assert "TOOL_CALL_START" in types + assert "TOOL_CALL_ARGS" in types + + +class TestAGUIProtocolToolCallBranches: + """测试工具调用的各种分支""" + + def get_client(self, invoke_agent): + server = AgentRunServer(invoke_agent=invoke_agent) + return TestClient(server.as_fastapi_app()) + + @pytest.mark.asyncio + async def test_tool_result_for_already_ended_tool(self): + """测试对已结束的工具调用发送结果""" + + async def invoke_agent(request: AgentRequest): + # 工具调用 + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool", "args_delta": "{}"}, + ) + # 第一个结果(会结束工具调用) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-1", "result": "result1"}, + ) + # 第二个结果(工具调用已结束) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-1", "result": "result2"}, + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 应该有两个 TOOL_CALL_RESULT + result_count = sum( + 1 + for line in lines + if line.startswith("data: ") and "TOOL_CALL_RESULT" in line + ) + assert result_count == 2 + + @pytest.mark.asyncio + async def test_empty_sse_data_filtered(self): + """测试空 SSE 数据被过滤""" + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.RAW, + data={"raw": ""}, # 空内容 + ) + yield "Hello" + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 应该正常完成 + types = [ + json.loads(line[6:]).get("type") + for line in lines + if line.startswith("data: ") + ] + 
assert "RUN_FINISHED" in types + + @pytest.mark.asyncio + async def test_tool_call_with_empty_id(self): + """测试空 tool_id 的工具调用""" + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "", "name": "tool", "args_delta": "{}"}, # 空 id + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_tool_result_without_prior_start(self): + """测试没有先发送 TOOL_CALL_CHUNK 的 TOOL_RESULT""" + + async def invoke_agent(request: AgentRequest): + # 直接发送 TOOL_RESULT,没有 TOOL_CALL_CHUNK + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-orphan", "result": "orphan result"}, + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 应该自动补充 TOOL_CALL_START 和 TOOL_CALL_END + types = [ + json.loads(line[6:]).get("type") + for line in lines + if line.startswith("data: ") + ] + assert "TOOL_CALL_START" in types + assert "TOOL_CALL_END" in types + assert "TOOL_CALL_RESULT" in types diff --git a/tests/unittests/server/test_invoker.py b/tests/unittests/server/test_invoker.py new file mode 100644 index 0000000..76d8c65 --- /dev/null +++ b/tests/unittests/server/test_invoker.py @@ -0,0 +1,438 @@ +"""Agent Invoker 单元测试 + +测试 AgentInvoker 的各种调用场景。 + +新设计:invoker 只输出核心事件(TEXT, TOOL_CALL_CHUNK 等), +边界事件(LIFECYCLE_START/END, TEXT_MESSAGE_START/END 等)由协议层自动生成。 +""" + +from typing import AsyncGenerator, List + +import pytest + +from agentrun.server.invoker import AgentInvoker +from agentrun.server.model import AgentEvent, AgentRequest, EventType + + +class TestInvokerBasic: + """基本调用测试""" + + @pytest.fixture + def req(self): + return AgentRequest( + 
messages=[], + tools=[], + stream=False, + raw_request=None, + protocol="unknown", + ) + + @pytest.mark.asyncio + async def test_async_generator_returns_stream(self, req): + """测试异步生成器返回流式结果""" + + async def invoke_agent(req: AgentRequest) -> AsyncGenerator[str, None]: + yield "hello" + yield " world" + + invoker = AgentInvoker(invoke_agent) + result = await invoker.invoke(req) + + # 结果应该是异步生成器 + assert hasattr(result, "__aiter__") + + # 收集所有结果 + items: List[AgentEvent] = [] + async for item in result: + items.append(item) + + # 应该有 2 个 TEXT 事件(不再有边界事件) + assert len(items) == 2 + + content_events = [ + item for item in items if item.event == EventType.TEXT + ] + assert len(content_events) == 2 + assert content_events[0].data["delta"] == "hello" + assert content_events[1].data["delta"] == " world" + + @pytest.mark.asyncio + async def test_text_events_structure(self, req): + """测试 TEXT 事件结构正确""" + + async def invoke_agent(req: AgentRequest) -> AsyncGenerator[str, None]: + yield "Hello" + yield " " + yield "World" + + invoker = AgentInvoker(invoke_agent) + result = await invoker.invoke(req) + + items: List[AgentEvent] = [] + async for item in result: + items.append(item) + + # 应该只有 TEXT 事件 + assert all(item.event == EventType.TEXT for item in items) + assert len(items) == 3 + + # 验证 delta 内容 + deltas = [item.data["delta"] for item in items] + assert deltas == ["Hello", " ", "World"] + + @pytest.mark.asyncio + async def test_async_coroutine_returns_list(self, req): + """测试异步协程返回列表结果""" + + async def invoke_agent(req: AgentRequest) -> str: + return "world" + + invoker = AgentInvoker(invoke_agent) + result = await invoker.invoke(req) + + # 非流式返回应该是列表 + assert isinstance(result, list) + + # 应该只包含 TEXT 事件(无边界事件) + assert len(result) == 1 + assert result[0].event == EventType.TEXT + assert result[0].data["delta"] == "world" + + +class TestInvokerStream: + """invoke_stream 方法测试""" + + @pytest.fixture + def req(self): + return AgentRequest( + messages=[], + tools=[], + 
stream=False, + raw_request=None, + protocol="unknown", + ) + + @pytest.mark.asyncio + async def test_invoke_stream_with_string(self, req): + """测试 invoke_stream 返回核心事件""" + + async def invoke_agent(req: AgentRequest) -> str: + return "hello" + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + # 应该只包含 TEXT 事件(边界事件由协议层生成) + event_types = [item.event for item in items] + assert EventType.TEXT in event_types + assert len(items) == 1 + + @pytest.mark.asyncio + async def test_invoke_stream_with_agent_event(self, req): + """测试返回 AgentEvent 事件""" + + async def invoke_agent( + req: AgentRequest, + ) -> AsyncGenerator[AgentEvent, None]: + yield AgentEvent( + event=EventType.CUSTOM, + data={"name": "step_started", "value": {"step": "test"}}, + ) + yield AgentEvent( + event=EventType.TEXT, + data={"delta": "hello"}, + ) + yield AgentEvent( + event=EventType.CUSTOM, + data={"name": "step_finished", "value": {"step": "test"}}, + ) + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + event_types = [item.event for item in items] + + # 应该包含用户返回的事件 + assert EventType.CUSTOM in event_types + assert EventType.TEXT in event_types + assert len(items) == 3 + + @pytest.mark.asyncio + async def test_invoke_stream_error_handling(self, req): + """测试错误处理""" + + async def invoke_agent(req: AgentRequest) -> str: + raise ValueError("Test error") + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + event_types = [item.event for item in items] + + # 应该包含 ERROR 事件 + assert EventType.ERROR in event_types + + # 检查错误信息 + error_event = next( + item for item in items if item.event == EventType.ERROR + ) + assert "Test error" in error_event.data["message"] + assert error_event.data["code"] == "ValueError" + + +class 
TestInvokerSync: + """同步调用测试""" + + @pytest.fixture + def req(self): + return AgentRequest( + messages=[], + tools=[], + stream=False, + raw_request=None, + protocol="unknown", + ) + + @pytest.mark.asyncio + async def test_sync_generator(self, req): + """测试同步生成器""" + + def invoke_agent(req: AgentRequest): + yield "hello" + yield " world" + + invoker = AgentInvoker(invoke_agent) + result = await invoker.invoke(req) + + # 结果应该是异步生成器 + assert hasattr(result, "__aiter__") + + items: List[AgentEvent] = [] + async for item in result: + items.append(item) + + content_events = [ + item for item in items if item.event == EventType.TEXT + ] + assert len(content_events) == 2 + + @pytest.mark.asyncio + async def test_sync_return(self): + """测试同步函数返回字符串""" + + def invoke_agent(req: AgentRequest) -> str: + return "sync result" + + invoker = AgentInvoker(invoke_agent) + result = await invoker.invoke(AgentRequest(messages=[])) + + assert isinstance(result, list) + # 只有一个 TEXT 事件(无边界事件) + assert len(result) == 1 + + content_event = result[0] + assert content_event.event == EventType.TEXT + assert content_event.data["delta"] == "sync result" + + +class TestInvokerMixed: + """混合内容测试""" + + @pytest.fixture + def req(self): + return AgentRequest( + messages=[], + tools=[], + stream=False, + raw_request=None, + protocol="unknown", + ) + + @pytest.mark.asyncio + async def test_mixed_string_and_events(self, req): + """测试混合字符串和事件""" + + async def invoke_agent(req: AgentRequest): + yield "Hello, " + yield AgentEvent( + event=EventType.TOOL_CALL, + data={"id": "tc-1", "name": "test", "args": "{}"}, + ) + yield "world!" 
+ + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + event_types = [item.event for item in items] + + # 应该包含文本和工具调用事件 + # TOOL_CALL 被展开为 TOOL_CALL_CHUNK + assert EventType.TEXT in event_types + assert EventType.TOOL_CALL_CHUNK in event_types + + # 验证内容 + text_events = [i for i in items if i.event == EventType.TEXT] + assert len(text_events) == 2 + assert text_events[0].data["delta"] == "Hello, " + assert text_events[1].data["delta"] == "world!" + + tool_events = [i for i in items if i.event == EventType.TOOL_CALL_CHUNK] + assert len(tool_events) == 1 + assert tool_events[0].data["id"] == "tc-1" + assert tool_events[0].data["name"] == "test" + + @pytest.mark.asyncio + async def test_empty_string_ignored(self, req): + """测试空字符串被忽略""" + + async def invoke_agent(req: AgentRequest): + yield "" + yield "hello" + yield "" + yield "world" + yield "" + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + content_events = [ + item for item in items if item.event == EventType.TEXT + ] + # 只有两个非空字符串 + assert len(content_events) == 2 + assert content_events[0].data["delta"] == "hello" + assert content_events[1].data["delta"] == "world" + + +class TestInvokerNone: + """None 值处理测试""" + + @pytest.fixture + def req(self): + return AgentRequest( + messages=[], + tools=[], + stream=False, + raw_request=None, + protocol="unknown", + ) + + @pytest.mark.asyncio + async def test_none_return(self, req): + """测试返回 None""" + + async def invoke_agent(req: AgentRequest): + return None + + invoker = AgentInvoker(invoke_agent) + result = await invoker.invoke(req) + + assert isinstance(result, list) + assert len(result) == 0 + + @pytest.mark.asyncio + async def test_none_in_stream(self, req): + """测试流中的 None 被忽略""" + + async def invoke_agent(req: AgentRequest): + yield None + yield "hello" + yield None + yield 
"world" + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + content_events = [ + item for item in items if item.event == EventType.TEXT + ] + assert len(content_events) == 2 + + +class TestInvokerToolCall: + """工具调用测试""" + + @pytest.fixture + def req(self): + return AgentRequest( + messages=[], + tools=[], + stream=False, + raw_request=None, + protocol="unknown", + ) + + @pytest.mark.asyncio + async def test_tool_call_expansion(self, req): + """测试 TOOL_CALL 被展开为 TOOL_CALL_CHUNK""" + + async def invoke_agent(req: AgentRequest): + yield AgentEvent( + event=EventType.TOOL_CALL, + data={ + "id": "call-123", + "name": "get_weather", + "args": '{"city": "Beijing"}', + }, + ) + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + # TOOL_CALL 被展开为 TOOL_CALL_CHUNK + assert len(items) == 1 + assert items[0].event == EventType.TOOL_CALL_CHUNK + assert items[0].data["id"] == "call-123" + assert items[0].data["name"] == "get_weather" + assert items[0].data["args_delta"] == '{"city": "Beijing"}' + + @pytest.mark.asyncio + async def test_tool_call_chunk_passthrough(self, req): + """测试 TOOL_CALL_CHUNK 直接透传""" + + async def invoke_agent(req: AgentRequest): + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": "call-456", + "name": "search", + "args_delta": '{"query":', + }, + ) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": "call-456", + "args_delta": '"hello"}', + }, + ) + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + assert len(items) == 2 + assert all(i.event == EventType.TOOL_CALL_CHUNK for i in items) diff --git a/tests/unittests/server/test_invoker_extended.py b/tests/unittests/server/test_invoker_extended.py new file mode 100644 index 0000000..5f65a29 
--- /dev/null +++ b/tests/unittests/server/test_invoker_extended.py @@ -0,0 +1,722 @@ +"""Invoker 扩展测试 + +测试 AgentInvoker 的更多边界情况。 +""" + +from typing import List + +import pytest + +from agentrun.server.invoker import AgentInvoker +from agentrun.server.model import AgentEvent, AgentRequest, EventType + + +class TestInvokerListReturn: + """测试列表返回值处理""" + + @pytest.fixture + def req(self): + return AgentRequest( + messages=[], + tools=[], + stream=False, + raw_request=None, + protocol="unknown", + ) + + @pytest.mark.asyncio + async def test_list_of_agent_events(self, req): + """测试返回 AgentEvent 列表""" + + def invoke_agent(req: AgentRequest): + return [ + AgentEvent(event=EventType.TEXT, data={"delta": "Hello"}), + AgentEvent(event=EventType.TEXT, data={"delta": " World"}), + ] + + invoker = AgentInvoker(invoke_agent) + result = await invoker.invoke(req) + + assert isinstance(result, list) + assert len(result) == 2 + assert result[0].event == EventType.TEXT + assert result[0].data["delta"] == "Hello" + assert result[1].data["delta"] == " World" + + @pytest.mark.asyncio + async def test_list_of_strings(self, req): + """测试返回字符串列表""" + + def invoke_agent(req: AgentRequest): + return ["Hello", "", "World"] # 空字符串被过滤 + + invoker = AgentInvoker(invoke_agent) + result = await invoker.invoke(req) + + assert isinstance(result, list) + # 注意:列表返回值中的空字符串不会被过滤,只有 `if item` 检查 + # 实际上 " " 不是空字符串,所以不会被过滤 + assert len(result) == 2 # 空字符串被过滤 + assert result[0].event == EventType.TEXT + assert result[0].data["delta"] == "Hello" + assert result[1].data["delta"] == "World" + + @pytest.mark.asyncio + async def test_list_of_mixed_items(self, req): + """测试返回混合列表""" + + def invoke_agent(req: AgentRequest): + return [ + "Hello", + AgentEvent( + event=EventType.TOOL_CALL, + data={"id": "tc-1", "name": "test", "args": "{}"}, + ), + "World", + ] + + invoker = AgentInvoker(invoke_agent) + result = await invoker.invoke(req) + + assert isinstance(result, list) + assert len(result) == 3 + assert 
result[0].event == EventType.TEXT + assert result[1].event == EventType.TOOL_CALL_CHUNK # TOOL_CALL 被展开 + assert result[2].event == EventType.TEXT + + @pytest.mark.asyncio + async def test_list_with_empty_strings(self, req): + """测试列表中的空字符串被过滤""" + + def invoke_agent(req: AgentRequest): + return ["Hello", "", "World", ""] + + invoker = AgentInvoker(invoke_agent) + result = await invoker.invoke(req) + + assert isinstance(result, list) + assert len(result) == 2 + assert result[0].data["delta"] == "Hello" + assert result[1].data["delta"] == "World" + + +class TestInvokerAsyncHandler: + """测试异步处理器的各种情况""" + + @pytest.fixture + def req(self): + return AgentRequest( + messages=[], + tools=[], + stream=False, + raw_request=None, + protocol="unknown", + ) + + @pytest.mark.asyncio + async def test_async_function_returning_non_awaitable(self, req): + """测试异步函数返回非 awaitable 值""" + + # 这种情况在实际中不太可能发生,但代码中有处理 + async def invoke_agent(req: AgentRequest): + # 直接返回字符串(不是 awaitable) + return "Hello" + + invoker = AgentInvoker(invoke_agent) + result = await invoker.invoke(req) + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0].data["delta"] == "Hello" + + +class TestInvokerWrapStream: + """测试 _wrap_stream 方法""" + + @pytest.fixture + def req(self): + return AgentRequest( + messages=[], + tools=[], + stream=False, + raw_request=None, + protocol="unknown", + ) + + @pytest.mark.asyncio + async def test_wrap_stream_with_none(self, req): + """测试流中的 None 值被过滤""" + + async def invoke_agent(req: AgentRequest): + yield None + yield "Hello" + yield None + + invoker = AgentInvoker(invoke_agent) + result = await invoker.invoke(req) + + items: List[AgentEvent] = [] + async for item in result: + items.append(item) + + assert len(items) == 1 + assert items[0].data["delta"] == "Hello" + + @pytest.mark.asyncio + async def test_wrap_stream_with_empty_string(self, req): + """测试流中的空字符串被过滤""" + + async def invoke_agent(req: AgentRequest): + yield "" + yield "Hello" + yield 
"" + + invoker = AgentInvoker(invoke_agent) + result = await invoker.invoke(req) + + items: List[AgentEvent] = [] + async for item in result: + items.append(item) + + assert len(items) == 1 + assert items[0].data["delta"] == "Hello" + + @pytest.mark.asyncio + async def test_wrap_stream_with_agent_event(self, req): + """测试流中的 AgentEvent""" + + async def invoke_agent(req: AgentRequest): + yield AgentEvent(event=EventType.TEXT, data={"delta": "Hello"}) + yield AgentEvent( + event=EventType.TOOL_CALL, + data={"id": "tc-1", "name": "test", "args": "{}"}, + ) + + invoker = AgentInvoker(invoke_agent) + result = await invoker.invoke(req) + + items: List[AgentEvent] = [] + async for item in result: + items.append(item) + + assert len(items) == 2 + assert items[0].event == EventType.TEXT + assert items[1].event == EventType.TOOL_CALL_CHUNK # TOOL_CALL 被展开 + + +class TestInvokerIterateAsync: + """测试 _iterate_async 方法""" + + @pytest.fixture + def req(self): + return AgentRequest( + messages=[], + tools=[], + stream=False, + raw_request=None, + protocol="unknown", + ) + + @pytest.mark.asyncio + async def test_iterate_sync_generator(self, req): + """测试迭代同步生成器""" + + def invoke_agent(req: AgentRequest): + yield "Hello" + yield "World" + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + assert len(items) == 2 + assert items[0].data["delta"] == "Hello" + assert items[1].data["delta"] == "World" + + @pytest.mark.asyncio + async def test_iterate_async_generator(self, req): + """测试迭代异步生成器""" + + async def invoke_agent(req: AgentRequest): + yield "Hello" + yield "World" + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + assert len(items) == 2 + assert items[0].data["delta"] == "Hello" + assert items[1].data["delta"] == "World" + + +class TestInvokerIsIterator: + """测试 _is_iterator 方法""" + + def 
test_is_iterator_with_agent_event(self): + """测试 AgentEvent 不是迭代器""" + + def invoke_agent(req: AgentRequest): + return "Hello" + + invoker = AgentInvoker(invoke_agent) + + event = AgentEvent(event=EventType.TEXT, data={"delta": "Hello"}) + assert invoker._is_iterator(event) is False + + def test_is_iterator_with_string(self): + """测试字符串不是迭代器""" + + def invoke_agent(req: AgentRequest): + return "Hello" + + invoker = AgentInvoker(invoke_agent) + assert invoker._is_iterator("Hello") is False + + def test_is_iterator_with_bytes(self): + """测试字节不是迭代器""" + + def invoke_agent(req: AgentRequest): + return "Hello" + + invoker = AgentInvoker(invoke_agent) + assert invoker._is_iterator(b"Hello") is False + + def test_is_iterator_with_dict(self): + """测试字典不是迭代器""" + + def invoke_agent(req: AgentRequest): + return "Hello" + + invoker = AgentInvoker(invoke_agent) + assert invoker._is_iterator({"key": "value"}) is False + + def test_is_iterator_with_list(self): + """测试列表不是迭代器""" + + def invoke_agent(req: AgentRequest): + return "Hello" + + invoker = AgentInvoker(invoke_agent) + assert invoker._is_iterator([1, 2, 3]) is False + + def test_is_iterator_with_generator(self): + """测试生成器是迭代器""" + + def invoke_agent(req: AgentRequest): + return "Hello" + + invoker = AgentInvoker(invoke_agent) + + def gen(): + yield 1 + + assert invoker._is_iterator(gen()) is True + + def test_is_iterator_with_async_generator(self): + """测试异步生成器是迭代器""" + + def invoke_agent(req: AgentRequest): + return "Hello" + + invoker = AgentInvoker(invoke_agent) + + async def async_gen(): + yield 1 + + assert invoker._is_iterator(async_gen()) is True + + +class TestInvokerProcessUserEvent: + """测试 _process_user_event 方法""" + + @pytest.fixture + def req(self): + return AgentRequest( + messages=[], + tools=[], + stream=False, + raw_request=None, + protocol="unknown", + ) + + @pytest.mark.asyncio + async def test_tool_call_expansion_with_missing_id(self, req): + """测试 TOOL_CALL 展开时缺少 id""" + + async def 
invoke_agent(req: AgentRequest): + yield AgentEvent( + event=EventType.TOOL_CALL, + data={"name": "test", "args": "{}"}, # 没有 id + ) + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + assert len(items) == 1 + assert items[0].event == EventType.TOOL_CALL_CHUNK + # id 应该被自动生成(UUID) + assert items[0].data["id"] is not None + assert len(items[0].data["id"]) > 0 + + @pytest.mark.asyncio + async def test_other_events_passthrough(self, req): + """测试其他事件直接传递""" + + async def invoke_agent(req: AgentRequest): + yield AgentEvent( + event=EventType.CUSTOM, + data={"name": "test", "value": {"data": 123}}, + ) + yield AgentEvent( + event=EventType.STATE, + data={"snapshot": {"key": "value"}}, + ) + yield AgentEvent( + event=EventType.ERROR, + data={"message": "Test error", "code": "TEST"}, + ) + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + assert len(items) == 3 + assert items[0].event == EventType.CUSTOM + assert items[1].event == EventType.STATE + assert items[2].event == EventType.ERROR + + +class TestInvokerSyncHandler: + """测试同步处理器""" + + @pytest.fixture + def req(self): + return AgentRequest( + messages=[], + tools=[], + stream=False, + raw_request=None, + protocol="unknown", + ) + + @pytest.mark.asyncio + async def test_sync_handler_returning_string(self, req): + """测试同步处理器返回字符串""" + + def invoke_agent(req: AgentRequest) -> str: + return "Hello from sync" + + invoker = AgentInvoker(invoke_agent) + result = await invoker.invoke(req) + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0].data["delta"] == "Hello from sync" + + @pytest.mark.asyncio + async def test_sync_handler_in_invoke_stream(self, req): + """测试同步处理器在 invoke_stream 中""" + + def invoke_agent(req: AgentRequest) -> str: + return "Hello from sync" + + invoker = AgentInvoker(invoke_agent) + + 
items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + assert len(items) == 1 + assert items[0].data["delta"] == "Hello from sync" + + +class TestInvokerAsyncNonAwaitable: + """测试异步函数返回非 awaitable 值""" + + @pytest.fixture + def req(self): + return AgentRequest( + messages=[], + tools=[], + stream=False, + raw_request=None, + protocol="unknown", + ) + + @pytest.mark.asyncio + async def test_async_function_returning_list_directly(self, req): + """测试异步函数直接返回列表(非 awaitable) + + 这种情况在实际中不太可能发生,但代码中有处理 + """ + + # 创建一个返回列表的异步函数 + async def invoke_agent(req: AgentRequest): + # 直接返回列表(不是 awaitable) + return [ + AgentEvent(event=EventType.TEXT, data={"delta": "Hello"}), + ] + + invoker = AgentInvoker(invoke_agent) + result = await invoker.invoke(req) + + assert isinstance(result, list) + assert len(result) == 1 + + +class TestInvokerStreamError: + """测试流式错误处理""" + + @pytest.fixture + def req(self): + return AgentRequest( + messages=[], + tools=[], + stream=False, + raw_request=None, + protocol="unknown", + ) + + @pytest.mark.asyncio + async def test_error_in_async_generator(self, req): + """测试异步生成器中的错误""" + + async def invoke_agent(req: AgentRequest): + yield "Hello" + raise RuntimeError("Test error in generator") + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + # 应该有文本事件和错误事件 + assert len(items) == 2 + assert items[0].event == EventType.TEXT + assert items[1].event == EventType.ERROR + assert "Test error in generator" in items[1].data["message"] + + @pytest.mark.asyncio + async def test_error_in_sync_generator(self, req): + """测试同步生成器中的错误""" + + def invoke_agent(req: AgentRequest): + yield "Hello" + raise ValueError("Test error in sync generator") + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + # 应该有文本事件和错误事件 + assert len(items) == 2 
+ assert items[0].event == EventType.TEXT + assert items[1].event == EventType.ERROR + + +class TestInvokerStreamAgentEvent: + """测试流式返回 AgentEvent 的情况(覆盖 123->111 和 267->255 分支)""" + + @pytest.fixture + def req(self): + return AgentRequest( + messages=[], + tools=[], + stream=False, + raw_request=None, + protocol="unknown", + ) + + @pytest.mark.asyncio + async def test_stream_with_agent_event_in_async_generator(self, req): + """测试异步生成器中返回 AgentEvent""" + + async def invoke_agent(req: AgentRequest): + yield AgentEvent(event=EventType.TEXT, data={"delta": "Hello"}) + yield AgentEvent(event=EventType.TEXT, data={"delta": " World"}) + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + assert len(items) == 2 + assert items[0].event == EventType.TEXT + assert items[0].data["delta"] == "Hello" + assert items[1].data["delta"] == " World" + + @pytest.mark.asyncio + async def test_stream_with_mixed_types_in_async_generator(self, req): + """测试异步生成器中混合返回字符串和 AgentEvent""" + + async def invoke_agent(req: AgentRequest): + yield "Hello" + yield AgentEvent( + event=EventType.TOOL_CALL, + data={"id": "tc-1", "name": "test", "args": "{}"}, + ) + yield " World" + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + assert len(items) == 3 + assert items[0].event == EventType.TEXT + assert items[0].data["delta"] == "Hello" + assert items[1].event == EventType.TOOL_CALL_CHUNK # TOOL_CALL 被展开 + assert items[2].event == EventType.TEXT + assert items[2].data["delta"] == " World" + + @pytest.mark.asyncio + async def test_stream_with_agent_event_in_sync_generator(self, req): + """测试同步生成器中返回 AgentEvent""" + + def invoke_agent(req: AgentRequest): + yield AgentEvent(event=EventType.TEXT, data={"delta": "Hello"}) + yield AgentEvent( + event=EventType.CUSTOM, data={"name": "test", "value": {}} + ) + + invoker = 
AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + assert len(items) == 2 + assert items[0].event == EventType.TEXT + assert items[1].event == EventType.CUSTOM + + +class TestInvokerNonStreamList: + """测试非流式返回列表的情况(覆盖 230->242 分支)""" + + @pytest.fixture + def req(self): + return AgentRequest( + messages=[], + tools=[], + stream=False, + raw_request=None, + protocol="unknown", + ) + + @pytest.mark.asyncio + async def test_return_list_with_only_agent_events(self, req): + """测试返回只包含 AgentEvent 的列表""" + + def invoke_agent(req: AgentRequest): + return [ + AgentEvent(event=EventType.TEXT, data={"delta": "Hello"}), + AgentEvent(event=EventType.TEXT, data={"delta": " World"}), + ] + + invoker = AgentInvoker(invoke_agent) + result = await invoker.invoke(req) + + assert isinstance(result, list) + assert len(result) == 2 + + @pytest.mark.asyncio + async def test_return_list_with_only_strings(self, req): + """测试返回只包含字符串的列表""" + + def invoke_agent(req: AgentRequest): + return ["Hello", "World"] + + invoker = AgentInvoker(invoke_agent) + result = await invoker.invoke(req) + + assert isinstance(result, list) + assert len(result) == 2 + assert result[0].event == EventType.TEXT + assert result[0].data["delta"] == "Hello" + + @pytest.mark.asyncio + async def test_return_list_with_mixed_types(self, req): + """测试返回混合类型的列表""" + + def invoke_agent(req: AgentRequest): + return [ + "Hello", + AgentEvent(event=EventType.CUSTOM, data={"name": "test"}), + "World", + ] + + invoker = AgentInvoker(invoke_agent) + result = await invoker.invoke(req) + + assert isinstance(result, list) + assert len(result) == 3 + assert result[0].event == EventType.TEXT + assert result[1].event == EventType.CUSTOM + assert result[2].event == EventType.TEXT + + +class TestInvokerAsyncNonIterator: + """测试异步函数返回非迭代器值的情况(覆盖 196 分支)""" + + @pytest.fixture + def req(self): + return AgentRequest( + messages=[], + tools=[], + stream=False, + 
raw_request=None, + protocol="unknown", + ) + + @pytest.mark.asyncio + async def test_async_function_returning_list(self, req): + """测试异步函数返回列表""" + + async def invoke_agent(req: AgentRequest): + return [AgentEvent(event=EventType.TEXT, data={"delta": "Test"})] + + invoker = AgentInvoker(invoke_agent) + result = await invoker.invoke(req) + + assert isinstance(result, list) + assert len(result) == 1 + + @pytest.mark.asyncio + async def test_handler_detected_as_async_but_returns_non_awaitable( + self, req + ): + """测试被检测为异步但返回非 awaitable 值的处理器(覆盖 196 行) + + 通过 mock is_async 属性来模拟这种边界情况 + """ + from unittest.mock import patch + + def sync_handler(req: AgentRequest): + return "Hello" + + invoker = AgentInvoker(sync_handler) + + # 强制设置 is_async 为 True,模拟边界情况 + with patch.object(invoker, "is_async", True): + # 由于 sync_handler 返回的是字符串(不是 awaitable), + # 代码会走到第 196 行的 else 分支 + result = await invoker.invoke(req) + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0].data["delta"] == "Hello" diff --git a/tests/unittests/server/test_openai_protocol.py b/tests/unittests/server/test_openai_protocol.py new file mode 100644 index 0000000..53e2c30 --- /dev/null +++ b/tests/unittests/server/test_openai_protocol.py @@ -0,0 +1,1008 @@ +"""OpenAI 协议处理器测试 + +测试 OpenAIProtocolHandler 的各种功能。 +""" + +import json +from typing import cast + +from fastapi.testclient import TestClient +import pytest + +from agentrun.server import ( + AgentEvent, + AgentRequest, + AgentRunServer, + EventType, + OpenAIProtocolHandler, + ServerConfig, +) + + +class TestOpenAIProtocolHandler: + """测试 OpenAIProtocolHandler""" + + def test_get_prefix_default(self): + """测试默认前缀""" + handler = OpenAIProtocolHandler() + assert handler.get_prefix() == "/openai/v1" + + def test_get_prefix_custom(self): + """测试自定义前缀""" + from agentrun.server.model import OpenAIProtocolConfig + + config = ServerConfig(openai=OpenAIProtocolConfig(prefix="/custom/api")) + handler = OpenAIProtocolHandler(config) + 
assert handler.get_prefix() == "/custom/api" + + def test_get_model_name_default(self): + """测试默认模型名称""" + handler = OpenAIProtocolHandler() + assert handler.get_model_name() == "agentrun" + + def test_get_model_name_custom(self): + """测试自定义模型名称""" + from agentrun.server.model import OpenAIProtocolConfig + + config = ServerConfig( + openai=OpenAIProtocolConfig(model_name="custom-model") + ) + handler = OpenAIProtocolHandler(config) + assert handler.get_model_name() == "custom-model" + + +class TestOpenAIProtocolEndpoints: + """测试 OpenAI 协议端点""" + + def get_client(self, invoke_agent): + server = AgentRunServer(invoke_agent=invoke_agent) + return TestClient(server.as_fastapi_app()) + + @pytest.mark.asyncio + async def test_list_models(self): + """测试 /models 端点""" + + def invoke_agent(request: AgentRequest): + return "Hello" + + client = self.get_client(invoke_agent) + response = client.get("/openai/v1/models") + + assert response.status_code == 200 + data = response.json() + assert data["object"] == "list" + assert len(data["data"]) == 1 + assert data["data"][0]["id"] == "agentrun" + assert data["data"][0]["object"] == "model" + assert data["data"][0]["owned_by"] == "agentrun" + + @pytest.mark.asyncio + async def test_missing_messages_error(self): + """测试缺少 messages 字段时返回错误""" + + def invoke_agent(request: AgentRequest): + return "Hello" + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={"model": "test"}, # 缺少 messages + ) + + assert response.status_code == 400 + data = response.json() + assert "error" in data + assert data["error"]["type"] == "invalid_request_error" + assert "messages" in data["error"]["message"] + + @pytest.mark.asyncio + async def test_invalid_message_format(self): + """测试无效消息格式""" + + def invoke_agent(request: AgentRequest): + return "Hello" + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={"messages": ["not a dict"]}, # 无效格式 + ) 
+ + assert response.status_code == 400 + data = response.json() + assert "error" in data + assert "Invalid message format" in data["error"]["message"] + + @pytest.mark.asyncio + async def test_missing_role_in_message(self): + """测试消息缺少 role 字段""" + + def invoke_agent(request: AgentRequest): + return "Hello" + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={"messages": [{"content": "Hello"}]}, # 缺少 role + ) + + assert response.status_code == 400 + data = response.json() + assert "error" in data + assert "role" in data["error"]["message"] + + @pytest.mark.asyncio + async def test_invalid_role(self): + """测试无效的消息角色""" + + def invoke_agent(request: AgentRequest): + return "Hello" + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={"messages": [{"role": "invalid_role", "content": "Hello"}]}, + ) + + assert response.status_code == 400 + data = response.json() + assert "error" in data + assert "Invalid message role" in data["error"]["message"] + + @pytest.mark.asyncio + async def test_internal_error(self): + """测试内部错误处理""" + + def invoke_agent(request: AgentRequest): + raise RuntimeError("Internal error") + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={"messages": [{"role": "user", "content": "Hello"}]}, + ) + + assert response.status_code == 500 + data = response.json() + assert "error" in data + assert data["error"]["type"] == "internal_error" + + @pytest.mark.asyncio + async def test_non_stream_with_tool_calls(self): + """测试非流式响应中的工具调用""" + + def invoke_agent(request: AgentRequest): + return AgentEvent( + event=EventType.TOOL_CALL, + data={ + "id": "tc-1", + "name": "test_tool", + "args": '{"param": "value"}', + }, + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hello"}], + "stream": 
False, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["choices"][0]["finish_reason"] == "tool_calls" + assert "tool_calls" in data["choices"][0]["message"] + assert data["choices"][0]["message"]["tool_calls"][0]["id"] == "tc-1" + + @pytest.mark.asyncio + async def test_non_stream_response_collection(self): + """测试非流式响应收集流式结果""" + + async def invoke_agent(request: AgentRequest): + yield "Hello" + yield " World" + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "stream": False, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["choices"][0]["message"]["content"] == "Hello World" + + @pytest.mark.asyncio + async def test_message_with_tool_calls(self): + """测试解析带有 tool_calls 的消息""" + + captured_request = {} + + def invoke_agent(request: AgentRequest): + captured_request["messages"] = request.messages + return "Done" + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [ + { + "role": "assistant", + "content": None, + "tool_calls": [{ + "id": "call_123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"city": "Beijing"}', + }, + }], + }, + { + "role": "tool", + "content": "Sunny", + "tool_call_id": "call_123", + }, + ], + }, + ) + + assert response.status_code == 200 + assert len(captured_request["messages"]) == 2 + assert captured_request["messages"][0].tool_calls is not None + assert captured_request["messages"][0].tool_calls[0].id == "call_123" + assert captured_request["messages"][1].tool_call_id == "call_123" + + @pytest.mark.asyncio + async def test_parse_tools(self): + """测试解析工具列表""" + + captured_request = {} + + def invoke_agent(request: AgentRequest): + captured_request["tools"] = request.tools + return "Done" + + client = self.get_client(invoke_agent) + response = 
client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "tools": [{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather", + "parameters": {"type": "object"}, + }, + }], + }, + ) + + assert response.status_code == 200 + assert captured_request["tools"] is not None + assert len(captured_request["tools"]) == 1 + assert captured_request["tools"][0].function["name"] == "get_weather" + + @pytest.mark.asyncio + async def test_parse_tools_with_non_dict(self): + """测试解析工具列表时跳过非字典项""" + + captured_request = {} + + def invoke_agent(request: AgentRequest): + captured_request["tools"] = request.tools + return "Done" + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "tools": [ + "invalid_tool", # 非字典项,应该被跳过 + { + "type": "function", + "function": {"name": "valid_tool"}, + }, + ], + }, + ) + + assert response.status_code == 200 + assert captured_request["tools"] is not None + assert len(captured_request["tools"]) == 1 + + @pytest.mark.asyncio + async def test_parse_tools_empty_after_filter(self): + """测试解析工具列表后为空时返回 None""" + + captured_request = {} + + def invoke_agent(request: AgentRequest): + captured_request["tools"] = request.tools + return "Done" + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "tools": ["invalid1", "invalid2"], # 全部非字典项 + }, + ) + + assert response.status_code == 200 + assert captured_request["tools"] is None + + @pytest.mark.asyncio + async def test_addition_merge_overrides(self): + """测试 addition 默认合并覆盖字段""" + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.TEXT, + data={"delta": "Hello"}, + addition={"custom": "value"}, + ) + + client = self.get_client(invoke_agent) + response = client.post( + 
"/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "stream": True, + }, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 第一个 chunk 应该包含 addition 字段 + first_line = lines[0] + assert first_line.startswith("data: ") + data = json.loads(first_line[6:]) + assert "custom" in data["choices"][0]["delta"] + + @pytest.mark.asyncio + async def test_addition_protocol_only_mode(self): + """测试 addition PROTOCOL_ONLY 模式""" + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.TEXT, + data={"delta": "Hello"}, + addition={ + "content": "overwritten", # 已存在的字段会被覆盖 + "new_field": "ignored", # 新字段会被忽略 + }, + addition_merge_options={"no_new_field": True}, + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "stream": True, + }, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + first_line = lines[0] + data = json.loads(first_line[6:]) + delta = data["choices"][0]["delta"] + # content 被覆盖 + assert delta["content"] == "overwritten" + # new_field 不存在(被忽略) + assert "new_field" not in delta + + @pytest.mark.asyncio + async def test_tool_call_with_addition(self): + """测试工具调用时的 addition 处理""" + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "test", "args_delta": "{}"}, + addition={"custom_tool_field": "value"}, + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "stream": True, + }, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 第二个 chunk 是参数增量,应该包含 addition + args_line = lines[1] + data = 
json.loads(args_line[6:]) + delta = data["choices"][0]["delta"] + assert "custom_tool_field" in delta + + @pytest.mark.asyncio + async def test_non_stream_multiple_tool_call_chunks(self): + """测试非流式响应中多个工具调用 chunk 的合并""" + + def invoke_agent(request: AgentRequest): + return [ + AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool1", "args_delta": '{"a":'}, + ), + AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "args_delta": "1}"}, + ), + ] + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "stream": False, + }, + ) + + assert response.status_code == 200 + data = response.json() + tool_calls = data["choices"][0]["message"]["tool_calls"] + assert len(tool_calls) == 1 + assert tool_calls[0]["function"]["arguments"] == '{"a":1}' + + +class TestOpenAIProtocolStreamBranches: + """测试 OpenAI 协议流式响应的各种分支""" + + def get_client(self, invoke_agent): + server = AgentRunServer(invoke_agent=invoke_agent) + return TestClient(server.as_fastapi_app()) + + @pytest.mark.asyncio + async def test_stream_with_only_tool_calls(self): + """测试只有工具调用没有文本的流式响应""" + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool1", "args_delta": "{}"}, + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "stream": True, + }, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 最后一个非 [DONE] 行应该有 finish_reason: tool_calls + for line in reversed(lines): + if line.startswith("data: {"): + data = json.loads(line[6:]) + if data["choices"][0].get("finish_reason"): + assert data["choices"][0]["finish_reason"] == "tool_calls" + break + + @pytest.mark.asyncio + async def 
test_stream_with_empty_content(self): + """测试空内容不会发送""" + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.TEXT, + data={"delta": ""}, # 空内容 + ) + yield AgentEvent( + event=EventType.TEXT, + data={"delta": "Hello"}, + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "stream": True, + }, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 计算实际内容行数(排除 [DONE] 和 finish_reason 行) + content_lines = [] + for line in lines: + if line.startswith("data: {"): + data = json.loads(line[6:]) + delta = data["choices"][0].get("delta", {}) + if delta.get("content"): + content_lines.append(data) + + # 只有一个非空内容 + assert len(content_lines) == 1 + + @pytest.mark.asyncio + async def test_stream_tool_call_without_args(self): + """测试工具调用没有参数增量""" + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool1", "args_delta": ""}, + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "stream": True, + }, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 应该有工具调用开始和结束 + assert len(lines) >= 2 + + @pytest.mark.asyncio + async def test_stream_multiple_tool_calls(self): + """测试多个工具调用的流式响应""" + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool1", "args_delta": "{}"}, + ) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-2", "name": "tool2", "args_delta": "{}"}, + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": 
"Hi"}], + "stream": True, + }, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 检查两个工具调用的索引 + tool_indices = set() + for line in lines: + if line.startswith("data: {"): + data = json.loads(line[6:]) + delta = data["choices"][0].get("delta", {}) + if "tool_calls" in delta: + for tc in delta["tool_calls"]: + tool_indices.add(tc.get("index")) + + assert 0 in tool_indices + assert 1 in tool_indices + + @pytest.mark.asyncio + async def test_stream_raw_event(self): + """测试 RAW 事件在流式响应中""" + + async def invoke_agent(request: AgentRequest): + yield "Hello" + yield AgentEvent( + event=EventType.RAW, + data={"raw": "custom raw data"}, + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "stream": True, + }, + ) + + assert response.status_code == 200 + content = response.text + assert "custom raw data" in content + + @pytest.mark.asyncio + async def test_stream_tool_result_ignored(self): + """测试 TOOL_RESULT 事件在流式响应中被忽略""" + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool1", "args_delta": "{}"}, + ) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-1", "result": "result"}, + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "stream": True, + }, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # TOOL_RESULT 不应该出现在响应中 + for line in lines: + if line.startswith("data: {"): + data = json.loads(line[6:]) + # 检查没有 tool_result 相关内容 + assert "result" not in str(data) + + +class TestOpenAIProtocolNonStreamBranches: + """测试 OpenAI 协议非流式响应的各种分支""" + + def get_client(self, invoke_agent): + server = 
AgentRunServer(invoke_agent=invoke_agent) + return TestClient(server.as_fastapi_app()) + + @pytest.mark.asyncio + async def test_non_stream_with_multiple_tools(self): + """测试非流式响应中多个工具调用""" + + def invoke_agent(request: AgentRequest): + return [ + AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool1", "args_delta": "{}"}, + ), + AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": "tc-2", + "name": "tool2", + "args_delta": '{"x": 1}', + }, + ), + ] + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "stream": False, + }, + ) + + assert response.status_code == 200 + data = response.json() + tool_calls = data["choices"][0]["message"]["tool_calls"] + assert len(tool_calls) == 2 + + @pytest.mark.asyncio + async def test_non_stream_with_empty_tool_id(self): + """测试非流式响应中空工具 ID""" + + def invoke_agent(request: AgentRequest): + return [ + AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": "", + "name": "tool1", + "args_delta": "{}", + }, # 空 ID + ), + ] + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "stream": False, + }, + ) + + assert response.status_code == 200 + data = response.json() + # 空 ID 的工具调用不会被添加 + assert data["choices"][0]["message"].get("tool_calls") is None + + @pytest.mark.asyncio + async def test_non_stream_with_empty_args_delta(self): + """测试非流式响应中空参数增量""" + + def invoke_agent(request: AgentRequest): + return [ + AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool1", "args_delta": ""}, + ), + ] + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "stream": False, + }, + ) + + assert response.status_code == 200 + data = response.json() + 
tool_calls = data["choices"][0]["message"]["tool_calls"] + assert len(tool_calls) == 1 + assert tool_calls[0]["function"]["arguments"] == "" + + @pytest.mark.asyncio + async def test_non_stream_with_text_events(self): + """测试非流式响应中的文本事件""" + + def invoke_agent(request: AgentRequest): + return [ + AgentEvent(event=EventType.TEXT, data={"delta": "Hello"}), + AgentEvent(event=EventType.TEXT, data={"delta": " World"}), + ] + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "stream": False, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["choices"][0]["message"]["content"] == "Hello World" + assert data["choices"][0]["finish_reason"] == "stop" + + +class TestOpenAIProtocolApplyAddition: + """测试 _apply_addition 方法""" + + def test_apply_addition_default_merge(self): + """默认合并应覆盖已有字段并保留原字段""" + handler = OpenAIProtocolHandler() + + delta = {"content": "Hello", "role": "assistant"} + addition = {"content": "overwritten", "new_field": "added"} + + result = handler._apply_addition( + delta.copy(), + addition.copy(), + ) + + assert result["content"] == "overwritten" + assert result["new_field"] == "added" + assert result["role"] == "assistant" + + def test_apply_addition_merge_options_none(self): + """显式传入 merge_options=None 仍按默认合并""" + handler = OpenAIProtocolHandler() + + delta = {"content": "Hello", "role": "assistant"} + addition = {"content": "overwritten", "new_field": "added"} + + result = handler._apply_addition( + delta.copy(), + addition.copy(), + ) + + assert result["content"] == "overwritten" + assert result["new_field"] == "added" + + def test_apply_addition_protocol_only_mode(self): + """测试 PROTOCOL_ONLY 模式(覆盖 527->530 分支)""" + handler = OpenAIProtocolHandler() + + delta = {"content": "Hello", "role": "assistant"} + addition = {"content": "overwritten", "new_field": "ignored"} + + result = handler._apply_addition( 
+ delta.copy(), + addition.copy(), + {"no_new_field": True}, + ) + + # content 被覆盖 + assert result["content"] == "overwritten" + # new_field 不存在(被忽略) + assert "new_field" not in result + # role 保持不变 + assert result["role"] == "assistant" + + +class TestOpenAIProtocolRawEvent: + """测试 RAW 事件处理""" + + def get_client(self, invoke_agent): + server = AgentRunServer(invoke_agent=invoke_agent) + return TestClient(server.as_fastapi_app()) + + @pytest.mark.asyncio + async def test_raw_event_with_newline(self): + """测试 RAW 事件已有换行符""" + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.RAW, + data={"raw": "custom data\n\n"}, # 已有换行符 + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "stream": True, + }, + ) + + assert response.status_code == 200 + content = response.text + assert "custom data" in content + + @pytest.mark.asyncio + async def test_raw_event_empty(self): + """测试空 RAW 事件""" + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.RAW, + data={"raw": ""}, # 空内容 + ) + yield "Hello" + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "stream": True, + }, + ) + + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_tool_call_chunk_without_id(self): + """测试没有 id 的 TOOL_CALL_CHUNK""" + + async def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "", "name": "tool", "args_delta": "{}"}, # 空 id + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "stream": True, + }, + ) + + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_tool_call_chunk_existing_id(self): + 
"""测试已存在 id 的 TOOL_CALL_CHUNK""" + + async def invoke_agent(request: AgentRequest): + # 第一个 chunk + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "name": "tool", "args_delta": '{"a":'}, + ) + # 第二个 chunk(同一个 id) + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={"id": "tc-1", "args_delta": "1}"}, + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "stream": True, + }, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 应该有多个工具调用相关的行 + tool_lines = [ + line + for line in lines + if line.startswith("data: {") and "tool_calls" in line + ] + assert len(tool_lines) >= 2 + + @pytest.mark.asyncio + async def test_stream_no_text_no_tools(self): + """测试没有文本也没有工具调用的流式响应""" + + async def invoke_agent(request: AgentRequest): + # 只发送其他类型的事件 + yield AgentEvent( + event=EventType.CUSTOM, + data={"name": "test", "value": {}}, + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "stream": True, + }, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines() if line] + + # 最后应该是 [DONE] + assert lines[-1] == "data: [DONE]" + + @pytest.mark.asyncio + async def test_non_stream_tool_call_without_args(self): + """测试非流式响应中没有参数增量的工具调用""" + + def invoke_agent(request: AgentRequest): + return [ + AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": "tc-1", + "name": "tool1", + "args_delta": "", + }, # 空参数 + ), + ] + + client = self.get_client(invoke_agent) + response = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "stream": False, + }, + ) + + assert response.status_code == 200 + data = response.json() + tool_calls = 
data["choices"][0]["message"]["tool_calls"] + assert len(tool_calls) == 1 + assert tool_calls[0]["function"]["arguments"] == "" diff --git a/tests/unittests/server/test_protocol.py b/tests/unittests/server/test_protocol.py new file mode 100644 index 0000000..05fb95d --- /dev/null +++ b/tests/unittests/server/test_protocol.py @@ -0,0 +1,146 @@ +"""协议基类测试 + +测试 ProtocolHandler 和 BaseProtocolHandler 的基类方法。 +""" + +import pytest + +from agentrun.server.protocol import BaseProtocolHandler, ProtocolHandler + + +class TestProtocolHandler: + """测试 ProtocolHandler 基类""" + + def test_get_prefix_default(self): + """测试 get_prefix 默认返回空字符串""" + + class TestHandler(ProtocolHandler): + name = "test" + + def as_fastapi_router(self, agent_invoker): + pass + + handler = TestHandler() + # 调用父类的 get_prefix 方法 + assert ProtocolHandler.get_prefix(handler) == "" + + def test_get_prefix_override(self): + """测试子类可以覆盖 get_prefix""" + + class TestHandler(ProtocolHandler): + name = "test" + + def as_fastapi_router(self, agent_invoker): + pass + + def get_prefix(self): + return "/custom" + + handler = TestHandler() + assert handler.get_prefix() == "/custom" + + +class TestBaseProtocolHandler: + """测试 BaseProtocolHandler 基类""" + + def test_parse_request_not_implemented(self): + """测试 parse_request 未实现时抛出异常""" + + class TestHandler(BaseProtocolHandler): + name = "test" + + def as_fastapi_router(self, agent_invoker): + pass + + handler = TestHandler() + + with pytest.raises( + NotImplementedError, match="Subclass must implement" + ): + import asyncio + + asyncio.run(handler.parse_request(None, {})) + + def test_is_iterator_with_dict(self): + """测试 _is_iterator 对字典返回 False""" + + class TestHandler(BaseProtocolHandler): + name = "test" + + def as_fastapi_router(self, agent_invoker): + pass + + handler = TestHandler() + assert handler._is_iterator({}) is False + assert handler._is_iterator({"a": 1}) is False + + def test_is_iterator_with_list(self): + """测试 _is_iterator 对列表返回 False""" + + class 
TestHandler(BaseProtocolHandler): + name = "test" + + def as_fastapi_router(self, agent_invoker): + pass + + handler = TestHandler() + assert handler._is_iterator([]) is False + assert handler._is_iterator([1, 2, 3]) is False + + def test_is_iterator_with_string(self): + """测试 _is_iterator 对字符串返回 False""" + + class TestHandler(BaseProtocolHandler): + name = "test" + + def as_fastapi_router(self, agent_invoker): + pass + + handler = TestHandler() + assert handler._is_iterator("") is False + assert handler._is_iterator("hello") is False + + def test_is_iterator_with_bytes(self): + """测试 _is_iterator 对字节返回 False""" + + class TestHandler(BaseProtocolHandler): + name = "test" + + def as_fastapi_router(self, agent_invoker): + pass + + handler = TestHandler() + assert handler._is_iterator(b"") is False + assert handler._is_iterator(b"hello") is False + + def test_is_iterator_with_generator(self): + """测试 _is_iterator 对生成器返回 True""" + + class TestHandler(BaseProtocolHandler): + name = "test" + + def as_fastapi_router(self, agent_invoker): + pass + + handler = TestHandler() + + def gen(): + yield 1 + + assert handler._is_iterator(gen()) is True + + def test_is_iterator_with_async_generator(self): + """测试 _is_iterator 对异步生成器返回 True""" + + class TestHandler(BaseProtocolHandler): + name = "test" + + def as_fastapi_router(self, agent_invoker): + pass + + handler = TestHandler() + + async def async_gen(): + yield 1 + + assert handler._is_iterator(async_gen()) is True diff --git a/tests/unittests/server/test_server.py b/tests/unittests/server/test_server.py new file mode 100644 index 0000000..20150b6 --- /dev/null +++ b/tests/unittests/server/test_server.py @@ -0,0 +1,933 @@ +import asyncio +import json +from typing import Any, cast, Union + +import pytest + +from agentrun.server.model import AgentRequest, MessageRole +from agentrun.server.server import AgentRunServer + + +class ProtocolValidator: + + def try_parse_streaming_line(self, line: str) -> Union[str, dict, list]: + 
"""解析流式响应行,去除前缀 'data: ' 并转换为 JSON""" + + if type(line) is not str or line.startswith("data: [DONE]"): + return line + if line.startswith("data: {") or line.startswith("data: ["): + json_str = line[len("data: ") :] + return json.loads(json_str) + return line + + def parse_streaming_line(self, line: str) -> Union[str, dict, list]: + """解析流式响应行,去除前缀 'data: ' 并转换为 JSON""" + + if type(line) is not str or line.startswith("data: [DONE]"): + return line + if line.startswith("data: {") or line.startswith("data: ["): + json_str = line[len("data: ") :] + return json.loads(json_str) + return line + + def valid_json(self, got: Any, expect: Any): + """检查 json 是否匹配,如果 expect 的 value 为 mock-placeholder 则仅检查 key 存在""" + + def valid(path: str, got: Any, expect: Any): + if expect == "mock-placeholder": + assert got, f"{path} 存在但值为空" + else: + got = self.try_parse_streaming_line(got) + expect = self.try_parse_streaming_line(expect) + + if isinstance(expect, dict): + assert isinstance( + got, dict + ), f"{path} 类型不匹配,期望 dict,实际 {type(got)}" + for k, v in expect.items(): + valid(f"{path}.{k}", got.get(k), v) + + for k in got.keys(): + if k not in expect: + assert False, f"{path} 多余的键: {k}" + elif isinstance(expect, list): + assert isinstance( + got, list + ), f"{path} 类型不匹配,期望 list,实际 {type(got)}" + assert len(got) == len( + expect + ), f"{path} 列表长度不匹配,期望 {len(expect)},实际 {len(got)}" + for i in range(len(expect)): + valid(f"{path}[{i}]", got[i], expect[i]) + else: + assert ( + got == expect + ), f"{path} 不匹配,期望: {type(expect)} {expect},实际: {got}" + + print("valid", got, expect) + valid("", got, expect) + + def all_field_equal( + self, + key: str, + arr: list, + ): + """检查列表中所有对象的指定字段值是否相等""" + print("all_field_equal", arr, key) + + value = "" + for item in arr: + data = self.try_parse_streaming_line(item) + data = cast(dict, data) + + if value == "": + value = data[key] + + assert value == data[key] + assert value + + +class TestServer(ProtocolValidator): + + def 
get_invoke_agent_non_streaming(self): + + def invoke_agent(request: AgentRequest): + # 检查请求消息,返回预期的响应 + user_message = next( + ( + msg.content + for msg in request.messages + if msg.role == MessageRole.USER + ), + "Hello", + ) + + return f"You said: {user_message}" + + return invoke_agent + + def get_invoke_agent_streaming(self): + + async def streaming_invoke_agent(request: AgentRequest): + yield "Hello, " + await asyncio.sleep(0.01) # 短暂延迟 + yield "this is " + await asyncio.sleep(0.01) + yield "a test." + + return streaming_invoke_agent + + def get_client(self, invoke_agent): + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + from fastapi.testclient import TestClient + + return TestClient(app) + + async def test_server_non_streaming_protocols(self): + """测试非流式的 OpenAI 和 AGUI 服务器响应功能""" + + client = self.get_client(self.get_invoke_agent_non_streaming()) + + # 测试 OpenAI 协议 + response_openai = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "AgentRun"}], + "model": "test-model", + }, + ) + + # 检查响应状态 + assert response_openai.status_code == 200 + + # 检查响应内容 + response_data_openai = response_openai.json() + + self.valid_json( + response_data_openai, + { + "id": "mock-placeholder", + "object": "chat.completion", + "created": "mock-placeholder", + "model": "test-model", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "You said: AgentRun", + }, + "finish_reason": "stop", + }], + }, + ) + + # AGUI 协议始终是流式传输,因此没有非流式测试 + # 测试 AGUI 协议(即使非流式请求也会以流式方式处理) + response_agui = client.post( + "/ag-ui/agent", + json={ + "messages": [{"role": "user", "content": "AgentRun"}], + "model": "test-model", + }, + ) + + # 检查响应状态 + assert response_agui.status_code == 200 + lines_agui = [line async for line in response_agui.aiter_lines()] + lines_agui = [line for line in lines_agui if line] + + # AG-UI 流式格式:RUN_STARTED + TEXT_MESSAGE_START + TEXT_MESSAGE_CONTENT + 
TEXT_MESSAGE_END + RUN_FINISHED + assert len(lines_agui) == 5 + + # 验证 AGUI 流式事件序列 + self.valid_json( + lines_agui, + [ + ( + "data:" + ' {"type":"RUN_STARTED","threadId":"mock-placeholder","runId":"mock-placeholder"}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_START","messageId":"mock-placeholder","role":"assistant"}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_CONTENT","messageId":"mock-placeholder","delta":"You' + ' said: AgentRun"}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_END","messageId":"mock-placeholder"}' + ), + ( + "data:" + ' {"type":"RUN_FINISHED","threadId":"mock-placeholder","runId":"mock-placeholder"}' + ), + ], + ) + + async def test_server_streaming_protocols(self): + """测试流式的 OpenAI 和 AGUI 服务器响应功能""" + + # 测试 OpenAI 协议流式响应 + client = self.get_client(self.get_invoke_agent_streaming()) + + response_openai = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "AgentRun"}], + "model": "test-model", + "stream": True, + }, + ) + + # 检查响应状态 + assert response_openai.status_code == 200 + lines_openai = [line async for line in response_openai.aiter_lines()] + + # 过滤空行 + lines_openai = [line for line in lines_openai if line] + + # OpenAI 流式格式:第一个 chunk 是 role 声明,后续是内容 + # 格式:data: {...} + self.valid_json( + lines_openai, + [ + ( + 'data: {"id": "mock-placeholder", "object":' + ' "chat.completion.chunk", "created": "mock-placeholder",' + ' "model": "test-model", "choices": [{"index": 0, "delta":' + ' {"role": "assistant", "content": "Hello, "},' + ' "finish_reason": null}]}' + ), + ( + 'data: {"id": "mock-placeholder", "object":' + ' "chat.completion.chunk", "created": "mock-placeholder",' + ' "model": "test-model", "choices": [{"index": 0, "delta":' + ' {"content": "this is "}, "finish_reason": null}]}' + ), + ( + 'data: {"id": "mock-placeholder", "object":' + ' "chat.completion.chunk", "created": "mock-placeholder",' + ' "model": "test-model", "choices": [{"index": 0, "delta":' + ' {"content": "a test."}, 
"finish_reason": null}]}' + ), + ( + 'data: {"id": "mock-placeholder", "object":' + ' "chat.completion.chunk", "created": "mock-placeholder",' + ' "model": "test-model", "choices": [{"index": 0, "delta":' + ' {}, "finish_reason": "stop"}]}' + ), + "data: [DONE]", + ], + ) + self.all_field_equal("id", lines_openai[:-1]) + + # 测试 AGUI 协议流式响应 + response_agui = client.post( + "/ag-ui/agent", + json={ + "messages": [{"role": "user", "content": "AgentRun"}], + "model": "test-model", + "stream": True, + }, + ) + + # 检查响应状态 + assert response_agui.status_code == 200 + lines_agui = [line async for line in response_agui.aiter_lines()] + + # 过滤空行 + lines_agui = [line for line in lines_agui if line] + + # AG-UI 流式格式:每个 chunk 是一个 JSON 对象 + self.valid_json( + lines_agui, + [ + ( + "data:" + ' {"type":"RUN_STARTED","threadId":"mock-placeholder","runId":"mock-placeholder"}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_START","messageId":"mock-placeholder","role":"assistant"}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_CONTENT","messageId":"mock-placeholder","delta":"Hello, "}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_CONTENT","messageId":"mock-placeholder","delta":"this' + ' is "}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_CONTENT","messageId":"mock-placeholder","delta":"a' + ' test."}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_END","messageId":"mock-placeholder"}' + ), + ( + "data:" + ' {"type":"RUN_FINISHED","threadId":"mock-placeholder","runId":"mock-placeholder"}' + ), + ], + ) + self.all_field_equal("threadId", [lines_agui[0], lines_agui[-1]]) + self.all_field_equal("runId", [lines_agui[0], lines_agui[-1]]) + self.all_field_equal("messageId", lines_agui[1:6]) + + async def test_server_raw_event_protocols(self): + """测试 RAW 事件直接返回原始数据(OpenAI 和 AG-UI 协议) + + RAW 事件可以在任何时间触发,输出原始 SSE 内容,不影响其他事件的正常处理。 + 支持任意 SSE 格式(data:, :注释, 等)。 + """ + from agentrun.server import AgentEvent, AgentRequest, EventType + + async def streaming_invoke_agent(request: AgentRequest): + # 测试 
RAW 事件与其他事件混合 + yield "你好" + yield AgentEvent( + event=EventType.RAW, + data={"raw": '{"custom": "data"}'}, + ) + yield AgentEvent(event=EventType.TEXT, data={"delta": "再见"}) + + client = self.get_client(streaming_invoke_agent) + + # 测试 OpenAI 协议的 RAW 事件 + response_openai = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "test"}], + "model": "agentrun", + "stream": True, + }, + ) + + assert response_openai.status_code == 200 + lines_openai = [line async for line in response_openai.aiter_lines()] + lines_openai = [line for line in lines_openai if line] + + # OpenAI 流式响应: + # 1. role: assistant + content: 你好(合并在首个 chunk) + # 2. RAW: {"custom": "data"} + # 3. content: 再见 + # 4. finish_reason: stop + # 5. [DONE] + self.valid_json( + lines_openai, + [ + ( + 'data: {"id": "mock-placeholder", "object":' + ' "chat.completion.chunk", "created": "mock-placeholder",' + ' "model": "agentrun", "choices": [{"index": 0, "delta":' + ' {"role": "assistant", "content": "你好"},' + ' "finish_reason": null}]}' + ), + '{"custom": "data"}', + ( + 'data: {"id": "mock-placeholder", "object":' + ' "chat.completion.chunk", "created": "mock-placeholder",' + ' "model": "agentrun", "choices": [{"index": 0, "delta":' + ' {"content": "再见"}, "finish_reason": null}]}' + ), + ( + 'data: {"id": "mock-placeholder", "object":' + ' "chat.completion.chunk", "created": "mock-placeholder",' + ' "model": "agentrun", "choices": [{"index": 0, "delta":' + ' {}, "finish_reason": "stop"}]}' + ), + "data: [DONE]", + ], + ) + self.all_field_equal("id", [lines_openai[0], *lines_openai[2:-1]]) + + # 测试 AGUI 协议的 RAW 事件 + response_agui = client.post( + "/ag-ui/agent", + json={ + "messages": [{"role": "user", "content": "test"}], + "stream": True, + }, + ) + + assert response_agui.status_code == 200 + lines_agui = [line async for line in response_agui.aiter_lines()] + lines_agui = [line for line in lines_agui if line] + + # AGUI 流式响应中应该包含 RAW 事件 + # 1. RUN_STARTED + # 2. 
TEXT_MESSAGE_START + # 3. TEXT_MESSAGE_CONTENT ("你好") + # 4. RAW 事件 '{"custom": "data"}' + # 5. TEXT_MESSAGE_CONTENT ("再见") + # 6. TEXT_MESSAGE_END + # 7. RUN_FINISHED + self.valid_json( + lines_agui, + [ + ( + "data:" + ' {"type":"RUN_STARTED","threadId":"mock-placeholder","runId":"mock-placeholder"}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_START","messageId":"mock-placeholder","role":"assistant"}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_CONTENT","messageId":"mock-placeholder","delta":"你好"}' + ), + '{"custom": "data"}', + ( + "data:" + ' {"type":"TEXT_MESSAGE_CONTENT","messageId":"mock-placeholder","delta":"再见"}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_END","messageId":"mock-placeholder"}' + ), + ( + "data:" + ' {"type":"RUN_FINISHED","threadId":"mock-placeholder","runId":"mock-placeholder"}' + ), + ], + ) + self.all_field_equal("threadId", [lines_agui[0], lines_agui[-1]]) + self.all_field_equal("runId", [lines_agui[0], lines_agui[-1]]) + self.all_field_equal( + "messageId", + [lines_agui[1], lines_agui[2], lines_agui[4], lines_agui[5]], + ) + + async def test_server_addition_merge(self): + """测试 addition 字段的合并功能""" + from agentrun.server import AgentEvent, AgentRequest, EventType + + async def streaming_invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.TEXT, + data={"message_id": "msg_1", "delta": "Hello"}, + addition={ + "model": "custom_model", + "custom_field": "custom_value", + }, + ) + + client = self.get_client(streaming_invoke_agent) + + # 测试 OpenAI 协议 + response_openai = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "test"}], + "model": "test-model", + "stream": True, + }, + ) + + assert response_openai.status_code == 200 + lines = [line async for line in response_openai.aiter_lines()] + lines = [line for line in lines if line] + + # OpenAI 流式格式:只有一个内容行 + 完成行 + [DONE] + + self.valid_json( + lines, + [ + ( + 'data: {"id": "mock-placeholder", "object":' + ' 
"chat.completion.chunk", "created": "mock-placeholder",' + ' "model": "test-model", "choices": [{"index": 0, "delta":' + ' {"role": "assistant", "content": "Hello", "model":' + ' "custom_model", "custom_field": "custom_value"},' + ' "finish_reason": null}]}' + ), + ( + 'data: {"id": "mock-placeholder", "object":' + ' "chat.completion.chunk", "created": "mock-placeholder",' + ' "model": "test-model", "choices": [{"index": 0, "delta":' + ' {}, "finish_reason": "stop"}]}' + ), + "data: [DONE]", + ], + ) + self.all_field_equal("id", lines[:-1]) + + response_agui = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "test"}]}, + ) + + assert response_agui.status_code == 200 + lines_agui = [line async for line in response_agui.aiter_lines()] + lines_agui = [line for line in lines_agui if line] + + # AG-UI 流式格式:RUN_STARTED + TEXT_MESSAGE_START + TEXT_MESSAGE_CONTENT + TEXT_MESSAGE_END + RUN_FINISHED + self.valid_json( + lines_agui, + [ + ( + "data:" + ' {"type":"RUN_STARTED","threadId":"mock-placeholder","runId":"mock-placeholder"}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_START","messageId":"mock-placeholder","role":"assistant"}' + ), + ( + 'data: {"type": "TEXT_MESSAGE_CONTENT", "messageId":' + ' "mock-placeholder", "delta": "Hello", "model":' + ' "custom_model", "custom_field": "custom_value"}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_END","messageId":"mock-placeholder"}' + ), + ( + "data:" + ' {"type":"RUN_FINISHED","threadId":"mock-placeholder","runId":"mock-placeholder"}' + ), + ], + ) + self.all_field_equal("threadId", [lines_agui[0], lines_agui[-1]]) + self.all_field_equal("runId", [lines_agui[0], lines_agui[-1]]) + self.all_field_equal("messageId", lines_agui[1:4]) + + async def test_server_tool_call_protocols(self): + """测试 OpenAI 和 AG-UI 协议中的工具调用事件序列""" + from agentrun.server import AgentEvent, AgentRequest, EventType + + async def streaming_invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.TOOL_CALL, 
+ data={ + "id": "tc-1", + "name": "weather_tool", + "args": '{"location": "Beijing"}', + }, + ) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-1", "result": "Sunny, 25°C"}, + ) + + client = self.get_client(streaming_invoke_agent) + + # 测试 OpenAI 协议的工具调用 + response_openai = client.post( + "/openai/v1/chat/completions", + json={ + "messages": [ + {"role": "user", "content": "What's the weather?"} + ], + "stream": True, + }, + ) + + assert response_openai.status_code == 200 + lines_openai = [line async for line in response_openai.aiter_lines()] + lines_openai = [line for line in lines_openai if line] + + # OpenAI 流式格式:包含工具调用的事件序列 + # 1. role + tool_calls(包含 id, type, function.name, function.arguments) + # 2. tool_calls(包含 id 和 function.arguments delta) + # 3. finish_reason: tool_calls + # 4. [DONE] + assert len(lines_openai) == 4 + + # 第一个 chunk 包含工具调用信息(可能不包含role,具体取决于工具调用的类型) + assert lines_openai[0].startswith("data: {") + line0 = self.try_parse_streaming_line(lines_openai[0]) + line0 = cast(dict, line0) # 类型断言 + assert line0["object"] == "chat.completion.chunk" + assert "tool_calls" in line0["choices"][0]["delta"] + assert ( + line0["choices"][0]["delta"]["tool_calls"][0]["type"] == "function" + ) + assert ( + line0["choices"][0]["delta"]["tool_calls"][0]["function"]["name"] + == "weather_tool" + ) + assert line0["choices"][0]["delta"]["tool_calls"][0]["id"] == "tc-1" + + # 第二个 chunk 包含函数参数(不包含ID,只有参数) + assert lines_openai[1].startswith("data: {") + line1 = self.try_parse_streaming_line(lines_openai[1]) + line1 = cast(dict, line1) # 类型断言 + assert line1["object"] == "chat.completion.chunk" + assert ( + line1["choices"][0]["delta"]["tool_calls"][0]["function"][ + "arguments" + ] + == '{"location": "Beijing"}' + ) + + # 第三个 chunk 包含 finish_reason + assert lines_openai[2].startswith("data: {") + line2 = self.try_parse_streaming_line(lines_openai[2]) + line2 = cast(dict, line2) # 类型断言 + assert line2["object"] == "chat.completion.chunk" + assert 
line2["choices"][0]["finish_reason"] == "tool_calls" + + # 最后是 [DONE] + assert lines_openai[3] == "data: [DONE]" + + # 测试 AG-UI 协议的工具调用 + response_agui = client.post( + "/ag-ui/agent", + json={ + "messages": [{"role": "user", "content": "What's the weather?"}] + }, + ) + + assert response_agui.status_code == 200 + lines_agui = [line async for line in response_agui.aiter_lines()] + lines_agui = [line for line in lines_agui if line] + + # AG-UI 流式格式:RUN_STARTED + TOOL_CALL_START + TOOL_CALL_ARGS + TOOL_CALL_END + TOOL_CALL_RESULT + RUN_FINISHED + # 注意:由于没有文本内容,所以不会触发 TEXT_MESSAGE_* 事件 + # TOOL_CALL 会先触发 TOOL_CALL_START,然后是 TOOL_CALL_ARGS(使用 args_delta),最后是 TOOL_CALL_END + # TOOL_RESULT 会被转换为 TOOL_CALL_RESULT + self.valid_json( + lines_agui, + [ + ( + "data:" + ' {"type":"RUN_STARTED","threadId":"mock-placeholder","runId":"mock-placeholder"}' + ), + ( + "data:" + ' {"type":"TOOL_CALL_START","toolCallId":"tc-1","toolCallName":"weather_tool"}' + ), + ( + "data:" + ' {"type":"TOOL_CALL_ARGS","toolCallId":"tc-1","delta":"{\\"location\\":' + ' \\"Beijing\\"}"}' + ), + 'data: {"type":"TOOL_CALL_END","toolCallId":"tc-1"}', + ( + "data:" + ' {"type":"TOOL_CALL_RESULT","messageId":"mock-placeholder","toolCallId":"tc-1","content":"Sunny,' + ' 25°C","role":"tool"}' + ), + ( + "data:" + ' {"type":"RUN_FINISHED","threadId":"mock-placeholder","runId":"mock-placeholder"}' + ), + ], + ) + self.all_field_equal("threadId", [lines_agui[0], lines_agui[-1]]) + self.all_field_equal("runId", [lines_agui[0], lines_agui[-1]]) + self.all_field_equal("toolCallId", lines_agui[1:5]) + + @pytest.mark.asyncio + async def test_server_text_then_tool_call_agui(self): + """测试 AG-UI 协议中先文本后工具调用的事件序列 + + AG-UI 协议要求:发送 TOOL_CALL_START 前必须先发送 TEXT_MESSAGE_END + """ + from agentrun.server import AgentEvent, AgentRequest, EventType + + async def streaming_invoke_agent(request: AgentRequest): + # 先发送文本 + yield "思考中..." 
+ # 然后发送工具调用 + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": "tc-1", + "name": "search_tool", + "args_delta": '{"query": "test"}', + }, + ) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-1", "result": "搜索结果"}, + ) + + client = self.get_client(streaming_invoke_agent) + + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "搜索一下"}]}, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines()] + lines = [line for line in lines if line] + + # 预期事件序列: + # 1. RUN_STARTED + # 2. TEXT_MESSAGE_START + # 3. TEXT_MESSAGE_CONTENT + # 4. TEXT_MESSAGE_END <-- 必须在 TOOL_CALL_START 之前 + # 5. TOOL_CALL_START + # 6. TOOL_CALL_ARGS + # 7. TOOL_CALL_END + # 8. TOOL_CALL_RESULT + # 9. RUN_FINISHED + self.valid_json( + lines, + [ + ( + "data:" + ' {"type":"RUN_STARTED","threadId":"mock-placeholder","runId":"mock-placeholder"}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_START","messageId":"mock-placeholder","role":"assistant"}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_CONTENT","messageId":"mock-placeholder","delta":"思考中..."}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_END","messageId":"mock-placeholder"}' + ), + ( + "data:" + ' {"type":"TOOL_CALL_START","toolCallId":"tc-1","toolCallName":"search_tool"}' + ), + ( + "data:" + ' {"type":"TOOL_CALL_ARGS","toolCallId":"tc-1","delta":"{\\"query\\":' + ' \\"test\\"}"}' + ), + 'data: {"type":"TOOL_CALL_END","toolCallId":"tc-1"}', + ( + "data:" + ' {"type":"TOOL_CALL_RESULT","messageId":"mock-placeholder","toolCallId":"tc-1","content":"搜索结果","role":"tool"}' + ), + ( + "data:" + ' {"type":"RUN_FINISHED","threadId":"mock-placeholder","runId":"mock-placeholder"}' + ), + ], + ) + self.all_field_equal("threadId", [lines[0], lines[-1]]) + self.all_field_equal("runId", [lines[0], lines[-1]]) + self.all_field_equal("messageId", [lines[1], lines[2], lines[3]]) + self.all_field_equal("toolCallId", lines[4:8]) + + 
@pytest.mark.asyncio + async def test_server_text_tool_text_agui(self): + """测试 AG-UI 协议中 文本->工具调用->文本 的事件序列 + + 场景:先输出思考内容,然后调用工具,最后输出结果 + AG-UI 协议要求: + 1. 发送 TOOL_CALL_START 前必须先发送 TEXT_MESSAGE_END + 2. 工具调用后的新文本需要新的 TEXT_MESSAGE_START + """ + from agentrun.server import AgentEvent, AgentRequest, EventType + + async def streaming_invoke_agent(request: AgentRequest): + # 第一段文本 + yield "让我搜索一下..." + # 工具调用 + yield AgentEvent( + event=EventType.TOOL_CALL_CHUNK, + data={ + "id": "tc-1", + "name": "search", + "args_delta": '{"q": "天气"}', + }, + ) + yield AgentEvent( + event=EventType.TOOL_RESULT, + data={"id": "tc-1", "result": "晴天"}, + ) + # 第二段文本(工具调用后) + yield "根据搜索结果,今天是晴天。" + + client = self.get_client(streaming_invoke_agent) + + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "今天天气如何"}]}, + ) + + assert response.status_code == 200 + lines = [line async for line in response.aiter_lines()] + lines = [line for line in lines if line] + + # 预期事件序列: + # 1. RUN_STARTED + # 2. TEXT_MESSAGE_START (第一个文本消息) + # 3. TEXT_MESSAGE_CONTENT + # 4. TEXT_MESSAGE_END <-- 工具调用前必须结束 + # 5. TOOL_CALL_START + # 6. TOOL_CALL_ARGS + # 7. TOOL_CALL_END + # 8. TOOL_CALL_RESULT + # 9. TEXT_MESSAGE_START (第二个文本消息,新的 messageId) + # 10. TEXT_MESSAGE_CONTENT + # 11. TEXT_MESSAGE_END + # 12. 
RUN_FINISHED + self.valid_json( + lines, + [ + ( + "data:" + ' {"type":"RUN_STARTED","threadId":"mock-placeholder","runId":"mock-placeholder"}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_START","messageId":"mock-placeholder","role":"assistant"}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_CONTENT","messageId":"mock-placeholder","delta":"让我搜索一下..."}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_END","messageId":"mock-placeholder"}' + ), + ( + "data:" + ' {"type":"TOOL_CALL_START","toolCallId":"tc-1","toolCallName":"search"}' + ), + ( + "data:" + ' {"type":"TOOL_CALL_ARGS","toolCallId":"tc-1","delta":"{\\"q\\":' + ' \\"天气\\"}"}' + ), + 'data: {"type":"TOOL_CALL_END","toolCallId":"tc-1"}', + ( + "data:" + ' {"type":"TOOL_CALL_RESULT","messageId":"mock-placeholder","toolCallId":"tc-1","content":"晴天","role":"tool"}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_START","messageId":"mock-placeholder","role":"assistant"}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_CONTENT","messageId":"mock-placeholder","delta":"根据搜索结果,' + '今天是晴天。"}' + ), + ( + "data:" + ' {"type":"TEXT_MESSAGE_END","messageId":"mock-placeholder"}' + ), + ( + "data:" + ' {"type":"RUN_FINISHED","threadId":"mock-placeholder","runId":"mock-placeholder"}' + ), + ], + ) + self.all_field_equal("threadId", [lines[0], lines[-1]]) + self.all_field_equal("runId", [lines[0], lines[-1]]) + self.all_field_equal("messageId", [lines[1], lines[2], lines[3]]) + self.all_field_equal("toolCallId", lines[4:8]) + + @pytest.mark.asyncio + async def test_agent_request_raw_request(self): + """测试 AgentRequest.raw_request 可以访问原始请求对象 + + 验证: + 1. raw_request 包含完整的 Starlette Request 对象 + 2. 
可以访问 headers, query_params, client 等属性 + """ + from agentrun.server import AgentRequest + + captured_request: dict = {} + + async def invoke_agent(request: AgentRequest): + # 捕获请求信息 + captured_request["protocol"] = request.protocol + captured_request["has_raw_request"] = ( + request.raw_request is not None + ) + if request.raw_request: + captured_request["headers"] = dict(request.raw_request.headers) + captured_request["path"] = request.raw_request.url.path + captured_request["method"] = request.raw_request.method + return "Hello" + + client = self.get_client(invoke_agent) + + # 测试 AG-UI 协议 + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "test"}]}, + headers={"X-Custom-Header": "custom-value"}, + ) + assert response.status_code == 200 + + # 验证捕获的请求信息 + assert captured_request["protocol"] == "agui" + assert captured_request["has_raw_request"] is True + assert captured_request["path"] == "/ag-ui/agent" + assert captured_request["method"] == "POST" + assert ( + captured_request["headers"].get("x-custom-header") == "custom-value" + ) + + # 重置 + captured_request.clear() + + # 测试 OpenAI 协议 + response = client.post( + "/openai/v1/chat/completions", + json={"messages": [{"role": "user", "content": "test"}]}, + headers={"Authorization": "Bearer test-token"}, + ) + assert response.status_code == 200 + + # 验证捕获的请求信息 + assert captured_request["protocol"] == "openai" + assert captured_request["has_raw_request"] is True + assert captured_request["path"] == "/openai/v1/chat/completions" + assert ( + captured_request["headers"].get("authorization") + == "Bearer test-token" + ) diff --git a/tests/unittests/server/test_server_extended.py b/tests/unittests/server/test_server_extended.py new file mode 100644 index 0000000..8e4db03 --- /dev/null +++ b/tests/unittests/server/test_server_extended.py @@ -0,0 +1,212 @@ +"""Server 扩展测试 + +测试 AgentRunServer 的 CORS 配置和其他功能。 +""" + +from unittest.mock import MagicMock, patch + +from 
fastapi.testclient import TestClient +import pytest + +from agentrun.server import ( + AgentRequest, + AgentRunServer, + OpenAIProtocolHandler, + ServerConfig, +) + + +class TestServerCORS: + """测试 CORS 配置""" + + def test_cors_enabled(self): + """测试启用 CORS""" + + def invoke_agent(request: AgentRequest): + return "Hello" + + config = ServerConfig(cors_origins=["http://localhost:3000"]) + server = AgentRunServer(invoke_agent=invoke_agent, config=config) + client = TestClient(server.as_fastapi_app()) + + # 发送预检请求 + response = client.options( + "/openai/v1/chat/completions", + headers={ + "Origin": "http://localhost:3000", + "Access-Control-Request-Method": "POST", + }, + ) + + # CORS 头应该存在 + assert "access-control-allow-origin" in response.headers + + def test_cors_disabled_by_default(self): + """测试默认不启用 CORS""" + + def invoke_agent(request: AgentRequest): + return "Hello" + + server = AgentRunServer(invoke_agent=invoke_agent) + client = TestClient(server.as_fastapi_app()) + + response = client.post( + "/openai/v1/chat/completions", + json={"messages": [{"role": "user", "content": "Hi"}]}, + headers={"Origin": "http://localhost:3000"}, + ) + + # 没有 CORS 配置时,不应该有 CORS 头 + assert response.status_code == 200 + + def test_cors_with_multiple_origins(self): + """测试多个允许的源""" + + def invoke_agent(request: AgentRequest): + return "Hello" + + config = ServerConfig( + cors_origins=["http://localhost:3000", "http://example.com"] + ) + server = AgentRunServer(invoke_agent=invoke_agent, config=config) + client = TestClient(server.as_fastapi_app()) + + # 测试第一个源 + response = client.options( + "/openai/v1/chat/completions", + headers={ + "Origin": "http://localhost:3000", + "Access-Control-Request-Method": "POST", + }, + ) + assert "access-control-allow-origin" in response.headers + + +class TestServerProtocols: + """测试协议配置""" + + def test_custom_protocols(self): + """测试自定义协议列表""" + + def invoke_agent(request: AgentRequest): + return "Hello" + + # 只使用 OpenAI 协议 + server = 
AgentRunServer( + invoke_agent=invoke_agent, + protocols=[OpenAIProtocolHandler()], + ) + client = TestClient(server.as_fastapi_app()) + + # OpenAI 端点应该存在 + response = client.post( + "/openai/v1/chat/completions", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + assert response.status_code == 200 + + # AG-UI 端点应该不存在 + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + assert response.status_code == 404 + + +class TestServerFastAPIApp: + """测试 FastAPI 应用导出""" + + def test_as_fastapi_app(self): + """测试 as_fastapi_app 方法""" + from fastapi import FastAPI + + def invoke_agent(request: AgentRequest): + return "Hello" + + server = AgentRunServer(invoke_agent=invoke_agent) + app = server.as_fastapi_app() + + assert isinstance(app, FastAPI) + assert app.title == "AgentRun Server" + + def test_mount_to_existing_app(self): + """测试挂载到现有 FastAPI 应用""" + from fastapi import FastAPI + + def invoke_agent(request: AgentRequest): + return "Hello" + + # 创建主应用 + main_app = FastAPI() + + @main_app.get("/") + def root(): + return {"message": "Main app"} + + # 挂载 AgentRun Server + agent_server = AgentRunServer(invoke_agent=invoke_agent) + main_app.mount("/agent", agent_server.as_fastapi_app()) + + client = TestClient(main_app) + + # 主应用端点 + response = client.get("/") + assert response.status_code == 200 + assert response.json()["message"] == "Main app" + + # AgentRun 端点 + response = client.post( + "/agent/openai/v1/chat/completions", + json={"messages": [{"role": "user", "content": "Hi"}]}, + ) + assert response.status_code == 200 + + +class TestServerStartMethod: + """测试 start 方法""" + + @patch("agentrun.server.server.uvicorn.run") + @patch("agentrun.server.server.logger") + def test_start_calls_uvicorn_run(self, mock_logger, mock_uvicorn_run): + """测试 start 方法调用 uvicorn.run""" + + def invoke_agent(request: AgentRequest): + return "Hello" + + server = AgentRunServer(invoke_agent=invoke_agent) + server.start( + 
host="127.0.0.1", port=8080, log_level="debug", reload=True
+        )
+
+        # 验证 uvicorn.run 被调用
+        mock_uvicorn_run.assert_called_once_with(
+            server.app,
+            host="127.0.0.1",
+            port=8080,
+            log_level="debug",
+            reload=True,
+        )
+
+        # 验证日志被记录
+        mock_logger.info.assert_called_once()
+        assert "127.0.0.1" in mock_logger.info.call_args[0][0]
+        assert "8080" in mock_logger.info.call_args[0][0]
+
+    @patch("agentrun.server.server.uvicorn.run")
+    @patch("agentrun.server.server.logger")
+    def test_start_with_default_args(self, mock_logger, mock_uvicorn_run):
+        """测试 start 方法使用默认参数"""
+
+        def invoke_agent(request: AgentRequest):
+            return "Hello"
+
+        server = AgentRunServer(invoke_agent=invoke_agent)
+        server.start()
+
+        # 验证使用默认参数
+        mock_uvicorn_run.assert_called_once_with(
+            server.app,
+            host="0.0.0.0",
+            port=9000,
+            log_level="info",
+        )
diff --git a/tests/unittests/utils/test_helper.py b/tests/unittests/utils/test_helper.py
index 976e580..445e92f 100644
--- a/tests/unittests/utils/test_helper.py
+++ b/tests/unittests/utils/test_helper.py
@@ -1,3 +1,8 @@
+from typing import Optional
+
+from agentrun.utils.model import BaseModel
+
+
 def test_mask_password():
     from agentrun.utils.helper import mask_password
 
@@ -9,3 +14,107 @@ def test_mask_password():
     assert mask_password("12") == "**"
     assert mask_password("1") == "*"
     assert mask_password("") == ""
+
+
+def test_merge():
+    from agentrun.utils.helper import merge
+
+    assert merge(1, 2) == 2
+    assert merge(
+        {"key1": "value1", "key2": {"subkey1": "subvalue1"}, "key3": 0},
+        {"key2": {"subkey2": "subvalue2"}, "key3": "value3"},
+    ) == {
+        "key1": "value1",
+        "key2": {"subkey1": "subvalue1", "subkey2": "subvalue2"},
+        "key3": "value3",
+    }
+
+
+def test_merge_list():
+    from agentrun.utils.helper import merge
+
+    assert merge({"a": ["a", "b"]}, {"a": ["b", "c"]}) == {"a": ["b", "c"]}
+    assert merge({"a": ["a", "b"]}, {"a": ["b", "c"]}, concat_list=True) == {
+        "a": ["a", "b", "b", "c"]
+    }
+
+    assert merge([1, 2], [3, 4]) == [3, 4]
+    assert merge([1, 2], [3, 4], concat_list=True) == [1, 2, 3, 4]
+    assert merge([1, 2], [3, 4], ignore_empty_list=True) == [3, 4]
+
+    assert merge([1, 2], []) == []
+    assert merge([1, 2], [], concat_list=True) == [1, 2]
+    assert merge([1, 2], [], ignore_empty_list=True) == [1, 2]
+
+
+def test_merge_dict():
+    from agentrun.utils.helper import merge
+
+    assert merge(
+        {"key1": "value1", "key2": "value2"},
+        {"key2": "newvalue2", "key3": "newvalue3"},
+    ) == {"key1": "value1", "key2": "newvalue2", "key3": "newvalue3"}
+
+    assert merge(
+        {"key1": "value1", "key2": "value2"},
+        {"key2": "newvalue2", "key3": "newvalue3"},
+        no_new_field=True,
+    ) == {"key1": "value1", "key2": "newvalue2"}
+
+    assert merge(
+        {"key1": {"subkey1": "subvalue1"}, "key2": {"subkey2": "subvalue2"}},
+        {
+            "key2": {"subkey2": "newsubvalue2", "subkey3": "newsubvalue3"},
+            "key3": "newvalue3",
+        },
+    ) == {
+        "key1": {"subkey1": "subvalue1"},
+        "key2": {"subkey2": "newsubvalue2", "subkey3": "newsubvalue3"},
+        "key3": "newvalue3",
+    }
+
+    assert merge(
+        {"key1": {"subkey1": "subvalue1"}, "key2": {"subkey2": "subvalue2"}},
+        {
+            "key2": {"subkey2": "newsubvalue2", "subkey3": "newsubvalue3"},
+            "key3": "newvalue3",
+        },
+        no_new_field=True,
+    ) == {"key1": {"subkey1": "subvalue1"}, "key2": {"subkey2": "newsubvalue2"}}
+
+
+def test_merge_class():
+    from agentrun.utils.helper import merge
+
+    class T(BaseModel):
+        a: Optional[int] = None
+        b: Optional[str] = None
+        c: Optional["T"] = None
+        d: Optional[list] = None
+
+    assert merge(
+        T(b="2", c=T(a=3), d=[1, 2]),
+        T(a=5, c=T(b="8", c=None, d=[]), d=[3, 4]),
+    ) == T(a=5, b="2", c=T(a=3, b="8", c=None, d=[]), d=[3, 4])
+
+    assert merge(
+        T(b="2", c=T(a=3), d=[1, 2]),
+        T(a=5, c=T(b="8", c=None, d=[]), d=[3, 4]),
+        concat_list=True,
+    ) == T(a=5, b="2", c=T(a=3, b="8", c=None, d=[]), d=[1, 2, 3, 4])
+
+    assert merge(
+        T(b="2", c=T(a=3), d=[1, 2]),
+        T(a=5, c=T(b="8", c=None, d=[]), d=[3, 4]),
+        ignore_empty_list=True,
+    ) == T(a=5, b="2", c=T(a=3, b="8", c=None, d=[]), d=[3, 4])
+
+    # class 所有字段都是存在的,因此不会被 no_new_field 影响
+    assert merge(
+        T(b="2", c=T(a=3), d=[1, 2]),
+        T(a=5, c=T(b="8", c=None, d=[]), d=[3, 4]),
+    ) == merge(
+        T(b="2", c=T(a=3), d=[1, 2]),
+        T(a=5, c=T(b="8", c=None, d=[]), d=[3, 4]),
+        no_new_field=True,
+    )