-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathasync_llm_chat.py
More file actions
318 lines (270 loc) · 13.6 KB
/
async_llm_chat.py
File metadata and controls
318 lines (270 loc) · 13.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
import sys
import os
import configparser
import logging
import yaml
import base64
import imghdr
from openai import AsyncAzureOpenAI, AsyncOpenAI
from .llm_response_cache import LLMResponseCache
class AsyncLLMChat:
    """Async chat wrapper around Azure OpenAI / local OpenAI-compatible endpoints.

    Model settings are loaded from a YAML config file; an optional JSON-file
    response cache can be enabled through ``cache_config``.
    """

    def __init__(self, model, config_path='./configs/models.yaml', logger=None, cache_config=None):
        """
        Args:
            model: Key under ``LLM_engines`` in the YAML config.
            config_path: Path to the YAML models configuration file.
            logger: Optional logger; defaults to a NullHandler logger.
            cache_config: Optional dict with keys ``enable`` (defaults to True
                when the dict is provided) and ``cache_file``. Caching is
                disabled entirely when ``cache_config`` is None/falsy.
        """
        self.config = self.load_config(config_path)
        self.model_config = self.config['LLM_engines'][model]
        self.model = self.model_config['model']
        self.logger = logger if logger is not None else self._default_logger()
        self.logger.info(f'[LLMChat] Initializing LLMChat with model: {model}')
        self.client = self._initialize_client()
        self.enable_cache = cache_config.get('enable', True) if cache_config else False
        if self.enable_cache:
            cache_file = cache_config.get('cache_file', './llm_cache.json')
            cache_dir = os.path.dirname(cache_file)
            # os.makedirs('') raises FileNotFoundError when the cache path is
            # a bare filename with no directory component — only create a
            # directory when there is one.
            if cache_dir:
                os.makedirs(cache_dir, exist_ok=True)
            self.cache = LLMResponseCache(cache_file=cache_file)
            self.logger.info(f'[LLMChat] LLM response caching enabled. Cache file: {cache_file}')
def load_config(self, path):
    """Read and parse the YAML configuration file at *path*."""
    with open(path, 'r', encoding='utf-8') as config_file:
        parsed = yaml.safe_load(config_file)
    return parsed
def _initialize_client(self):
    """Create the async API client that matches the configured model.

    Also enables the optional simplified-to-traditional Chinese converter
    when the model config sets ``translate_to_cht``.
    """
    self.translate_to_cht = self.model_config.get('translate_to_cht', False)
    if self.translate_to_cht:
        from opencc import OpenCC
        self.logger.info('[LLMChat] Translation to traditional Chinese enabled')
        self.zn_converter = OpenCC('s2twp')
    # Azure is used for 'gpt' models, except open-source 'oss' variants,
    # which go through the local OpenAI-compatible endpoint.
    use_azure = 'gpt' in self.model and 'oss' not in self.model
    if not use_azure:
        self.logger.info('[LLMChat] Initializing Local OpenAI client')
        return AsyncOpenAI(
            api_key=self.model_config['local_api_key'],
            base_url=self.model_config['local_base_url'],
        )
    self.logger.info('[LLMChat] Initializing AzureOpenAI client')
    return AsyncAzureOpenAI(
        azure_endpoint=self.model_config['azure_api_base'],
        api_key=self.model_config['azure_api_key'],
        api_version=self.model_config['azure_api_version'],
    )
