Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
# PostgreSQL 数据库文件
data/postgresql/

# Excel 文件存储
data/sqlbot/excel/

# 图片文件存储
data/sqlbot/images/

# 日志文件存储
data/sqlbot/logs/

.vscode
node_modules/
/test-results/
Expand Down
17 changes: 16 additions & 1 deletion backend/apps/ai_model/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,27 @@ def _init_llm(self) -> AzureChatOpenAI:
)
class OpenAILLM(BaseLLM):
def _init_llm(self) -> BaseChatModel:
    """Build the OpenAI-compatible chat model from ``self.config``.

    String values in ``self.config.additional_params`` that look like a JSON
    object or array (they start with ``{`` or ``[`` once surrounding
    whitespace is stripped) are decoded first, so structured options are
    handed to the client as dicts/lists instead of raw text. A value that
    fails to decode is passed through unchanged.

    Returns:
        A ``BaseChatOpenAI`` instance configured with the model name, the
        API key (``'Empty'`` when none is configured), the base URL, usage
        streaming enabled, and the decoded additional parameters.
    """
    # Imported once at function scope instead of on every loop iteration.
    import json

    params = {}
    for key, value in self.config.additional_params.items():
        if isinstance(value, str) and value.strip().startswith(('{', '[')):
            try:
                params[key] = json.loads(value)
            except json.JSONDecodeError:
                # Only looked like JSON; keep the raw string (best effort).
                params[key] = value
        else:
            params[key] = value

    return BaseChatOpenAI(
        model=self.config.model_name,
        api_key=self.config.api_key or 'Empty',
        base_url=self.config.api_base_url,
        stream_usage=True,
        # Spread only the decoded mapping. Spreading the raw
        # additional_params as well would pass every key twice and raise
        # "TypeError: got multiple values for keyword argument".
        **params,
    )

def generate(self, prompt: str) -> str:
Expand Down
174 changes: 167 additions & 7 deletions backend/apps/chat/task/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion,
current_assistant: Optional[CurrentAssistant] = None, no_reasoning: bool = False,
config: LLMConfig = None):
self.chunk_list = []
self._retry_thinking_updates = [] # 存储重试时的thinking更新
engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
session_maker = sessionmaker(bind=engine)
self.session = session_maker()
Expand Down Expand Up @@ -545,6 +546,9 @@ def generate_sql(self):
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
# else:
# reasoning_content_chunk = chunk.get('reasoning_content')
# 如果没有 reasoning_content,就用普通内容作为思考过程
if not reasoning_content_chunk and chunk.content:
reasoning_content_chunk = chunk.content
if reasoning_content_chunk is None:
reasoning_content_chunk = ''
full_thinking_text += reasoning_content_chunk
Expand All @@ -561,8 +565,29 @@ def generate_sql(self):
for msg in self.sql_message],
reasoning_content=full_thinking_text,
token_usage=token_usage)
self.record = save_sql_answer(session=self.session, record_id=self.record.id,
answer=orjson.dumps({'content': full_sql_text}).decode())
# 如果没有思考内容,用完整输出作为思考过程
if not full_thinking_text.strip():
full_thinking_text = f"""SQL 生成过程:

用户问题:
{self.chat_question.question}

AI 完整输出:
{full_sql_text}
"""

# 保存思考过程到 sql_answer 字段
if full_thinking_text.strip():
save_sql_answer(session=self.session, record_id=self.record.id,
answer=full_thinking_text)
else:
# 原有逻辑
thinking_json = {
"reasoning_content": full_thinking_text,
"content": "SQL生成思考过程"
}
save_sql_answer(session=self.session, record_id=self.record.id,
answer=orjson.dumps(thinking_json).decode())

def generate_with_sub_sql(self, sql, sub_mappings: list):
sub_query = json.dumps(sub_mappings, ensure_ascii=False)
Expand Down Expand Up @@ -780,7 +805,126 @@ def get_chart_type_from_sql_answer(res: str) -> Optional[str]:
return None

