Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 150 additions & 111 deletions backend/apps/chat/task/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import warnings
from concurrent.futures import ThreadPoolExecutor, Future
from datetime import datetime
from typing import Any, List, Optional, Union, Dict
from typing import Any, List, Optional, Union, Dict, Iterator

import numpy as np
import orjson
Expand Down Expand Up @@ -259,22 +259,14 @@ def generate_analysis(self):
in analysis_msg])
full_thinking_text = ''
full_analysis_text = ''
res = self.llm.stream(analysis_msg)
token_usage = {}
res = process_stream(self.llm.stream(analysis_msg), token_usage)
for chunk in res:
SQLBotLogUtil.info(chunk)
reasoning_content_chunk = ''
if 'reasoning_content' in chunk.additional_kwargs:
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
# else:
# reasoning_content_chunk = chunk.get('reasoning_content')
if reasoning_content_chunk is None:
reasoning_content_chunk = ''
full_thinking_text += reasoning_content_chunk

full_analysis_text += chunk.content
yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
get_token_usage(chunk, token_usage)
if chunk.get('content'):
full_analysis_text += chunk.get('content')
if chunk.get('reasoning_content'):
full_thinking_text += chunk.get('reasoning_content')
yield chunk

analysis_msg.append(AIMessage(full_analysis_text))

Expand Down Expand Up @@ -311,22 +303,14 @@ def generate_predict(self):
in predict_msg])
full_thinking_text = ''
full_predict_text = ''
res = self.llm.stream(predict_msg)
token_usage = {}
res = process_stream(self.llm.stream(predict_msg), token_usage)
for chunk in res:
SQLBotLogUtil.info(chunk)
reasoning_content_chunk = ''
if 'reasoning_content' in chunk.additional_kwargs:
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
# else:
# reasoning_content_chunk = chunk.get('reasoning_content')
if reasoning_content_chunk is None:
reasoning_content_chunk = ''
full_thinking_text += reasoning_content_chunk

full_predict_text += chunk.content
yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
get_token_usage(chunk, token_usage)
if chunk.get('content'):
full_predict_text += chunk.get('content')
if chunk.get('reasoning_content'):
full_thinking_text += chunk.get('reasoning_content')
yield chunk

predict_msg.append(AIMessage(full_predict_text))
self.record = save_predict_answer(session=self.session, record_id=self.record.id,
Expand Down Expand Up @@ -370,21 +354,13 @@ def generate_recommend_questions_task(self):
full_thinking_text = ''
full_guess_text = ''
token_usage = {}
res = self.llm.stream(guess_msg)
res = process_stream(self.llm.stream(guess_msg), token_usage)
for chunk in res:
SQLBotLogUtil.info(chunk)
reasoning_content_chunk = ''
if 'reasoning_content' in chunk.additional_kwargs:
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
# else:
# reasoning_content_chunk = chunk.get('reasoning_content')
if reasoning_content_chunk is None:
reasoning_content_chunk = ''
full_thinking_text += reasoning_content_chunk

full_guess_text += chunk.content
yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
get_token_usage(chunk, token_usage)
if chunk.get('content'):
full_guess_text += chunk.get('content')
if chunk.get('reasoning_content'):
full_thinking_text += chunk.get('reasoning_content')
yield chunk

guess_msg.append(AIMessage(full_guess_text))

Expand Down Expand Up @@ -450,21 +426,13 @@ def select_datasource(self):
msg in datasource_msg])

token_usage = {}
res = self.llm.stream(datasource_msg)
res = process_stream(self.llm.stream(datasource_msg), token_usage)
for chunk in res:
SQLBotLogUtil.info(chunk)
reasoning_content_chunk = ''
if 'reasoning_content' in chunk.additional_kwargs:
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
# else:
# reasoning_content_chunk = chunk.get('reasoning_content')
if reasoning_content_chunk is None:
reasoning_content_chunk = ''
full_thinking_text += reasoning_content_chunk

