-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathllm_chat.py
More file actions
321 lines (277 loc) · 13.8 KB
/
llm_chat.py
File metadata and controls
321 lines (277 loc) · 13.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
import sys
import logging
import yaml
import base64
import imghdr
from openai import AzureOpenAI, OpenAI
class LLMChat:
    """Thin wrapper around an OpenAI-compatible chat-completions client.

    Engine settings are read from a YAML config file: ``LLM_engines[<model>]``
    selects the per-engine section and ``params.default`` supplies fallback
    sampling parameters for :meth:`chat`. Models whose configured name contains
    ``"gpt"`` (but not ``"oss"``) use ``AzureOpenAI``; every other model uses a
    plain ``OpenAI`` client pointed at ``local_base_url``.
    """

    # Default system prompt ("You are a professional assistant who answers in
    # Traditional Chinese."). Runtime text - keep byte-identical.
    DEFAULT_SYSTEM_PROMPT = '你是一個專業的助手,會用繁體中文回答問題。'

    def __init__(self, model, config_path='./configs/models.yaml', logger=None):
        """Load the config and build the API client for ``model``.

        Args:
            model (str): Key under ``LLM_engines`` in the config file.
            config_path (str): Path to the YAML model configuration.
            logger (logging.Logger, optional): Logger to use; when omitted a
                logger with only a ``NullHandler`` attached is used.

        Raises:
            KeyError: If ``model`` or a required client setting is missing
                from the config.
        """
        self.config = self.load_config(config_path)
        self.model_config = self.config['LLM_engines'][model]
        self.model = self.model_config['model']
        self.logger = logger if logger is not None else self._default_logger()
        self.logger.info(f'[LLMChat] Initializing LLMChat with model: {model}')
        self.client = self._initialize_client()

    def load_config(self, path):
        """Parse the YAML file at ``path`` (UTF-8) with ``yaml.safe_load``."""
        with open(path, 'r', encoding='utf-8') as file:
            return yaml.safe_load(file)

    def _initialize_client(self):
        """Create the API client and configure optional Simplified->Traditional conversion.

        Sets ``self.translate_to_cht`` (and ``self.zn_converter`` when enabled)
        as a side effect. Translation is always switched off for Azure GPT
        engines.

        Returns:
            AzureOpenAI | OpenAI: The client used for chat completions.
        """
        self.translate_to_cht = self.model_config.get('translate_to_cht', False)
        if self.translate_to_cht:
            # Imported lazily so opencc is only required when translation is enabled.
            from opencc import OpenCC
            self.logger.info('[LLMChat] Translation to traditional Chinese enabled')
            self.zn_converter = OpenCC('s2twp')
        if 'gpt' in self.model and "oss" not in self.model:
            self.logger.info('[LLMChat] Initializing AzureOpenAI client')
            self.translate_to_cht = False
            return AzureOpenAI(
                azure_endpoint=self.model_config['azure_api_base'],
                api_key=self.model_config['azure_api_key'],
                api_version=self.model_config['azure_api_version'],
            )
        self.logger.info('[LLMChat] Initializing Local OpenAI client')
        return OpenAI(
            api_key=self.model_config['local_api_key'],
            base_url=self.model_config['local_base_url'],
        )

    def _default_logger(self):
        """Return the shared 'LLMChatLogger' logger with a NullHandler attached."""
        logger = logging.getLogger('LLMChatLogger')
        # getLogger returns a process-wide singleton, so attach the NullHandler
        # only once instead of adding another one per LLMChat instance.
        if not any(isinstance(h, logging.NullHandler) for h in logger.handlers):
            logger.addHandler(logging.NullHandler())
        return logger

    def initialize_history(self, system_message, user_message):
        """Start a message list; the system turn is omitted when ``system_message`` is None."""
        if system_message is None:
            return [{'role': 'user', 'content': user_message}]
        return [
            {'role': 'system', 'content': system_message},
            {'role': 'user', 'content': user_message},
        ]

    def chat(self, query, history=None, system=DEFAULT_SYSTEM_PROMPT, params=None,
             response_format=None, stream=False, extra_body=None, multi_response=False,
             include_reasoning=False, tools=None, tool_choice=None):
        """Send ``query`` to the chat-completions API.

        Args:
            query (str): User message for this turn.
            history (list, optional): Prior messages. When empty/None a new
                history is started (with ``system``); otherwise the user turn
                is appended to this list in place. The assistant reply is NOT
                appended - callers must add it themselves.
            system (str | None): System prompt used only when starting a new history.
            params (dict, optional): Extra API kwargs. When None, the five
                sampling values under ``params.default`` in the config are used.
                The caller's dict is never modified.
            response_format (str, optional): Sent as ``{'type': response_format}``.
            stream (bool): Stream the reply. Forced off by ``multi_response``
                or when ``tools`` are given.
            extra_body (dict, optional): Forwarded as the client's ``extra_body``.
            multi_response (bool): Return every choice (use ``n`` in ``params``).
            include_reasoning (bool): Prefix replies with the model's
                ``reasoning_content`` wrapped in ``<think>`` tags, when present.
            tools (list, optional): Tool definitions for tool calling.
            tool_choice (optional): Forwarded only when ``tools`` is given.

        Returns:
            A generator of text chunks when streaming; otherwise a
            ``(response, history)`` tuple where ``response`` is a str, a list
            of str (``multi_response``) or a dict with ``content`` and
            ``tool_calls`` when the model requested tool calls.
        """
        if multi_response:
            self.logger.info('[LLMChat] Multi-response mode enabled, You can pass n (in params) parameter to specify the number of responses to return.')
            self.logger.info('[LLMChat] Stream mode disabled')
            stream = False
        if tools is not None and stream:
            self.logger.info('[LLMChat] Tool calling mode enabled, Stream mode disabled')
            stream = False

        if not history:
            self.logger.info('[LLMChat] Initializing history')
            history = self.initialize_history(system, query)
        else:
            # Mutates the caller's list on purpose so multi-turn callers can reuse it.
            history.append({'role': 'user', 'content': query})

        if params is None:
            self.logger.info('[LLMChat] Using default parameters')
            defaults = self.config['params']['default']
            params = {
                key: defaults[key]
                for key in ('temperature', 'max_tokens', 'top_p',
                            'frequency_penalty', 'presence_penalty')
            }
        else:
            # Shallow copy: adding extra_body below must not leak into the caller's dict.
            params = dict(params)
        if extra_body is not None:
            params['extra_body'] = extra_body
        # detect n in params
        if params.get('n', 1) > 1 and not multi_response:
            self.logger.info('[LLMChat] please turn on multi-response mode to get multi responses.')

        completion_params = {
            'model': self.model,
            'stream': stream,
            'messages': history,
            **params,
        }
        if response_format is not None:
            completion_params['response_format'] = {'type': response_format}
        if tools is not None:
            completion_params['tools'] = tools
            if tool_choice is not None:
                completion_params['tool_choice'] = tool_choice
            self.logger.info(f'[LLMChat] Tool calling enabled with {len(tools)} tools')

        completion = self.client.chat.completions.create(**completion_params)
        if stream:
            return self._handle_stream_response(completion, include_reasoning=include_reasoning)
        response = self._handle_response(
            completion,
            multi_response=multi_response,
            include_reasoning=include_reasoning,
            tools_enabled=tools is not None,
        )
        return response, history

    @staticmethod
    def _with_reasoning(message, content):
        """Prefix ``content`` with the message's reasoning in ``<think>`` tags, if any.

        ``content`` may be None (e.g. a tool-call-only reply); it is then
        treated as an empty string so the concatenation cannot fail.
        """
        reasoning_content = getattr(message, 'reasoning_content', '')
        if not reasoning_content:
            return content
        return f"\n<think>\n{reasoning_content}\n</think>\n" + (content or '')

    def _handle_response(self, completion, multi_response=False, include_reasoning=False, tools_enabled=False):
        """Extract the reply text (or all choices, or tool calls) from a non-streamed completion."""
        if multi_response:
            responses = []
            for choice in completion.choices:
                response = choice.message.content
                if include_reasoning:
                    response = self._with_reasoning(choice.message, response)
                responses.append(self._maybe_translate(response))
            return responses

        message = completion.choices[0].message
        response = message.content
        if include_reasoning:
            response = self._with_reasoning(message, response)
        if tools_enabled and getattr(message, 'tool_calls', None):
            return {
                'content': self._maybe_translate(response) if response else None,
                'tool_calls': message.tool_calls,
            }
        return self._maybe_translate(response)

    def _handle_stream_response(self, completion, include_reasoning=False):
        """Yield text chunks from a streamed completion.

        When ``include_reasoning`` is set, ``reasoning_content`` deltas are
        wrapped in ``<think>``/``</think>``. The closing tag is emitted before
        the first regular content chunk, or at the end of the stream if no
        content followed. Per-chunk errors are logged and skipped.
        """
        reasoning_started = False
        for chunk in completion:
            try:
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                content = delta.content
                if content:
                    if reasoning_started:
                        content = "\n</think>\n" + content
                        reasoning_started = False
                    yield self._maybe_translate(content)
                reasoning_content = getattr(delta, 'reasoning_content', None)
                if reasoning_content and include_reasoning:
                    piece = self._maybe_translate(reasoning_content)
                    if not reasoning_started:
                        piece = "\n<think>\n" + piece
                        reasoning_started = True
                    yield piece
            except Exception as e:
                self.logger.error(f"Streaming error: {e}")
        if reasoning_started:
            # Stream ended while still inside reasoning: keep the tags balanced.
            yield "\n</think>\n"

    def _maybe_translate(self, content):
        """Convert to Traditional Chinese when enabled; None/empty content passes through."""
        if self.translate_to_cht and content:
            return self.zn_converter.convert(content)
        return content

    def _image_to_base64(self, image_path):
        """Convert image file to base64 string."""
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    def _detect_image_type(self, image_path):
        """Detect image type using imghdr (module removed in Python 3.13)."""
        with open(image_path, "rb") as image_file:
            return imghdr.what(image_file)

    def prepare_image_base64(self, image_path):
        """Prepare base64 encoded image with data URI prefix for reuse.

        Args:
            image_path (str): Path to the image file.

        Returns:
            str: ``data:image/<type>;base64,<data>``.

        Raises:
            ValueError: If the file type is not recognised as an image.
        """
        try:
            image_base64 = self._image_to_base64(image_path)
            image_type = self._detect_image_type(image_path)
            if image_type is None:
                raise ValueError("Unsupported image type or file is not an image.")
            return f"data:image/{image_type};base64,{image_base64}"
        except Exception as e:
            self.logger.error(f"Error preparing image base64: {e}")
            raise

    def vision_chat(self, query, image_path=None, image_base64_with_prefix=None, history=None,
                    system=DEFAULT_SYSTEM_PROMPT, params=None, response_format=None, stream=False):
        """Send ``query`` plus one image to a vision-capable model.

        Args:
            query (str): Text sent alongside the image.
            image_path (str, optional): Path to the image file.
            image_base64_with_prefix (str, optional): Image as a ``data:image/...`` URI.
            history (list, optional): Prior messages; the new user turn is
                appended in place when non-empty.
            system (str | None): System prompt for a new history (omitted when None).
            params (dict, optional): Extra API kwargs (e.g. temperature) forwarded as-is.
            response_format (str, optional): Sent as ``{'type': response_format}``.
            stream (bool): Whether to stream the response.

        Returns:
            tuple: ``(response, history)`` if not streaming, otherwise a generator
            of text chunks.

        Raises:
            ValueError: Unless exactly one image source is given, or if it is invalid.
        """
        if image_path is None and image_base64_with_prefix is None:
            raise ValueError("Either image_path or image_base64_with_prefix must be provided.")
        if image_path is not None and image_base64_with_prefix is not None:
            raise ValueError("Cannot provide both image_path and image_base64_with_prefix. Choose one.")

        if image_path is not None:
            image_base64_with_prefix = self.prepare_image_base64(image_path)
        else:
            if not image_base64_with_prefix.startswith('data:image/'):
                raise ValueError("image_base64_with_prefix must start with 'data:image/' prefix.")
            self.logger.info('[LLMChat] Using provided base64 image data')

        user_message = {
            "role": "user",
            "content": [
                {"type": "text", "text": query},
                {"type": "image_url", "image_url": {"url": image_base64_with_prefix}},
            ],
        }
        if not history:
            self.logger.info('[LLMChat] Initializing history for vision chat')
            history = [] if system is None else [{'role': 'system', 'content': system}]
        history.append(user_message)

        completion_params = {
            'model': self.model,
            'stream': stream,
            'messages': history,
            **(params or {}),
        }
        # Omit the key entirely (like chat()) rather than sending an explicit null.
        if response_format is not None:
            completion_params['response_format'] = {'type': response_format}

        completion = self.client.chat.completions.create(**completion_params)
        if stream:
            return self._handle_stream_response(completion)
        return self._handle_response(completion), history
# Usage example: a minimal interactive REPL (type "quit" to exit).
if __name__ == '__main__':
    # Other engine keys that may exist in ./configs/models.yaml include
    # 'gpt-4o', 'qwen2' and 'Qwen1.5-14B-Chat'.
    llmchat = LLMChat(model='Qwen2-7B-Instruct', config_path='./configs/models.yaml')

    system_prompt = '你是一個專業的助手,會用繁體中文回答問題。'
    # Seed history with the system turn: chat() then appends each user query to
    # this same list in place, which also works in stream mode (where chat()
    # returns only a generator, not the history).
    history = [{'role': 'system', 'content': system_prompt}]
    is_stream = True
    while True:
        input_text = input('請輸入你的問題: ')
        if input_text.lower() == 'quit':
            break
        if is_stream:
            parts = []
            for chunk in llmchat.chat(query=input_text, history=history, stream=True):
                print(chunk, end='', flush=True)
                parts.append(chunk)
            print()
            reply = ''.join(parts)
        else:
            reply, history = llmchat.chat(query=input_text, history=history, stream=False)
            print(reply)
        # chat() never records the assistant turn; add it so follow-ups have context.
        history.append({'role': 'assistant', 'content': reply})