return chart_type
def validate_and_retry_sql(self, initial_sql: str, max_retries: int = 3, in_chat: bool = True):
    """Validate SQL by executing it, asking the LLM for a fix on failure.

    This is a generator. It yields two kinds of items:

    * ``'data:...'`` event strings (``type`` ``'sql-retry-thinking'``) that
      carry the streamed retry output, emitted only when ``in_chat`` is true;
    * the SQL string that executed without raising, always as the last item.

    Args:
        initial_sql: SQL statement to try first.
        max_retries: Maximum number of execution attempts. With a value of
            zero or less nothing is executed and ``initial_sql`` is yielded
            unchanged.
        in_chat: Whether to stream the retry output to the caller.

    Raises:
        Exception: The error of the last execution attempt, re-raised after
            all collected errors were stored with ``save_error_message``.
    """
    current_sql = initial_sql
    retry_count = 0
    retry_errors = []

    while retry_count < max_retries:
        try:
            # Only success/failure matters here; the result is discarded.
            self.execute_sql(sql=current_sql)
        except Exception as sql_error:
            retry_count += 1
            error_msg = str(sql_error)
            retry_errors.append(f"第{retry_count}次尝试: {error_msg}")

            SQLBotLogUtil.warning(f"SQL 执行失败 (第{retry_count}次): {error_msg}")

            if retry_count >= max_retries:
                final_error_msg = f"SQL验证失败,共重试{retry_count}次:\n" + "\n".join(retry_errors)
                save_error_message(
                    session=self.session,
                    record_id=self.record.id,
                    message=final_error_msg
                )
                # Bare raise keeps the original traceback intact.
                raise
        else:
            # Yield outside the try so an exception thrown into the generator
            # at this point is not mistaken for a SQL execution failure.
            yield current_sql
            return

        # Ask the model for a corrected statement and forward its output.
        full_retry_text = ''
        for chunk in self.regenerate_sql_with_error(current_sql, error_msg):
            if isinstance(chunk, dict):
                # ``or ''`` guards against an explicit None content value.
                full_retry_text += chunk.get('content') or ''
                if in_chat:
                    yield 'data:' + orjson.dumps({
                        'content': chunk.get('content'),
                        'reasoning_content': chunk.get('reasoning_content'),
                        'type': 'sql-retry-thinking'
                    }).decode() + '\n\n'
            elif isinstance(chunk, str):
                # A plain string is the complete model output; it replaces
                # whatever was accumulated from the partial chunks.
                full_retry_text = chunk

        try:
            # Star-unpack like the other check_sql callers so a result with
            # extra items is not misreported as a parse failure.
            current_sql, *_ = self.check_sql(full_retry_text)
            SQLBotLogUtil.info(f"第{retry_count}次重试生成的SQL: {current_sql}")
        except Exception as parse_error:
            # Keep the previous SQL; the next iteration executes it again and
            # consumes another attempt.
            SQLBotLogUtil.error(f"重试生成的SQL解析失败: {str(parse_error)}")

    # Reached only when max_retries <= 0: nothing was validated.
    yield current_sql

def regenerate_sql_with_error(self, failed_sql: str, error_message: str):
    """Ask the LLM for a corrected SQL statement and stream its output.

    This is a generator. Once the model output contains ``<think>``, every
    streamed piece from that tag onward is yielded as a dict whose
    ``'content'`` and ``'reasoning_content'`` keys hold the same text. The
    final item is always the complete model output as a plain ``str``, which
    the caller parses for the SQL.

    Args:
        failed_sql: The SQL statement that failed to execute.
        error_message: The error text produced by that execution.
    """
    fix_sql_msg = [SystemMessage(content=self.chat_question.sql_sys_question())]

    fix_prompt = f"""
之前生成的 SQL 语句执行失败:

SQL: {failed_sql}
错误信息: {error_message}

请分析错误原因并生成修正后的 SQL 语句。常见错误类型:
1. 字段名错误 - 检查表结构中的实际字段名
2. 表名错误 - 确认表是否存在
3. 语法错误 - 检查 SQL 语法
4. 数据类型不匹配 - 检查字段类型

原始问题: {self.chat_question.question}
数据库结构: {self.chat_question.db_schema}
"""

    fix_sql_msg.append(HumanMessage(content=fix_prompt))
    SQLBotLogUtil.info(f"正在重新生成 SQL,原错误: {error_message}")

    full_text = ''
    collecting = False

    for chunk in self.llm.stream(fix_sql_msg):
        content = chunk.content
        full_text += content

        if not collecting:
            # Search the accumulated text, not just this chunk: the tag can
            # be split across two streamed chunks ("<th" + "ink>") and would
            # otherwise never be detected.
            tag_pos = full_text.find('<think>')
            if tag_pos == -1:
                continue
            collecting = True
            # Emit everything from the tag onward, dropping any text the
            # model produced before it.
            content = full_text[tag_pos:]

        yield {
            'content': content,
            'reasoning_content': content
        }

    # Complete output, used by the caller to extract the SQL.
    yield full_text