full_text += chunk.content
yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
get_token_usage(chunk, token_usage)
if chunk.get('content'):
full_text += chunk.get('content')
if chunk.get('reasoning_content'):
full_thinking_text += chunk.get('reasoning_content')
yield chunk
datasource_msg.append(AIMessage(full_text))

self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = end_log(session=self.session,
Expand Down Expand Up @@ -560,21 +528,13 @@ def generate_sql(self):
full_thinking_text = ''
full_sql_text = ''
token_usage = {}
res = self.llm.stream(self.sql_message)
res = process_stream(self.llm.stream(self.sql_message), token_usage)
for chunk in res:
SQLBotLogUtil.info(chunk)
reasoning_content_chunk = ''
if 'reasoning_content' in chunk.additional_kwargs:
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
# else:
# reasoning_content_chunk = chunk.get('reasoning_content')
if reasoning_content_chunk is None:
reasoning_content_chunk = ''
full_thinking_text += reasoning_content_chunk

full_sql_text += chunk.content
yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
get_token_usage(chunk, token_usage)
if chunk.get('content'):
full_sql_text += chunk.get('content')
if chunk.get('reasoning_content'):
full_thinking_text += chunk.get('reasoning_content')
yield chunk

self.sql_message.append(AIMessage(full_sql_text))

Expand Down Expand Up @@ -607,18 +567,14 @@ def generate_with_sub_sql(self, sql, sub_mappings: list):

full_thinking_text = ''
full_dynamic_text = ''
res = self.llm.stream(dynamic_sql_msg)
token_usage = {}
res = process_stream(self.llm.stream(dynamic_sql_msg), token_usage)
for chunk in res:
SQLBotLogUtil.info(chunk)
reasoning_content_chunk = ''
if 'reasoning_content' in chunk.additional_kwargs:
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
if reasoning_content_chunk is None:
reasoning_content_chunk = ''
full_thinking_text += reasoning_content_chunk
full_dynamic_text += chunk.content
get_token_usage(chunk, token_usage)
if chunk.get('content'):
full_dynamic_text += chunk.get('content')
if chunk.get('reasoning_content'):
full_thinking_text += chunk.get('reasoning_content')
yield chunk

dynamic_sql_msg.append(AIMessage(full_dynamic_text))

Expand Down Expand Up @@ -670,22 +626,13 @@ def build_table_filter(self, sql: str, filters: list):
in permission_sql_msg])
full_thinking_text = ''
full_filter_text = ''
res = self.llm.stream(permission_sql_msg)
token_usage = {}
res = process_stream(self.llm.stream(permission_sql_msg), token_usage)
for chunk in res:
SQLBotLogUtil.info(chunk)
reasoning_content_chunk = ''
if 'reasoning_content' in chunk.additional_kwargs:
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
# else:
# reasoning_content_chunk = chunk.get('reasoning_content')
if reasoning_content_chunk is None:
reasoning_content_chunk = ''
full_thinking_text += reasoning_content_chunk

full_filter_text += chunk.content
# yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
get_token_usage(chunk, token_usage)
if chunk.get('content'):
full_filter_text += chunk.get('content')
if chunk.get('reasoning_content'):
full_thinking_text += chunk.get('reasoning_content')

permission_sql_msg.append(AIMessage(full_filter_text))

Expand Down Expand Up @@ -735,21 +682,13 @@ def generate_chart(self, chart_type: Optional[str] = ''):
full_thinking_text = ''
full_chart_text = ''
token_usage = {}
res = self.llm.stream(self.chart_message)
res = process_stream(self.llm.stream(self.chart_message), token_usage)
for chunk in res:
SQLBotLogUtil.info(chunk)
reasoning_content_chunk = ''
if 'reasoning_content' in chunk.additional_kwargs:
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
# else:
# reasoning_content_chunk = chunk.get('reasoning_content')
if reasoning_content_chunk is None:
reasoning_content_chunk = ''
full_thinking_text += reasoning_content_chunk

