|
97 | 97 | import asyncio |
98 | 98 |
|
99 | 99 | import numpy as np |
100 | | -from openai.types.chat import ChatCompletionMessage |
101 | 100 |
|
102 | 101 | from eaa.api.memory import MemoryManagerConfig |
| 102 | +from eaa.message_proc import ( |
| 103 | + generate_openai_message, |
| 104 | + has_tool_call, |
| 105 | + get_tool_call_info, |
| 106 | + print_message |
| 107 | +) |
103 | 108 | from eaa.tools.base import BaseTool, ToolReturnType, ExposedToolSpec, generate_openai_tool_schema |
104 | 109 | from eaa.tools.mcp import MCPTool |
105 | 110 | from eaa.comms import get_api_key |
106 | | -from eaa.util import encode_image_base64, get_image_path_from_text |
| 111 | +from eaa.util import get_image_path_from_text |
107 | 112 | from eaa.api.llm_config import LLMConfig |
108 | 113 | from eaa.agents.memory import ( |
109 | 114 | MemoryManager, |
@@ -919,256 +924,3 @@ def register_message_hook(self, hook: Callable) -> None: |
919 | 924 | The hook function. |
920 | 925 | """ |
921 | 926 | self.message_hooks.append(hook) |
922 | | - |
923 | | - |
def to_dict(message: ChatCompletionMessage | Dict[str, Any]) -> dict:
    """Return the message as a plain dictionary.

    Parameters
    ----------
    message : ChatCompletionMessage | Dict[str, Any]
        Either an OpenAI ``ChatCompletionMessage`` object or an
        already-converted dictionary.

    Returns
    -------
    dict
        The dictionary form of the message.
    """
    # Dictionaries pass through unchanged; only message objects convert.
    if not isinstance(message, ChatCompletionMessage):
        return message
    return message.to_dict()
932 | | - |
def generate_openai_message(
    content: str,
    role: Literal["user", "system", "tool"] = "user",
    tool_call_id: Optional[str] = None,
    image: Optional[np.ndarray] = None,
    image_path: Optional[str] = None,
    encoded_image: Optional[str] = None
) -> Dict[str, Any]:
    """Generate a dictionary in OpenAI-compatible format
    containing the message to be sent to the agent.

    Parameters
    ----------
    content : str
        The content of the message.
    role : Literal["user", "system", "tool"], optional
        The role of the sender.
    tool_call_id : str, optional
        The ID of the tool call this message responds to. Only
        included in the message when `role` is "tool".
    image : np.ndarray, optional
        The image to be sent to the agent. Exclusive with `encoded_image` and `image_path`.
    image_path : str, optional
        The path to the image to be sent to the agent. Exclusive with `image` and `encoded_image`.
    encoded_image : str, optional
        The base-64 encoded image to be sent to the agent. Exclusive with `image` and `image_path`.

    Returns
    -------
    Dict[str, Any]
        The OpenAI-compatible message dictionary.

    Raises
    ------
    ValueError
        If more than one of `image`, `encoded_image`, and `image_path`
        is given, or if `role` is not one of the allowed values.
    """
    if sum([image is not None, encoded_image is not None, image_path is not None]) > 1:
        raise ValueError("Only one of `image`, `encoded_image`, or `image_path` should be provided.")
    if role not in ["user", "system", "tool"]:
        raise ValueError("Invalid role. Must be one of `user`, `system`, or `tool`.")

    if image is not None or image_path is not None:
        encoded_image = encode_image_base64(image=image, image_path=image_path)

    # All three roles share the same basic structure; only tool messages
    # additionally carry the ID of the call they respond to.
    message: Dict[str, Any] = {"role": role, "content": content}
    if role == "tool":
        message["tool_call_id"] = tool_call_id

    if encoded_image is not None:
        # Multimodal content: a text part followed by a base64-encoded
        # PNG data URL, per the OpenAI image-input message format.
        message["content"] = [
            {
                "type": "text",
                "text": content
            },
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/png;base64,{encoded_image}"
                }
            }
        ]
    return message
997 | | - |
def has_tool_call(message: "dict | ChatCompletionMessage") -> bool:
    """Check if the message has a tool call.

    Parameters
    ----------
    message : dict | ChatCompletionMessage
        A message in OpenAI-compatible format.

    Returns
    -------
    bool
        True if the message contains a `tool_calls` field,
        False otherwise.
    """
    # Accept either a plain dict or an OpenAI message object.
    if not isinstance(message, dict):
        message = message.to_dict()
    return "tool_calls" in message
1015 | | - |
1016 | | -def get_tool_call_info( |
1017 | | - message: dict | ChatCompletionMessage, |
1018 | | - index: Optional[int] = 0 |
1019 | | -) -> str | List[str]: |
1020 | | - """Get the tool call ID from the message. |
1021 | | -
|
1022 | | - Parameters |
1023 | | - ---------- |
1024 | | - message : dict | ChatCompletionMessage |
1025 | | - The message to get the tool call ID from. The message |
1026 | | - should be in OpenAI-compatible format. |
1027 | | - index : int, optional |
1028 | | - The index of the tool call to get the ID from. If None, |
1029 | | - all tool calls are returned as a list. |
1030 | | -
|
1031 | | - Returns |
1032 | | - ------- |
1033 | | - str | List[str] |
1034 | | - The tool call(s). |
1035 | | - """ |
1036 | | - message = to_dict(message) |
1037 | | - if index is None: |
1038 | | - return message["tool_calls"] |
1039 | | - else: |
1040 | | - return message["tool_calls"][index] |
1041 | | - |
1042 | | - |
def get_message_elements_as_text(message: Dict[str, Any]) -> Dict[str, Any]:
    """Get the elements of the message as human readable text.

    Parameters
    ----------
    message : Dict[str, Any]
        The message to get the elements from. The message
        should be in OpenAI-compatible format.

    Returns
    -------
    Dict[str, Any]
        A dictionary with keys "role", "content" (flattened text,
        images shown as a placeholder), "tool_calls" (formatted text
        or None), and "image" (the last image URL seen, or None).
    """
    image_url = None
    text_content = ""
    raw_content = message.get("content")
    if isinstance(raw_content, str):
        text_content += raw_content + "\n"
    elif isinstance(raw_content, list):
        for part in raw_content:
            if part["type"] == "text":
                text_content += part["text"] + "\n"
            elif part["type"] == "image_url":
                # Images are replaced by a placeholder in the text view.
                text_content += "<image> \n"
                image_url = part["image_url"]["url"]

    calls_text = None
    if "tool_calls" in message:
        calls_text = ""
        for call in message["tool_calls"]:
            calls_text += f"{call['id']}: {call['function']['name']}\n"
            calls_text += f"Arguments: {call['function']['arguments']}\n"

    return {
        "role": message["role"],
        "content": text_content,
        "tool_calls": calls_text,
        "image": image_url,
    }
1086 | | - |
def get_message_elements(message: Dict[str, Any]) -> Dict[str, Any]:
    """Get the elements of the message as a structured dictionary.

    Parameters
    ----------
    message : Dict[str, Any]
        The message to get the elements from. The message
        should be in OpenAI-compatible format.

    Returns
    -------
    Dict[str, Any]
        A dictionary with keys "role", "content" (a list of content
        parts), "tool_calls", "image" (a list of image URLs), and
        "tool_response_id".
    """
    images = []
    parts = []
    raw_content = message.get("content")
    if isinstance(raw_content, str):
        parts.append(raw_content)
    elif isinstance(raw_content, list):
        # Keep the original content list as-is and collect image URLs.
        parts = raw_content
        images = [p["image_url"]["url"] for p in parts if p["type"] == "image_url"]

    return {
        "role": message["role"],
        "content": parts,
        "tool_calls": message.get("tool_calls"),
        "image": images,
        "tool_response_id": message.get("tool_call_id"),
    }
1130 | | - |
def print_message(
    message: Dict[str, Any],
    response_requested: Optional[bool] = None,
    return_string: bool = False
) -> Optional[str]:
    """Print the message to the terminal with role-based coloring.

    Parameters
    ----------
    message : Dict[str, Any]
        The message to be printed. The message should be in
        OpenAI-compatible format.
    response_requested : bool, optional
        Whether a response is requested for the message.
    return_string : bool, optional
        If True, the message is returned as a string instead of printed.

    Returns
    -------
    Optional[str]
        The formatted message text when `return_string` is True;
        otherwise None (the text is printed instead).
    """
    # ANSI color codes, keyed by sender role.
    color_dict = {
        "user": "\033[94m",
        "system": "\033[92m",
        "tool": "\033[93m",
        "assistant": "\033[91m"
    }
    color = color_dict[message["role"]]

    text = f"[Role] {message['role']}\n"
    if response_requested is not None:
        text += f"[Response requested] {response_requested}\n"

    elements = get_message_elements_as_text(message)

    text += "[Content]\n"
    text += elements["content"] + "\n"

    if elements["tool_calls"] is not None:
        text += "[Tool call]\n"
        text += elements["tool_calls"] + "\n"

    text += "\n ========================================= \n"

    if return_string:
        return text
    # Wrap the text in the role color and reset the terminal color at the end.
    print(f"{color}{text}\033[0m")
    return None
0 commit comments