Skip to content

Commit 6d41d85

Browse files
committed
fix: Fix SQL injection / LLM Prompt Injection vulnerability causing unauthorized queries
Security hardening: - Add SQLBOT_ALLOW_METADATA_QUERIES config option, disable SHOW/DESCRIBE/EXPLAIN by default - Add table whitelist check, use sqlglot to parse actual SQL table names and compare with authorized table list - Add dangerous function check, block LOAD_FILE, INTO OUTFILE, EXEC etc. by database type - Improve check_sql_read to return specific error reasons for better debugging
1 parent 82b1028 commit 6d41d85

4 files changed

Lines changed: 141 additions & 17 deletions

File tree

backend/apps/chat/task/llm.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import pandas as pd
1414
import requests
1515
import sqlparse
16+
import sqlglot
17+
from sqlglot import exp
1618
from langchain.chat_models.base import BaseChatModel
1719
from langchain_community.utilities import SQLDatabase
1820
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, BaseMessageChunk
@@ -40,7 +42,7 @@
4042
from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user
4143
from apps.datasource.embedding.ds_embedding import get_ds_embedding
4244
from apps.datasource.models.datasource import CoreDatasource
43-
from apps.db.db import exec_sql, get_version, check_connection
45+
from apps.db.db import exec_sql, get_version, check_connection, get_sqlglot_dialect
4446
from apps.system.crud.aimodel_manage import get_ai_model_list_by_workspace
4547
from apps.system.crud.assistant import AssistantOutDs, AssistantOutDsFactory, get_assistant_ds
4648
from apps.system.crud.parameter_manage import get_groups
@@ -66,6 +68,23 @@
6668
i18n = I18n()
6769

6870

71+
72+
def extract_tables_from_sql(sql: str, ds_type: str = None) -> set:
73+
"""从 SQL 中提取表名(使用 sqlglot 解析,可信)"""
74+
tables = set()
75+
dialect = get_sqlglot_dialect(ds_type)
76+
try:
77+
statements = sqlglot.parse(sql, dialect=dialect)
78+
for stmt in statements:
79+
if stmt:
80+
for table in stmt.find_all(exp.Table):
81+
if table.name:
82+
tables.add(table.name)
83+
except Exception:
84+
pass
85+
return tables
86+
87+
6988
class LLMService:
7089
ds: CoreDatasource
7190
chat_question: ChatQuestion
@@ -106,6 +125,9 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C
106125
self.chunk_list = []
107126
self.current_user = current_user
108127
self.current_assistant = current_assistant
128+
129+
self.table_name_list = []
130+
109131
chat_id = chat_question.chat_id
110132
chat: Chat | None = session.get(Chat, chat_id)
111133
if not chat:
@@ -222,7 +244,7 @@ def is_running(self, timeout=0.5):
222244

223245
def init_messages(self, session: Session):
224246

225-
self.choose_table_schema(session)
247+
self.table_name_list = self.choose_table_schema(session)
226248

227249
last_sql_messages: List[dict[str, Any]] = self.generate_sql_logs[-1].messages if len(
228250
self.generate_sql_logs) > 0 else []
@@ -404,6 +426,7 @@ def choose_table_schema(self, _session: Session):
404426
self.current_logs[OperationEnum.CHOOSE_TABLE] = end_log(session=_session,
405427
log=self.current_logs[OperationEnum.CHOOSE_TABLE],
406428
full_message=self.chat_question.db_schema)
429+
return tables
407430

408431
def generate_analysis(self, _session: Session):
409432
fields = self.get_fields_from_chart(_session)
@@ -1266,6 +1289,22 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
12661289

12671290
sql_operate = OperationEnum.GENERATE_SQL
12681291
sql, tables = self.check_sql(session=_session, res=full_sql_text, operate=sql_operate)
1292+
1293+
# 表名安全检查:用 sqlglot 解析真实 SQL,不信任 AI 返回的 tables
1294+
actual_tables = extract_tables_from_sql(sql, ds_type=self.ds.type)
1295+
if not actual_tables:
1296+
raise SingleMessageError(
1297+
"SQL parsing failed: unable to extract table names. "
1298+
"This may indicate an unsupported SQL syntax or a security issue."
1299+
)
1300+
allowed_tables = set(self.table_name_list)
1301+
unauthorized_tables = actual_tables - allowed_tables
1302+
if unauthorized_tables:
1303+
raise SingleMessageError(
1304+
f"SQL contains unauthorized tables: {', '.join(unauthorized_tables)}. "
1305+
f"Allowed tables: {', '.join(allowed_tables)}"
1306+
)
1307+
12691308
if ((not self.current_assistant or is_page_embedded) and is_normal_user(
12701309
self.current_user)) or use_dynamic_ds:
12711310
sql_result = None

