Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 10 additions & 21 deletions sqle/api/controller/v1/sql_audit_record.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,10 @@ func CreateSQLAuditRecord(c echo.Context) error {
if err != nil {
return controller.JSONBaseErrorReq(c, err)
}
if req.Sqls != "" {
if rawSQLs, ok := getRawFormValue(c, "sqls"); ok {
sqls = GetSQLFromFileResp{
SourceType: model.TaskSQLSourceFromFormData,
SQLsFromFormData: req.Sqls,
SQLsFromFormData: rawSQLs,
}
} else {
sqls, err = GetSQLFromFile(c)
Expand Down Expand Up @@ -232,27 +232,16 @@ func addSQLsFromFileToTasks(sqls GetSQLFromFileResp, task *model.Task, plugin dr
var num uint = 1

fileTask := func(sqlsText, filePath string, defaultStartLine uint64) error {
nodes, err := plugin.Parse(context.TODO(), sqlsText)
executeSQLs, err := server.BuildExecuteSQLsFromSQL(context.TODO(), plugin, sqlsText, server.BuildExecuteSQLsOptions{
StartNumber: num,
SourceFile: filePath,
StartLine: defaultStartLine,
})
if err != nil {
return fmt.Errorf("parse sqls failed: %v", err)
}
for _, node := range nodes {
startLine := defaultStartLine
if startLine == 0 {
startLine = node.StartLine
}
task.ExecuteSQLs = append(task.ExecuteSQLs, &model.ExecuteSQL{
BaseSQL: model.BaseSQL{
Number: num,
Content: node.Text,
SourceFile: filePath,
StartLine: startLine,
SQLType: node.Type,
ExecBatchId: node.ExecBatchId,
},
})
num++
return fmt.Errorf("build execute sqls failed: %v", err)
}
task.ExecuteSQLs = append(task.ExecuteSQLs, executeSQLs...)
num += uint(len(executeSQLs))
return nil
}

Expand Down
52 changes: 50 additions & 2 deletions sqle/api/controller/v1/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,54 @@ func GetSQLFromFile(c echo.Context) (GetSQLFromFileResp, error) {
return GetSQLFromFileResp{}, errors.New(errors.DataInvalid, fmt.Errorf("input sql is empty"))
}

func getRawFormValue(c echo.Context, name string) (string, bool) {
req := c.Request()
if req == nil {
return "", false
}
if req.MultipartForm != nil {
if values, ok := req.MultipartForm.Value[name]; ok {
if len(values) == 0 {
return "", true
}
return values[0], true
}
}
if req.Form != nil {
if values, ok := req.Form[name]; ok {
if len(values) == 0 {
return "", true
}
return values[0], true
}
}
if req.PostForm != nil {
if values, ok := req.PostForm[name]; ok {
if len(values) == 0 {
return "", true
}
return values[0], true
}
}
if err := req.ParseMultipartForm(32 << 20); err == nil && req.MultipartForm != nil {
if values, ok := req.MultipartForm.Value[name]; ok {
if len(values) == 0 {
return "", true
}
return values[0], true
}
}
if err := req.ParseForm(); err == nil && req.PostForm != nil {
if values, ok := req.PostForm[name]; ok {
if len(values) == 0 {
return "", true
}
return values[0], true
}
}
return "", false
}

func saveFileFromContext(c echo.Context) ([]*model.AuditFile, error) {
fileHeader, fileType, err := getFileHeaderFromContext(c)
if err != nil {
Expand Down Expand Up @@ -320,10 +368,10 @@ func CreateAndAuditTask(c echo.Context) error {
return controller.JSONBaseErrorReq(c, err)
}

if req.Sql != "" {
if rawSQL, ok := getRawFormValue(c, "sql"); ok {
sqls = GetSQLFromFileResp{
SourceType: model.TaskSQLSourceFromFormData,
SQLsFromFormData: req.Sql,
SQLsFromFormData: rawSQL,
}
} else {
sqls, err = GetSQLFromFile(c)
Expand Down
104 changes: 95 additions & 9 deletions sqle/server/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,20 +125,59 @@

func convertSQLsToTask(sql string, p driver.Plugin) (*model.Task, error) {
task := &model.Task{}
nodes, err := p.Parse(context.TODO(), sql)
executeSQLs, err := BuildExecuteSQLsFromSQL(context.TODO(), p, sql, BuildExecuteSQLsOptions{})
if err != nil {
return nil, err
}
for n, node := range nodes {
task.ExecuteSQLs = append(task.ExecuteSQLs, &model.ExecuteSQL{
task.ExecuteSQLs = executeSQLs
return task, nil
}

type BuildExecuteSQLsOptions struct {
StartNumber uint
SourceFile string
StartLine uint64
}

func BuildExecuteSQLsFromSQL(ctx context.Context, p driver.Plugin, sqlText string, opts BuildExecuteSQLsOptions) ([]*model.ExecuteSQL, error) {
trimmedSQL := strings.TrimSpace(sqlText)
if trimmedSQL == "" {
return nil, nil
}
number := opts.StartNumber
if number == 0 {
number = 1
}
nodes, err := p.Parse(ctx, sqlText)
if err != nil || len(nodes) == 0 {
return []*model.ExecuteSQL{{

Check failure on line 153 in sqle/server/audit.go

View workflow job for this annotation

GitHub Actions / lint

error is not nil (line 151) but it returns nil (nilerr)
BaseSQL: model.BaseSQL{
Number: uint(n + 1),
Content: node.Text,
SQLType: node.Type,
Number: number,
Content: trimmedSQL,
SourceFile: opts.SourceFile,
StartLine: opts.StartLine,
},
}}, nil
}

executeSQLs := make([]*model.ExecuteSQL, 0, len(nodes))
for i, node := range nodes {
startLine := opts.StartLine
if startLine == 0 {
startLine = node.StartLine
}
executeSQLs = append(executeSQLs, &model.ExecuteSQL{
BaseSQL: model.BaseSQL{
Number: number + uint(i),
Content: node.Text,
SourceFile: opts.SourceFile,
StartLine: startLine,
SQLType: node.Type,
ExecBatchId: node.ExecBatchId,
},
})
}
return task, nil
return executeSQLs, nil
}

func audit(projectId string, l *logrus.Entry, task *model.Task, p driver.Plugin, customRules []*model.CustomRule) (err error) {
Expand Down Expand Up @@ -187,7 +226,8 @@
// - In these case, we pass the raw SQL to plugins, it's ok.
node, err := parse(l, p, strings.TrimSpace(executeSQL.Content))
if err != nil {
return err
appendManualConfirmWarn(executeSQL, executeSQL.Content, err)
continue
}
var whitelistMatch bool
var matchedWhitelistID uint
Expand Down Expand Up @@ -232,7 +272,24 @@

results, err := p.Audit(context.TODO(), sqls)
if err != nil {
return err
for i, sql := range auditSqls {
result, singleErr := auditSingleSQL(l, p, task, sqls[i], customRules)
hook.AfterAudit(sql)
if singleErr != nil {
appendManualConfirmWarn(sql, sqls[i], singleErr)
continue
}
sql.AuditStatus = model.SQLAuditStatusFinished
sql.AuditLevel = string(result.Level())
sql.AuditFingerprint = utils.Md5String(string(append([]byte(result.Message()), []byte(nodes[i].Fingerprint)...)))
sql.SqlFingerprint = nodes[i].Fingerprint
appendExecuteSqlResults(sql, result)
}
ReplenishTaskStatistics(task)
if AfterAuditHook != nil {
go AfterAuditHook(task)
}
return nil
}
if len(results) != len(sqls) {
return fmt.Errorf("audit results [%d] does not match the number of SQL [%d]", len(results), len(sqls))
Expand All @@ -259,6 +316,13 @@
}

func ReplenishTaskStatistics(task *model.Task) {
if len(task.ExecuteSQLs) == 0 {
task.PassRate = 1
task.AuditLevel = string(driverV2.RuleLevelNull)
task.Score = scoreTask(task)
task.Status = model.TaskStatusAudited
return
}
var normalCount float64
maxAuditLevel := driverV2.RuleLevelNull
for _, executeSQL := range task.ExecuteSQLs {
Expand All @@ -276,6 +340,28 @@
task.Status = model.TaskStatusAudited
}

func auditSingleSQL(l *logrus.Entry, p driver.Plugin, task *model.Task, sql string, customRules []*model.CustomRule) (*driverV2.AuditResults, error) {
results, err := p.Audit(context.TODO(), []string{sql})
if err != nil {
return nil, err
}
if len(results) != 1 {
return nil, fmt.Errorf("audit results [%d] does not match the number of SQL [1]", len(results))
}
CustomRuleAudit(l, task, []string{sql}, results, customRules)
return results[0], nil
}

func appendManualConfirmWarn(executeSQL *model.ExecuteSQL, sql string, err error) {
result := driverV2.NewAuditResults()
result.AddResultWithError(driverV2.RuleLevelWarn, "", err.Error(), false, plocale.Bundle.LocalizeAll(plocale.UnsupportedSyntaxError))
executeSQL.AuditStatus = model.SQLAuditStatusFinished
executeSQL.AuditLevel = string(result.Level())
executeSQL.AuditFingerprint = utils.Md5String(result.Message() + strings.TrimSpace(sql))
executeSQL.SqlFingerprint = utils.Md5String(strings.TrimSpace(sql))
appendExecuteSqlResults(executeSQL, result)
}

// Scoring rules from https://github.com/actiontech/sqle/issues/284
func scoreTask(task *model.Task) int32 {
if len(task.ExecuteSQLs) == 0 {
Expand Down
Loading
Loading