def check_save_sql_with_validation(self, res: str, in_chat: bool = True):
    """Extract SQL from ``res``, validate it, then save the validated SQL.

    This is a generator. Strings starting with ``'data:'`` produced by
    ``validate_and_retry_sql`` are forwarded unchanged; any other item is
    remembered as the validated SQL. After validation finishes, that SQL is
    stored with ``save_sql``, assigned to ``self.chat_question.sql`` and
    yielded as the last item.

    Args:
        res: Text passed to ``check_sql`` to extract the SQL.
        in_chat: Forwarded to ``validate_and_retry_sql``.
    """
    parsed_sql, *_ = self.check_sql(res=res)

    final_sql = None
    for event in self.validate_and_retry_sql(parsed_sql, in_chat=in_chat):
        is_stream_event = isinstance(event, str) and event.startswith('data:')
        if not is_stream_event:
            # Anything that is not a stream event is the SQL result.
            final_sql = event
            continue
        yield event

    save_sql(session=self.session, sql=final_sql, record_id=self.record.id)
    self.chat_question.sql = final_sql
    yield final_sql
def check_save_sql(self, res: str) -> str:
sql, *_ = self.check_sql(res=res)
save_sql(session=self.session, sql=sql, record_id=self.record.id)
Expand Down Expand Up @@ -1001,14 +1145,30 @@ def run_task(self, in_chat: bool = True):

if sql_result:
SQLBotLogUtil.info(sql_result)
sql = self.check_save_sql(res=sql_result)
# 处理生成器输出
for item in self.check_save_sql_with_validation(res=sql_result, in_chat=in_chat):
if isinstance(item, str) and item.startswith('data:'):
yield item # 传递流式输出
else:
sql = item # 最终SQL
elif dynamic_sql_result:
sql = self.check_save_sql(res=dynamic_sql_result)
for item in self.check_save_sql_with_validation(res=dynamic_sql_result, in_chat=in_chat):
if isinstance(item, str) and item.startswith('data:'):
yield item
else:
sql = item
else:
sql = self.check_save_sql(res=full_sql_text)
for item in self.check_save_sql_with_validation(res=full_sql_text, in_chat=in_chat):
if isinstance(item, str) and item.startswith('data:'):
yield item
else:
sql = item
else:
sql = self.check_save_sql(res=full_sql_text)

for item in self.check_save_sql_with_validation(res=full_sql_text, in_chat=in_chat):
if isinstance(item, str) and item.startswith('data:'):
yield item
else:
sql = item
SQLBotLogUtil.info(sql)
format_sql = sqlparse.format(sql, reindent=True)
if in_chat:
Expand Down
13 changes: 11 additions & 2 deletions backend/apps/system/schemas/ai_model_schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

from typing import List
from pydantic import BaseModel
from pydantic import BaseModel, field_validator
import json

from common.core.schemas import BaseCreatorDTO

Expand All @@ -19,7 +20,15 @@ class AiModelConfigItem(BaseModel):
key: str
val: object
name: str

@field_validator('val')
@classmethod
def parse_json_strings(cls, v):
    """Decode ``val`` when it is a string holding a JSON object or array.

    Non-string values, strings that do not start with ``{`` or ``[`` after
    stripping whitespace, and strings that fail to decode are all returned
    unchanged.
    """
    looks_like_json = isinstance(v, str) and v.strip().startswith(('{', '['))
    if not looks_like_json:
        return v
    try:
        decoded = json.loads(v)
    except json.JSONDecodeError:
        # Only resembled JSON; fall back to the raw string.
        return v
    return decoded