backend/apps/db/db.py

Lines changed: 93 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -586,8 +586,9 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=
586586
while sql.endswith(';'):
587587
sql = sql[:-1]
588588
# check execute sql only contain read operations
589-
if not check_sql_read(sql, ds):
590-
raise ValueError(f"SQL can only contain read operations")
589+
is_safe, error_reason = check_sql_read(sql, ds)
590+
if not is_safe:
591+
raise ValueError(f"SQL can only contain read operations: {error_reason}")
591592

592593
db = DB.get_db(ds.type)
593594
if db.connect_type == ConnectType.sqlalchemy:
@@ -716,11 +717,78 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=
716717
raise ParseSQLResultError(str(ex))
717718

718719

719-
def check_sql_read(sql: str, ds: CoreDatasource | AssistantOutDsSchema):
720+
def get_sqlglot_dialect(ds_type: str) -> str:
721+
"""根据数据源类型获取 sqlglot dialect"""
722+
if equals_ignore_case(ds_type, 'mysql', 'doris', 'starrocks'):
723+
return 'mysql'
724+
elif equals_ignore_case(ds_type, 'sqlServer'):
725+
return 'tsql'
726+
elif equals_ignore_case(ds_type, 'hive'):
727+
return 'hive'
728+
return None
729+
730+
731+
# 通用危险函数(适用于所有数据库)
732+
COMMON_DANGEROUS_FUNCTIONS = {'version', 'current_user', 'user', 'database'}
733+
734+
# 特定数据库的危险函数
735+
DS_SPECIFIC_DANGEROUS_FUNCTIONS = {
736+
'mysql': {'LOAD_FILE', 'INTO OUTFILE', 'INTO DUMPFILE'},
737+
'doris': {'LOAD_FILE', 'INTO OUTFILE', 'INTO DUMPFILE'},
738+
'starrocks': {'LOAD_FILE', 'INTO OUTFILE', 'INTO DUMPFILE'},
739+
'postgresql': {'pg_read_file', 'pg_write_file', 'lo_import', 'lo_export'},
740+
'sqlserver': {'EXEC', 'xp_cmdshell', 'sp_executesql'},
741+
'oracle': {'UTL_FILE', 'DBMS_PIPE', 'DBMS_LOCK'},
742+
'hive': {'ADD FILE', 'ADD JAR'},
743+
}
744+
745+
# 危险模式正则表达式(用于检查特殊语法)
746+
import re
747+
DANGEROUS_PATTERNS = [
748+
r'\bINTO\s+OUTFILE\b',
749+
r'\bINTO\s+DUMPFILE\b',
750+
r'\bEXEC\s*\(',
751+
r'\bCOPY\s+.*\bTO\s+PROGRAM\b',
752+
]
753+
754+
755+
def get_dangerous_functions(ds_type: str) -> set:
756+
"""获取危险函数(通用 + 特定数据源)"""
757+
functions = COMMON_DANGEROUS_FUNCTIONS.copy()
758+
ds_key = ds_type.lower() if ds_type else ''
759+
if ds_key in DS_SPECIFIC_DANGEROUS_FUNCTIONS:
760+
functions.update(DS_SPECIFIC_DANGEROUS_FUNCTIONS[ds_key])
761+
return functions
762+
763+
764+
def check_dangerous_functions(statements: list, ds_type: str) -> bool:
765+
"""检查是否使用了危险函数,返回 True 表示安全"""
766+
dangerous_functions = get_dangerous_functions(ds_type)
767+
dangerous_functions_upper = {f.upper() for f in dangerous_functions}
768+
769+
for stmt in statements:
770+
if stmt:
771+
for func in stmt.find_all(exp.Anonymous):
772+
if func.name.upper() in dangerous_functions_upper:
773+
return False
774+
return True
775+
776+
777+
def check_sql_read(sql: str, ds: CoreDatasource | AssistantOutDsSchema) -> tuple[bool, str]:
778+
"""
779+
检查 SQL 是否为安全的只读查询
780+
返回: (是否安全, 错误原因)
781+
"""
720782
try:
721783
normalized_sql = sql.strip().lstrip("(").strip()
722784
first_keyword = normalized_sql.split(None, 1)[0].upper() if normalized_sql else ""
723-
allowed_read_commands = {"SELECT", "WITH", "SHOW", "DESCRIBE", "DESC", "EXPLAIN"}
785+
786+
# 根据配置决定是否允许元数据查询
787+
if settings.SQLBOT_ALLOW_METADATA_QUERIES:
788+
allowed_read_commands = {"SELECT", "WITH", "SHOW", "DESCRIBE", "DESC", "EXPLAIN"}
789+
else:
790+
allowed_read_commands = {"SELECT", "WITH"}
791+
724792
denied_write_commands = {
725793
"INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER",
726794
"TRUNCATE", "MERGE", "COPY", "REPLACE", "GRANT", "REVOKE",
@@ -730,21 +798,29 @@ def check_sql_read(sql: str, ds: CoreDatasource | AssistantOutDsSchema):
730798
if not first_keyword:
731799
raise ValueError("Parse SQL Error")
732800
if first_keyword in denied_write_commands:
733-
return False
801+
return False, f"Write operation '{first_keyword}' is not allowed"
734802

735-
dialect = None
736-
if equals_ignore_case(ds.type, 'mysql', 'doris', 'starrocks'):
737-
dialect = 'mysql'
738-
elif equals_ignore_case(ds.type, 'sqlServer'):
739-
dialect = 'tsql'
740-
elif equals_ignore_case(ds.type, 'hive'):
741-
dialect = 'hive'
803+
# 1. 使用正则检查特殊模式
804+
for pattern in DANGEROUS_PATTERNS:
805+
if re.search(pattern, sql, re.IGNORECASE):
806+
return False, f"SQL contains dangerous pattern: {pattern}"
742807

808+
dialect = get_sqlglot_dialect(ds.type)
743809
statements = sqlglot.parse(sql, dialect=dialect)
744810

745811
if not statements:
746812
raise ValueError("Parse SQL Error")
747813

814+
# 2. 使用 sqlglot 检查函数调用
815+
dangerous_functions = get_dangerous_functions(ds.type)
816+
dangerous_functions_upper = {f.upper() for f in dangerous_functions}
817+
for stmt in statements:
818+
if stmt:
819+
for func in stmt.find_all(exp.Anonymous):
820+
if func.name.upper() in dangerous_functions_upper:
821+
return False, f"SQL contains dangerous function: {func.name}"
822+
823+
# 3. 检查写操作类型
748824
write_types = (
749825
exp.Insert, exp.Update, exp.Delete,
750826
exp.Create, exp.Drop, exp.Alter,
@@ -755,9 +831,12 @@ def check_sql_read(sql: str, ds: CoreDatasource | AssistantOutDsSchema):
755831
if stmt is None:
756832
continue
757833
if isinstance(stmt, write_types):
758-
return False
834+
return False, f"SQL contains write operation: {type(stmt).__name__}"
835+
836+
if first_keyword not in allowed_read_commands:
837+
return False, f"SQL command '{first_keyword}' is not allowed. Only SELECT and WITH are permitted"
759838

760-
return first_keyword in allowed_read_commands
839+
return True, ""
761840

762841
except Exception as e:
763842
raise ValueError(f"Parse SQL Error: {e}")

backend/apps/system/crud/assistant.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def get_db_schema(self, ds_id: int, question: str = '', embedding: bool = True,
187187
db_name = ds.db_schema if ds.db_schema is not None and ds.db_schema != "" else ds.dataBase
188188
schema_str += f"【DB_ID】 {db_name}\n【Schema】\n"
189189
tables = []
190+
table_name_list = []
190191
i = 0
191192
for table in ds.tables:
192193
# 如果传入了 table_list,则只处理在列表中的表
@@ -213,6 +214,7 @@ def get_db_schema(self, ds_id: int, question: str = '', embedding: bool = True,
213214
schema_table += '\n]\n'
214215
t_obj = {"id": i, "schema_table": schema_table}
215216
tables.append(t_obj)
217+
table_name_list.append(table.name)
216218

217219
# do table embedding
218220
# if embedding and tables and settings.TABLE_EMBEDDING_ENABLED:
@@ -222,7 +224,7 @@ def get_db_schema(self, ds_id: int, question: str = '', embedding: bool = True,
222224
for s in tables:
223225
schema_str += s.get('schema_table')
224226

225-
return schema_str, []
227+
return schema_str, table_name_list
226228

227229
def get_ds(self, ds_id: int, trans: Trans = None):
228230
if self.ds_list:

backend/common/core/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn | str:
111111
GENERATE_SQL_QUERY_LIMIT_ENABLED: bool = True
112112
GENERATE_SQL_QUERY_HISTORY_ROUND_COUNT: int = 3
113113

114+
# 安全配置:是否允许元数据查询(SHOW/DESCRIBE/DESC/EXPLAIN)
115+
# 默认关闭,防止通过元数据查询泄露数据库结构
116+
SQLBOT_ALLOW_METADATA_QUERIES: bool = False
117+
114118
PARSE_REASONING_BLOCK_ENABLED: bool = True
115119
DEFAULT_REASONING_CONTENT_START: str = '<think>'
116120
DEFAULT_REASONING_CONTENT_END: str = '</think>'

0 commit comments

Comments
 (0)