@@ -586,8 +586,9 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=
586586 while sql .endswith (';' ):
587587 sql = sql [:- 1 ]
588588 # check execute sql only contain read operations
589- if not check_sql_read (sql , ds ):
590- raise ValueError (f"SQL can only contain read operations" )
589+ is_safe , error_reason = check_sql_read (sql , ds )
590+ if not is_safe :
591+ raise ValueError (f"SQL can only contain read operations: { error_reason } " )
591592
592593 db = DB .get_db (ds .type )
593594 if db .connect_type == ConnectType .sqlalchemy :
@@ -716,11 +717,78 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=
716717 raise ParseSQLResultError (str (ex ))
717718
718719
719- def check_sql_read (sql : str , ds : CoreDatasource | AssistantOutDsSchema ):
720+ def get_sqlglot_dialect (ds_type : str ) -> str :
721+ """根据数据源类型获取 sqlglot dialect"""
722+ if equals_ignore_case (ds_type , 'mysql' , 'doris' , 'starrocks' ):
723+ return 'mysql'
724+ elif equals_ignore_case (ds_type , 'sqlServer' ):
725+ return 'tsql'
726+ elif equals_ignore_case (ds_type , 'hive' ):
727+ return 'hive'
728+ return None
729+
730+
731+ # 通用危险函数(适用于所有数据库)
732+ COMMON_DANGEROUS_FUNCTIONS = {'version' , 'current_user' , 'user' , 'database' }
733+
734+ # 特定数据库的危险函数
735+ DS_SPECIFIC_DANGEROUS_FUNCTIONS = {
736+ 'mysql' : {'LOAD_FILE' , 'INTO OUTFILE' , 'INTO DUMPFILE' },
737+ 'doris' : {'LOAD_FILE' , 'INTO OUTFILE' , 'INTO DUMPFILE' },
738+ 'starrocks' : {'LOAD_FILE' , 'INTO OUTFILE' , 'INTO DUMPFILE' },
739+ 'postgresql' : {'pg_read_file' , 'pg_write_file' , 'lo_import' , 'lo_export' },
740+ 'sqlserver' : {'EXEC' , 'xp_cmdshell' , 'sp_executesql' },
741+ 'oracle' : {'UTL_FILE' , 'DBMS_PIPE' , 'DBMS_LOCK' },
742+ 'hive' : {'ADD FILE' , 'ADD JAR' },
743+ }
744+
745+ # 危险模式正则表达式(用于检查特殊语法)
746+ import re
747+ DANGEROUS_PATTERNS = [
748+ r'\bINTO\s+OUTFILE\b' ,
749+ r'\bINTO\s+DUMPFILE\b' ,
750+ r'\bEXEC\s*\(' ,
751+ r'\bCOPY\s+.*\bTO\s+PROGRAM\b' ,
752+ ]
753+
754+
755+ def get_dangerous_functions (ds_type : str ) -> set :
756+ """获取危险函数(通用 + 特定数据源)"""
757+ functions = COMMON_DANGEROUS_FUNCTIONS .copy ()
758+ ds_key = ds_type .lower () if ds_type else ''
759+ if ds_key in DS_SPECIFIC_DANGEROUS_FUNCTIONS :
760+ functions .update (DS_SPECIFIC_DANGEROUS_FUNCTIONS [ds_key ])
761+ return functions
762+
763+
764+ def check_dangerous_functions (statements : list , ds_type : str ) -> bool :
765+ """检查是否使用了危险函数,返回 True 表示安全"""
766+ dangerous_functions = get_dangerous_functions (ds_type )
767+ dangerous_functions_upper = {f .upper () for f in dangerous_functions }
768+
769+ for stmt in statements :
770+ if stmt :
771+ for func in stmt .find_all (exp .Anonymous ):
772+ if func .name .upper () in dangerous_functions_upper :
773+ return False
774+ return True
775+
776+
777+ def check_sql_read (sql : str , ds : CoreDatasource | AssistantOutDsSchema ) -> tuple [bool , str ]:
778+ """
779+ 检查 SQL 是否为安全的只读查询
780+ 返回: (是否安全, 错误原因)
781+ """
720782 try :
721783 normalized_sql = sql .strip ().lstrip ("(" ).strip ()
722784 first_keyword = normalized_sql .split (None , 1 )[0 ].upper () if normalized_sql else ""
723- allowed_read_commands = {"SELECT" , "WITH" , "SHOW" , "DESCRIBE" , "DESC" , "EXPLAIN" }
785+
786+ # 根据配置决定是否允许元数据查询
787+ if settings .SQLBOT_ALLOW_METADATA_QUERIES :
788+ allowed_read_commands = {"SELECT" , "WITH" , "SHOW" , "DESCRIBE" , "DESC" , "EXPLAIN" }
789+ else :
790+ allowed_read_commands = {"SELECT" , "WITH" }
791+
724792 denied_write_commands = {
725793 "INSERT" , "UPDATE" , "DELETE" , "CREATE" , "DROP" , "ALTER" ,
726794 "TRUNCATE" , "MERGE" , "COPY" , "REPLACE" , "GRANT" , "REVOKE" ,
@@ -730,21 +798,29 @@ def check_sql_read(sql: str, ds: CoreDatasource | AssistantOutDsSchema):
730798 if not first_keyword :
731799 raise ValueError ("Parse SQL Error" )
732800 if first_keyword in denied_write_commands :
733- return False
801+ return False , f"Write operation ' { first_keyword } ' is not allowed"
734802
735- dialect = None
736- if equals_ignore_case (ds .type , 'mysql' , 'doris' , 'starrocks' ):
737- dialect = 'mysql'
738- elif equals_ignore_case (ds .type , 'sqlServer' ):
739- dialect = 'tsql'
740- elif equals_ignore_case (ds .type , 'hive' ):
741- dialect = 'hive'
803+ # 1. 使用正则检查特殊模式
804+ for pattern in DANGEROUS_PATTERNS :
805+ if re .search (pattern , sql , re .IGNORECASE ):
806+ return False , f"SQL contains dangerous pattern: { pattern } "
742807
808+ dialect = get_sqlglot_dialect (ds .type )
743809 statements = sqlglot .parse (sql , dialect = dialect )
744810
745811 if not statements :
746812 raise ValueError ("Parse SQL Error" )
747813
814+ # 2. 使用 sqlglot 检查函数调用
815+ dangerous_functions = get_dangerous_functions (ds .type )
816+ dangerous_functions_upper = {f .upper () for f in dangerous_functions }
817+ for stmt in statements :
818+ if stmt :
819+ for func in stmt .find_all (exp .Anonymous ):
820+ if func .name .upper () in dangerous_functions_upper :
821+ return False , f"SQL contains dangerous function: { func .name } "
822+
823+ # 3. 检查写操作类型
748824 write_types = (
749825 exp .Insert , exp .Update , exp .Delete ,
750826 exp .Create , exp .Drop , exp .Alter ,
@@ -755,9 +831,12 @@ def check_sql_read(sql: str, ds: CoreDatasource | AssistantOutDsSchema):
755831 if stmt is None :
756832 continue
757833 if isinstance (stmt , write_types ):
758- return False
834+ return False , f"SQL contains write operation: { type (stmt ).__name__ } "
835+
836+ if first_keyword not in allowed_read_commands :
837+ return False , f"SQL command '{ first_keyword } ' is not allowed. Only SELECT and WITH are permitted"
759838
760- return first_keyword in allowed_read_commands
839+ return True , ""
761840
762841 except Exception as e :
763842 raise ValueError (f"Parse SQL Error: { e } " )
0 commit comments