From dd54174611d4f3537817d732dbff8d147c365278 Mon Sep 17 00:00:00 2001 From: ulleo Date: Tue, 23 Sep 2025 18:59:41 +0800 Subject: [PATCH] feat: support parse reasoning block --- backend/apps/chat/task/llm.py | 261 +++++++++++++++++++--------------- backend/common/core/config.py | 4 + 2 files changed, 154 insertions(+), 111 deletions(-) diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index 15b9f03a1..b88d04b53 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -6,7 +6,7 @@ import warnings from concurrent.futures import ThreadPoolExecutor, Future from datetime import datetime -from typing import Any, List, Optional, Union, Dict +from typing import Any, List, Optional, Union, Dict, Iterator import numpy as np import orjson @@ -259,22 +259,14 @@ def generate_analysis(self): in analysis_msg]) full_thinking_text = '' full_analysis_text = '' - res = self.llm.stream(analysis_msg) token_usage = {} + res = process_stream(self.llm.stream(analysis_msg), token_usage) for chunk in res: - SQLBotLogUtil.info(chunk) - reasoning_content_chunk = '' - if 'reasoning_content' in chunk.additional_kwargs: - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') - # else: - # reasoning_content_chunk = chunk.get('reasoning_content') - if reasoning_content_chunk is None: - reasoning_content_chunk = '' - full_thinking_text += reasoning_content_chunk - - full_analysis_text += chunk.content - yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} - get_token_usage(chunk, token_usage) + if chunk.get('content'): + full_analysis_text += chunk.get('content') + if chunk.get('reasoning_content'): + full_thinking_text += chunk.get('reasoning_content') + yield chunk analysis_msg.append(AIMessage(full_analysis_text)) @@ -311,22 +303,14 @@ def generate_predict(self): in predict_msg]) full_thinking_text = '' full_predict_text = '' - res = self.llm.stream(predict_msg) token_usage = {} + res = process_stream(self.llm.stream(predict_msg), token_usage) for chunk in res: - SQLBotLogUtil.info(chunk) - reasoning_content_chunk = '' - if 'reasoning_content' in chunk.additional_kwargs: - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') - # else: - # reasoning_content_chunk = chunk.get('reasoning_content') - if reasoning_content_chunk is None: - reasoning_content_chunk = '' - full_thinking_text += reasoning_content_chunk - - full_predict_text += chunk.content - yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} - get_token_usage(chunk, token_usage) + if chunk.get('content'): + full_predict_text += chunk.get('content') + if chunk.get('reasoning_content'): + full_thinking_text += chunk.get('reasoning_content') + yield chunk predict_msg.append(AIMessage(full_predict_text)) self.record = save_predict_answer(session=self.session, record_id=self.record.id, @@ -370,21 +354,13 @@ def generate_recommend_questions_task(self): full_thinking_text = '' full_guess_text = '' token_usage = {} - res = self.llm.stream(guess_msg) + res = process_stream(self.llm.stream(guess_msg), token_usage) for chunk in res: - SQLBotLogUtil.info(chunk) - reasoning_content_chunk = '' - if 'reasoning_content' in chunk.additional_kwargs: - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') - # else: - # reasoning_content_chunk = chunk.get('reasoning_content') - if reasoning_content_chunk is None: - reasoning_content_chunk = '' - full_thinking_text += reasoning_content_chunk - - full_guess_text += chunk.content - yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} - get_token_usage(chunk, token_usage) + if chunk.get('content'): + full_guess_text += chunk.get('content') + if chunk.get('reasoning_content'): + full_thinking_text += chunk.get('reasoning_content') + yield chunk guess_msg.append(AIMessage(full_guess_text)) @@ -450,21 +426,13 @@ def select_datasource(self): msg in datasource_msg]) token_usage = {} - res = self.llm.stream(datasource_msg) + res = process_stream(self.llm.stream(datasource_msg), token_usage) for chunk in res: - SQLBotLogUtil.info(chunk) - reasoning_content_chunk = '' - if 'reasoning_content' in chunk.additional_kwargs: - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') - # else: - # reasoning_content_chunk = chunk.get('reasoning_content') - if reasoning_content_chunk is None: - reasoning_content_chunk = '' - full_thinking_text += reasoning_content_chunk - - full_text += chunk.content - yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} - get_token_usage(chunk, token_usage) + if chunk.get('content'): + full_text += chunk.get('content') + if chunk.get('reasoning_content'): + full_thinking_text += chunk.get('reasoning_content') + yield chunk datasource_msg.append(AIMessage(full_text)) self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = end_log(session=self.session, @@ -560,21 +528,13 @@ def generate_sql(self): full_thinking_text = '' full_sql_text = '' token_usage = {} - res = self.llm.stream(self.sql_message) + res = process_stream(self.llm.stream(self.sql_message), token_usage) for chunk in res: - SQLBotLogUtil.info(chunk) - reasoning_content_chunk = '' - if 'reasoning_content' in chunk.additional_kwargs: - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') - # else: - # reasoning_content_chunk = chunk.get('reasoning_content') - if reasoning_content_chunk is None: - reasoning_content_chunk = '' - full_thinking_text += reasoning_content_chunk - - full_sql_text += chunk.content - yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} - get_token_usage(chunk, token_usage) + if chunk.get('content'): + full_sql_text += chunk.get('content') + if chunk.get('reasoning_content'): + full_thinking_text += chunk.get('reasoning_content') + yield chunk self.sql_message.append(AIMessage(full_sql_text)) @@ -607,18 +567,14 @@ def generate_with_sub_sql(self, sql, sub_mappings: list): full_thinking_text = '' full_dynamic_text = '' - res = self.llm.stream(dynamic_sql_msg) token_usage = {} + res = process_stream(self.llm.stream(dynamic_sql_msg), token_usage) for chunk in res: - SQLBotLogUtil.info(chunk) - reasoning_content_chunk = '' - if 'reasoning_content' in chunk.additional_kwargs: - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') - if reasoning_content_chunk is None: - reasoning_content_chunk = '' - full_thinking_text += reasoning_content_chunk - full_dynamic_text += chunk.content - get_token_usage(chunk, token_usage) + if chunk.get('content'): + full_dynamic_text += chunk.get('content') + if chunk.get('reasoning_content'): + full_thinking_text += chunk.get('reasoning_content') + yield chunk dynamic_sql_msg.append(AIMessage(full_dynamic_text)) @@ -670,22 +626,13 @@ def build_table_filter(self, sql: str, filters: list): in permission_sql_msg]) full_thinking_text = '' full_filter_text = '' - res = self.llm.stream(permission_sql_msg) token_usage = {} + res = process_stream(self.llm.stream(permission_sql_msg), token_usage) for chunk in res: - SQLBotLogUtil.info(chunk) - reasoning_content_chunk = '' - if 'reasoning_content' in chunk.additional_kwargs: - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') - # else: - # reasoning_content_chunk = chunk.get('reasoning_content') - if reasoning_content_chunk is None: - reasoning_content_chunk = '' - full_thinking_text += reasoning_content_chunk - - full_filter_text += chunk.content - # yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} - get_token_usage(chunk, token_usage) + if chunk.get('content'): + full_filter_text += chunk.get('content') + if chunk.get('reasoning_content'): + full_thinking_text += chunk.get('reasoning_content') permission_sql_msg.append(AIMessage(full_filter_text)) @@ -735,21 +682,13 @@ def generate_chart(self, chart_type: Optional[str] = ''): full_thinking_text = '' full_chart_text = '' token_usage = {} - res = self.llm.stream(self.chart_message) + res = process_stream(self.llm.stream(self.chart_message), token_usage) for chunk in res: - SQLBotLogUtil.info(chunk) - reasoning_content_chunk = '' - if 'reasoning_content' in chunk.additional_kwargs: - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') - # else: - # reasoning_content_chunk = chunk.get('reasoning_content') - if reasoning_content_chunk is None: - reasoning_content_chunk = '' - full_thinking_text += reasoning_content_chunk - - full_chart_text += chunk.content - yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} - get_token_usage(chunk, token_usage) + if chunk.get('content'): + full_chart_text += chunk.get('content') + if chunk.get('reasoning_content'): + full_thinking_text += chunk.get('reasoning_content') + yield chunk self.chart_message.append(AIMessage(full_chart_text)) @@ -1053,7 +992,7 @@ def run_task(self, in_chat: bool = True, stream: bool = True, else: sql = self.check_save_sql(res=full_sql_text) - SQLBotLogUtil.info(sql) + SQLBotLogUtil.info('sql: ' + sql) if not stream: json_result['sql'] = sql @@ -1372,9 +1311,11 @@ def request_picture(chat_id: int, record_id: int, chart: dict, data: dict): return request_path -def get_token_usage(chunk: BaseMessageChunk, token_usage: dict = {}): +def get_token_usage(chunk: BaseMessageChunk, token_usage: dict = None): try: if chunk.usage_metadata: + if token_usage is None: + token_usage = {} token_usage['input_tokens'] = chunk.usage_metadata.get('input_tokens') token_usage['output_tokens'] = chunk.usage_metadata.get('output_tokens') token_usage['total_tokens'] = chunk.usage_metadata.get('total_tokens') @@ -1382,6 +1323,104 @@ def get_token_usage(chunk: BaseMessageChunk, token_usage: dict = {}): pass +def process_stream(res: Iterator[BaseMessageChunk], + token_usage: Dict[str, Any] = None, + enable_tag_parsing: bool = settings.PARSE_REASONING_BLOCK_ENABLED, + start_tag: str = settings.DEFAULT_REASONING_CONTENT_START, + end_tag: str = settings.DEFAULT_REASONING_CONTENT_END + ): + if token_usage is None: + token_usage = {} + in_thinking_block = False # 标记是否在思考过程块中 + current_thinking = '' # 当前收集的思考过程内容 + pending_start_tag = '' # 用于缓存可能被截断的开始标签部分 + + for chunk in res: + SQLBotLogUtil.info(chunk) + reasoning_content_chunk = '' + content = chunk.content + output_content = '' # 实际要输出的内容 + + # 检查additional_kwargs中的reasoning_content + if 'reasoning_content' in chunk.additional_kwargs: + reasoning_content = chunk.additional_kwargs.get('reasoning_content', '') + if reasoning_content is None: + reasoning_content = '' + + # 累积additional_kwargs中的思考内容到current_thinking + current_thinking += reasoning_content + reasoning_content_chunk = reasoning_content + + # 只有当current_thinking不是空字符串时才跳过标签解析 + if not in_thinking_block and current_thinking.strip() != '': + output_content = content # 正常输出content + yield { + 'content': output_content, + 'reasoning_content': reasoning_content_chunk + } + get_token_usage(chunk, token_usage) + continue # 跳过后续的标签解析逻辑 + + # 如果没有有效的思考内容,并且启用了标签解析,才执行标签解析逻辑 + # 如果有缓存的开始标签部分,先拼接当前内容 + if pending_start_tag: + content = pending_start_tag + content + pending_start_tag = '' + + # 检查是否开始思考过程块(处理可能被截断的开始标签) + if enable_tag_parsing and not in_thinking_block and start_tag: + if start_tag in content: + start_idx = content.index(start_tag) + # 只有当开始标签前面没有其他文本时才认为是真正的思考块开始 + if start_idx == 0 or content[:start_idx].strip() == '': + # 完整标签存在且前面没有其他文本 + output_content += content[:start_idx] # 输出开始标签之前的内容 + content = content[start_idx + len(start_tag):] # 移除开始标签 + in_thinking_block = True + else: + # 开始标签前面有其他文本,不认为是思考块开始 + output_content += content + content = '' + else: + # 检查是否可能有部分开始标签 + for i in range(1, len(start_tag)): + if content.endswith(start_tag[:i]): + # 只有当当前内容全是空白时才缓存部分标签 + if content[:-i].strip() == '': + pending_start_tag = start_tag[:i] + content = content[:-i] # 移除可能的部分标签 + output_content += content + content = '' + break + + # 处理思考块内容 + if enable_tag_parsing and in_thinking_block and end_tag: + if end_tag in content: + # 找到结束标签 + end_idx = content.index(end_tag) + current_thinking += content[:end_idx] # 收集思考内容 + reasoning_content_chunk += current_thinking # 添加到当前块的思考内容 + content = content[end_idx + len(end_tag):] # 移除结束标签后的内容 + current_thinking = '' # 重置当前思考内容 + in_thinking_block = False + output_content += content # 输出结束标签之后的内容 + else: + # 在遇到结束标签前,持续收集思考内容 + current_thinking += content + reasoning_content_chunk += content + content = '' + + else: + # 不在思考块中或标签解析未启用,正常输出 + output_content += content + + yield { + 'content': output_content, + 'reasoning_content': reasoning_content_chunk + } + get_token_usage(chunk, token_usage) + + def get_lang_name(lang: str): if lang and lang == 'en': return '英文' diff --git a/backend/common/core/config.py b/backend/common/core/config.py index 64ff9242f..e528d1cfd 100644 --- a/backend/common/core/config.py +++ b/backend/common/core/config.py @@ -96,6 +96,10 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn | str: EMBEDDING_TERMINOLOGY_TOP_COUNT: int = EMBEDDING_DEFAULT_TOP_COUNT EMBEDDING_DATA_TRAINING_TOP_COUNT: int = EMBEDDING_DEFAULT_TOP_COUNT + PARSE_REASONING_BLOCK_ENABLED: bool = True + DEFAULT_REASONING_CONTENT_START: str = '' + DEFAULT_REASONING_CONTENT_END: str = '' + PG_POOL_SIZE: int = 20 PG_MAX_OVERFLOW: int = 30 PG_POOL_RECYCLE: int = 3600