class AiModelCreator(AiModelItem):
api_domain: str
api_key: str
Expand Down
1 change: 1 addition & 0 deletions backend/common/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def prepare_for_orjson(data):
def prepare_model_arg(origin_arg: str):
if not isinstance(origin_arg, str):
return origin_arg
origin_arg = str(origin_arg).strip()
if not origin_arg.strip()[0] in {'{', '['}:
return origin_arg
try:
Expand Down
4 changes: 4 additions & 0 deletions backend/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ template:
<rule>
生成的SQL必须符合<db-engine>内提供数据库引擎的规范
</rule>
<rule>
对字符串字段的查询,一律使用模糊匹配: `ILIKE '%关键词%'`
只有当用户明确说明“精确匹配”时,才使用 `=`
</rule>
<rule>
若用户提问中提供了参考SQL,你需要判断该SQL是否是查询语句
</rule>
Expand Down
26 changes: 23 additions & 3 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ services:
networks:
- sqlbot-network
ports:
- 8000:8000
- 9090:8000
- 8001:8001
environment:
# Database configuration
Expand All @@ -23,7 +23,7 @@ services:
# Auth & Security
SECRET_KEY: y5txe1mRmS_JpOrUzFzHEu-kIQn3lf7ll0AOv9DQh0s
# CORS settings
BACKEND_CORS_ORIGINS: "http://localhost,http://localhost:5173,https://localhost,https://localhost:5173"
BACKEND_CORS_ORIGINS: "http://localhost,http://localhost:5173,http://sqlbot-frontend-dev:5173,https://localhost,https://localhost:5173"
# Logging
LOG_LEVEL: "INFO"
SQL_DEBUG: False
Expand All @@ -32,6 +32,26 @@ services:
- ./data/sqlbot/images:/opt/sqlbot/images
- ./data/sqlbot/logs:/opt/sqlbot/logs
- ./data/postgresql:/var/lib/postgresql/data

- ./backend/apps/ai_model/model_factory.py:/opt/sqlbot/app/apps/ai_model/model_factory.py
- ./backend/apps:/opt/sqlbot/app/apps
- ./backend/common:/opt/sqlbot/app/common
- ./backend/main.py:/opt/sqlbot/app/main.py
- ./frontend/dist:/opt/sqlbot/frontend/dist
frontend:
image: node:18
container_name: sqlbot-frontend-dev
working_dir: /opt/sqlbot/frontend
volumes:
- ./frontend:/opt/sqlbot/frontend # 挂载源码
- /opt/sqlbot/frontend/node_modules # 避免宿主机覆盖 node_modules
ports:
- "5173:5173" # Vue3 Vite 默认端口
command: sh -c "npm install && npm run dev -- --host 0.0.0.0"
stdin_open: true
tty: true
networks:
- sqlbot-network
depends_on:
- sqlbot
networks:
sqlbot-network:
2 changes: 1 addition & 1 deletion frontend/.env.development
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
VITE_API_BASE_URL=http://localhost:8000/api/v1
VITE_API_BASE_URL=http://localhost:9090/api/v1
VITE_APP_TITLE=SQLBot (Development)
2 changes: 1 addition & 1 deletion frontend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
"@types/crypto-js": "^4.2.2",
"@types/element-resize-detector": "^1.1.6",
"@types/markdown-it": "^14.1.2",
"@types/node": "^22.14.1",
"@types/node": "^22.18.4",
"@typescript-eslint/eslint-plugin": "^8.34.0",
"@typescript-eslint/parser": "^8.34.0",
"@vitejs/plugin-vue": "^5.2.2",
Expand Down
1 change: 1 addition & 0 deletions frontend/src/api/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ export class ChatRecord {
recommended_question?: string
analysis_record_id?: number
predict_record_id?: number
sql_retry_thinking?: string // 添加SQL重试思考字段

constructor()
constructor(
Expand Down
1 change: 1 addition & 0 deletions frontend/src/i18n/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@
"thinking": "Thinking",
"thinking_step": "Thought Process",
"ask_again": "Regenerate",
"retry_thinking": "Rethink",
"today": "Today",
"week": "This Week",
"earlier": "Earlier",
Expand Down
3 changes: 2 additions & 1 deletion frontend/src/i18n/zh-CN.json
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,10 @@
"data_analysis": "数据分析",
"data_predict": "数据预测",
"chat_search": "搜索",
"thinking": "思考中",
"thinking": "思考中...🔥",
"thinking_step": "思考过程",
"ask_again": "重新生成",
"retry_thinking": "重新思考",
"today": "今天",
"week": "7天内",
"earlier": "更早以前",
Expand Down
Loading