def _default_logger(self):
logger = logging.getLogger("LLMChatLogger")
logger.addHandler(logging.NullHandler())
return logger
def initialize_history(self, system_message, user_message):
if system_message is None:
return [{"role": "user", "content": user_message}]
else:
return [
{"role": "system", "content": system_message},
{"role": "user", "content": user_message},
]
async def chat(self, query, history=[], system='你是一個專業的助手,會用繁體中文回答問題。', params=None, response_format=None, stream=False, extra_body=None, multi_response=False, include_reasoning=False, tools=None, tool_choice=None):
if multi_response:
self.logger.info(f'[LLMChat] Multi-response mode enabled, You can pass n (in params) parameter to specify the number of responses to return.')
self.logger.info(f'[LLMChat] Stream mode disabled')
stream = False
if tools is not None and stream:
self.logger.info(f'[LLMChat] Tool calling mode enabled, Stream mode disabled')
stream = False
if not history:
self.logger.info(f'[LLMChat] Initializing history')
history = self.initialize_history(system, query)
else:
history.append({'role': 'user', 'content': query})
if params is None:
self.logger.info(f'[LLMChat] Using default parameters')
params = {
'temperature': self.config['params']['default']['temperature'],
'max_tokens': self.config['params']['default']['max_tokens'],
'top_p': self.config['params']['default']['top_p'],
'frequency_penalty': self.config['params']['default']['frequency_penalty'],
'presence_penalty': self.config['params']['default']['presence_penalty']
}
else:
params = params
if extra_body is not None:
params['extra_body'] = extra_body
# detect n in params
if params.get('n', 1) > 1 and not multi_response:
self.logger.info(f'[LLMChat] please turn on multi-response mode to get multi responses.')
cache_key = self.cache.make_key(self.model, history, params, tools=tools) if self.enable_cache else None
if self.enable_cache:
cached = await self.cache.get(cache_key)
if cached:
self.logger.info(f'[LLMChat] Cache hit for key: {cache_key}')
return cached['return'], history
completion_params = {
'model': self.model,
'stream': stream,
'messages': history,
**params
}
if response_format is not None:
completion_params['response_format'] = {'type': response_format}
if tools is not None:
completion_params['tools'] = tools
if tool_choice is not None:
completion_params['tool_choice'] = tool_choice
self.logger.info(f'[LLMChat] Tool calling enabled with {len(tools)} tools')
completion = await self.client.chat.completions.create(**completion_params)
if stream:
return self._handle_stream_response(completion, include_reasoning=include_reasoning)
response = await self._handle_response(completion, multi_response=multi_response, include_reasoning=include_reasoning, tools_enabled=tools is not None)
if self.enable_cache:
await self.cache.set(
cache_key,
result=response,
model=self.model
)
return response, history
async def _handle_response(self, completion, multi_response=False, include_reasoning=False, tools_enabled=False):
if multi_response:
responses = []
for choice in completion.choices:
response = choice.message.content
if include_reasoning:
reasoning_content = getattr(choice.message, 'reasoning_content', '')
if reasoning_content:
response = f"\n<think>\n{reasoning_content}\n</think>\n" + response
responses.append(self._maybe_translate(response))
return responses
else:
message = completion.choices[0].message
response = message.content
if include_reasoning:
reasoning_content = getattr(message, 'reasoning_content', '')
if reasoning_content:
response = f"\n<think>\n{reasoning_content}\n</think>\n" + response
if tools_enabled and hasattr(message, 'tool_calls') and message.tool_calls:
return {
'content': self._maybe_translate(response) if response else None,
'tool_calls': message.tool_calls
}
else:
return self._maybe_translate(response)
async def _handle_stream_response(self, completion, include_reasoning=False):
reasoning_started = False
async for chunk in completion:
try:
if not chunk.choices:
continue
delta = chunk.choices[0].delta
content = delta.content
if content:
if reasoning_started:
content = "\n</think>\n" + content
reasoning_started = False
translated_content = self._maybe_translate(content)
yield translated_content
reasoning_content = getattr(delta, 'reasoning_content', None)
if reasoning_content and include_reasoning:
translated_reasoning_content = self._maybe_translate(reasoning_content)
if not reasoning_started:
translated_reasoning_content = "\n<think>\n" + translated_reasoning_content
reasoning_started = True
yield translated_reasoning_content
else:
yield translated_reasoning_content
except Exception as e:
self.logger.error(f"Streaming error: {e}")
def _maybe_translate(self, content):
if self.translate_to_cht:
return self.zn_converter.convert(content)
return content
def _image_to_base64(self, image_path):
"""Convert image file to base64 string."""
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
def _detect_image_type(self, image_path):
"""Detect image type using imghdr."""
with open(image_path, "rb") as image_file:
return imghdr.what(image_file)
def prepare_image_base64(self, image_path):
"""
Prepare base64 encoded image with data URI prefix for reuse.
Args:
image_path (str): Path to the image file.
Returns:
str: Base64 encoded image with data URI prefix.
"""
try:
image_base64 = self._image_to_base64(image_path)
image_type = self._detect_image_type(image_path)
if image_type is None:
raise ValueError("Unsupported image type or file is not an image.")
return f"data:image/{image_type};base64,{image_base64}"
except Exception as e:
self.logger.error(f"Error preparing image base64: {e}")
raise
async def vision_chat(self, query, image_path=None, image_base64_with_prefix=None, history=[], system='你是一個專業的助手,會用繁體中文回答問題。', params=None, response_format=None, stream=False, extra_body=None):
"""
Process an image and get a response from the vision LLM asynchronously.
Args:
query (str): The query to be sent to the chat completion API.
image_path (str, optional): Path to the image file.
image_base64_with_prefix (str, optional): Base64 encoded image with data URI prefix.
history (list): Conversation history.
system (str): System message.
params (dict): Parameters for the API call.
response_format (str): Response format type.
stream (bool): Whether to stream the response.
extra_body (dict): Extra body parameters.
Returns:
tuple: (response, history) if not streaming, otherwise async generator for streaming.
Note:
Either image_path or image_base64_with_prefix must be provided, but not both.
"""
# Validate input parameters
if image_path is None and image_base64_with_prefix is None:
raise ValueError("Either image_path or image_base64_with_prefix must be provided.")
if image_path is not None and image_base64_with_prefix is not None:
raise ValueError("Cannot provide both image_path and image_base64_with_prefix. Choose one.")
# Process image based on input type
if image_path is not None:
# Convert image file to base64
try:
image_base64 = self._image_to_base64(image_path)
image_type = self._detect_image_type(image_path)
if image_type is None:
raise ValueError("Unsupported image type or file is not an image.")
image_base64_with_prefix = f"data:image/{image_type};base64,{image_base64}"
except Exception as e:
self.logger.error(f"Error processing image file: {e}")
raise
else:
# Use provided base64 string directly
if not image_base64_with_prefix.startswith('data:image/'):
raise ValueError("image_base64_with_prefix must start with 'data:image/' prefix.")
self.logger.info(f'[LLMChat] Using provided base64 image data')
# Prepare message with image
user_message = {
"role": "user",
"content": [
{"type": "text", "text": query},
{"type": "image_url", "image_url": {"url": image_base64_with_prefix}},
],
}
if not history:
self.logger.info(f'[LLMChat] Initializing history for vision chat')
history = [{'role': 'system', 'content': system}]
history.append(user_message)
cache_key = self.cache.make_key(self.model, history, params) if self.enable_cache else None
# Check cache
if self.enable_cache:
cached = await self.cache.get(cache_key)
if cached:
self.logger.info(f'[LLMChat] Cache hit for key: {cache_key}')
return cached['return'], history
completion = await self.client.chat.completions.create(
model=self.model,
stream=stream,
response_format={'type':response_format} if response_format is not None else None,
messages=history,
)
if stream:
return self._handle_stream_response(completion)
response = await self._handle_response(completion)
if self.enable_cache:
await self.cache.set(
cache_key,
result=response,
model=self.model
)
return response, history