diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bae0a487a1..eb705215ac 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,25 +12,18 @@ jobs: runs-on: ubuntu-latest steps: - name: Set up repository - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v5 with: # Should be consistent with go.mod go-version: '1.19.6' - name: Lint - uses: golangci/golangci-lint-action@v2 + uses: golangci/golangci-lint-action@v6 with: - # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version version: v1.45.2 - - # If set to true then the action will use pre-installed Go - skip-go-installation: true - - # Optional: golangci-lint command line arguments. - # The config file has lower priority than command-line options. args: --config=.golangci.yml test: @@ -38,10 +31,10 @@ jobs: runs-on: ubuntu-latest steps: - name: Set up repository - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v5 with: # Should be consistent with go.mod go-version: '1.19.6' diff --git a/sqle/api/controller/v1/report_data_builder.go b/sqle/api/controller/v1/report_data_builder.go new file mode 100644 index 0000000000..bfa3bc52a4 --- /dev/null +++ b/sqle/api/controller/v1/report_data_builder.go @@ -0,0 +1,265 @@ +package v1 + +import ( + "context" + "fmt" + "sort" + "strings" + "time" + + "github.com/actiontech/sqle/sqle/locale" + "github.com/actiontech/sqle/sqle/model" + "github.com/actiontech/sqle/sqle/server" + "github.com/actiontech/sqle/sqle/server/auditreport" +) + +// BuildAuditReportData 从 Task 和数据库查询构建报告数据。 +// 该函数放在 controller 层;报告数据模型在 server/auditreport。utils 被 model 引用, +// 若 utils 反向引用 model 会产生循环依赖。 +func BuildAuditReportData(task *model.Task, s *model.Storage, noDuplicate bool, ctx context.Context) (*auditreport.AuditReportData, error) { + // 1. 获取 SQL 列表 + data := map[string]interface{}{ + "task_id": fmt.Sprintf("%d", task.ID), + "no_duplicate": noDuplicate, + } + + taskSQLsDetail, _, err := s.GetTaskSQLsByReq(data) + if err != nil { + return nil, fmt.Errorf("get task SQLs failed: %w", err) + } + + // 2. 获取回滚 SQL 映射 + rollbackSqlMap, err := server.BackupService{}.GetRollbackSqlsMap(task.ID) + if err != nil { + return nil, fmt.Errorf("get rollback SQLs failed: %w", err) + } + + // 3. 构建 SQL 列表和统计数据(等级:固定顺序 normal/notice/warn/error,其余等级单独统计) + var levelCounts [4]int + extrasLevel := make(map[string]int) + ruleHits := make(map[string]int) + n := len(taskSQLsDetail) + sqlList := make([]auditreport.AuditSQLItem, 0, n) + problemSQLs := make([]auditreport.AuditSQLItem, 0, n) + + for _, td := range taskSQLsDetail { + // 构造临时 ExecuteSQL 对象以复用状态描述方法 + tempSQL := &model.ExecuteSQL{ + AuditResults: td.AuditResults, + AuditStatus: td.AuditStatus, + } + tempSQL.ExecStatus = td.ExecStatus + + // 提取规则名称和审核建议 + ruleName, suggestion := extractRuleInfo(td.AuditResults, ctx) + + item := auditreport.AuditSQLItem{ + Number: td.Number, + SQL: td.ExecSQL, + AuditLevel: td.AuditLevel, + AuditStatus: tempSQL.GetAuditStatusDesc(ctx), + AuditResult: tempSQL.GetAuditResultDesc(ctx), + ExecStatus: tempSQL.GetExecStatusDesc(ctx), + ExecResult: td.ExecResult, + RollbackSQL: strings.Join(rollbackSqlMap[td.Id], "\n"), + Description: td.Description, + RuleName: ruleName, + Suggestion: suggestion, + } + sqlList = append(sqlList, item) + + // 统计等级分布 + level := td.AuditLevel + if level == "" { + level = "normal" + } + switch level { + case "normal": + levelCounts[0]++ + case "notice": + levelCounts[1]++ + case "warn": + levelCounts[2]++ + case "error": + levelCounts[3]++ + default: + extrasLevel[level]++ + } + + // 区分问题 SQL(AuditLevel 不是 normal 且不为空) + if level != "normal" { + problemSQLs = append(problemSQLs, item) + } + + // 统计规则命中 + for _, ar := range td.AuditResults { + if ar.RuleName != "" { + ruleHits[ar.RuleName]++ + } + } + } + + // 4. 构建国际化标签(当前使用 locale 包提供的 i18n 标签) + labels := buildReportLabels(ctx) + + now := time.Now() + auditTime := now.Format("2006-01-02 15:04:05") + if task.CreatedAt.Year() > 1 { + auditTime = task.CreatedAt.Format("2006-01-02 15:04:05") + } + + instanceName := task.InstanceName() + if instanceName == "" { + instanceName = "unknown" + } + + return &auditreport.AuditReportData{ + TaskID: uint64(task.ID), + Title: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelTitle), + InstanceName: instanceName, + Schema: task.Schema, + GeneratedAt: now, + Lang: locale.Bundle.GetLangTagFromCtx(ctx).String(), + LogoBase64: "", + Summary: auditreport.AuditSummary{ + AuditTime: auditTime, + InstanceName: instanceName, + Schema: task.Schema, + TotalSQL: len(sqlList), + PassRate: task.PassRate * 100, + Score: task.Score, + AuditLevel: task.AuditLevel, + }, + Statistics: auditreport.AuditStatistics{ + LevelDistribution: formatLevelDistribution(levelCounts, extrasLevel), + RuleHits: toRuleHits(ruleHits), + }, + SQLList: sqlList, + ProblemSQLs: problemSQLs, + Labels: labels, + }, nil +} + +// extractRuleInfo 从审核结果中提取规则名称和审核建议。 +// 如果有多条规则命中,使用逗号分隔拼接。 +func extractRuleInfo(auditResults model.AuditResults, ctx context.Context) (ruleName string, suggestion string) { + if len(auditResults) == 0 { + return "", "" + } + + lang := locale.Bundle.GetLangTagFromCtx(ctx) + var ruleNames []string + var suggestions []string + + for _, ar := range auditResults { + if ar.RuleName != "" { + ruleNames = append(ruleNames, ar.RuleName) + } + msg := ar.GetAuditMsgByLangTag(lang) + if msg != "" { + suggestions = append(suggestions, msg) + } + } + + return strings.Join(ruleNames, ", "), strings.Join(suggestions, "; ") +} + +// formatLevelDistribution 将各等级计数转为 LevelCount 切片:先按固定顺序输出非零的标准等级, +// 再按名字排序输出其余等级。 +func formatLevelDistribution(counts [4]int, extras map[string]int) []auditreport.LevelCount { + names := []string{"normal", "notice", "warn", "error"} + out := make([]auditreport.LevelCount, 0, 4+len(extras)) + for i, name := range names { + if counts[i] > 0 { + out = append(out, auditreport.LevelCount{Level: name, Count: counts[i]}) + } + } + if len(extras) == 0 { + return out + } + keys := make([]string, 0, len(extras)) + for k := range extras { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + out = append(out, auditreport.LevelCount{Level: k, Count: extras[k]}) + } + return out +} + +// toLevelCounts 将等级分布 map 转为 LevelCount 切片(供测试与 map 输入场景)。 +func toLevelCounts(dist map[string]int) []auditreport.LevelCount { + if len(dist) == 0 { + return []auditreport.LevelCount{} + } + var counts [4]int + extras := make(map[string]int) + for level, c := range dist { + switch level { + case "normal": + counts[0] += c + case "notice": + counts[1] += c + case "warn": + counts[2] += c + case "error": + counts[3] += c + default: + extras[level] += c + } + } + return formatLevelDistribution(counts, extras) +} + +// toRuleHits 将规则命中 map 转换为按命中次数降序排列的 RuleHit 切片。 +func toRuleHits(hits map[string]int) []auditreport.RuleHit { + if len(hits) == 0 { + return []auditreport.RuleHit{} + } + + result := make([]auditreport.RuleHit, 0, len(hits)) + for name, count := range hits { + result = append(result, auditreport.RuleHit{ + RuleName: name, + HitCount: count, + }) + } + + sort.Slice(result, func(i, j int) bool { + return result[i].HitCount > result[j].HitCount + }) + + return result +} + +// buildReportLabels 构建报告中使用的国际化标签。 +// 当前版本使用 locale 包已有的国际化消息和硬编码中文标签, +// 后续阶段 8 将接入 go-i18n 框架实现完整国际化。 +func buildReportLabels(ctx context.Context) auditreport.ReportLabels { + return auditreport.ReportLabels{ + AuditSummary: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelAuditSummary), + ResultStatistics: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelResultStatistics), + ProblemSQLList: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelProblemSQLList), + RuleHitStatistics: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelRuleHitStatistics), + AuditTime: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelAuditTime), + DataSource: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelDataSource), + Schema: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelSchema), + TotalSQL: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelTotalSQL), + PassRate: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelPassRate), + Score: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelScore), + AuditLevel: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelAuditLevel), + Number: locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportIndex), + SQL: locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportSQL), + AuditStatus: locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportAuditStatus), + AuditResult: locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportAuditResult), + ExecStatus: locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportExecStatus), + ExecResult: locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportExecResult), + RollbackSQL: locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportRollbackSQL), + RuleName: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelRuleName), + Description: locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportDescription), + Suggestion: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelSuggestion), + Count: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelCount), + HitCount: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelHitCount), + } +} diff --git a/sqle/api/controller/v1/report_data_builder_test.go b/sqle/api/controller/v1/report_data_builder_test.go new file mode 100644 index 0000000000..c0155c0623 --- /dev/null +++ b/sqle/api/controller/v1/report_data_builder_test.go @@ -0,0 +1,451 @@ +package v1 + +import ( + "context" + "testing" + + "github.com/actiontech/sqle/sqle/model" + "github.com/actiontech/sqle/sqle/server/auditreport" + "golang.org/x/text/language" +) + +func TestToLevelCounts(t *testing.T) { + testCases := map[string]struct { + input map[string]int + wantLen int + wantFirst string // expected first level (highest priority) + wantLast string // expected last level (lowest priority) + description string + }{ + "mixed levels in fixed order": { + input: map[string]int{ + "normal": 5, + "error": 2, + "warn": 3, + "notice": 1, + }, + wantLen: 4, + wantFirst: "normal", + wantLast: "error", + description: "should list normal, notice, warn, error when all present", + }, + "empty map returns empty slice": { + input: map[string]int{}, + wantLen: 0, + description: "empty input should return empty result", + }, + "single level": { + input: map[string]int{ + "warn": 10, + }, + wantLen: 1, + wantFirst: "warn", + wantLast: "warn", + description: "single level should still work", + }, + "only normal level": { + input: map[string]int{ + "normal": 100, + }, + wantLen: 1, + wantFirst: "normal", + wantLast: "normal", + description: "all normal should return single entry", + }, + "unknown level after standard levels": { + input: map[string]int{ + "normal": 1, + "error": 1, + "unknown": 1, + }, + wantLen: 3, + wantFirst: "normal", + wantLast: "unknown", + description: "unknown levels should follow standard levels, sorted by name", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + result := toLevelCounts(tc.input) + + if len(result) != tc.wantLen { + t.Fatalf("toLevelCounts() returned %d items, want %d", len(result), tc.wantLen) + } + + if tc.wantLen == 0 { + return + } + + // Verify first (highest priority) element + if result[0].Level != tc.wantFirst { + t.Errorf("first level = %q, want %q (%s)", result[0].Level, tc.wantFirst, tc.description) + } + + // Verify last (lowest priority) element + if result[len(result)-1].Level != tc.wantLast { + t.Errorf("last level = %q, want %q (%s)", result[len(result)-1].Level, tc.wantLast, tc.description) + } + + // Verify counts match input + for _, lc := range result { + expectedCount, ok := tc.input[lc.Level] + if !ok { + t.Errorf("unexpected level %q in result", lc.Level) + continue + } + if lc.Count != expectedCount { + t.Errorf("level %q count = %d, want %d", lc.Level, lc.Count, expectedCount) + } + } + }) + } +} + +func TestToRuleHits(t *testing.T) { + testCases := map[string]struct { + input map[string]int + wantLen int + wantFirst string // expected first rule (highest hit count) + description string + }{ + "multiple rules sorted by hit count descending": { + input: map[string]int{ + "no_select_all": 5, + "no_drop_table": 10, + "add_index": 3, + }, + wantLen: 3, + wantFirst: "no_drop_table", + description: "should sort by hit count descending", + }, + "empty map returns empty slice": { + input: map[string]int{}, + wantLen: 0, + description: "empty input should return empty result", + }, + "single rule": { + input: map[string]int{ + "no_select_all": 7, + }, + wantLen: 1, + wantFirst: "no_select_all", + description: "single rule should work correctly", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + result := toRuleHits(tc.input) + + if len(result) != tc.wantLen { + t.Fatalf("toRuleHits() returned %d items, want %d", len(result), tc.wantLen) + } + + if tc.wantLen == 0 { + return + } + + // Verify first element has highest hit count + if result[0].RuleName != tc.wantFirst { + t.Errorf("first rule = %q, want %q (%s)", result[0].RuleName, tc.wantFirst, tc.description) + } + + // Verify descending order + for i := 1; i < len(result); i++ { + if result[i].HitCount > result[i-1].HitCount { + t.Errorf("rule at index %d (count=%d) > rule at index %d (count=%d), expected descending order", + i, result[i].HitCount, i-1, result[i-1].HitCount) + } + } + + // Verify counts match input + for _, rh := range result { + expectedCount, ok := tc.input[rh.RuleName] + if !ok { + t.Errorf("unexpected rule %q in result", rh.RuleName) + continue + } + if rh.HitCount != expectedCount { + t.Errorf("rule %q hit count = %d, want %d", rh.RuleName, rh.HitCount, expectedCount) + } + } + }) + } +} + +func TestExtractRuleInfo(t *testing.T) { + ctx := context.Background() + + testCases := map[string]struct { + auditResults model.AuditResults + wantRuleName string + wantHasRuleName bool + wantHasSugg bool + description string + }{ + "empty audit results": { + auditResults: model.AuditResults{}, + wantRuleName: "", + wantHasRuleName: false, + wantHasSugg: false, + description: "empty results should return empty strings", + }, + "single rule hit": { + auditResults: model.AuditResults{ + { + Level: "warn", + RuleName: "no_select_all", + I18nAuditResultInfo: model.I18nAuditResultInfo{ + language.Chinese: model.AuditResultInfo{Message: "should not use SELECT *"}, + }, + }, + }, + wantRuleName: "no_select_all", + wantHasRuleName: true, + wantHasSugg: true, + description: "single rule should return its name and message", + }, + "multiple rule hits": { + auditResults: model.AuditResults{ + { + Level: "warn", + RuleName: "no_select_all", + I18nAuditResultInfo: model.I18nAuditResultInfo{ + language.Chinese: model.AuditResultInfo{Message: "avoid SELECT *"}, + }, + }, + { + Level: "error", + RuleName: "no_drop_table", + I18nAuditResultInfo: model.I18nAuditResultInfo{ + language.Chinese: model.AuditResultInfo{Message: "DROP TABLE not allowed"}, + }, + }, + }, + wantRuleName: "no_select_all, no_drop_table", + wantHasRuleName: true, + wantHasSugg: true, + description: "multiple rules should be comma-separated", + }, + "rule with empty name is skipped": { + auditResults: model.AuditResults{ + { + Level: "notice", + RuleName: "", + I18nAuditResultInfo: model.I18nAuditResultInfo{ + language.Chinese: model.AuditResultInfo{Message: "some notice"}, + }, + }, + }, + wantRuleName: "", + wantHasRuleName: false, + wantHasSugg: true, + description: "rule with empty name should be skipped in rule names but message still included", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + ruleName, suggestion := extractRuleInfo(tc.auditResults, ctx) + + if tc.wantHasRuleName { + if ruleName != tc.wantRuleName { + t.Errorf("ruleName = %q, want %q (%s)", ruleName, tc.wantRuleName, tc.description) + } + } else { + if ruleName != "" { + t.Errorf("ruleName = %q, want empty (%s)", ruleName, tc.description) + } + } + + if tc.wantHasSugg { + if suggestion == "" { + t.Errorf("suggestion should not be empty (%s)", tc.description) + } + } else { + if suggestion != "" { + t.Errorf("suggestion = %q, want empty (%s)", suggestion, tc.description) + } + } + }) + } +} + +func TestBuildReportLabels(t *testing.T) { + ctx := context.Background() + + testCases := map[string]struct { + description string + }{ + "default context returns non-empty labels": { + description: "all label fields should be non-empty with default locale", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + labels := buildReportLabels(ctx) + + // Verify all label fields are non-empty + fieldChecks := map[string]string{ + "AuditSummary": labels.AuditSummary, + "ResultStatistics": labels.ResultStatistics, + "ProblemSQLList": labels.ProblemSQLList, + "RuleHitStatistics": labels.RuleHitStatistics, + "AuditTime": labels.AuditTime, + "DataSource": labels.DataSource, + "Schema": labels.Schema, + "TotalSQL": labels.TotalSQL, + "PassRate": labels.PassRate, + "Score": labels.Score, + "AuditLevel": labels.AuditLevel, + "Number": labels.Number, + "SQL": labels.SQL, + "AuditStatus": labels.AuditStatus, + "AuditResult": labels.AuditResult, + "ExecStatus": labels.ExecStatus, + "ExecResult": labels.ExecResult, + "RollbackSQL": labels.RollbackSQL, + "RuleName": labels.RuleName, + "Description": labels.Description, + "Suggestion": labels.Suggestion, + "Count": labels.Count, + "HitCount": labels.HitCount, + } + + for field, value := range fieldChecks { + if value == "" { + t.Errorf("%s: label field %q is empty (%s)", tc.description, field, tc.description) + } + } + }) + } +} + +// TestLevelCountsPreserveAllLevels verifies that toLevelCounts preserves +// the count data correctly for all four standard audit levels. +func TestLevelCountsPreserveAllLevels(t *testing.T) { + input := map[string]int{ + "error": 3, + "warn": 5, + "notice": 2, + "normal": 10, + } + + result := toLevelCounts(input) + + if len(result) != 4 { + t.Fatalf("expected 4 levels, got %d", len(result)) + } + + // Build a lookup map from result + resultMap := make(map[string]int) + for _, lc := range result { + resultMap[lc.Level] = lc.Count + } + + // Verify each level count matches + for level, expectedCount := range input { + if count, ok := resultMap[level]; !ok { + t.Errorf("level %q missing from result", level) + } else if count != expectedCount { + t.Errorf("level %q: count = %d, want %d", level, count, expectedCount) + } + } + + // Verify ordering: normal, notice, warn, error + expectedOrder := []string{"normal", "notice", "warn", "error"} + for i, expected := range expectedOrder { + if result[i].Level != expected { + t.Errorf("position %d: level = %q, want %q", i, result[i].Level, expected) + } + } +} + +// TestToRuleHitsStableSortForEqualCounts verifies that toRuleHits handles +// rules with equal hit counts without error. +func TestToRuleHitsStableSortForEqualCounts(t *testing.T) { + input := map[string]int{ + "rule_a": 5, + "rule_b": 5, + "rule_c": 5, + } + + result := toRuleHits(input) + + if len(result) != 3 { + t.Fatalf("expected 3 rules, got %d", len(result)) + } + + // All should have count 5 + for _, rh := range result { + if rh.HitCount != 5 { + t.Errorf("rule %q: hit count = %d, want 5", rh.RuleName, rh.HitCount) + } + } +} + +// TestToLevelCountsNilMap verifies toLevelCounts handles nil map gracefully. +func TestToLevelCountsNilMap(t *testing.T) { + result := toLevelCounts(nil) + if result == nil { + t.Fatal("toLevelCounts(nil) should return empty slice, not nil") + } + if len(result) != 0 { + t.Errorf("toLevelCounts(nil) returned %d items, want 0", len(result)) + } +} + +// TestToRuleHitsNilMap verifies toRuleHits handles nil map gracefully. +func TestToRuleHitsNilMap(t *testing.T) { + result := toRuleHits(nil) + if result == nil { + t.Fatal("toRuleHits(nil) should return empty slice, not nil") + } + if len(result) != 0 { + t.Errorf("toRuleHits(nil) returned %d items, want 0", len(result)) + } +} + +// TestExtractRuleInfoNilResults verifies extractRuleInfo handles nil AuditResults. +func TestExtractRuleInfoNilResults(t *testing.T) { + ctx := context.Background() + ruleName, suggestion := extractRuleInfo(nil, ctx) + if ruleName != "" { + t.Errorf("extractRuleInfo(nil) ruleName = %q, want empty", ruleName) + } + if suggestion != "" { + t.Errorf("extractRuleInfo(nil) suggestion = %q, want empty", suggestion) + } +} + +// TestCSVHeaders verifies that CSVHeaders returns the correct number of columns +// based on the report labels. +func TestCSVHeaders(t *testing.T) { + data := &auditreport.AuditReportData{ + Labels: auditreport.ReportLabels{ + Number: "Number", + SQL: "SQL", + AuditStatus: "Audit Status", + AuditResult: "Audit Result", + ExecStatus: "Exec Status", + ExecResult: "Exec Result", + RollbackSQL: "Rollback SQL", + Description: "Description", + }, + } + + gen := auditreport.NewCSVReportGenerator() + headers := gen.CSVHeaders(data) + if len(headers) != 8 { + t.Errorf("CSVHeaders() returned %d columns, want 8", len(headers)) + } + + expectedHeaders := []string{"Number", "SQL", "Audit Status", "Audit Result", "Exec Status", "Exec Result", "Rollback SQL", "Description"} + for i, h := range headers { + if h != expectedHeaders[i] { + t.Errorf("CSVHeaders()[%d] = %q, want %q", i, h, expectedHeaders[i]) + } + } +} diff --git a/sqle/api/controller/v1/task.go b/sqle/api/controller/v1/task.go index 4af23db0b4..d04c7f67d2 100644 --- a/sqle/api/controller/v1/task.go +++ b/sqle/api/controller/v1/task.go @@ -10,7 +10,6 @@ import ( "net/http" "net/http/httputil" "net/url" - "strconv" "strings" "time" @@ -21,10 +20,10 @@ import ( "github.com/actiontech/sqle/sqle/config" "github.com/actiontech/sqle/sqle/dms" "github.com/actiontech/sqle/sqle/errors" - "github.com/actiontech/sqle/sqle/locale" "github.com/actiontech/sqle/sqle/log" "github.com/actiontech/sqle/sqle/model" "github.com/actiontech/sqle/sqle/server" + "github.com/actiontech/sqle/sqle/server/auditreport" "github.com/actiontech/sqle/sqle/utils" "github.com/labstack/echo/v4" @@ -599,7 +598,8 @@ type DownloadAuditTaskSQLsFileReqV1 struct { // @Security ApiKeyAuth // @Param task_id path string true "task id" // @Param no_duplicate query boolean false "select unique (fingerprint and audit result) for task sql" -// @Success 200 file 1 "sql report csv file" +// @Param export_format query string false "export format" Enums(csv,html,pdf,word) default(csv) +// @Success 200 file 1 "sql report file" // @router /v1/tasks/audits/{task_id}/sql_report [get] func DownloadTaskSQLReportFile(c echo.Context) error { req := new(DownloadAuditTaskSQLsFileReqV1) @@ -618,59 +618,30 @@ func DownloadTaskSQLReportFile(c echo.Context) error { return controller.JSONBaseErrorReq(c, err) } - data := map[string]interface{}{ - "task_id": taskId, - "no_duplicate": req.NoDuplicate, - } + ctx := c.Request().Context() - taskSQLsDetail, _, err := s.GetTaskSQLsByReq(data) + // 解析导出格式参数,缺失时默认返回 CSV(向后兼容),无效格式返回 400 错误 + exportFormat, err := auditreport.NormalizeExportFormatStr(c.QueryParam("export_format")) if err != nil { return controller.JSONBaseErrorReq(c, err) } - ctx := c.Request().Context() - csvBuilder := utils.NewCSVBuilder() - err = csvBuilder.WriteHeader([]string{ - locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportIndex), // "序号", - locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportSQL), // "SQL", - locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportAuditStatus), // "SQL审核状态", - locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportAuditResult), // "SQL审核结果", - locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportExecStatus), // "SQL执行状态", - locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportExecResult), // "SQL执行结果", - locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportRollbackSQL), // "SQL对应的回滚语句", - locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportDescription), // "SQL描述", - }) + // 构建报告数据 + reportData, err := BuildAuditReportData(task, s, req.NoDuplicate, ctx) if err != nil { - return controller.JSONBaseErrorReq(c, errors.New(errors.WriteDataToTheFileError, err)) + return controller.JSONBaseErrorReq(c, err) } - rollbackSqlMap, err := server.BackupService{}.GetRollbackSqlsMap(task.ID) + + // 根据格式生成报告 + result, err := auditreport.ExportAuditReport(exportFormat, reportData) if err != nil { return controller.JSONBaseErrorReq(c, err) } - for _, td := range taskSQLsDetail { - taskSql := &model.ExecuteSQL{ - AuditResults: td.AuditResults, - AuditStatus: td.AuditStatus, - } - taskSql.ExecStatus = td.ExecStatus - err := csvBuilder.WriteRow([]string{ - strconv.FormatUint(uint64(td.Number), 10), - td.ExecSQL, - taskSql.GetAuditStatusDesc(ctx), - taskSql.GetAuditResultDesc(ctx), - taskSql.GetExecStatusDesc(ctx), - td.ExecResult, - strings.Join(rollbackSqlMap[taskSql.ID], "\n"), - td.Description, - }) - if err != nil { - return controller.JSONBaseErrorReq(c, errors.New(errors.WriteDataToTheFileError, err)) - } - } - fileName := fmt.Sprintf("SQL_audit_report_%v_%v.csv", task.InstanceName(), taskId) + + // 设置响应头并返回 c.Response().Header().Set(echo.HeaderContentDisposition, - mime.FormatMediaType("attachment", map[string]string{"filename": fileName})) - return c.Blob(http.StatusOK, "text/csv", csvBuilder.FlushAndGetBuffer().Bytes()) + mime.FormatMediaType("attachment", map[string]string{"filename": result.FileName})) + return c.Blob(http.StatusOK, result.ContentType, result.Content) } // @Summary 下载指定扫描任务的SQL文件 diff --git a/sqle/docs/docs.go b/sqle/docs/docs.go index 1520dd0a3e..9bc0ab54b8 100644 --- a/sqle/docs/docs.go +++ b/sqle/docs/docs.go @@ -10474,11 +10474,24 @@ var doc = `{ "description": "select unique (fingerprint and audit result) for task sql", "name": "no_duplicate", "in": "query" + }, + { + "enum": [ + "csv", + "html", + "pdf", + "word" + ], + "type": "string", + "default": "csv", + "description": "export format", + "name": "export_format", + "in": "query" } ], "responses": { "200": { - "description": "sql report csv file", + "description": "sql report file", "schema": { "type": "file" } diff --git a/sqle/docs/swagger.json b/sqle/docs/swagger.json index 60cafb8cd1..bf284277d3 100644 --- a/sqle/docs/swagger.json +++ b/sqle/docs/swagger.json @@ -10458,11 +10458,24 @@ "description": "select unique (fingerprint and audit result) for task sql", "name": "no_duplicate", "in": "query" + }, + { + "enum": [ + "csv", + "html", + "pdf", + "word" + ], + "type": "string", + "default": "csv", + "description": "export format", + "name": "export_format", + "in": "query" } ], "responses": { "200": { - "description": "sql report csv file", + "description": "sql report file", "schema": { "type": "file" } diff --git a/sqle/docs/swagger.yaml b/sqle/docs/swagger.yaml index c0b8cb67f1..a6dd9f6d0a 100644 --- a/sqle/docs/swagger.yaml +++ b/sqle/docs/swagger.yaml @@ -14559,9 +14559,19 @@ paths: in: query name: no_duplicate type: boolean + - default: csv + description: export format + enum: + - csv + - html + - pdf + - word + in: query + name: export_format + type: string responses: "200": - description: sql report csv file + description: sql report file schema: type: file security: diff --git a/sqle/locale/active.en.toml b/sqle/locale/active.en.toml index 43d8dbba5f..cd5593cb3d 100644 --- a/sqle/locale/active.en.toml +++ b/sqle/locale/active.en.toml @@ -409,3 +409,19 @@ WorkflowStepStateApprove = "Approved" WorkflowStepStateReject = "Rejected" WorkflowStepTypeSQLAudit = "Auditing" WorkflowStepTypeSQLExecute = "Executing" +ReportLabelAuditSummary = "Audit Summary" +ReportLabelResultStatistics = "Audit Result Statistics" +ReportLabelProblemSQLList = "Problem SQL List" +ReportLabelRuleHitStatistics = "Rule Hit Statistics" +ReportLabelAuditTime = "Audit Time" +ReportLabelDataSource = "Data Source" +ReportLabelSchema = "Schema" +ReportLabelTotalSQL = "Total SQL" +ReportLabelPassRate = "Pass Rate" +ReportLabelScore = "Score" +ReportLabelAuditLevel = "Audit Level" +ReportLabelRuleName = "Rule Name" +ReportLabelSuggestion = "Suggestion" +ReportLabelCount = "Count" +ReportLabelHitCount = "Hit Count" +ReportLabelTitle = "SQL Audit Report" diff --git a/sqle/locale/active.zh.toml b/sqle/locale/active.zh.toml index 59c94f11e5..a25de8b025 100644 --- a/sqle/locale/active.zh.toml +++ b/sqle/locale/active.zh.toml @@ -409,3 +409,19 @@ WorkflowStepStateApprove = "通过" WorkflowStepStateReject = "驳回" WorkflowStepTypeSQLAudit = "审批" WorkflowStepTypeSQLExecute = "上线" +ReportLabelAuditSummary = "审核概要" +ReportLabelResultStatistics = "审核结果统计" +ReportLabelProblemSQLList = "问题SQL列表" +ReportLabelRuleHitStatistics = "规则命中统计" +ReportLabelAuditTime = "审核时间" +ReportLabelDataSource = "数据源" +ReportLabelSchema = "数据库" +ReportLabelTotalSQL = "SQL总数" +ReportLabelPassRate = "通过率" +ReportLabelScore = "评分" +ReportLabelAuditLevel = "审核等级" +ReportLabelRuleName = "规则名称" +ReportLabelSuggestion = "优化建议" +ReportLabelCount = "数量" +ReportLabelHitCount = "命中次数" +ReportLabelTitle = "SQL审核报告" diff --git a/sqle/locale/message_zh.go b/sqle/locale/message_zh.go index fdcad22c83..98d1bc0794 100644 --- a/sqle/locale/message_zh.go +++ b/sqle/locale/message_zh.go @@ -64,6 +64,26 @@ var ( TaskSQLReportDescription = &i18n.Message{ID: "TaskSQLReportDescription", Other: "SQL描述"} ) +// report labels (for audit report export) +var ( + ReportLabelAuditSummary = &i18n.Message{ID: "ReportLabelAuditSummary", Other: "审核概要"} + ReportLabelResultStatistics = &i18n.Message{ID: "ReportLabelResultStatistics", Other: "审核结果统计"} + ReportLabelProblemSQLList = &i18n.Message{ID: "ReportLabelProblemSQLList", Other: "问题SQL列表"} + ReportLabelRuleHitStatistics = &i18n.Message{ID: "ReportLabelRuleHitStatistics", Other: "规则命中统计"} + ReportLabelAuditTime = &i18n.Message{ID: "ReportLabelAuditTime", Other: "审核时间"} + ReportLabelDataSource = &i18n.Message{ID: "ReportLabelDataSource", Other: "数据源"} + ReportLabelSchema = &i18n.Message{ID: "ReportLabelSchema", Other: "数据库"} + ReportLabelTotalSQL = &i18n.Message{ID: "ReportLabelTotalSQL", Other: "SQL总数"} + ReportLabelPassRate = &i18n.Message{ID: "ReportLabelPassRate", Other: "通过率"} + ReportLabelScore = &i18n.Message{ID: "ReportLabelScore", Other: "评分"} + ReportLabelAuditLevel = &i18n.Message{ID: "ReportLabelAuditLevel", Other: "审核等级"} + ReportLabelRuleName = &i18n.Message{ID: "ReportLabelRuleName", Other: "规则名称"} + ReportLabelSuggestion = &i18n.Message{ID: "ReportLabelSuggestion", Other: "优化建议"} + ReportLabelCount = &i18n.Message{ID: "ReportLabelCount", Other: "数量"} + ReportLabelHitCount = &i18n.Message{ID: "ReportLabelHitCount", Other: "命中次数"} + ReportLabelTitle = &i18n.Message{ID: "ReportLabelTitle", Other: "SQL审核报告"} +) + // workflow var ( WorkflowStepStateApprove = &i18n.Message{ID: "WorkflowStepStateApprove", Other: "通过"} diff --git a/sqle/server/auditreport/export_format.go b/sqle/server/auditreport/export_format.go new file mode 100644 index 0000000000..762d887ef8 --- /dev/null +++ b/sqle/server/auditreport/export_format.go @@ -0,0 +1,36 @@ +package auditreport + +import ( + "fmt" + "strings" +) + +// ExportFormat 审核报告等业务的导出格式(与 utils 中通用表格导出的 csv/excel 区分)。 +type ExportFormat string + +const ( + CsvExportFormat ExportFormat = "csv" + ExcelExportFormat ExportFormat = "excel" + ExportFormatHTML ExportFormat = "html" + ExportFormatPDF ExportFormat = "pdf" + ExportFormatWORD ExportFormat = "word" +) + +// NormalizeExportFormatStr 规范化导出格式查询参数。 +// 空字符串默认返回 CSV(向后兼容);无效格式返回错误。 +func NormalizeExportFormatStr(format string) (ExportFormat, error) { + switch strings.ToLower(strings.TrimSpace(format)) { + case "html": + return ExportFormatHTML, nil + case "pdf": + return ExportFormatPDF, nil + case "word", "docx": + return ExportFormatWORD, nil + case "excel", "xlsx": + return ExcelExportFormat, nil + case "csv", "": + return CsvExportFormat, nil + default: + return "", fmt.Errorf("unsupported export format: %s", format) + } +} diff --git a/sqle/server/auditreport/report_csv.go b/sqle/server/auditreport/report_csv.go new file mode 100644 index 0000000000..39b5a4096a --- /dev/null +++ b/sqle/server/auditreport/report_csv.go @@ -0,0 +1,75 @@ +package auditreport + +import ( + "fmt" + + "github.com/actiontech/sqle/sqle/utils" +) + +// CSVReportGenerator CSV 格式报告生成器 +// 复用已有的 CSVBuilder 生成 CSV 报告,实现 ReportGenerator 接口。 +type CSVReportGenerator struct{} + +// NewCSVReportGenerator 创建并返回一个新的 CSVReportGenerator 实例 +func NewCSVReportGenerator() *CSVReportGenerator { + return &CSVReportGenerator{} +} + +// ReportType 返回生成器支持的导出格式 +func (g *CSVReportGenerator) ReportType() ExportFormat { + return CsvExportFormat +} + +// CSVHeaders 返回 CSV 报告的表头列表 +func (g *CSVReportGenerator) CSVHeaders(data *AuditReportData) []string { + return []string{ + data.Labels.Number, + data.Labels.SQL, + data.Labels.AuditStatus, + data.Labels.AuditResult, + data.Labels.ExecStatus, + data.Labels.ExecResult, + data.Labels.RollbackSQL, + data.Labels.Description, + } +} + +// ToCSVRow 将单条审核 SQL 转为 CSV 行 +func (g *CSVReportGenerator) ToCSVRow(item *AuditSQLItem) []string { + return []string{ + fmt.Sprintf("%d", item.Number), + item.SQL, + item.AuditStatus, + item.AuditResult, + item.ExecStatus, + item.ExecResult, + item.RollbackSQL, + item.Description, + } +} + +// Generate 根据审核报告数据生成 CSV 格式的文件 +func (g *CSVReportGenerator) Generate(data *AuditReportData) (*utils.ExportDataResult, error) { + builder := utils.NewCSVBuilder() + + if err := builder.WriteHeader(g.CSVHeaders(data)); err != nil { + return nil, fmt.Errorf("write csv header failed: %v", err) + } + + for i := range data.SQLList { + if err := builder.WriteRow(g.ToCSVRow(&data.SQLList[i])); err != nil { + return nil, fmt.Errorf("write csv row failed: %v", err) + } + } + + content := builder.FlushAndGetBuffer().Bytes() + if err := builder.Error(); err != nil { + return nil, fmt.Errorf("csv builder error: %v", err) + } + + return &utils.ExportDataResult{ + Content: content, + ContentType: "text/csv", + FileName: fmt.Sprintf("SQL_audit_report_%s_%d.csv", data.InstanceName, data.TaskID), + }, nil +} diff --git a/sqle/server/auditreport/report_generator.go b/sqle/server/auditreport/report_generator.go new file mode 100644 index 0000000000..658888a700 --- /dev/null +++ b/sqle/server/auditreport/report_generator.go @@ -0,0 +1,112 @@ +package auditreport + +import ( + "time" + + "github.com/actiontech/sqle/sqle/utils" +) + +// AuditReportData 审核报告完整数据模型 +type AuditReportData struct { + // 元信息 + TaskID uint64 `json:"task_id"` + Title string `json:"title"` // 报告标题 (i18n) + InstanceName string `json:"instance_name"` // 数据源名称 + Schema string `json:"schema"` + GeneratedAt time.Time `json:"generated_at"` // 报告生成时间 + Lang string `json:"lang"` // 语言: zh-CN / en-US + LogoBase64 string `json:"logo_base64"` // Logo 图片 base64 + + // 审核概要 + Summary AuditSummary `json:"summary"` + + // 审核结果统计 + Statistics AuditStatistics `json:"statistics"` + + // SQL 列表 + SQLList []AuditSQLItem `json:"sql_list"` // 全部 SQL + ProblemSQLs []AuditSQLItem `json:"problem_sqls"` // 问题 SQL(AuditLevel != normal) + + // 国际化标签 + Labels ReportLabels `json:"labels"` +} + +// AuditSummary 审核概要 +type AuditSummary struct { + AuditTime string `json:"audit_time"` + InstanceName string `json:"instance_name"` + Schema string `json:"schema"` + TotalSQL int `json:"total_sql"` + PassRate float64 `json:"pass_rate"` + Score int32 `json:"score"` + AuditLevel string `json:"audit_level"` +} + +// AuditStatistics 审核结果统计 +type AuditStatistics struct { + LevelDistribution []LevelCount `json:"level_distribution"` // 按等级分布 + RuleHits []RuleHit `json:"rule_hits"` // 规则命中统计 +} + +// LevelCount 等级统计 +type LevelCount struct { + Level string `json:"level"` // normal/notice/warn/error + Count int `json:"count"` +} + +// RuleHit 规则命中统计 +type RuleHit struct { + RuleName string `json:"rule_name"` + HitCount int `json:"hit_count"` +} + +// AuditSQLItem 单条 SQL 审核结果 +type AuditSQLItem struct { + Number uint `json:"number"` + SQL string `json:"sql"` + AuditLevel string `json:"audit_level"` + AuditStatus string `json:"audit_status"` + AuditResult string `json:"audit_result"` // 审核结果描述 + ExecStatus string `json:"exec_status"` + ExecResult string `json:"exec_result"` + RollbackSQL string `json:"rollback_sql"` + Description string `json:"description"` + // HTML/PDF/WORD 报告扩展字段 + RuleName string `json:"rule_name"` // 触发的规则名称 + Suggestion string `json:"suggestion"` // 优化建议 +} + +// ReportLabels 报告中的国际化标签 +type ReportLabels struct { + AuditSummary string `json:"audit_summary"` + ResultStatistics string `json:"result_statistics"` + ProblemSQLList string `json:"problem_sql_list"` + RuleHitStatistics string `json:"rule_hit_statistics"` + AuditTime string `json:"audit_time"` + DataSource string `json:"data_source"` + Schema string `json:"schema"` + TotalSQL string `json:"total_sql"` + PassRate string `json:"pass_rate"` + Score string `json:"score"` + AuditLevel string `json:"audit_level"` + Number string `json:"number"` + SQL string `json:"sql"` + AuditStatus string `json:"audit_status"` + AuditResult string `json:"audit_result"` + ExecStatus string `json:"exec_status"` + ExecResult string `json:"exec_result"` + RollbackSQL string `json:"rollback_sql"` + RuleName string `json:"rule_name"` + Description string `json:"description"` + Suggestion string `json:"suggestion"` + Count string `json:"count"` + HitCount string `json:"hit_count"` +} + +// ReportGenerator 报告生成器接口 +type ReportGenerator interface { + // Generate 根据报告数据生成指定格式的文件 + Generate(data *AuditReportData) (*utils.ExportDataResult, error) + // ReportType 返回生成器对应的导出格式 + ReportType() ExportFormat +} diff --git a/sqle/server/auditreport/report_generator_ce.go b/sqle/server/auditreport/report_generator_ce.go new file mode 100644 index 0000000000..d2d034eb44 --- /dev/null +++ b/sqle/server/auditreport/report_generator_ce.go @@ -0,0 +1,30 @@ +//go:build !enterprise + +package auditreport + +import ( + "fmt" + + "github.com/actiontech/sqle/sqle/utils" +) + +// ExportAuditReport CE 版统一导出入口。 +// CE 版仅支持 CSV 和 HTML 两种格式。 +// 请求 PDF 或 WORD 格式时返回错误提示,提醒用户需要企业版。 +// 无效格式返回错误(REQ-6.3)。 +func ExportAuditReport(format ExportFormat, data *AuditReportData) (*utils.ExportDataResult, error) { + switch format { + case CsvExportFormat: + return NewCSVReportGenerator().Generate(data) + case ExportFormatHTML: + gen, err := NewHTMLReportGenerator() + if err != nil { + return nil, err + } + return gen.Generate(data) + case ExportFormatPDF, ExportFormatWORD: + return nil, fmt.Errorf("export format %s is only supported in enterprise edition", format) + default: + return nil, fmt.Errorf("unsupported export format: %s", format) + } +} diff --git a/sqle/server/auditreport/report_generator_ce_test.go b/sqle/server/auditreport/report_generator_ce_test.go new file mode 100644 index 0000000000..e3f852ffa1 --- /dev/null +++ b/sqle/server/auditreport/report_generator_ce_test.go @@ -0,0 +1,109 @@ +//go:build !enterprise + +package auditreport + +import ( + "strings" + "testing" +) + +func TestExportAuditReport_CEEdition_PDFBlocked(t *testing.T) { + testCases := map[string]struct { + format ExportFormat + wantErrSubstr string + }{ + "PDF format is blocked in CE edition": { + format: ExportFormatPDF, + wantErrSubstr: "enterprise edition", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + data := buildTestReportData() + result, err := ExportAuditReport(tc.format, data) + if err == nil { + t.Fatal("ExportAuditReport(PDF) should return error in CE edition, got nil") + } + if result != nil { + t.Errorf("ExportAuditReport(PDF) should return nil result in CE edition, got %+v", result) + } + if !strings.Contains(err.Error(), tc.wantErrSubstr) { + t.Errorf("error message %q does not contain %q", err.Error(), tc.wantErrSubstr) + } + if !strings.Contains(err.Error(), string(tc.format)) { + t.Errorf("error message %q does not contain format name %q", err.Error(), tc.format) + } + }) + } +} + +func TestExportAuditReport_CEEdition_WORDBlocked(t *testing.T) { + testCases := map[string]struct { + format ExportFormat + wantErrSubstr string + }{ + "WORD format is blocked in CE edition": { + format: ExportFormatWORD, + wantErrSubstr: "enterprise edition", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + data := buildTestReportData() + result, err := ExportAuditReport(tc.format, data) + if err == nil { + t.Fatal("ExportAuditReport(WORD) should return error in CE edition, got nil") + } + if result != nil { + t.Errorf("ExportAuditReport(WORD) should return nil result in CE edition, got %+v", result) + } + if !strings.Contains(err.Error(), tc.wantErrSubstr) { + t.Errorf("error message %q does not contain %q", err.Error(), tc.wantErrSubstr) + } + if !strings.Contains(err.Error(), string(tc.format)) { + t.Errorf("error message %q does not contain format name %q", err.Error(), tc.format) + } + }) + } +} + +func TestExportAuditReport_DefaultCSV(t *testing.T) { + testCases := map[string]struct { + format ExportFormat + wantErrSubstr string + }{ + "invalid format returns error": { + format: ExportFormat("invalid"), + wantErrSubstr: "unsupported export format", + }, + "empty format returns error": { + format: ExportFormat(""), + wantErrSubstr: "unsupported export format", + }, + "unknown format returns error": { + format: ExportFormat("json"), + wantErrSubstr: "unsupported export format", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + data := buildTestReportData() + result, err := ExportAuditReport(tc.format, data) + if err == nil { + t.Fatalf("ExportAuditReport(%q) expected error, got nil", tc.format) + } + if result != nil { + t.Errorf("ExportAuditReport(%q) expected nil result, got %+v", tc.format, result) + } + if !strings.Contains(err.Error(), tc.wantErrSubstr) { + t.Errorf("error message %q does not contain %q", err.Error(), tc.wantErrSubstr) + } + if !strings.Contains(err.Error(), string(tc.format)) { + t.Errorf("error message %q does not contain format name %q", err.Error(), tc.format) + } + }) + } +} diff --git a/sqle/server/auditreport/report_generator_test.go b/sqle/server/auditreport/report_generator_test.go new file mode 100644 index 0000000000..f4ae8a24fe --- /dev/null +++ b/sqle/server/auditreport/report_generator_test.go @@ -0,0 +1,751 @@ +package auditreport + +import ( + "fmt" + "strings" + "testing" + "time" +) + +// buildTestReportData 构建测试用的 AuditReportData +func buildTestReportData() *AuditReportData { + return &AuditReportData{ + TaskID: 1001, + Title: "SQL Audit Report", + InstanceName: "test-mysql", + Schema: "test_db", + GeneratedAt: time.Now(), + Lang: "en-US", + Summary: AuditSummary{ + AuditTime: "2026-03-31 10:00:00", + InstanceName: "test-mysql", + Schema: "test_db", + TotalSQL: 3, + PassRate: 66.7, + Score: 70, + AuditLevel: "warn", + }, + Statistics: AuditStatistics{ + LevelDistribution: []LevelCount{ + {Level: "normal", Count: 1}, + {Level: "warn", Count: 1}, + {Level: "error", Count: 1}, + }, + RuleHits: []RuleHit{ + {RuleName: "no_select_all", HitCount: 1}, + {RuleName: "no_drop_table", HitCount: 1}, + }, + }, + SQLList: []AuditSQLItem{ + { + Number: 1, + SQL: "SELECT * FROM users", + AuditLevel: "warn", + AuditStatus: "finished", + AuditResult: "should not use SELECT *", + ExecStatus: "initialized", + ExecResult: "", + RollbackSQL: "", + Description: "query all users", + RuleName: "no_select_all", + Suggestion: "specify column names", + }, + { + Number: 2, + SQL: "DROP TABLE test", + AuditLevel: "error", + AuditStatus: "finished", + AuditResult: "DROP TABLE is prohibited", + ExecStatus: "initialized", + ExecResult: "", + RollbackSQL: "", + Description: "drop test table", + RuleName: "no_drop_table", + Suggestion: "do not use DROP TABLE", + }, + { + Number: 3, + SQL: "INSERT INTO t VALUES(1)", + AuditLevel: "normal", + AuditStatus: "finished", + AuditResult: "", + ExecStatus: "initialized", + ExecResult: "", + RollbackSQL: "DELETE FROM t WHERE id=1", + Description: "insert a row", + }, + }, + ProblemSQLs: []AuditSQLItem{ + {Number: 1, SQL: "SELECT * FROM users", AuditLevel: "warn", RuleName: "no_select_all"}, + {Number: 2, SQL: "DROP TABLE test", AuditLevel: "error", RuleName: "no_drop_table"}, + }, + Labels: ReportLabels{ + AuditSummary: "Audit Summary", + ResultStatistics: "Audit Result Statistics", + ProblemSQLList: "Problem SQL List", + RuleHitStatistics: "Rule Hit Statistics", + AuditTime: "Audit Time", + DataSource: "Data Source", + Schema: "Schema", + TotalSQL: "Total SQL", + PassRate: "Pass Rate", + Score: "Score", + AuditLevel: "Audit Level", + Number: "Number", + SQL: "SQL", + AuditStatus: "Audit Status", + AuditResult: "Audit Result", + ExecStatus: "Exec Status", + ExecResult: "Exec Result", + RollbackSQL: "Rollback SQL", + RuleName: "Rule Name", + Description: "Description", + Suggestion: "Suggestion", + Count: "Count", + HitCount: "Hit Count", + }, + } +} + +func TestCSVReportGenerator_Normal(t *testing.T) { + testCases := map[string]struct { + data *AuditReportData + wantContentType string + wantFilePrefix string + wantFileSuffix string + wantBOM bool + wantHeaders []string + wantDataRows int + }{ + "normal data generates valid CSV report": { + data: buildTestReportData(), + wantContentType: "text/csv", + wantFilePrefix: "SQL_audit_report_test-mysql_1001", + wantFileSuffix: ".csv", + wantBOM: true, + wantHeaders: []string{"Number", "SQL", "Audit Status", "Audit Result", "Exec Status", "Exec Result", "Rollback SQL", "Description"}, + wantDataRows: 3, + }, + } + + gen := NewCSVReportGenerator() + + // Verify the generator implements ReportGenerator interface + var _ ReportGenerator = gen + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + result, err := gen.Generate(tc.data) + if err != nil { + t.Fatalf("Generate() returned unexpected error: %v", err) + } + if result == nil { + t.Fatal("Generate() returned nil result") + } + + // Verify ContentType + if result.ContentType != tc.wantContentType { + t.Errorf("ContentType = %q, want %q", result.ContentType, tc.wantContentType) + } + + // Verify FileName contains InstanceName and TaskID, ends with .csv + if !strings.Contains(result.FileName, tc.data.InstanceName) { + t.Errorf("FileName %q does not contain InstanceName %q", result.FileName, tc.data.InstanceName) + } + if !strings.Contains(result.FileName, "1001") { + t.Errorf("FileName %q does not contain TaskID", result.FileName) + } + if !strings.HasSuffix(result.FileName, tc.wantFileSuffix) { + t.Errorf("FileName %q does not end with %q", result.FileName, tc.wantFileSuffix) + } + if !strings.HasPrefix(result.FileName, tc.wantFilePrefix) { + t.Errorf("FileName %q does not start with %q", result.FileName, tc.wantFilePrefix) + } + + // Verify UTF-8 BOM + content := string(result.Content) + if tc.wantBOM { + if !strings.HasPrefix(content, "\xEF\xBB\xBF") { + t.Error("Content does not start with UTF-8 BOM") + } + } + + // Verify headers exist in content + for _, h := range tc.wantHeaders { + if !strings.Contains(content, h) { + t.Errorf("Content does not contain header %q", h) + } + } + + // Verify the number of data rows (excluding BOM and header line) + contentWithoutBOM := strings.TrimPrefix(content, "\xEF\xBB\xBF") + lines := strings.Split(strings.TrimRight(contentWithoutBOM, "\n"), "\n") + // First line is header, remaining are data rows + dataLineCount := len(lines) - 1 + if dataLineCount != tc.wantDataRows { + t.Errorf("data row count = %d, want %d", dataLineCount, tc.wantDataRows) + } + }) + } +} + +func TestCSVReportGenerator_EmptyData(t *testing.T) { + testCases := map[string]struct { + data *AuditReportData + wantDataRows int + }{ + "empty SQL list produces header only": { + data: &AuditReportData{ + TaskID: 2002, + InstanceName: "empty-instance", + SQLList: []AuditSQLItem{}, + Labels: ReportLabels{ + Number: "Number", + SQL: "SQL", + AuditStatus: "Audit Status", + AuditResult: "Audit Result", + ExecStatus: "Exec Status", + ExecResult: "Exec Result", + RollbackSQL: "Rollback SQL", + Description: "Description", + }, + }, + wantDataRows: 0, + }, + } + + gen := NewCSVReportGenerator() + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + result, err := gen.Generate(tc.data) + if err != nil { + t.Fatalf("Generate() returned unexpected error: %v", err) + } + if result == nil { + t.Fatal("Generate() returned nil result") + } + + content := string(result.Content) + + // Verify BOM is present + if !strings.HasPrefix(content, "\xEF\xBB\xBF") { + t.Error("Content does not start with UTF-8 BOM") + } + + // Verify only header line, no data rows + contentWithoutBOM := strings.TrimPrefix(content, "\xEF\xBB\xBF") + lines := strings.Split(strings.TrimRight(contentWithoutBOM, "\n"), "\n") + // Should have exactly 1 line (header only) + if len(lines) != 1 { + t.Errorf("expected 1 line (header only), got %d lines", len(lines)) + } + + // Verify headers are present + headerLine := lines[0] + for _, h := range []string{"Number", "SQL", "Audit Status"} { + if !strings.Contains(headerLine, h) { + t.Errorf("header line does not contain %q", h) + } + } + + // Verify FileName + if result.FileName != "SQL_audit_report_empty-instance_2002.csv" { + t.Errorf("FileName = %q, want %q", result.FileName, "SQL_audit_report_empty-instance_2002.csv") + } + }) + } +} + +func TestCSVReportGenerator_SpecialChars(t *testing.T) { + testCases := map[string]struct { + sqlItem AuditSQLItem + wantInRow string + description string + }{ + "SQL with comma is quoted": { + sqlItem: AuditSQLItem{ + Number: 1, + SQL: "SELECT a, b FROM users", + AuditStatus: "finished", + AuditResult: "ok", + ExecStatus: "initialized", + ExecResult: "", + RollbackSQL: "", + Description: "", + }, + wantInRow: `"SELECT a, b FROM users"`, + description: "field containing comma should be wrapped in double quotes", + }, + "SQL with double quote is escaped": { + sqlItem: AuditSQLItem{ + Number: 2, + SQL: `SELECT "name" FROM users`, + AuditStatus: "finished", + AuditResult: "ok", + ExecStatus: "initialized", + ExecResult: "", + RollbackSQL: "", + Description: "", + }, + wantInRow: `"SELECT ""name"" FROM users"`, + description: "double quotes within a field should be escaped as two double quotes", + }, + "SQL with newline is quoted": { + sqlItem: AuditSQLItem{ + Number: 3, + SQL: "SELECT *\nFROM users", + AuditStatus: "finished", + AuditResult: "ok", + ExecStatus: "initialized", + ExecResult: "", + RollbackSQL: "", + Description: "", + }, + wantInRow: "\"SELECT *\nFROM users\"", + description: "field containing newline should be wrapped in double quotes", + }, + } + + gen := NewCSVReportGenerator() + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + data := &AuditReportData{ + TaskID: 3003, + InstanceName: "special-instance", + SQLList: []AuditSQLItem{tc.sqlItem}, + Labels: ReportLabels{ + Number: "Number", + SQL: "SQL", + AuditStatus: "Audit Status", + AuditResult: "Audit Result", + ExecStatus: "Exec Status", + ExecResult: "Exec Result", + RollbackSQL: "Rollback SQL", + Description: "Description", + }, + } + + result, err := gen.Generate(data) + if err != nil { + t.Fatalf("Generate() returned unexpected error: %v", err) + } + if result == nil { + t.Fatal("Generate() returned nil result") + } + + content := string(result.Content) + + // Verify that the special character handling is correct + if !strings.Contains(content, tc.wantInRow) { + t.Errorf("%s\ncontent does not contain expected substring %q\nfull content:\n%s", + tc.description, tc.wantInRow, content) + } + }) + } +} + +func TestCSVReportGenerator_ReportType(t *testing.T) { + gen := NewCSVReportGenerator() + if gen.ReportType() != CsvExportFormat { + t.Errorf("ReportType() = %q, want %q", gen.ReportType(), CsvExportFormat) + } +} + +func TestHTMLReportGenerator_Normal(t *testing.T) { + gen, err := NewHTMLReportGenerator() + if err != nil { + t.Fatalf("NewHTMLReportGenerator() returned unexpected error: %v", err) + } + + // Verify the generator implements ReportGenerator interface + var _ ReportGenerator = gen + + testCases := map[string]struct { + data *AuditReportData + wantContentType string + wantFilePrefix string + wantFileSuffix string + wantHTMLTags []string + wantSQLContents []string + wantLabels []string + }{ + "normal data generates valid HTML report": { + data: buildTestReportData(), + wantContentType: "text/html", + wantFilePrefix: "SQL_audit_report_test-mysql_1001", + wantFileSuffix: ".html", + wantHTMLTags: []string{"