diff --git a/sqle/api/controller/v1/sql_audit_record.go b/sqle/api/controller/v1/sql_audit_record.go index 1cdea8721..e3d9a8455 100644 --- a/sqle/api/controller/v1/sql_audit_record.go +++ b/sqle/api/controller/v1/sql_audit_record.go @@ -117,10 +117,10 @@ func CreateSQLAuditRecord(c echo.Context) error { if err != nil { return controller.JSONBaseErrorReq(c, err) } - if req.Sqls != "" { + if rawSQLs, ok := getRawFormValue(c, "sqls"); ok { sqls = GetSQLFromFileResp{ SourceType: model.TaskSQLSourceFromFormData, - SQLsFromFormData: req.Sqls, + SQLsFromFormData: rawSQLs, } } else { sqls, err = GetSQLFromFile(c) @@ -232,27 +232,16 @@ func addSQLsFromFileToTasks(sqls GetSQLFromFileResp, task *model.Task, plugin dr var num uint = 1 fileTask := func(sqlsText, filePath string, defaultStartLine uint64) error { - nodes, err := plugin.Parse(context.TODO(), sqlsText) + executeSQLs, err := server.BuildExecuteSQLsFromSQL(context.TODO(), plugin, sqlsText, server.BuildExecuteSQLsOptions{ + StartNumber: num, + SourceFile: filePath, + StartLine: defaultStartLine, + }) if err != nil { - return fmt.Errorf("parse sqls failed: %v", err) - } - for _, node := range nodes { - startLine := defaultStartLine - if startLine == 0 { - startLine = node.StartLine - } - task.ExecuteSQLs = append(task.ExecuteSQLs, &model.ExecuteSQL{ - BaseSQL: model.BaseSQL{ - Number: num, - Content: node.Text, - SourceFile: filePath, - StartLine: startLine, - SQLType: node.Type, - ExecBatchId: node.ExecBatchId, - }, - }) - num++ + return fmt.Errorf("build execute sqls failed: %v", err) } + task.ExecuteSQLs = append(task.ExecuteSQLs, executeSQLs...) + num += uint(len(executeSQLs)) return nil } diff --git a/sqle/api/controller/v1/task.go b/sqle/api/controller/v1/task.go index 4af23db0b..50e348c58 100644 --- a/sqle/api/controller/v1/task.go +++ b/sqle/api/controller/v1/task.go @@ -183,6 +183,54 @@ func GetSQLFromFile(c echo.Context) (GetSQLFromFileResp, error) { return GetSQLFromFileResp{}, errors.New(errors.DataInvalid, fmt.Errorf("input sql is empty")) } +func getRawFormValue(c echo.Context, name string) (string, bool) { + req := c.Request() + if req == nil { + return "", false + } + if req.MultipartForm != nil { + if values, ok := req.MultipartForm.Value[name]; ok { + if len(values) == 0 { + return "", true + } + return values[0], true + } + } + if req.Form != nil { + if values, ok := req.Form[name]; ok { + if len(values) == 0 { + return "", true + } + return values[0], true + } + } + if req.PostForm != nil { + if values, ok := req.PostForm[name]; ok { + if len(values) == 0 { + return "", true + } + return values[0], true + } + } + if err := req.ParseMultipartForm(32 << 20); err == nil && req.MultipartForm != nil { + if values, ok := req.MultipartForm.Value[name]; ok { + if len(values) == 0 { + return "", true + } + return values[0], true + } + } + if err := req.ParseForm(); err == nil && req.PostForm != nil { + if values, ok := req.PostForm[name]; ok { + if len(values) == 0 { + return "", true + } + return values[0], true + } + } + return "", false +} + func saveFileFromContext(c echo.Context) ([]*model.AuditFile, error) { fileHeader, fileType, err := getFileHeaderFromContext(c) if err != nil { @@ -320,10 +368,10 @@ func CreateAndAuditTask(c echo.Context) error { return controller.JSONBaseErrorReq(c, err) } - if req.Sql != "" { + if rawSQL, ok := getRawFormValue(c, "sql"); ok { sqls = GetSQLFromFileResp{ SourceType: model.TaskSQLSourceFromFormData, - SQLsFromFormData: req.Sql, + SQLsFromFormData: rawSQL, } } else { sqls, err = GetSQLFromFile(c) diff --git a/sqle/server/audit.go b/sqle/server/audit.go index 5268fa120..119c500c4 100644 --- a/sqle/server/audit.go +++ b/sqle/server/audit.go @@ -125,20 +125,59 @@ func AuditSQLByDriver(projectId string, l *logrus.Entry, sql string, p driver.Pl func convertSQLsToTask(sql string, p driver.Plugin) (*model.Task, error) { task := &model.Task{} - nodes, err := p.Parse(context.TODO(), sql) + executeSQLs, err := BuildExecuteSQLsFromSQL(context.TODO(), p, sql, BuildExecuteSQLsOptions{}) if err != nil { return nil, err } - for n, node := range nodes { - task.ExecuteSQLs = append(task.ExecuteSQLs, &model.ExecuteSQL{ + task.ExecuteSQLs = executeSQLs + return task, nil +} + +type BuildExecuteSQLsOptions struct { + StartNumber uint + SourceFile string + StartLine uint64 +} + +func BuildExecuteSQLsFromSQL(ctx context.Context, p driver.Plugin, sqlText string, opts BuildExecuteSQLsOptions) ([]*model.ExecuteSQL, error) { + trimmedSQL := strings.TrimSpace(sqlText) + if trimmedSQL == "" { + return nil, nil + } + number := opts.StartNumber + if number == 0 { + number = 1 + } + nodes, err := p.Parse(ctx, sqlText) + if err != nil || len(nodes) == 0 { + return []*model.ExecuteSQL{{ BaseSQL: model.BaseSQL{ - Number: uint(n + 1), - Content: node.Text, - SQLType: node.Type, + Number: number, + Content: trimmedSQL, + SourceFile: opts.SourceFile, + StartLine: opts.StartLine, + }, + }}, nil + } + + executeSQLs := make([]*model.ExecuteSQL, 0, len(nodes)) + for i, node := range nodes { + startLine := opts.StartLine + if startLine == 0 { + startLine = node.StartLine + } + executeSQLs = append(executeSQLs, &model.ExecuteSQL{ + BaseSQL: model.BaseSQL{ + Number: number + uint(i), + Content: node.Text, + SourceFile: opts.SourceFile, + StartLine: startLine, + SQLType: node.Type, + ExecBatchId: node.ExecBatchId, }, }) } - return task, nil + return executeSQLs, nil } func audit(projectId string, l *logrus.Entry, task *model.Task, p driver.Plugin, customRules []*model.CustomRule) (err error) { @@ -187,7 +226,8 @@ func hookAudit(l *logrus.Entry, task *model.Task, p driver.Plugin, hook AuditHoo // - In these case, we pass the raw SQL to plugins, it's ok. node, err := parse(l, p, strings.TrimSpace(executeSQL.Content)) if err != nil { - return err + appendManualConfirmWarn(executeSQL, executeSQL.Content, err) + continue } var whitelistMatch bool var matchedWhitelistID uint @@ -232,7 +272,24 @@ func hookAudit(l *logrus.Entry, task *model.Task, p driver.Plugin, hook AuditHoo results, err := p.Audit(context.TODO(), sqls) if err != nil { - return err + for i, sql := range auditSqls { + result, singleErr := auditSingleSQL(l, p, task, sqls[i], customRules) + hook.AfterAudit(sql) + if singleErr != nil { + appendManualConfirmWarn(sql, sqls[i], singleErr) + continue + } + sql.AuditStatus = model.SQLAuditStatusFinished + sql.AuditLevel = string(result.Level()) + sql.AuditFingerprint = utils.Md5String(string(append([]byte(result.Message()), []byte(nodes[i].Fingerprint)...))) + sql.SqlFingerprint = nodes[i].Fingerprint + appendExecuteSqlResults(sql, result) + } + ReplenishTaskStatistics(task) + if AfterAuditHook != nil { + go AfterAuditHook(task) + } + return nil } if len(results) != len(sqls) { return fmt.Errorf("audit results [%d] does not match the number of SQL [%d]", len(results), len(sqls)) @@ -259,6 +316,13 @@ func hookAudit(l *logrus.Entry, task *model.Task, p driver.Plugin, hook AuditHoo } func ReplenishTaskStatistics(task *model.Task) { + if len(task.ExecuteSQLs) == 0 { + task.PassRate = 1 + task.AuditLevel = string(driverV2.RuleLevelNull) + task.Score = scoreTask(task) + task.Status = model.TaskStatusAudited + return + } var normalCount float64 maxAuditLevel := driverV2.RuleLevelNull for _, executeSQL := range task.ExecuteSQLs { @@ -276,6 +340,28 @@ func ReplenishTaskStatistics(task *model.Task) { task.Status = model.TaskStatusAudited } +func auditSingleSQL(l *logrus.Entry, p driver.Plugin, task *model.Task, sql string, customRules []*model.CustomRule) (*driverV2.AuditResults, error) { + results, err := p.Audit(context.TODO(), []string{sql}) + if err != nil { + return nil, err + } + if len(results) != 1 { + return nil, fmt.Errorf("audit results [%d] does not match the number of SQL [1]", len(results)) + } + CustomRuleAudit(l, task, []string{sql}, results, customRules) + return results[0], nil +} + +func appendManualConfirmWarn(executeSQL *model.ExecuteSQL, sql string, err error) { + result := driverV2.NewAuditResults() + result.AddResultWithError(driverV2.RuleLevelWarn, "", err.Error(), false, plocale.Bundle.LocalizeAll(plocale.UnsupportedSyntaxError)) + executeSQL.AuditStatus = model.SQLAuditStatusFinished + executeSQL.AuditLevel = string(result.Level()) + executeSQL.AuditFingerprint = utils.Md5String(result.Message() + strings.TrimSpace(sql)) + executeSQL.SqlFingerprint = utils.Md5String(strings.TrimSpace(sql)) + appendExecuteSqlResults(executeSQL, result) +} + // Scoring rules from https://github.com/actiontech/sqle/issues/284 func scoreTask(task *model.Task) int32 { if len(task.ExecuteSQLs) == 0 { diff --git a/sqle/server/audit_degrade_test.go b/sqle/server/audit_degrade_test.go new file mode 100644 index 000000000..04ec22dfe --- /dev/null +++ b/sqle/server/audit_degrade_test.go @@ -0,0 +1,160 @@ +package server + +import ( + "context" + _driver "database/sql/driver" + "errors" + "reflect" + "strings" + "testing" + + "github.com/actiontech/dms/pkg/dms-common/i18nPkg" + "github.com/actiontech/sqle/sqle/driver" + driverV2 "github.com/actiontech/sqle/sqle/driver/v2" + "github.com/actiontech/sqle/sqle/log" + "github.com/actiontech/sqle/sqle/model" + "github.com/agiledragon/gomonkey" + "github.com/stretchr/testify/assert" +) + +type degradeAuditPlugin struct { + parseFn func(context.Context, string) ([]driverV2.Node, error) + auditFn func(context.Context, []string) ([]*driverV2.AuditResults, error) +} + +func (p *degradeAuditPlugin) Close(ctx context.Context) {} +func (p *degradeAuditPlugin) Ping(ctx context.Context) error { return nil } +func (p *degradeAuditPlugin) KillProcess(ctx context.Context) error { return nil } +func (p *degradeAuditPlugin) Exec(ctx context.Context, query string) (_driver.Result, error) { + return nil, nil +} +func (p *degradeAuditPlugin) ExecBatch(ctx context.Context, queries ...string) ([]_driver.Result, error) { + return nil, nil +} +func (p *degradeAuditPlugin) Tx(ctx context.Context, queries ...string) (*driverV2.TxResponse, error) { + return nil, nil +} +func (p *degradeAuditPlugin) Schemas(ctx context.Context) ([]string, error) { return nil, nil } +func (p *degradeAuditPlugin) Parse(ctx context.Context, sqlText string) ([]driverV2.Node, error) { + if p.parseFn != nil { + return p.parseFn(ctx, sqlText) + } + return []driverV2.Node{{Text: sqlText, Fingerprint: sqlText}}, nil +} +func (p *degradeAuditPlugin) Audit(ctx context.Context, sqls []string) ([]*driverV2.AuditResults, error) { + if p.auditFn != nil { + return p.auditFn(ctx, sqls) + } + results := make([]*driverV2.AuditResults, 0, len(sqls)) + for range sqls { + result := driverV2.NewAuditResults() + result.Add(driverV2.RuleLevelNormal, "normal_rule", i18nPkg.I18nStr{i18nPkg.DefaultLang: "normal"}) + results = append(results, result) + } + return results, nil +} +func (p *degradeAuditPlugin) GenRollbackSQL(ctx context.Context, sql string) (string, i18nPkg.I18nStr, error) { + return "", nil, nil +} +func (p *degradeAuditPlugin) Explain(ctx context.Context, conf *driverV2.ExplainConf) (*driverV2.ExplainResult, error) { + return nil, nil +} +func (p *degradeAuditPlugin) ExplainJSONFormat(ctx context.Context, conf *driverV2.ExplainConf) (*driverV2.ExplainJSONResult, error) { + return nil, nil +} +func (p *degradeAuditPlugin) GetTableMetaBySQL(ctx context.Context, conf *driver.GetTableMetaBySQLConf) (*driver.GetTableMetaBySQLResult, error) { + return nil, nil +} +func (p *degradeAuditPlugin) Query(ctx context.Context, sql string, conf *driverV2.QueryConf) (*driverV2.QueryResult, error) { + return nil, nil +} +func (p *degradeAuditPlugin) EstimateSQLAffectRows(ctx context.Context, sql string) (*driverV2.EstimatedAffectRows, error) { + return nil, nil +} +func (p *degradeAuditPlugin) GetDatabaseObjectDDL(ctx context.Context, objInfos []*driverV2.DatabaseSchemaInfo) ([]*driverV2.DatabaseSchemaObjectResult, error) { + return nil, nil +} +func (p *degradeAuditPlugin) GetDatabaseDiffModifySQL(ctx context.Context, calibratedDSN *driverV2.DSN, objInfos []*driverV2.DatabasCompareSchemaInfo) ([]*driverV2.DatabaseDiffModifySQLResult, error) { + return nil, nil +} +func (p *degradeAuditPlugin) Backup(ctx context.Context, backupStrategy string, sql string, backupMaxRows uint64) ([]string, string, error) { + return nil, "", nil +} +func (p *degradeAuditPlugin) RecommendBackupStrategy(ctx context.Context, sql string) (*driver.RecommendBackupStrategyRes, error) { + return nil, nil +} +func (p *degradeAuditPlugin) GetSelectivityOfSQLColumns(ctx context.Context, sql string) (map[string]map[string]float32, error) { + return nil, nil +} + +func TestBuildExecuteSQLsFromSQLToleratesEmptyAndParseFailure(t *testing.T) { + plugin := °radeAuditPlugin{parseFn: func(ctx context.Context, sqlText string) ([]driverV2.Node, error) { + return nil, errors.New("parse failed") + }} + + empty, err := BuildExecuteSQLsFromSQL(context.Background(), plugin, " \n\t ", BuildExecuteSQLsOptions{}) + assert.NoError(t, err) + assert.Empty(t, empty) + + executeSQLs, err := BuildExecuteSQLsFromSQL(context.Background(), plugin, " bad tdsql syntax ", BuildExecuteSQLsOptions{StartNumber: 7, SourceFile: "a.sql", StartLine: 3}) + assert.NoError(t, err) + assert.Len(t, executeSQLs, 1) + assert.Equal(t, uint(7), executeSQLs[0].Number) + assert.Equal(t, "bad tdsql syntax", executeSQLs[0].Content) + assert.Equal(t, "a.sql", executeSQLs[0].SourceFile) + assert.Equal(t, uint64(3), executeSQLs[0].StartLine) +} + +func TestHookAuditDegradesParseAndBatchAuditFailures(t *testing.T) { + patches := gomonkey.ApplyMethod(reflect.TypeOf(&model.Storage{}), "GetSqlWhitelistByProjectId", func(_ *model.Storage, _ string) ([]model.SqlWhitelist, error) { + return nil, nil + }) + defer patches.Reset() + + auditCalls := make([][]string, 0) + plugin := °radeAuditPlugin{ + parseFn: func(ctx context.Context, sqlText string) ([]driverV2.Node, error) { + if strings.Contains(sqlText, "bad_parse") { + return nil, errors.New("parse failed") + } + return []driverV2.Node{{Text: sqlText, Fingerprint: "fp:" + sqlText}}, nil + }, + auditFn: func(ctx context.Context, sqls []string) ([]*driverV2.AuditResults, error) { + auditCalls = append(auditCalls, append([]string{}, sqls...)) + if len(sqls) > 1 { + return nil, errors.New("batch audit failed") + } + if strings.Contains(sqls[0], "bad_audit") { + return nil, errors.New("single audit failed") + } + result := driverV2.NewAuditResults() + result.Add(driverV2.RuleLevelNormal, "normal_rule", i18nPkg.I18nStr{i18nPkg.DefaultLang: "normal"}) + return []*driverV2.AuditResults{result}, nil + }, + } + task := &model.Task{ExecuteSQLs: []*model.ExecuteSQL{ + {BaseSQL: model.BaseSQL{Content: "select 1"}}, + {BaseSQL: model.BaseSQL{Content: "bad_parse"}}, + {BaseSQL: model.BaseSQL{Content: "bad_audit"}}, + }} + + err := hookAudit(log.NewEntry(), task, plugin, &EmptyAuditHook{}, "project1", nil) + assert.NoError(t, err) + assert.Equal(t, model.TaskStatusAudited, task.Status) + assert.Equal(t, string(driverV2.RuleLevelNormal), task.ExecuteSQLs[0].AuditLevel) + assert.Equal(t, string(driverV2.RuleLevelWarn), task.ExecuteSQLs[1].AuditLevel) + assert.Equal(t, string(driverV2.RuleLevelWarn), task.ExecuteSQLs[2].AuditLevel) + assert.Len(t, auditCalls, 3) + assert.Equal(t, []string{"select 1", "bad_audit"}, auditCalls[0]) + assert.Equal(t, []string{"select 1"}, auditCalls[1]) + assert.Equal(t, []string{"bad_audit"}, auditCalls[2]) +} + +func TestReplenishTaskStatisticsWithEmptyTask(t *testing.T) { + task := &model.Task{} + ReplenishTaskStatistics(task) + assert.Equal(t, model.TaskStatusAudited, task.Status) + assert.Equal(t, float64(1), task.PassRate) + assert.Equal(t, string(driverV2.RuleLevelNull), task.AuditLevel) + assert.Equal(t, int32(0), task.Score) +}