full_chart_text += chunk.content
yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
get_token_usage(chunk, token_usage)
if chunk.get('content'):
full_chart_text += chunk.get('content')
if chunk.get('reasoning_content'):
full_thinking_text += chunk.get('reasoning_content')
yield chunk

self.chart_message.append(AIMessage(full_chart_text))

Expand Down Expand Up @@ -1053,7 +992,7 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
else:
sql = self.check_save_sql(res=full_sql_text)

SQLBotLogUtil.info(sql)
SQLBotLogUtil.info('sql: ' + sql)

if not stream:
json_result['sql'] = sql
Expand Down Expand Up @@ -1372,16 +1311,116 @@ def request_picture(chat_id: int, record_id: int, chart: dict, data: dict):
return request_path


def get_token_usage(chunk: 'BaseMessageChunk', token_usage: Optional[dict] = None):
    """Copy the token counters carried by a streamed chunk into ``token_usage``.

    Args:
        chunk: A streamed message chunk. Its ``usage_metadata`` attribute is
            read; when it is missing or falsy nothing is recorded.
        token_usage: Caller-owned dict updated in place with the keys
            ``input_tokens``, ``output_tokens`` and ``total_tokens``. Counters
            absent from the metadata are stored as ``None``. When this argument
            is ``None`` there is nowhere to record the counters, so the call
            is a no-op.

    Returns:
        None. The result is delivered through the in-place update.
    """
    if token_usage is None:
        return

    try:
        usage = chunk.usage_metadata
        if usage:
            token_usage['input_tokens'] = usage.get('input_tokens')
            token_usage['output_tokens'] = usage.get('output_tokens')
            token_usage['total_tokens'] = usage.get('total_tokens')
    except (AttributeError, TypeError):
        # Best effort: token accounting must never interrupt the stream.
        # AttributeError covers chunks without ``usage_metadata`` (or metadata
        # that is not dict-like); TypeError covers a non-mapping accumulator.
        pass


def _partial_tag_length(text: str, tag: str) -> int:
    """Return the length of the longest proper prefix of ``tag`` that ``text`` ends with (0 if none)."""
    for size in range(min(len(tag) - 1, len(text)), 0, -1):
        if text.endswith(tag[:size]):
            return size
    return 0


def process_stream(res: Iterator[BaseMessageChunk],
                   token_usage: Optional[Dict[str, Any]] = None,
                   enable_tag_parsing: bool = settings.PARSE_REASONING_BLOCK_ENABLED,
                   start_tag: str = settings.DEFAULT_REASONING_CONTENT_START,
                   end_tag: str = settings.DEFAULT_REASONING_CONTENT_END
                   ) -> Iterator[Dict[str, Any]]:
    """Normalise a streamed LLM response into ``{'content', 'reasoning_content'}`` dicts.

    Reasoning text is taken from two places:

    * ``chunk.additional_kwargs['reasoning_content']`` when the chunk carries it;
    * text enclosed between ``start_tag`` and ``end_tag`` inside ``chunk.content``
      (only when ``enable_tag_parsing`` is true). A start tag only opens a
      reasoning block when nothing but whitespace precedes it in the chunk.

    Once non-blank reasoning has arrived through ``additional_kwargs`` (outside
    a tag block), tag parsing is bypassed and ``chunk.content`` is passed
    through unchanged.

    Every yielded ``reasoning_content`` is an increment: concatenating the
    values of all yielded dicts gives each piece of reasoning exactly once.
    A tag that is split across chunks is recognised by holding back the
    trailing partial tag until the next chunk arrives. Held-back text that
    turns out not to be a tag is emitted later, or in one extra trailing dict
    if the stream ends first.

    Args:
        res: Iterator of streamed message chunks; ``content`` and
            ``additional_kwargs`` are read from each one.
        token_usage: Dict updated in place through ``get_token_usage`` for
            every chunk. A private dict is used when omitted.
        enable_tag_parsing: Whether to look for ``start_tag``/``end_tag`` in
            the content.
        start_tag: Marker that opens a reasoning block.
        end_tag: Marker that closes a reasoning block.

    Yields:
        One dict per input chunk with the keys ``content`` (visible text) and
        ``reasoning_content`` (reasoning text), plus at most one extra dict at
        the end that flushes held-back text.
    """
    if token_usage is None:
        token_usage = {}

    in_thinking_block = False  # currently between start_tag and end_tag
    has_native_reasoning = False  # non-blank reasoning seen via additional_kwargs
    pending = ''  # trailing text held back because it may be the start of a split tag

    for chunk in res:
        SQLBotLogUtil.info(chunk)
        # Record usage before yielding so the counters are not lost if the
        # consumer stops iterating right after this chunk.
        get_token_usage(chunk, token_usage)

        content = chunk.content
        reasoning_piece = ''
        visible = ''

        if 'reasoning_content' in chunk.additional_kwargs:
            reasoning_piece = chunk.additional_kwargs.get('reasoning_content', '')
            if reasoning_piece is None:
                reasoning_piece = ''
            if reasoning_piece.strip():
                has_native_reasoning = True

        # Put text held back from the previous chunk in front of this one.
        if pending:
            content = pending + content
            pending = ''

        # The provider already separates reasoning from the answer: pass the
        # content through untouched and skip tag parsing.
        if not in_thinking_block and has_native_reasoning:
            yield {
                'content': content,
                'reasoning_content': reasoning_piece
            }
            continue

        # Look for the opening tag (possibly split across chunks).
        if enable_tag_parsing and not in_thinking_block and start_tag:
            head, found, tail = content.partition(start_tag)
            if found:
                if head.strip() == '':
                    # Only whitespace precedes the tag: a real reasoning block.
                    visible += head
                    content = tail
                    in_thinking_block = True
                else:
                    # Other text precedes the tag: treat everything as answer text.
                    visible += content
                    content = ''
            else:
                held = _partial_tag_length(content, start_tag)
                # Hold the partial tag back only when nothing but whitespace
                # precedes it, mirroring the rule for a complete tag.
                if held and content[:-held].strip() == '':
                    pending = content[-held:]
                    visible += content[:-held]
                    content = ''

        # Collect reasoning until the closing tag shows up.
        if enable_tag_parsing and in_thinking_block and end_tag:
            thought, found, tail = content.partition(end_tag)
            if found:
                # Emit only the not-yet-emitted part; earlier parts of the
                # block were already yielded with previous chunks.
                reasoning_piece += thought
                visible += tail
                in_thinking_block = False
                has_native_reasoning = False
            else:
                held = _partial_tag_length(content, end_tag)
                if held:
                    # Might be the beginning of a split closing tag.
                    pending = content[-held:]
                    content = content[:-held]
                reasoning_piece += content
        else:
            # Not inside a reasoning block, or tag parsing is disabled.
            visible += content

        yield {
            'content': visible,
            'reasoning_content': reasoning_piece
        }

    # The stream ended while text was held back as a possible partial tag:
    # it was ordinary text after all, so emit it instead of dropping it.
    if pending:
        if in_thinking_block:
            yield {'content': '', 'reasoning_content': pending}
        else:
            yield {'content': pending, 'reasoning_content': ''}


def get_lang_name(lang: str):
if lang and lang == 'en':
return '英文'
Expand Down
4 changes: 4 additions & 0 deletions backend/common/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn | str:
EMBEDDING_TERMINOLOGY_TOP_COUNT: int = EMBEDDING_DEFAULT_TOP_COUNT
EMBEDDING_DATA_TRAINING_TOP_COUNT: int = EMBEDDING_DEFAULT_TOP_COUNT

PARSE_REASONING_BLOCK_ENABLED: bool = True
DEFAULT_REASONING_CONTENT_START: str = '<think>'
DEFAULT_REASONING_CONTENT_END: str = '</think>'

PG_POOL_SIZE: int = 20
PG_MAX_OVERFLOW: int = 30
PG_POOL_RECYCLE: int = 3600
Expand Down