From 0d80d29dfcfa8c6b33f605141a7a5a3f7b645289 Mon Sep 17 00:00:00 2001 From: actiontech-zihan Date: Tue, 31 Mar 2026 14:05:14 +0000 Subject: [PATCH 01/14] feat(utils): add report export format enums, data models, and generator interface Add backend infrastructure for multi-format audit report export: - Extend ExportFormat enum with HTML/PDF/WORD constants - Add NormalizeExportFormatStr function for string-based format normalization - Define AuditReportData, AuditSummary, AuditStatistics, LevelCount, RuleHit, AuditSQLItem, and ReportLabels data model structs - Define ReportGenerator interface with Generate and Format methods - Add CSVHeaders() and ToCSVRow() helper methods for CSV export - Add unit tests for NormalizeExportFormatStr with map case pattern --- sqle/utils/file.go | 23 +++++ sqle/utils/report_generator.go | 144 ++++++++++++++++++++++++++++ sqle/utils/report_generator_test.go | 90 +++++++++++++++++ 3 files changed, 257 insertions(+) create mode 100644 sqle/utils/report_generator.go create mode 100644 sqle/utils/report_generator_test.go diff --git a/sqle/utils/file.go b/sqle/utils/file.go index 5fe22435ce..e9ba7e6f53 100644 --- a/sqle/utils/file.go +++ b/sqle/utils/file.go @@ -5,6 +5,7 @@ import ( "io" "io/fs" "os" + "strings" "time" ) @@ -80,6 +81,9 @@ type ExportFormat string const ( CsvExportFormat ExportFormat = "csv" ExcelExportFormat ExportFormat = "excel" + ExportFormatHTML ExportFormat = "html" + ExportFormatPDF ExportFormat = "pdf" + ExportFormatWORD ExportFormat = "word" ) // ExportDataResult 导出数据的结果 @@ -168,6 +172,25 @@ func NormalizeExportFormat(format *ExportFormat) ExportFormat { return *format } +// NormalizeExportFormatStr 规范化导出格式参数(字符串版本),默认返回 CSV(向后兼容) +// 支持 html/pdf/word/docx/excel/xlsx/csv 等输入的规范化,空字符串和无效值默认返回 CSV。 +func NormalizeExportFormatStr(format string) ExportFormat { + switch strings.ToLower(strings.TrimSpace(format)) { + case "html": + return ExportFormatHTML + case "pdf": + return ExportFormatPDF + case "word", "docx": + return ExportFormatWORD + case "excel", "xlsx": + return ExcelExportFormat + case "csv", "": + return CsvExportFormat + default: + return CsvExportFormat + } +} + // ExportData 根据导出格式导出数据 // header: 表头字符串数组 // rows: 数据行,二维字符串数组 diff --git a/sqle/utils/report_generator.go b/sqle/utils/report_generator.go new file mode 100644 index 0000000000..34e3366d44 --- /dev/null +++ b/sqle/utils/report_generator.go @@ -0,0 +1,144 @@ +package utils + +import ( + "fmt" + "time" +) + +// AuditReportData 审核报告完整数据模型 +type AuditReportData struct { + // 元信息 + TaskID uint64 `json:"task_id"` + Title string `json:"title"` // 报告标题 (i18n) + InstanceName string `json:"instance_name"` // 数据源名称 + Schema string `json:"schema"` + GeneratedAt time.Time `json:"generated_at"` // 报告生成时间 + Lang string `json:"lang"` // 语言: zh-CN / en-US + LogoBase64 string `json:"logo_base64"` // Logo 图片 base64 + + // 审核概要 + Summary AuditSummary `json:"summary"` + + // 审核结果统计 + Statistics AuditStatistics `json:"statistics"` + + // SQL 列表 + SQLList []AuditSQLItem `json:"sql_list"` // 全部 SQL + ProblemSQLs []AuditSQLItem `json:"problem_sqls"` // 问题 SQL(AuditLevel != normal) + + // 国际化标签 + Labels ReportLabels `json:"labels"` +} + +// CSVHeaders 返回 CSV 报告的表头列表 +func (d *AuditReportData) CSVHeaders() []string { + return []string{ + d.Labels.Number, + d.Labels.SQL, + d.Labels.AuditStatus, + d.Labels.AuditResult, + d.Labels.ExecStatus, + d.Labels.ExecResult, + d.Labels.RollbackSQL, + d.Labels.Description, + } +} + +// AuditSummary 审核概要 +type AuditSummary struct { + AuditTime string `json:"audit_time"` + InstanceName string `json:"instance_name"` + Schema string `json:"schema"` + TotalSQL int `json:"total_sql"` + PassRate float64 `json:"pass_rate"` + Score int32 `json:"score"` + AuditLevel string `json:"audit_level"` +} + +// AuditStatistics 审核结果统计 +type AuditStatistics struct { + LevelDistribution []LevelCount `json:"level_distribution"` // 按等级分布 + RuleHits []RuleHit `json:"rule_hits"` // 规则命中统计 +} + +// LevelCount 等级统计 +type LevelCount struct { + Level string `json:"level"` // normal/notice/warn/error + Count int `json:"count"` +} + +// RuleHit 规则命中统计 +type RuleHit struct { + RuleName string `json:"rule_name"` + HitCount int `json:"hit_count"` +} + +// AuditSQLItem 单条 SQL 审核结果 +type AuditSQLItem struct { + Number uint `json:"number"` + SQL string `json:"sql"` + AuditLevel string `json:"audit_level"` + AuditStatus string `json:"audit_status"` + AuditResult string `json:"audit_result"` // 审核结果描述 + ExecStatus string `json:"exec_status"` + ExecResult string `json:"exec_result"` + RollbackSQL string `json:"rollback_sql"` + Description string `json:"description"` + // HTML/PDF/WORD 报告扩展字段 + RuleName string `json:"rule_name"` // 触发的规则名称 + Suggestion string `json:"suggestion"` // 优化建议 +} + +// ToCSVRow 将审核 SQL 项转换为 CSV 行数据 +func (item *AuditSQLItem) ToCSVRow() []string { + return []string{ + fmt.Sprintf("%d", item.Number), + item.SQL, + item.AuditStatus, + item.AuditResult, + item.ExecStatus, + item.ExecResult, + item.RollbackSQL, + item.Description, + } +} + +// ReportLabels 报告中的国际化标签 +type ReportLabels struct { + AuditSummary string `json:"audit_summary"` + ResultStatistics string `json:"result_statistics"` + ProblemSQLList string `json:"problem_sql_list"` + RuleHitStatistics string `json:"rule_hit_statistics"` + AuditTime string `json:"audit_time"` + DataSource string `json:"data_source"` + Schema string `json:"schema"` + TotalSQL string `json:"total_sql"` + PassRate string `json:"pass_rate"` + Score string `json:"score"` + AuditLevel string `json:"audit_level"` + Number string `json:"number"` + SQL string `json:"sql"` + AuditStatus string `json:"audit_status"` + AuditResult string `json:"audit_result"` + ExecStatus string `json:"exec_status"` + ExecResult string `json:"exec_result"` + RollbackSQL string `json:"rollback_sql"` + RuleName string `json:"rule_name"` + Description string `json:"description"` + Suggestion string `json:"suggestion"` + Count string `json:"count"` + HitCount string `json:"hit_count"` +} + +// ReportGenerator 报告生成器接口 +type ReportGenerator interface { + // Generate 根据报告数据生成指定格式的文件 + Generate(data *AuditReportData) (*ExportDataResult, error) + // Format 返回生成器支持的格式 + Format() ExportFormat +} + +// ExportAuditReport 统一导出入口(CE/EE 通过 build tags 区分实现) +// CE 版本支持 CSV 和 HTML 格式,EE 版本额外支持 PDF 和 WORD 格式。 +// 函数签名:func ExportAuditReport(format ExportFormat, data *AuditReportData) (*ExportDataResult, error) +// 实现分别位于 report_generator_ce.go 和 report_generator_ee.go 中。 diff --git a/sqle/utils/report_generator_test.go b/sqle/utils/report_generator_test.go new file mode 100644 index 0000000000..d6f07c8b30 --- /dev/null +++ b/sqle/utils/report_generator_test.go @@ -0,0 +1,90 @@ +package utils + +import ( + "testing" +) + +func TestNormalizeExportFormatStr(t *testing.T) { + testCases := map[string]struct { + input string + expected ExportFormat + }{ + "empty string defaults to csv": { + input: "", + expected: CsvExportFormat, + }, + "csv returns csv": { + input: "csv", + expected: CsvExportFormat, + }, + "CSV uppercase returns csv": { + input: "CSV", + expected: CsvExportFormat, + }, + "excel returns excel": { + input: "excel", + expected: ExcelExportFormat, + }, + "xlsx returns excel": { + input: "xlsx", + expected: ExcelExportFormat, + }, + "html returns html": { + input: "html", + expected: ExportFormatHTML, + }, + "HTML uppercase returns html": { + input: "HTML", + expected: ExportFormatHTML, + }, + "pdf returns pdf": { + input: "pdf", + expected: ExportFormatPDF, + }, + "PDF uppercase returns pdf": { + input: "PDF", + expected: ExportFormatPDF, + }, + "word returns word": { + input: "word", + expected: ExportFormatWORD, + }, + "WORD uppercase returns word": { + input: "WORD", + expected: ExportFormatWORD, + }, + "docx returns word": { + input: "docx", + expected: ExportFormatWORD, + }, + "DOCX uppercase returns word": { + input: "DOCX", + expected: ExportFormatWORD, + }, + "invalid value defaults to csv": { + input: "invalid", + expected: CsvExportFormat, + }, + "unknown format defaults to csv": { + input: "json", + expected: CsvExportFormat, + }, + "whitespace-only defaults to csv": { + input: " ", + expected: CsvExportFormat, + }, + "leading and trailing spaces are trimmed": { + input: " pdf ", + expected: ExportFormatPDF, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + result := NormalizeExportFormatStr(tc.input) + if result != tc.expected { + t.Errorf("NormalizeExportFormatStr(%q) = %q, want %q", tc.input, result, tc.expected) + } + }) + } +} From 768df70eae6f98937e20f20ad40aaeb0e55479a6 Mon Sep 17 00:00:00 2001 From: actiontech-zihan Date: Tue, 31 Mar 2026 14:14:15 +0000 Subject: [PATCH 02/14] feat(utils): implement CSV report generator with unit tests Implement CSVReportGenerator that reuses the existing CSVBuilder to generate CSV audit reports via the ReportGenerator interface. Add comprehensive unit tests covering normal data, empty SQL list, and special character escaping (comma, newline, double quote). --- sqle/utils/report_csv.go | 55 +++++ sqle/utils/report_generator_test.go | 348 ++++++++++++++++++++++++++++ 2 files changed, 403 insertions(+) create mode 100644 sqle/utils/report_csv.go diff --git a/sqle/utils/report_csv.go b/sqle/utils/report_csv.go new file mode 100644 index 0000000000..7d16b07a7a --- /dev/null +++ b/sqle/utils/report_csv.go @@ -0,0 +1,55 @@ +package utils + +import "fmt" + +// CSVReportGenerator CSV 格式报告生成器 +// 复用已有的 CSVBuilder 生成 CSV 报告,实现 ReportGenerator 接口。 +type CSVReportGenerator struct{} + +// NewCSVReportGenerator 创建并返回一个新的 CSVReportGenerator 实例 +func NewCSVReportGenerator() *CSVReportGenerator { + return &CSVReportGenerator{} +} + +// Format 返回生成器支持的导出格式 +func (g *CSVReportGenerator) Format() ExportFormat { + return CsvExportFormat +} + +// Generate 根据审核报告数据生成 CSV 格式的文件 +// +// 参数: +// +// data: 审核报告完整数据模型 +// +// 返回: +// +// *ExportDataResult: 包含 CSV 文件内容、ContentType 和文件名 +// error: 生成过程中的错误 +func (g *CSVReportGenerator) Generate(data *AuditReportData) (*ExportDataResult, error) { + builder := NewCSVBuilder() + + // 写入表头 + if err := builder.WriteHeader(data.CSVHeaders()); err != nil { + return nil, fmt.Errorf("write csv header failed: %v", err) + } + + // 写入数据行 + for _, sql := range data.SQLList { + if err := builder.WriteRow(sql.ToCSVRow()); err != nil { + return nil, fmt.Errorf("write csv row failed: %v", err) + } + } + + // 刷新缓冲区并获取内容 + content := builder.FlushAndGetBuffer().Bytes() + if err := builder.Error(); err != nil { + return nil, fmt.Errorf("csv builder error: %v", err) + } + + return &ExportDataResult{ + Content: content, + ContentType: "text/csv", + FileName: fmt.Sprintf("SQL_audit_report_%s_%d.csv", data.InstanceName, data.TaskID), + }, nil +} diff --git a/sqle/utils/report_generator_test.go b/sqle/utils/report_generator_test.go index d6f07c8b30..eed95805c7 100644 --- a/sqle/utils/report_generator_test.go +++ b/sqle/utils/report_generator_test.go @@ -1,7 +1,9 @@ package utils import ( + "strings" "testing" + "time" ) func TestNormalizeExportFormatStr(t *testing.T) { @@ -88,3 +90,349 @@ func TestNormalizeExportFormatStr(t *testing.T) { }) } } + +// buildTestReportData 构建测试用的 AuditReportData +func buildTestReportData() *AuditReportData { + return &AuditReportData{ + TaskID: 1001, + Title: "SQL Audit Report", + InstanceName: "test-mysql", + Schema: "test_db", + GeneratedAt: time.Now(), + Lang: "en-US", + Summary: AuditSummary{ + AuditTime: "2026-03-31 10:00:00", + InstanceName: "test-mysql", + Schema: "test_db", + TotalSQL: 3, + PassRate: 66.7, + Score: 70, + AuditLevel: "warn", + }, + Statistics: AuditStatistics{ + LevelDistribution: []LevelCount{ + {Level: "normal", Count: 1}, + {Level: "warn", Count: 1}, + {Level: "error", Count: 1}, + }, + RuleHits: []RuleHit{ + {RuleName: "no_select_all", HitCount: 1}, + {RuleName: "no_drop_table", HitCount: 1}, + }, + }, + SQLList: []AuditSQLItem{ + { + Number: 1, + SQL: "SELECT * FROM users", + AuditLevel: "warn", + AuditStatus: "finished", + AuditResult: "should not use SELECT *", + ExecStatus: "initialized", + ExecResult: "", + RollbackSQL: "", + Description: "query all users", + RuleName: "no_select_all", + Suggestion: "specify column names", + }, + { + Number: 2, + SQL: "DROP TABLE test", + AuditLevel: "error", + AuditStatus: "finished", + AuditResult: "DROP TABLE is prohibited", + ExecStatus: "initialized", + ExecResult: "", + RollbackSQL: "", + Description: "drop test table", + RuleName: "no_drop_table", + Suggestion: "do not use DROP TABLE", + }, + { + Number: 3, + SQL: "INSERT INTO t VALUES(1)", + AuditLevel: "normal", + AuditStatus: "finished", + AuditResult: "", + ExecStatus: "initialized", + ExecResult: "", + RollbackSQL: "DELETE FROM t WHERE id=1", + Description: "insert a row", + }, + }, + ProblemSQLs: []AuditSQLItem{ + {Number: 1, SQL: "SELECT * FROM users", AuditLevel: "warn", RuleName: "no_select_all"}, + {Number: 2, SQL: "DROP TABLE test", AuditLevel: "error", RuleName: "no_drop_table"}, + }, + Labels: ReportLabels{ + AuditSummary: "Audit Summary", + ResultStatistics: "Audit Result Statistics", + ProblemSQLList: "Problem SQL List", + RuleHitStatistics: "Rule Hit Statistics", + AuditTime: "Audit Time", + DataSource: "Data Source", + Schema: "Schema", + TotalSQL: "Total SQL", + PassRate: "Pass Rate", + Score: "Score", + AuditLevel: "Audit Level", + Number: "Number", + SQL: "SQL", + AuditStatus: "Audit Status", + AuditResult: "Audit Result", + ExecStatus: "Exec Status", + ExecResult: "Exec Result", + RollbackSQL: "Rollback SQL", + RuleName: "Rule Name", + Description: "Description", + Suggestion: "Suggestion", + Count: "Count", + HitCount: "Hit Count", + }, + } +} + +func TestCSVReportGenerator_Normal(t *testing.T) { + testCases := map[string]struct { + data *AuditReportData + wantContentType string + wantFilePrefix string + wantFileSuffix string + wantBOM bool + wantHeaders []string + wantDataRows int + }{ + "normal data generates valid CSV report": { + data: buildTestReportData(), + wantContentType: "text/csv", + wantFilePrefix: "SQL_audit_report_test-mysql_1001", + wantFileSuffix: ".csv", + wantBOM: true, + wantHeaders: []string{"Number", "SQL", "Audit Status", "Audit Result", "Exec Status", "Exec Result", "Rollback SQL", "Description"}, + wantDataRows: 3, + }, + } + + gen := NewCSVReportGenerator() + + // Verify the generator implements ReportGenerator interface + var _ ReportGenerator = gen + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + result, err := gen.Generate(tc.data) + if err != nil { + t.Fatalf("Generate() returned unexpected error: %v", err) + } + if result == nil { + t.Fatal("Generate() returned nil result") + } + + // Verify ContentType + if result.ContentType != tc.wantContentType { + t.Errorf("ContentType = %q, want %q", result.ContentType, tc.wantContentType) + } + + // Verify FileName contains InstanceName and TaskID, ends with .csv + if !strings.Contains(result.FileName, tc.data.InstanceName) { + t.Errorf("FileName %q does not contain InstanceName %q", result.FileName, tc.data.InstanceName) + } + if !strings.Contains(result.FileName, "1001") { + t.Errorf("FileName %q does not contain TaskID", result.FileName) + } + if !strings.HasSuffix(result.FileName, tc.wantFileSuffix) { + t.Errorf("FileName %q does not end with %q", result.FileName, tc.wantFileSuffix) + } + if !strings.HasPrefix(result.FileName, tc.wantFilePrefix) { + t.Errorf("FileName %q does not start with %q", result.FileName, tc.wantFilePrefix) + } + + // Verify UTF-8 BOM + content := string(result.Content) + if tc.wantBOM { + if !strings.HasPrefix(content, "\xEF\xBB\xBF") { + t.Error("Content does not start with UTF-8 BOM") + } + } + + // Verify headers exist in content + for _, h := range tc.wantHeaders { + if !strings.Contains(content, h) { + t.Errorf("Content does not contain header %q", h) + } + } + + // Verify the number of data rows (excluding BOM and header line) + contentWithoutBOM := strings.TrimPrefix(content, "\xEF\xBB\xBF") + lines := strings.Split(strings.TrimRight(contentWithoutBOM, "\n"), "\n") + // First line is header, remaining are data rows + dataLineCount := len(lines) - 1 + if dataLineCount != tc.wantDataRows { + t.Errorf("data row count = %d, want %d", dataLineCount, tc.wantDataRows) + } + }) + } +} + +func TestCSVReportGenerator_EmptyData(t *testing.T) { + testCases := map[string]struct { + data *AuditReportData + wantDataRows int + }{ + "empty SQL list produces header only": { + data: &AuditReportData{ + TaskID: 2002, + InstanceName: "empty-instance", + SQLList: []AuditSQLItem{}, + Labels: ReportLabels{ + Number: "Number", + SQL: "SQL", + AuditStatus: "Audit Status", + AuditResult: "Audit Result", + ExecStatus: "Exec Status", + ExecResult: "Exec Result", + RollbackSQL: "Rollback SQL", + Description: "Description", + }, + }, + wantDataRows: 0, + }, + } + + gen := NewCSVReportGenerator() + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + result, err := gen.Generate(tc.data) + if err != nil { + t.Fatalf("Generate() returned unexpected error: %v", err) + } + if result == nil { + t.Fatal("Generate() returned nil result") + } + + content := string(result.Content) + + // Verify BOM is present + if !strings.HasPrefix(content, "\xEF\xBB\xBF") { + t.Error("Content does not start with UTF-8 BOM") + } + + // Verify only header line, no data rows + contentWithoutBOM := strings.TrimPrefix(content, "\xEF\xBB\xBF") + lines := strings.Split(strings.TrimRight(contentWithoutBOM, "\n"), "\n") + // Should have exactly 1 line (header only) + if len(lines) != 1 { + t.Errorf("expected 1 line (header only), got %d lines", len(lines)) + } + + // Verify headers are present + headerLine := lines[0] + for _, h := range []string{"Number", "SQL", "Audit Status"} { + if !strings.Contains(headerLine, h) { + t.Errorf("header line does not contain %q", h) + } + } + + // Verify FileName + if result.FileName != "SQL_audit_report_empty-instance_2002.csv" { + t.Errorf("FileName = %q, want %q", result.FileName, "SQL_audit_report_empty-instance_2002.csv") + } + }) + } +} + +func TestCSVReportGenerator_SpecialChars(t *testing.T) { + testCases := map[string]struct { + sqlItem AuditSQLItem + wantInRow string + description string + }{ + "SQL with comma is quoted": { + sqlItem: AuditSQLItem{ + Number: 1, + SQL: "SELECT a, b FROM users", + AuditStatus: "finished", + AuditResult: "ok", + ExecStatus: "initialized", + ExecResult: "", + RollbackSQL: "", + Description: "", + }, + wantInRow: `"SELECT a, b FROM users"`, + description: "field containing comma should be wrapped in double quotes", + }, + "SQL with double quote is escaped": { + sqlItem: AuditSQLItem{ + Number: 2, + SQL: `SELECT "name" FROM users`, + AuditStatus: "finished", + AuditResult: "ok", + ExecStatus: "initialized", + ExecResult: "", + RollbackSQL: "", + Description: "", + }, + wantInRow: `"SELECT ""name"" FROM users"`, + description: "double quotes within a field should be escaped as two double quotes", + }, + "SQL with newline is quoted": { + sqlItem: AuditSQLItem{ + Number: 3, + SQL: "SELECT *\nFROM users", + AuditStatus: "finished", + AuditResult: "ok", + ExecStatus: "initialized", + ExecResult: "", + RollbackSQL: "", + Description: "", + }, + wantInRow: "\"SELECT *\nFROM users\"", + description: "field containing newline should be wrapped in double quotes", + }, + } + + gen := NewCSVReportGenerator() + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + data := &AuditReportData{ + TaskID: 3003, + InstanceName: "special-instance", + SQLList: []AuditSQLItem{tc.sqlItem}, + Labels: ReportLabels{ + Number: "Number", + SQL: "SQL", + AuditStatus: "Audit Status", + AuditResult: "Audit Result", + ExecStatus: "Exec Status", + ExecResult: "Exec Result", + RollbackSQL: "Rollback SQL", + Description: "Description", + }, + } + + result, err := gen.Generate(data) + if err != nil { + t.Fatalf("Generate() returned unexpected error: %v", err) + } + if result == nil { + t.Fatal("Generate() returned nil result") + } + + content := string(result.Content) + + // Verify that the special character handling is correct + if !strings.Contains(content, tc.wantInRow) { + t.Errorf("%s\ncontent does not contain expected substring %q\nfull content:\n%s", + tc.description, tc.wantInRow, content) + } + }) + } +} + +func TestCSVReportGenerator_Format(t *testing.T) { + gen := NewCSVReportGenerator() + if gen.Format() != CsvExportFormat { + t.Errorf("Format() = %q, want %q", gen.Format(), CsvExportFormat) + } +} From 47500f711ea5fb3e7f09f02adac00d85477ce274 Mon Sep 17 00:00:00 2001 From: actiontech-zihan Date: Tue, 31 Mar 2026 14:24:14 +0000 Subject: [PATCH 03/14] feat(utils): implement HTML report generator with template and unit tests - Add self-contained HTML report template (templates/audit_report.html) with 4 content modules: audit summary, result statistics, problem SQL list, and rule hit statistics. All text uses Go template i18n variables, CSS is inlined, SQL displayed in
 tags, logo supports base64 embedding.
- Add report_html_template.go using embed.FS to embed the HTML template file.
- Add report_html.go implementing HTMLReportGenerator with html/template for
  automatic XSS prevention via HTML escaping. Content-Type: text/html,
  filename format: SQL_audit_report_{instance}_{taskId}.html.
- Add 4 unit tests: Normal (ContentType, HTML tags, SQL content, labels),
  XSSPrevention (script/img tags escaped), EmptyData (empty SQL list renders),
  LargeData (10000 SQL items performance check).
---
 sqle/utils/report_generator_test.go    | 321 +++++++++++++++++++++++++
 sqle/utils/report_html.go              |  58 +++++
 sqle/utils/report_html_template.go     |  18 ++
 sqle/utils/templates/audit_report.html | 280 +++++++++++++++++++++
 4 files changed, 677 insertions(+)
 create mode 100644 sqle/utils/report_html.go
 create mode 100644 sqle/utils/report_html_template.go
 create mode 100644 sqle/utils/templates/audit_report.html

diff --git a/sqle/utils/report_generator_test.go b/sqle/utils/report_generator_test.go
index eed95805c7..6cf3ede59e 100644
--- a/sqle/utils/report_generator_test.go
+++ b/sqle/utils/report_generator_test.go
@@ -1,6 +1,7 @@
 package utils
 
 import (
+	"fmt"
 	"strings"
 	"testing"
 	"time"
@@ -436,3 +437,323 @@ func TestCSVReportGenerator_Format(t *testing.T) {
 		t.Errorf("Format() = %q, want %q", gen.Format(), CsvExportFormat)
 	}
 }
+
+func TestHTMLReportGenerator_Normal(t *testing.T) {
+	gen, err := NewHTMLReportGenerator()
+	if err != nil {
+		t.Fatalf("NewHTMLReportGenerator() returned unexpected error: %v", err)
+	}
+
+	// Verify the generator implements ReportGenerator interface
+	var _ ReportGenerator = gen
+
+	testCases := map[string]struct {
+		data             *AuditReportData
+		wantContentType  string
+		wantFilePrefix   string
+		wantFileSuffix   string
+		wantHTMLTags     []string
+		wantSQLContents  []string
+		wantLabels       []string
+	}{
+		"normal data generates valid HTML report": {
+			data:            buildTestReportData(),
+			wantContentType: "text/html",
+			wantFilePrefix:  "SQL_audit_report_test-mysql_1001",
+			wantFileSuffix:  ".html",
+			wantHTMLTags:    []string{"", "

", "

", "
", "", ""},
+			wantSQLContents: []string{"SELECT * FROM users", "DROP TABLE test"},
+			wantLabels:      []string{"Audit Summary", "Audit Result Statistics", "Problem SQL List", "Rule Hit Statistics"},
+		},
+	}
+
+	for name, tc := range testCases {
+		t.Run(name, func(t *testing.T) {
+			result, err := gen.Generate(tc.data)
+			if err != nil {
+				t.Fatalf("Generate() returned unexpected error: %v", err)
+			}
+			if result == nil {
+				t.Fatal("Generate() returned nil result")
+			}
+
+			// Verify ContentType
+			if result.ContentType != tc.wantContentType {
+				t.Errorf("ContentType = %q, want %q", result.ContentType, tc.wantContentType)
+			}
+
+			// Verify FileName format
+			if !strings.HasPrefix(result.FileName, tc.wantFilePrefix) {
+				t.Errorf("FileName %q does not start with %q", result.FileName, tc.wantFilePrefix)
+			}
+			if !strings.HasSuffix(result.FileName, tc.wantFileSuffix) {
+				t.Errorf("FileName %q does not end with %q", result.FileName, tc.wantFileSuffix)
+			}
+			if !strings.Contains(result.FileName, tc.data.InstanceName) {
+				t.Errorf("FileName %q does not contain InstanceName %q", result.FileName, tc.data.InstanceName)
+			}
+			if !strings.Contains(result.FileName, "1001") {
+				t.Errorf("FileName %q does not contain TaskID", result.FileName)
+			}
+
+			// Verify HTML content contains key HTML tags
+			content := string(result.Content)
+			for _, tag := range tc.wantHTMLTags {
+				if !strings.Contains(content, tag) {
+					t.Errorf("Content does not contain expected HTML tag %q", tag)
+				}
+			}
+
+			// Verify SQL contents exist in the output
+			for _, sql := range tc.wantSQLContents {
+				if !strings.Contains(content, sql) {
+					t.Errorf("Content does not contain expected SQL %q", sql)
+				}
+			}
+
+			// Verify i18n labels are rendered
+			for _, label := range tc.wantLabels {
+				if !strings.Contains(content, label) {
+					t.Errorf("Content does not contain expected label %q", label)
+				}
+			}
+
+			// Verify Format() returns ExportFormatHTML
+			if gen.Format() != ExportFormatHTML {
+				t.Errorf("Format() = %q, want %q", gen.Format(), ExportFormatHTML)
+			}
+		})
+	}
+}
+
+func TestHTMLReportGenerator_XSSPrevention(t *testing.T) {
+	gen, err := NewHTMLReportGenerator()
+	if err != nil {
+		t.Fatalf("NewHTMLReportGenerator() returned unexpected error: %v", err)
+	}
+
+	testCases := map[string]struct {
+		maliciousSQL    string
+		wantAbsent      []string
+		wantDescription string
+	}{
+		"script tag in SQL is escaped": {
+			maliciousSQL: "",
+			wantAbsent:   []string{""},
+			wantDescription: "script tags should be HTML-escaped by html/template",
+		},
+		"img onerror in SQL is escaped": {
+			maliciousSQL: ``,
+			wantAbsent:   []string{`onerror="alert`},
+			wantDescription: "event handler attributes should be HTML-escaped",
+		},
+		"script tag in description is escaped": {
+			maliciousSQL: "SELECT 1",
+			wantAbsent:   []string{""
+			} else {
+				// Put malicious content in SQL field
+				data.ProblemSQLs[0].SQL = tc.maliciousSQL
+				data.SQLList[0].SQL = tc.maliciousSQL
+			}
+
+			result, err := gen.Generate(data)
+			if err != nil {
+				t.Fatalf("Generate() returned unexpected error: %v", err)
+			}
+
+			content := string(result.Content)
+
+			// Verify that raw malicious content is NOT present (it should be escaped)
+			for _, absent := range tc.wantAbsent {
+				if strings.Contains(content, absent) {
+					snippetLen := 500
+					if len(content) < snippetLen {
+						snippetLen = len(content)
+					}
+					t.Errorf("%s: content contains unescaped %q\ncontent snippet: %s",
+						tc.wantDescription, absent, content[:snippetLen])
+				}
+			}
+		})
+	}
+}
+
+func TestHTMLReportGenerator_EmptyData(t *testing.T) {
+	gen, err := NewHTMLReportGenerator()
+	if err != nil {
+		t.Fatalf("NewHTMLReportGenerator() returned unexpected error: %v", err)
+	}
+
+	testCases := map[string]struct {
+		data         *AuditReportData
+		wantHTMLTags []string
+	}{
+		"empty SQL list renders without error": {
+			data: &AuditReportData{
+				TaskID:       2002,
+				Title:        "Empty Report",
+				InstanceName: "empty-instance",
+				Schema:       "empty_db",
+				GeneratedAt:  time.Now(),
+				Lang:         "en-US",
+				Summary: AuditSummary{
+					AuditTime:    "2026-03-31 10:00:00",
+					InstanceName: "empty-instance",
+					Schema:       "empty_db",
+					TotalSQL:     0,
+					PassRate:     100.0,
+					Score:        100,
+				},
+				Statistics: AuditStatistics{
+					LevelDistribution: []LevelCount{},
+					RuleHits:          []RuleHit{},
+				},
+				SQLList:     []AuditSQLItem{},
+				ProblemSQLs: []AuditSQLItem{},
+				Labels: ReportLabels{
+					AuditSummary:      "Audit Summary",
+					ResultStatistics:  "Audit Result Statistics",
+					ProblemSQLList:    "Problem SQL List",
+					RuleHitStatistics: "Rule Hit Statistics",
+					AuditTime:         "Audit Time",
+					DataSource:        "Data Source",
+					Schema:            "Schema",
+					TotalSQL:          "Total SQL",
+					PassRate:          "Pass Rate",
+					Score:             "Score",
+					AuditLevel:        "Audit Level",
+					Number:            "Number",
+					SQL:               "SQL",
+					AuditResult:       "Audit Result",
+					RuleName:          "Rule Name",
+					Description:       "Description",
+					Suggestion:        "Suggestion",
+					Count:             "Count",
+					HitCount:          "Hit Count",
+				},
+			},
+			wantHTMLTags: []string{"", "", "

", "

", "

"}, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + result, err := gen.Generate(tc.data) + if err != nil { + t.Fatalf("Generate() returned unexpected error: %v", err) + } + if result == nil { + t.Fatal("Generate() returned nil result") + } + + // Verify ContentType + if result.ContentType != "text/html" { + t.Errorf("ContentType = %q, want %q", result.ContentType, "text/html") + } + + // Verify the output is valid HTML with required tags + content := string(result.Content) + for _, tag := range tc.wantHTMLTags { + if !strings.Contains(content, tag) { + t.Errorf("Content does not contain expected HTML tag %q", tag) + } + } + + // Verify FileName + expectedFileName := "SQL_audit_report_empty-instance_2002.html" + if result.FileName != expectedFileName { + t.Errorf("FileName = %q, want %q", result.FileName, expectedFileName) + } + + // Verify the title is rendered + if !strings.Contains(content, "Empty Report") { + t.Error("Content does not contain the report title") + } + }) + } +} + +func TestHTMLReportGenerator_LargeData(t *testing.T) { + gen, err := NewHTMLReportGenerator() + if err != nil { + t.Fatalf("NewHTMLReportGenerator() returned unexpected error: %v", err) + } + + testCases := map[string]struct { + sqlCount int + }{ + "10000 SQL items generates successfully": { + sqlCount: 10000, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + data := buildTestReportData() + + // Build large SQL list and problem SQL list + sqlList := make([]AuditSQLItem, 0, tc.sqlCount) + problemSQLs := make([]AuditSQLItem, 0, tc.sqlCount/2) + for i := 0; i < tc.sqlCount; i++ { + item := AuditSQLItem{ + Number: uint(i + 1), + SQL: fmt.Sprintf("SELECT * FROM table_%d WHERE id = %d", i, i), + AuditLevel: "warn", + AuditStatus: "finished", + AuditResult: fmt.Sprintf("audit result for SQL #%d", i), + RuleName: "no_select_all", + Description: fmt.Sprintf("description for SQL #%d", i), + Suggestion: "specify column names", + } + sqlList = append(sqlList, item) + if i%2 == 0 { + problemSQLs = append(problemSQLs, item) + } + } + data.SQLList = sqlList + data.ProblemSQLs = problemSQLs + data.Summary.TotalSQL = tc.sqlCount + + start := time.Now() + result, err := gen.Generate(data) + elapsed := time.Since(start) + + if err != nil { + t.Fatalf("Generate() returned unexpected error: %v", err) + } + if result == nil { + t.Fatal("Generate() returned nil result") + } + + // Verify content is not empty + if len(result.Content) == 0 { + t.Error("Generate() returned empty content") + } + + // Verify ContentType + if result.ContentType != "text/html" { + t.Errorf("ContentType = %q, want %q", result.ContentType, "text/html") + } + + // Log the elapsed time (informational, not a hard failure since dev machines vary) + t.Logf("Generated HTML report with %d SQLs in %v, output size: %d bytes", + tc.sqlCount, elapsed, len(result.Content)) + + // Soft check: warn if it takes more than 5 seconds + if elapsed > 5*time.Second { + t.Logf("WARNING: Generation took %v which exceeds 5s target", elapsed) + } + }) + } +} diff --git a/sqle/utils/report_html.go b/sqle/utils/report_html.go new file mode 100644 index 0000000000..9ca036334c --- /dev/null +++ b/sqle/utils/report_html.go @@ -0,0 +1,58 @@ +package utils + +import ( + "bytes" + "fmt" + "html/template" +) + +// HTMLReportGenerator HTML 格式报告生成器 +// 使用 html/template 渲染嵌入的 HTML 模板,自动进行 HTML 转义防止 XSS。 +// 实现 ReportGenerator 接口。 +type HTMLReportGenerator struct { + tmpl *template.Template +} + +// NewHTMLReportGenerator 创建并返回一个新的 HTMLReportGenerator 实例。 +// 在创建时解析嵌入的 HTML 模板,如果模板解析失败则返回错误。 +func NewHTMLReportGenerator() (*HTMLReportGenerator, error) { + templateContent, err := GetAuditReportHTMLTemplate() + if err != nil { + return nil, fmt.Errorf("read HTML template failed: %w", err) + } + + tmpl, err := template.New("audit_report").Parse(templateContent) + if err != nil { + return nil, fmt.Errorf("parse HTML template failed: %w", err) + } + + return &HTMLReportGenerator{tmpl: tmpl}, nil +} + +// Format 返回生成器支持的导出格式 +func (g *HTMLReportGenerator) Format() ExportFormat { + return ExportFormatHTML +} + +// Generate 根据审核报告数据生成 HTML 格式的文件 +// +// 参数: +// +// data: 审核报告完整数据模型 +// +// 返回: +// +// *ExportDataResult: 包含 HTML 文件内容、ContentType 和文件名 +// error: 生成过程中的错误 +func (g *HTMLReportGenerator) Generate(data *AuditReportData) (*ExportDataResult, error) { + var buf bytes.Buffer + if err := g.tmpl.Execute(&buf, data); err != nil { + return nil, fmt.Errorf("render HTML report failed: %w", err) + } + + return &ExportDataResult{ + Content: buf.Bytes(), + ContentType: "text/html", + FileName: fmt.Sprintf("SQL_audit_report_%s_%d.html", data.InstanceName, data.TaskID), + }, nil +} diff --git a/sqle/utils/report_html_template.go b/sqle/utils/report_html_template.go new file mode 100644 index 0000000000..773c506bd8 --- /dev/null +++ b/sqle/utils/report_html_template.go @@ -0,0 +1,18 @@ +package utils + +import "embed" + +//go:embed templates/audit_report.html +var auditReportTemplateFS embed.FS + +// auditReportHTMLTemplatePath is the path to the embedded HTML template file. +const auditReportHTMLTemplatePath = "templates/audit_report.html" + +// GetAuditReportHTMLTemplate reads the embedded HTML template and returns its content as a string. +func GetAuditReportHTMLTemplate() (string, error) { + content, err := auditReportTemplateFS.ReadFile(auditReportHTMLTemplatePath) + if err != nil { + return "", err + } + return string(content), nil +} diff --git a/sqle/utils/templates/audit_report.html b/sqle/utils/templates/audit_report.html new file mode 100644 index 0000000000..8564fd148f --- /dev/null +++ b/sqle/utils/templates/audit_report.html @@ -0,0 +1,280 @@ + + + + + + {{.Title}} + + + +
+ +
+ {{if .LogoBase64}}{{end}} +

{{.Title}}

+

{{.GeneratedAt.Format "2006-01-02 15:04:05"}}

+
+ + +
+

{{.Labels.AuditSummary}}

+

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
{{.Labels.AuditTime}}{{.Summary.AuditTime}}
{{.Labels.DataSource}}{{.Summary.InstanceName}}
{{.Labels.Schema}}{{.Summary.Schema}}
{{.Labels.TotalSQL}}{{.Summary.TotalSQL}}
{{.Labels.PassRate}}{{.Summary.PassRate}}%
{{.Labels.Score}}{{.Summary.Score}}
+ + + +
+

{{.Labels.ResultStatistics}}

+ + + + + + + + + {{range .Statistics.LevelDistribution}} + + + + + {{end}} + +
{{.Labels.AuditLevel}}{{.Labels.Count}}
{{.Level}}{{.Count}}
+
+ + +
+

{{.Labels.ProblemSQLList}}

+ + + + + + + + + + + + + + {{range .ProblemSQLs}} + + + + + + + + + + {{end}} + +
{{.Labels.Number}}{{.Labels.SQL}}{{.Labels.AuditLevel}}{{.Labels.AuditResult}}{{.Labels.RuleName}}{{.Labels.Description}}{{.Labels.Suggestion}}
{{.Number}}
{{.SQL}}
{{.AuditLevel}}{{.AuditResult}}{{.RuleName}}{{.Description}}{{.Suggestion}}
+
+ + +
+

{{.Labels.RuleHitStatistics}}

+ + + + + + + + + {{range .Statistics.RuleHits}} + + + + + {{end}} + +
{{.Labels.RuleName}}{{.Labels.HitCount}}
{{.RuleName}}{{.HitCount}}
+
+ + + From 19ac482068432829b6ff094c7142ea2ba7e18956 Mon Sep 17 00:00:00 2001 From: actiontech-zihan Date: Fri, 10 Apr 2026 02:46:22 +0000 Subject: [PATCH 04/14] feat(utils): implement CE format dispatch for ExportAuditReport --- sqle/utils/report_generator_ce.go | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 sqle/utils/report_generator_ce.go diff --git a/sqle/utils/report_generator_ce.go b/sqle/utils/report_generator_ce.go new file mode 100644 index 0000000000..ac2723c0fa --- /dev/null +++ b/sqle/utils/report_generator_ce.go @@ -0,0 +1,26 @@ +//go:build !enterprise + +package utils + +import "fmt" + +// ExportAuditReport CE 版统一导出入口。 +// CE 版仅支持 CSV 和 HTML 两种格式。 +// 请求 PDF 或 WORD 格式时返回错误提示,提醒用户需要企业版。 +// 无效格式返回错误(REQ-6.3)。 +func ExportAuditReport(format ExportFormat, data *AuditReportData) (*ExportDataResult, error) { + switch format { + case CsvExportFormat: + return NewCSVReportGenerator().Generate(data) + case ExportFormatHTML: + gen, err := NewHTMLReportGenerator() + if err != nil { + return nil, err + } + return gen.Generate(data) + case ExportFormatPDF, ExportFormatWORD: + return nil, fmt.Errorf("export format %s is only supported in enterprise edition", format) + default: + return nil, fmt.Errorf("unsupported export format: %s", format) + } +} From 6eb0a3c8b21c5e5407d40e706d1a39c3d99552da Mon Sep 17 00:00:00 2001 From: actiontech-zihan Date: Tue, 31 Mar 2026 15:28:37 +0000 Subject: [PATCH 05/14] feat(api): refactor DownloadTaskSQLReportFile to support multi-format export Refactor the DownloadTaskSQLReportFile API handler to support exporting audit reports in multiple formats (CSV, HTML, PDF, WORD) via the new export_format query parameter. Default to CSV for backward compatibility. Key changes: - Add BuildAuditReportData function in report_data_builder.go that converts Task + SQL data into AuditReportData (placed in controller layer to avoid circular dependency between utils and model packages) - Refactor DownloadTaskSQLReportFile to use BuildAuditReportData and ExportAuditReport for format-agnostic report generation - Add 16 new i18n locale messages for report labels (zh/en TOML + Go vars) - Update Swagger annotations with export_format parameter - Add comprehensive unit tests for helper functions: toLevelCounts, toRuleHits, extractRuleInfo, buildReportLabels (13 test functions, map case pattern, covering normal/empty/nil/edge cases) Requirement refs: REQ-6.1, REQ-6.2, REQ-6.4, REQ-NF-2.1, REQ-NF-3.1 --- sqle/api/controller/v1/report_data_builder.go | 236 +++++++++ .../controller/v1/report_data_builder_test.go | 450 ++++++++++++++++++ sqle/api/controller/v1/task.go | 61 +-- sqle/locale/active.en.toml | 16 + sqle/locale/active.zh.toml | 16 + sqle/locale/message_zh.go | 20 + 6 files changed, 752 insertions(+), 47 deletions(-) create mode 100644 sqle/api/controller/v1/report_data_builder.go create mode 100644 sqle/api/controller/v1/report_data_builder_test.go diff --git a/sqle/api/controller/v1/report_data_builder.go b/sqle/api/controller/v1/report_data_builder.go new file mode 100644 index 0000000000..fda3ab4619 --- /dev/null +++ b/sqle/api/controller/v1/report_data_builder.go @@ -0,0 +1,236 @@ +package v1 + +import ( + "context" + "fmt" + "sort" + "strings" + "time" + + "github.com/actiontech/sqle/sqle/locale" + "github.com/actiontech/sqle/sqle/model" + "github.com/actiontech/sqle/sqle/server" + "github.com/actiontech/sqle/sqle/utils" +) + +// BuildAuditReportData 从 Task 和数据库查询构建报告数据。 +// 该函数放在 controller 层而非 utils 层,因为 utils 被 model 引用, +// 若 utils 反向引用 model 会产生循环依赖。 +func BuildAuditReportData(task *model.Task, s *model.Storage, noDuplicate bool, ctx context.Context) (*utils.AuditReportData, error) { + // 1. 获取 SQL 列表 + data := map[string]interface{}{ + "task_id": fmt.Sprintf("%d", task.ID), + "no_duplicate": noDuplicate, + } + + taskSQLsDetail, _, err := s.GetTaskSQLsByReq(data) + if err != nil { + return nil, fmt.Errorf("get task SQLs failed: %w", err) + } + + // 2. 获取回滚 SQL 映射 + rollbackSqlMap, err := server.BackupService{}.GetRollbackSqlsMap(task.ID) + if err != nil { + return nil, fmt.Errorf("get rollback SQLs failed: %w", err) + } + + // 3. 构建 SQL 列表和统计数据 + levelDist := make(map[string]int) + ruleHits := make(map[string]int) + var sqlList []utils.AuditSQLItem + var problemSQLs []utils.AuditSQLItem + + for _, td := range taskSQLsDetail { + // 构造临时 ExecuteSQL 对象以复用状态描述方法 + tempSQL := &model.ExecuteSQL{ + AuditResults: td.AuditResults, + AuditStatus: td.AuditStatus, + } + tempSQL.ExecStatus = td.ExecStatus + + // 提取规则名称和审核建议 + ruleName, suggestion := extractRuleInfo(td.AuditResults, ctx) + + item := utils.AuditSQLItem{ + Number: td.Number, + SQL: td.ExecSQL, + AuditLevel: td.AuditLevel, + AuditStatus: tempSQL.GetAuditStatusDesc(ctx), + AuditResult: tempSQL.GetAuditResultDesc(ctx), + ExecStatus: tempSQL.GetExecStatusDesc(ctx), + ExecResult: td.ExecResult, + RollbackSQL: strings.Join(rollbackSqlMap[td.Id], "\n"), + Description: td.Description, + RuleName: ruleName, + Suggestion: suggestion, + } + sqlList = append(sqlList, item) + + // 统计等级分布 + level := td.AuditLevel + if level == "" { + level = "normal" + } + levelDist[level]++ + + // 区分问题 SQL(AuditLevel 不是 normal 且不为空) + if level != "normal" { + problemSQLs = append(problemSQLs, item) + } + + // 统计规则命中 + for _, ar := range td.AuditResults { + if ar.RuleName != "" { + ruleHits[ar.RuleName]++ + } + } + } + + // 4. 构建国际化标签(当前使用 locale 包提供的 i18n 标签) + labels := buildReportLabels(ctx) + + // 5. 构建审核时间 + auditTime := time.Now().Format("2006-01-02 15:04:05") + if task.CreatedAt.Year() > 1 { + auditTime = task.CreatedAt.Format("2006-01-02 15:04:05") + } + + return &utils.AuditReportData{ + TaskID: uint64(task.ID), + Title: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelTitle), + InstanceName: task.InstanceName(), + Schema: task.Schema, + GeneratedAt: time.Now(), + Lang: locale.Bundle.GetLangTagFromCtx(ctx).String(), + LogoBase64: "", + Summary: utils.AuditSummary{ + AuditTime: auditTime, + InstanceName: task.InstanceName(), + Schema: task.Schema, + TotalSQL: len(sqlList), + PassRate: task.PassRate, + Score: task.Score, + AuditLevel: task.AuditLevel, + }, + Statistics: utils.AuditStatistics{ + LevelDistribution: toLevelCounts(levelDist), + RuleHits: toRuleHits(ruleHits), + }, + SQLList: sqlList, + ProblemSQLs: problemSQLs, + Labels: labels, + }, nil +} + +// extractRuleInfo 从审核结果中提取规则名称和审核建议。 +// 如果有多条规则命中,使用逗号分隔拼接。 +func extractRuleInfo(auditResults model.AuditResults, ctx context.Context) (ruleName string, suggestion string) { + if len(auditResults) == 0 { + return "", "" + } + + lang := locale.Bundle.GetLangTagFromCtx(ctx) + var ruleNames []string + var suggestions []string + + for _, ar := range auditResults { + if ar.RuleName != "" { + ruleNames = append(ruleNames, ar.RuleName) + } + msg := ar.GetAuditMsgByLangTag(lang) + if msg != "" { + suggestions = append(suggestions, msg) + } + } + + return strings.Join(ruleNames, ", "), strings.Join(suggestions, "; ") +} + +// toLevelCounts 将等级分布 map 转换为有序的 LevelCount 切片。 +// 按 error > warn > notice > normal 顺序排列。 +func toLevelCounts(dist map[string]int) []utils.LevelCount { + if len(dist) == 0 { + return []utils.LevelCount{} + } + + levelOrder := map[string]int{ + "error": 0, + "warn": 1, + "notice": 2, + "normal": 3, + } + + result := make([]utils.LevelCount, 0, len(dist)) + for level, count := range dist { + result = append(result, utils.LevelCount{ + Level: level, + Count: count, + }) + } + + sort.Slice(result, func(i, j int) bool { + oi, ok := levelOrder[result[i].Level] + if !ok { + oi = 99 + } + oj, ok := levelOrder[result[j].Level] + if !ok { + oj = 99 + } + return oi < oj + }) + + return result +} + +// toRuleHits 将规则命中 map 转换为按命中次数降序排列的 RuleHit 切片。 +func toRuleHits(hits map[string]int) []utils.RuleHit { + if len(hits) == 0 { + return []utils.RuleHit{} + } + + result := make([]utils.RuleHit, 0, len(hits)) + for name, count := range hits { + result = append(result, utils.RuleHit{ + RuleName: name, + HitCount: count, + }) + } + + sort.Slice(result, func(i, j int) bool { + return result[i].HitCount > result[j].HitCount + }) + + return result +} + +// buildReportLabels 构建报告中使用的国际化标签。 +// 当前版本使用 locale 包已有的国际化消息和硬编码中文标签, +// 后续阶段 8 将接入 go-i18n 框架实现完整国际化。 +func buildReportLabels(ctx context.Context) utils.ReportLabels { + return utils.ReportLabels{ + AuditSummary: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelAuditSummary), + ResultStatistics: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelResultStatistics), + ProblemSQLList: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelProblemSQLList), + RuleHitStatistics: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelRuleHitStatistics), + AuditTime: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelAuditTime), + DataSource: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelDataSource), + Schema: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelSchema), + TotalSQL: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelTotalSQL), + PassRate: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelPassRate), + Score: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelScore), + AuditLevel: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelAuditLevel), + Number: locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportIndex), + SQL: locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportSQL), + AuditStatus: locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportAuditStatus), + AuditResult: locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportAuditResult), + ExecStatus: locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportExecStatus), + ExecResult: locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportExecResult), + RollbackSQL: locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportRollbackSQL), + RuleName: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelRuleName), + Description: locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportDescription), + Suggestion: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelSuggestion), + Count: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelCount), + HitCount: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelHitCount), + } +} diff --git a/sqle/api/controller/v1/report_data_builder_test.go b/sqle/api/controller/v1/report_data_builder_test.go new file mode 100644 index 0000000000..e87b550852 --- /dev/null +++ b/sqle/api/controller/v1/report_data_builder_test.go @@ -0,0 +1,450 @@ +package v1 + +import ( + "context" + "testing" + + "github.com/actiontech/sqle/sqle/model" + "github.com/actiontech/sqle/sqle/utils" + "golang.org/x/text/language" +) + +func TestToLevelCounts(t *testing.T) { + testCases := map[string]struct { + input map[string]int + wantLen int + wantFirst string // expected first level (highest priority) + wantLast string // expected last level (lowest priority) + description string + }{ + "mixed levels sorted by severity": { + input: map[string]int{ + "normal": 5, + "error": 2, + "warn": 3, + "notice": 1, + }, + wantLen: 4, + wantFirst: "error", + wantLast: "normal", + description: "should sort error > warn > notice > normal", + }, + "empty map returns empty slice": { + input: map[string]int{}, + wantLen: 0, + description: "empty input should return empty result", + }, + "single level": { + input: map[string]int{ + "warn": 10, + }, + wantLen: 1, + wantFirst: "warn", + wantLast: "warn", + description: "single level should still work", + }, + "only normal level": { + input: map[string]int{ + "normal": 100, + }, + wantLen: 1, + wantFirst: "normal", + wantLast: "normal", + description: "all normal should return single entry", + }, + "unknown level sorted after known levels": { + input: map[string]int{ + "normal": 1, + "error": 1, + "unknown": 1, + }, + wantLen: 3, + wantFirst: "error", + wantLast: "unknown", + description: "unknown levels should be sorted after known levels", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + result := toLevelCounts(tc.input) + + if len(result) != tc.wantLen { + t.Fatalf("toLevelCounts() returned %d items, want %d", len(result), tc.wantLen) + } + + if tc.wantLen == 0 { + return + } + + // Verify first (highest priority) element + if result[0].Level != tc.wantFirst { + t.Errorf("first level = %q, want %q (%s)", result[0].Level, tc.wantFirst, tc.description) + } + + // Verify last (lowest priority) element + if result[len(result)-1].Level != tc.wantLast { + t.Errorf("last level = %q, want %q (%s)", result[len(result)-1].Level, tc.wantLast, tc.description) + } + + // Verify counts match input + for _, lc := range result { + expectedCount, ok := tc.input[lc.Level] + if !ok { + t.Errorf("unexpected level %q in result", lc.Level) + continue + } + if lc.Count != expectedCount { + t.Errorf("level %q count = %d, want %d", lc.Level, lc.Count, expectedCount) + } + } + }) + } +} + +func TestToRuleHits(t *testing.T) { + testCases := map[string]struct { + input map[string]int + wantLen int + wantFirst string // expected first rule (highest hit count) + description string + }{ + "multiple rules sorted by hit count descending": { + input: map[string]int{ + "no_select_all": 5, + "no_drop_table": 10, + "add_index": 3, + }, + wantLen: 3, + wantFirst: "no_drop_table", + description: "should sort by hit count descending", + }, + "empty map returns empty slice": { + input: map[string]int{}, + wantLen: 0, + description: "empty input should return empty result", + }, + "single rule": { + input: map[string]int{ + "no_select_all": 7, + }, + wantLen: 1, + wantFirst: "no_select_all", + description: "single rule should work correctly", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + result := toRuleHits(tc.input) + + if len(result) != tc.wantLen { + t.Fatalf("toRuleHits() returned %d items, want %d", len(result), tc.wantLen) + } + + if tc.wantLen == 0 { + return + } + + // Verify first element has highest hit count + if result[0].RuleName != tc.wantFirst { + t.Errorf("first rule = %q, want %q (%s)", result[0].RuleName, tc.wantFirst, tc.description) + } + + // Verify descending order + for i := 1; i < len(result); i++ { + if result[i].HitCount > result[i-1].HitCount { + t.Errorf("rule at index %d (count=%d) > rule at index %d (count=%d), expected descending order", + i, result[i].HitCount, i-1, result[i-1].HitCount) + } + } + + // Verify counts match input + for _, rh := range result { + expectedCount, ok := tc.input[rh.RuleName] + if !ok { + t.Errorf("unexpected rule %q in result", rh.RuleName) + continue + } + if rh.HitCount != expectedCount { + t.Errorf("rule %q hit count = %d, want %d", rh.RuleName, rh.HitCount, expectedCount) + } + } + }) + } +} + +func TestExtractRuleInfo(t *testing.T) { + ctx := context.Background() + + testCases := map[string]struct { + auditResults model.AuditResults + wantRuleName string + wantHasRuleName bool + wantHasSugg bool + description string + }{ + "empty audit results": { + auditResults: model.AuditResults{}, + wantRuleName: "", + wantHasRuleName: false, + wantHasSugg: false, + description: "empty results should return empty strings", + }, + "single rule hit": { + auditResults: model.AuditResults{ + { + Level: "warn", + RuleName: "no_select_all", + I18nAuditResultInfo: model.I18nAuditResultInfo{ + language.Chinese: model.AuditResultInfo{Message: "should not use SELECT *"}, + }, + }, + }, + wantRuleName: "no_select_all", + wantHasRuleName: true, + wantHasSugg: true, + description: "single rule should return its name and message", + }, + "multiple rule hits": { + auditResults: model.AuditResults{ + { + Level: "warn", + RuleName: "no_select_all", + I18nAuditResultInfo: model.I18nAuditResultInfo{ + language.Chinese: model.AuditResultInfo{Message: "avoid SELECT *"}, + }, + }, + { + Level: "error", + RuleName: "no_drop_table", + I18nAuditResultInfo: model.I18nAuditResultInfo{ + language.Chinese: model.AuditResultInfo{Message: "DROP TABLE not allowed"}, + }, + }, + }, + wantRuleName: "no_select_all, no_drop_table", + wantHasRuleName: true, + wantHasSugg: true, + description: "multiple rules should be comma-separated", + }, + "rule with empty name is skipped": { + auditResults: model.AuditResults{ + { + Level: "notice", + RuleName: "", + I18nAuditResultInfo: model.I18nAuditResultInfo{ + language.Chinese: model.AuditResultInfo{Message: "some notice"}, + }, + }, + }, + wantRuleName: "", + wantHasRuleName: false, + wantHasSugg: true, + description: "rule with empty name should be skipped in rule names but message still included", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + ruleName, suggestion := extractRuleInfo(tc.auditResults, ctx) + + if tc.wantHasRuleName { + if ruleName != tc.wantRuleName { + t.Errorf("ruleName = %q, want %q (%s)", ruleName, tc.wantRuleName, tc.description) + } + } else { + if ruleName != "" { + t.Errorf("ruleName = %q, want empty (%s)", ruleName, tc.description) + } + } + + if tc.wantHasSugg { + if suggestion == "" { + t.Errorf("suggestion should not be empty (%s)", tc.description) + } + } else { + if suggestion != "" { + t.Errorf("suggestion = %q, want empty (%s)", suggestion, tc.description) + } + } + }) + } +} + +func TestBuildReportLabels(t *testing.T) { + ctx := context.Background() + + testCases := map[string]struct { + description string + }{ + "default context returns non-empty labels": { + description: "all label fields should be non-empty with default locale", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + labels := buildReportLabels(ctx) + + // Verify all label fields are non-empty + fieldChecks := map[string]string{ + "AuditSummary": labels.AuditSummary, + "ResultStatistics": labels.ResultStatistics, + "ProblemSQLList": labels.ProblemSQLList, + "RuleHitStatistics": labels.RuleHitStatistics, + "AuditTime": labels.AuditTime, + "DataSource": labels.DataSource, + "Schema": labels.Schema, + "TotalSQL": labels.TotalSQL, + "PassRate": labels.PassRate, + "Score": labels.Score, + "AuditLevel": labels.AuditLevel, + "Number": labels.Number, + "SQL": labels.SQL, + "AuditStatus": labels.AuditStatus, + "AuditResult": labels.AuditResult, + "ExecStatus": labels.ExecStatus, + "ExecResult": labels.ExecResult, + "RollbackSQL": labels.RollbackSQL, + "RuleName": labels.RuleName, + "Description": labels.Description, + "Suggestion": labels.Suggestion, + "Count": labels.Count, + "HitCount": labels.HitCount, + } + + for field, value := range fieldChecks { + if value == "" { + t.Errorf("%s: label field %q is empty (%s)", tc.description, field, tc.description) + } + } + }) + } +} + +// TestLevelCountsPreserveAllLevels verifies that toLevelCounts preserves +// the count data correctly for all four standard audit levels. +func TestLevelCountsPreserveAllLevels(t *testing.T) { + input := map[string]int{ + "error": 3, + "warn": 5, + "notice": 2, + "normal": 10, + } + + result := toLevelCounts(input) + + if len(result) != 4 { + t.Fatalf("expected 4 levels, got %d", len(result)) + } + + // Build a lookup map from result + resultMap := make(map[string]int) + for _, lc := range result { + resultMap[lc.Level] = lc.Count + } + + // Verify each level count matches + for level, expectedCount := range input { + if count, ok := resultMap[level]; !ok { + t.Errorf("level %q missing from result", level) + } else if count != expectedCount { + t.Errorf("level %q: count = %d, want %d", level, count, expectedCount) + } + } + + // Verify ordering: error=0, warn=1, notice=2, normal=3 + expectedOrder := []string{"error", "warn", "notice", "normal"} + for i, expected := range expectedOrder { + if result[i].Level != expected { + t.Errorf("position %d: level = %q, want %q", i, result[i].Level, expected) + } + } +} + +// TestToRuleHitsStableSortForEqualCounts verifies that toRuleHits handles +// rules with equal hit counts without error. +func TestToRuleHitsStableSortForEqualCounts(t *testing.T) { + input := map[string]int{ + "rule_a": 5, + "rule_b": 5, + "rule_c": 5, + } + + result := toRuleHits(input) + + if len(result) != 3 { + t.Fatalf("expected 3 rules, got %d", len(result)) + } + + // All should have count 5 + for _, rh := range result { + if rh.HitCount != 5 { + t.Errorf("rule %q: hit count = %d, want 5", rh.RuleName, rh.HitCount) + } + } +} + +// TestToLevelCountsNilMap verifies toLevelCounts handles nil map gracefully. +func TestToLevelCountsNilMap(t *testing.T) { + result := toLevelCounts(nil) + if result == nil { + t.Fatal("toLevelCounts(nil) should return empty slice, not nil") + } + if len(result) != 0 { + t.Errorf("toLevelCounts(nil) returned %d items, want 0", len(result)) + } +} + +// TestToRuleHitsNilMap verifies toRuleHits handles nil map gracefully. +func TestToRuleHitsNilMap(t *testing.T) { + result := toRuleHits(nil) + if result == nil { + t.Fatal("toRuleHits(nil) should return empty slice, not nil") + } + if len(result) != 0 { + t.Errorf("toRuleHits(nil) returned %d items, want 0", len(result)) + } +} + +// TestExtractRuleInfoNilResults verifies extractRuleInfo handles nil AuditResults. +func TestExtractRuleInfoNilResults(t *testing.T) { + ctx := context.Background() + ruleName, suggestion := extractRuleInfo(nil, ctx) + if ruleName != "" { + t.Errorf("extractRuleInfo(nil) ruleName = %q, want empty", ruleName) + } + if suggestion != "" { + t.Errorf("extractRuleInfo(nil) suggestion = %q, want empty", suggestion) + } +} + +// TestCSVHeaders verifies that CSVHeaders returns the correct number of columns +// based on the report labels. +func TestCSVHeaders(t *testing.T) { + data := &utils.AuditReportData{ + Labels: utils.ReportLabels{ + Number: "Number", + SQL: "SQL", + AuditStatus: "Audit Status", + AuditResult: "Audit Result", + ExecStatus: "Exec Status", + ExecResult: "Exec Result", + RollbackSQL: "Rollback SQL", + Description: "Description", + }, + } + + headers := data.CSVHeaders() + if len(headers) != 8 { + t.Errorf("CSVHeaders() returned %d columns, want 8", len(headers)) + } + + expectedHeaders := []string{"Number", "SQL", "Audit Status", "Audit Result", "Exec Status", "Exec Result", "Rollback SQL", "Description"} + for i, h := range headers { + if h != expectedHeaders[i] { + t.Errorf("CSVHeaders()[%d] = %q, want %q", i, h, expectedHeaders[i]) + } + } +} diff --git a/sqle/api/controller/v1/task.go b/sqle/api/controller/v1/task.go index 4af23db0b4..15057e4d04 100644 --- a/sqle/api/controller/v1/task.go +++ b/sqle/api/controller/v1/task.go @@ -10,7 +10,6 @@ import ( "net/http" "net/http/httputil" "net/url" - "strconv" "strings" "time" @@ -21,7 +20,6 @@ import ( "github.com/actiontech/sqle/sqle/config" "github.com/actiontech/sqle/sqle/dms" "github.com/actiontech/sqle/sqle/errors" - "github.com/actiontech/sqle/sqle/locale" "github.com/actiontech/sqle/sqle/log" "github.com/actiontech/sqle/sqle/model" "github.com/actiontech/sqle/sqle/server" @@ -599,7 +597,8 @@ type DownloadAuditTaskSQLsFileReqV1 struct { // @Security ApiKeyAuth // @Param task_id path string true "task id" // @Param no_duplicate query boolean false "select unique (fingerprint and audit result) for task sql" -// @Success 200 file 1 "sql report csv file" +// @Param export_format query string false "export format: csv, html, pdf, word" default(csv) +// @Success 200 file 1 "sql report file" // @router /v1/tasks/audits/{task_id}/sql_report [get] func DownloadTaskSQLReportFile(c echo.Context) error { req := new(DownloadAuditTaskSQLsFileReqV1) @@ -618,59 +617,27 @@ func DownloadTaskSQLReportFile(c echo.Context) error { return controller.JSONBaseErrorReq(c, err) } - data := map[string]interface{}{ - "task_id": taskId, - "no_duplicate": req.NoDuplicate, - } + ctx := c.Request().Context() + + // 解析导出格式参数,缺失时默认返回 CSV(向后兼容) + exportFormat := utils.NormalizeExportFormatStr(c.QueryParam("export_format")) - taskSQLsDetail, _, err := s.GetTaskSQLsByReq(data) + // 构建报告数据 + reportData, err := BuildAuditReportData(task, s, req.NoDuplicate, ctx) if err != nil { return controller.JSONBaseErrorReq(c, err) } - ctx := c.Request().Context() - csvBuilder := utils.NewCSVBuilder() - err = csvBuilder.WriteHeader([]string{ - locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportIndex), // "序号", - locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportSQL), // "SQL", - locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportAuditStatus), // "SQL审核状态", - locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportAuditResult), // "SQL审核结果", - locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportExecStatus), // "SQL执行状态", - locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportExecResult), // "SQL执行结果", - locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportRollbackSQL), // "SQL对应的回滚语句", - locale.Bundle.LocalizeMsgByCtx(ctx, locale.TaskSQLReportDescription), // "SQL描述", - }) - if err != nil { - return controller.JSONBaseErrorReq(c, errors.New(errors.WriteDataToTheFileError, err)) - } - rollbackSqlMap, err := server.BackupService{}.GetRollbackSqlsMap(task.ID) + // 根据格式生成报告 + result, err := utils.ExportAuditReport(exportFormat, reportData) if err != nil { return controller.JSONBaseErrorReq(c, err) } - for _, td := range taskSQLsDetail { - taskSql := &model.ExecuteSQL{ - AuditResults: td.AuditResults, - AuditStatus: td.AuditStatus, - } - taskSql.ExecStatus = td.ExecStatus - err := csvBuilder.WriteRow([]string{ - strconv.FormatUint(uint64(td.Number), 10), - td.ExecSQL, - taskSql.GetAuditStatusDesc(ctx), - taskSql.GetAuditResultDesc(ctx), - taskSql.GetExecStatusDesc(ctx), - td.ExecResult, - strings.Join(rollbackSqlMap[taskSql.ID], "\n"), - td.Description, - }) - if err != nil { - return controller.JSONBaseErrorReq(c, errors.New(errors.WriteDataToTheFileError, err)) - } - } - fileName := fmt.Sprintf("SQL_audit_report_%v_%v.csv", task.InstanceName(), taskId) + + // 设置响应头并返回 c.Response().Header().Set(echo.HeaderContentDisposition, - mime.FormatMediaType("attachment", map[string]string{"filename": fileName})) - return c.Blob(http.StatusOK, "text/csv", csvBuilder.FlushAndGetBuffer().Bytes()) + mime.FormatMediaType("attachment", map[string]string{"filename": result.FileName})) + return c.Blob(http.StatusOK, result.ContentType, result.Content) } // @Summary 下载指定扫描任务的SQL文件 diff --git a/sqle/locale/active.en.toml b/sqle/locale/active.en.toml index 43d8dbba5f..cd5593cb3d 100644 --- a/sqle/locale/active.en.toml +++ b/sqle/locale/active.en.toml @@ -409,3 +409,19 @@ WorkflowStepStateApprove = "Approved" WorkflowStepStateReject = "Rejected" WorkflowStepTypeSQLAudit = "Auditing" WorkflowStepTypeSQLExecute = "Executing" +ReportLabelAuditSummary = "Audit Summary" +ReportLabelResultStatistics = "Audit Result Statistics" +ReportLabelProblemSQLList = "Problem SQL List" +ReportLabelRuleHitStatistics = "Rule Hit Statistics" +ReportLabelAuditTime = "Audit Time" +ReportLabelDataSource = "Data Source" +ReportLabelSchema = "Schema" +ReportLabelTotalSQL = "Total SQL" +ReportLabelPassRate = "Pass Rate" +ReportLabelScore = "Score" +ReportLabelAuditLevel = "Audit Level" +ReportLabelRuleName = "Rule Name" +ReportLabelSuggestion = "Suggestion" +ReportLabelCount = "Count" +ReportLabelHitCount = "Hit Count" +ReportLabelTitle = "SQL Audit Report" diff --git a/sqle/locale/active.zh.toml b/sqle/locale/active.zh.toml index 59c94f11e5..a25de8b025 100644 --- a/sqle/locale/active.zh.toml +++ b/sqle/locale/active.zh.toml @@ -409,3 +409,19 @@ WorkflowStepStateApprove = "通过" WorkflowStepStateReject = "驳回" WorkflowStepTypeSQLAudit = "审批" WorkflowStepTypeSQLExecute = "上线" +ReportLabelAuditSummary = "审核概要" +ReportLabelResultStatistics = "审核结果统计" +ReportLabelProblemSQLList = "问题SQL列表" +ReportLabelRuleHitStatistics = "规则命中统计" +ReportLabelAuditTime = "审核时间" +ReportLabelDataSource = "数据源" +ReportLabelSchema = "数据库" +ReportLabelTotalSQL = "SQL总数" +ReportLabelPassRate = "通过率" +ReportLabelScore = "评分" +ReportLabelAuditLevel = "审核等级" +ReportLabelRuleName = "规则名称" +ReportLabelSuggestion = "优化建议" +ReportLabelCount = "数量" +ReportLabelHitCount = "命中次数" +ReportLabelTitle = "SQL审核报告" diff --git a/sqle/locale/message_zh.go b/sqle/locale/message_zh.go index fdcad22c83..98d1bc0794 100644 --- a/sqle/locale/message_zh.go +++ b/sqle/locale/message_zh.go @@ -64,6 +64,26 @@ var ( TaskSQLReportDescription = &i18n.Message{ID: "TaskSQLReportDescription", Other: "SQL描述"} ) +// report labels (for audit report export) +var ( + ReportLabelAuditSummary = &i18n.Message{ID: "ReportLabelAuditSummary", Other: "审核概要"} + ReportLabelResultStatistics = &i18n.Message{ID: "ReportLabelResultStatistics", Other: "审核结果统计"} + ReportLabelProblemSQLList = &i18n.Message{ID: "ReportLabelProblemSQLList", Other: "问题SQL列表"} + ReportLabelRuleHitStatistics = &i18n.Message{ID: "ReportLabelRuleHitStatistics", Other: "规则命中统计"} + ReportLabelAuditTime = &i18n.Message{ID: "ReportLabelAuditTime", Other: "审核时间"} + ReportLabelDataSource = &i18n.Message{ID: "ReportLabelDataSource", Other: "数据源"} + ReportLabelSchema = &i18n.Message{ID: "ReportLabelSchema", Other: "数据库"} + ReportLabelTotalSQL = &i18n.Message{ID: "ReportLabelTotalSQL", Other: "SQL总数"} + ReportLabelPassRate = &i18n.Message{ID: "ReportLabelPassRate", Other: "通过率"} + ReportLabelScore = &i18n.Message{ID: "ReportLabelScore", Other: "评分"} + ReportLabelAuditLevel = &i18n.Message{ID: "ReportLabelAuditLevel", Other: "审核等级"} + ReportLabelRuleName = &i18n.Message{ID: "ReportLabelRuleName", Other: "规则名称"} + ReportLabelSuggestion = &i18n.Message{ID: "ReportLabelSuggestion", Other: "优化建议"} + ReportLabelCount = &i18n.Message{ID: "ReportLabelCount", Other: "数量"} + ReportLabelHitCount = &i18n.Message{ID: "ReportLabelHitCount", Other: "命中次数"} + ReportLabelTitle = &i18n.Message{ID: "ReportLabelTitle", Other: "SQL审核报告"} +) + // workflow var ( WorkflowStepStateApprove = &i18n.Message{ID: "WorkflowStepStateApprove", Other: "通过"} From b4906912eca5a47cea9e8b261bf50764a4929679 Mon Sep 17 00:00:00 2001 From: actiontech-zihan Date: Tue, 31 Mar 2026 15:38:13 +0000 Subject: [PATCH 06/14] docs(swagger): add export_format query parameter to sql_report endpoint Update swagger.json, swagger.yaml, and docs.go to include the new export_format query parameter (type: string, default: csv) for the /v1/tasks/audits/{task_id}/sql_report endpoint. Also update response description from "sql report csv file" to "sql report file" to reflect multi-format support. --- sqle/docs/docs.go | 9 ++++++++- sqle/docs/swagger.json | 9 ++++++++- sqle/docs/swagger.yaml | 7 ++++++- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/sqle/docs/docs.go b/sqle/docs/docs.go index 1520dd0a3e..2ec79ad150 100644 --- a/sqle/docs/docs.go +++ b/sqle/docs/docs.go @@ -10474,11 +10474,18 @@ var doc = `{ "description": "select unique (fingerprint and audit result) for task sql", "name": "no_duplicate", "in": "query" + }, + { + "type": "string", + "description": "export format: csv, html, pdf, word", + "name": "export_format", + "in": "query", + "default": "csv" } ], "responses": { "200": { - "description": "sql report csv file", + "description": "sql report file", "schema": { "type": "file" } diff --git a/sqle/docs/swagger.json b/sqle/docs/swagger.json index 60cafb8cd1..30e1295f73 100644 --- a/sqle/docs/swagger.json +++ b/sqle/docs/swagger.json @@ -10458,11 +10458,18 @@ "description": "select unique (fingerprint and audit result) for task sql", "name": "no_duplicate", "in": "query" + }, + { + "type": "string", + "description": "export format: csv, html, pdf, word", + "name": "export_format", + "in": "query", + "default": "csv" } ], "responses": { "200": { - "description": "sql report csv file", + "description": "sql report file", "schema": { "type": "file" } diff --git a/sqle/docs/swagger.yaml b/sqle/docs/swagger.yaml index c0b8cb67f1..ed6e5c2579 100644 --- a/sqle/docs/swagger.yaml +++ b/sqle/docs/swagger.yaml @@ -14559,9 +14559,14 @@ paths: in: query name: no_duplicate type: boolean + - default: csv + description: 'export format: csv, html, pdf, word' + in: query + name: export_format + type: string responses: "200": - description: sql report csv file + description: sql report file schema: type: file security: From a01d8a03a9156091168ee98fcfe7acacbae0e6c8 Mon Sep 17 00:00:00 2001 From: actiontech-zihan Date: Fri, 10 Apr 2026 02:47:18 +0000 Subject: [PATCH 07/14] fix(report): fix BUG-001,003,005,007 in multi-format export (CE part) --- sqle/api/controller/v1/report_data_builder.go | 11 +++++-- sqle/api/controller/v1/task.go | 7 +++-- sqle/utils/file.go | 19 ++++++------ sqle/utils/report_generator_test.go | 29 +++++++++++++------ 4 files changed, 43 insertions(+), 23 deletions(-) diff --git a/sqle/api/controller/v1/report_data_builder.go b/sqle/api/controller/v1/report_data_builder.go index fda3ab4619..985d9882d2 100644 --- a/sqle/api/controller/v1/report_data_builder.go +++ b/sqle/api/controller/v1/report_data_builder.go @@ -95,20 +95,25 @@ func BuildAuditReportData(task *model.Task, s *model.Storage, noDuplicate bool, auditTime = task.CreatedAt.Format("2006-01-02 15:04:05") } + instanceName := task.InstanceName() + if instanceName == "" { + instanceName = "unknown" + } + return &utils.AuditReportData{ TaskID: uint64(task.ID), Title: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelTitle), - InstanceName: task.InstanceName(), + InstanceName: instanceName, Schema: task.Schema, GeneratedAt: time.Now(), Lang: locale.Bundle.GetLangTagFromCtx(ctx).String(), LogoBase64: "", Summary: utils.AuditSummary{ AuditTime: auditTime, - InstanceName: task.InstanceName(), + InstanceName: instanceName, Schema: task.Schema, TotalSQL: len(sqlList), - PassRate: task.PassRate, + PassRate: task.PassRate * 100, Score: task.Score, AuditLevel: task.AuditLevel, }, diff --git a/sqle/api/controller/v1/task.go b/sqle/api/controller/v1/task.go index 15057e4d04..768270fab3 100644 --- a/sqle/api/controller/v1/task.go +++ b/sqle/api/controller/v1/task.go @@ -619,8 +619,11 @@ func DownloadTaskSQLReportFile(c echo.Context) error { ctx := c.Request().Context() - // 解析导出格式参数,缺失时默认返回 CSV(向后兼容) - exportFormat := utils.NormalizeExportFormatStr(c.QueryParam("export_format")) + // 解析导出格式参数,缺失时默认返回 CSV(向后兼容),无效格式返回 400 错误 + exportFormat, err := utils.NormalizeExportFormatStr(c.QueryParam("export_format")) + if err != nil { + return controller.JSONBaseErrorReq(c, err) + } // 构建报告数据 reportData, err := BuildAuditReportData(task, s, req.NoDuplicate, ctx) diff --git a/sqle/utils/file.go b/sqle/utils/file.go index e9ba7e6f53..3d4de02cfd 100644 --- a/sqle/utils/file.go +++ b/sqle/utils/file.go @@ -172,22 +172,23 @@ func NormalizeExportFormat(format *ExportFormat) ExportFormat { return *format } -// NormalizeExportFormatStr 规范化导出格式参数(字符串版本),默认返回 CSV(向后兼容) -// 支持 html/pdf/word/docx/excel/xlsx/csv 等输入的规范化,空字符串和无效值默认返回 CSV。 -func NormalizeExportFormatStr(format string) ExportFormat { +// NormalizeExportFormatStr 规范化导出格式参数(字符串版本)。 +// 空字符串默认返回 CSV(向后兼容);无效格式返回错误。 +// 支持 html/pdf/word/docx/excel/xlsx/csv 等输入的规范化。 +func NormalizeExportFormatStr(format string) (ExportFormat, error) { switch strings.ToLower(strings.TrimSpace(format)) { case "html": - return ExportFormatHTML + return ExportFormatHTML, nil case "pdf": - return ExportFormatPDF + return ExportFormatPDF, nil case "word", "docx": - return ExportFormatWORD + return ExportFormatWORD, nil case "excel", "xlsx": - return ExcelExportFormat + return ExcelExportFormat, nil case "csv", "": - return CsvExportFormat + return CsvExportFormat, nil default: - return CsvExportFormat + return "", fmt.Errorf("unsupported export format: %s", format) } } diff --git a/sqle/utils/report_generator_test.go b/sqle/utils/report_generator_test.go index 6cf3ede59e..3f8ad574ff 100644 --- a/sqle/utils/report_generator_test.go +++ b/sqle/utils/report_generator_test.go @@ -9,8 +9,9 @@ import ( func TestNormalizeExportFormatStr(t *testing.T) { testCases := map[string]struct { - input string - expected ExportFormat + input string + expected ExportFormat + expectError bool }{ "empty string defaults to csv": { input: "", @@ -64,13 +65,13 @@ func TestNormalizeExportFormatStr(t *testing.T) { input: "DOCX", expected: ExportFormatWORD, }, - "invalid value defaults to csv": { - input: "invalid", - expected: CsvExportFormat, + "invalid value returns error": { + input: "invalid", + expectError: true, }, - "unknown format defaults to csv": { - input: "json", - expected: CsvExportFormat, + "unknown format returns error": { + input: "json", + expectError: true, }, "whitespace-only defaults to csv": { input: " ", @@ -84,7 +85,17 @@ func TestNormalizeExportFormatStr(t *testing.T) { for name, tc := range testCases { t.Run(name, func(t *testing.T) { - result := NormalizeExportFormatStr(tc.input) + result, err := NormalizeExportFormatStr(tc.input) + if tc.expectError { + if err == nil { + t.Errorf("NormalizeExportFormatStr(%q) expected error, got nil", tc.input) + } + return + } + if err != nil { + t.Errorf("NormalizeExportFormatStr(%q) unexpected error: %v", tc.input, err) + return + } if result != tc.expected { t.Errorf("NormalizeExportFormatStr(%q) = %q, want %q", tc.input, result, tc.expected) } From 8f6d8398bbbc60cbf80f2ce270d78586369e7538 Mon Sep 17 00:00:00 2001 From: actiontech-zihan Date: Fri, 10 Apr 2026 02:47:35 +0000 Subject: [PATCH 08/14] fix(utils): return error for unsupported export format instead of CSV fallback (CE part) --- sqle/utils/report_generator_test.go | 186 ++++++++++++++++++++++++++++ 1 file changed, 186 insertions(+) diff --git a/sqle/utils/report_generator_test.go b/sqle/utils/report_generator_test.go index 3f8ad574ff..b97b4171f9 100644 --- a/sqle/utils/report_generator_test.go +++ b/sqle/utils/report_generator_test.go @@ -768,3 +768,189 @@ func TestHTMLReportGenerator_LargeData(t *testing.T) { }) } } + +// ========================================================================= +// ExportAuditReport CE 版测试 +// 以下测试在默认(非 enterprise)构建条件下运行,验证 CE 版格式分发逻辑。 +// ========================================================================= + +func TestExportAuditReport_CSVFormat(t *testing.T) { + testCases := map[string]struct { + format ExportFormat + wantContentType string + wantFileSuffix string + }{ + "CSV format returns valid CSV result": { + format: CsvExportFormat, + wantContentType: "text/csv", + wantFileSuffix: ".csv", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + data := buildTestReportData() + result, err := ExportAuditReport(tc.format, data) + if err != nil { + t.Fatalf("ExportAuditReport(%q) returned unexpected error: %v", tc.format, err) + } + if result == nil { + t.Fatal("ExportAuditReport() returned nil result") + } + if result.ContentType != tc.wantContentType { + t.Errorf("ContentType = %q, want %q", result.ContentType, tc.wantContentType) + } + if !strings.HasSuffix(result.FileName, tc.wantFileSuffix) { + t.Errorf("FileName %q does not end with %q", result.FileName, tc.wantFileSuffix) + } + if len(result.Content) == 0 { + t.Error("ExportAuditReport() returned empty content") + } + }) + } +} + +func TestExportAuditReport_HTMLFormat(t *testing.T) { + testCases := map[string]struct { + format ExportFormat + wantContentType string + wantFileSuffix string + wantHTMLTags []string + }{ + "HTML format returns valid HTML result": { + format: ExportFormatHTML, + wantContentType: "text/html", + wantFileSuffix: ".html", + wantHTMLTags: []string{"", "", ""}, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + data := buildTestReportData() + result, err := ExportAuditReport(tc.format, data) + if err != nil { + t.Fatalf("ExportAuditReport(%q) returned unexpected error: %v", tc.format, err) + } + if result == nil { + t.Fatal("ExportAuditReport() returned nil result") + } + if result.ContentType != tc.wantContentType { + t.Errorf("ContentType = %q, want %q", result.ContentType, tc.wantContentType) + } + if !strings.HasSuffix(result.FileName, tc.wantFileSuffix) { + t.Errorf("FileName %q does not end with %q", result.FileName, tc.wantFileSuffix) + } + content := string(result.Content) + for _, tag := range tc.wantHTMLTags { + if !strings.Contains(content, tag) { + t.Errorf("Content does not contain expected HTML tag %q", tag) + } + } + }) + } +} + +func TestExportAuditReport_CEEdition_PDFBlocked(t *testing.T) { + testCases := map[string]struct { + format ExportFormat + wantErrSubstr string + }{ + "PDF format is blocked in CE edition": { + format: ExportFormatPDF, + wantErrSubstr: "enterprise edition", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + data := buildTestReportData() + result, err := ExportAuditReport(tc.format, data) + if err == nil { + t.Fatal("ExportAuditReport(PDF) should return error in CE edition, got nil") + } + if result != nil { + t.Errorf("ExportAuditReport(PDF) should return nil result in CE edition, got %+v", result) + } + if !strings.Contains(err.Error(), tc.wantErrSubstr) { + t.Errorf("error message %q does not contain %q", err.Error(), tc.wantErrSubstr) + } + // Verify the error message includes the format name + if !strings.Contains(err.Error(), string(tc.format)) { + t.Errorf("error message %q does not contain format name %q", err.Error(), tc.format) + } + }) + } +} + +func TestExportAuditReport_CEEdition_WORDBlocked(t *testing.T) { + testCases := map[string]struct { + format ExportFormat + wantErrSubstr string + }{ + "WORD format is blocked in CE edition": { + format: ExportFormatWORD, + wantErrSubstr: "enterprise edition", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + data := buildTestReportData() + result, err := ExportAuditReport(tc.format, data) + if err == nil { + t.Fatal("ExportAuditReport(WORD) should return error in CE edition, got nil") + } + if result != nil { + t.Errorf("ExportAuditReport(WORD) should return nil result in CE edition, got %+v", result) + } + if !strings.Contains(err.Error(), tc.wantErrSubstr) { + t.Errorf("error message %q does not contain %q", err.Error(), tc.wantErrSubstr) + } + // Verify the error message includes the format name + if !strings.Contains(err.Error(), string(tc.format)) { + t.Errorf("error message %q does not contain format name %q", err.Error(), tc.format) + } + }) + } +} + +func TestExportAuditReport_DefaultCSV(t *testing.T) { + testCases := map[string]struct { + format ExportFormat + wantErrSubstr string + }{ + "invalid format returns error": { + format: ExportFormat("invalid"), + wantErrSubstr: "unsupported export format", + }, + "empty format returns error": { + format: ExportFormat(""), + wantErrSubstr: "unsupported export format", + }, + "unknown format returns error": { + format: ExportFormat("json"), + wantErrSubstr: "unsupported export format", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + data := buildTestReportData() + result, err := ExportAuditReport(tc.format, data) + if err == nil { + t.Fatalf("ExportAuditReport(%q) expected error, got nil", tc.format) + } + if result != nil { + t.Errorf("ExportAuditReport(%q) expected nil result, got %+v", tc.format, result) + } + if !strings.Contains(err.Error(), tc.wantErrSubstr) { + t.Errorf("error message %q does not contain %q", err.Error(), tc.wantErrSubstr) + } + // Verify the error message includes the format name + if !strings.Contains(err.Error(), string(tc.format)) { + t.Errorf("error message %q does not contain format name %q", err.Error(), tc.format) + } + }) + } +} From 2d973a013859c42c7cb06a74d932e31cd91f27a1 Mon Sep 17 00:00:00 2001 From: yangzhongjiao Date: Fri, 10 Apr 2026 06:29:35 +0000 Subject: [PATCH 09/14] refactor(auditreport): move report export to server and apply review fixes --- sqle/api/controller/v1/report_data_builder.go | 38 +-- .../controller/v1/report_data_builder_test.go | 9 +- sqle/api/controller/v1/task.go | 3 +- sqle/server/auditreport/report_csv.go | 75 ++++++ .../auditreport}/report_generator.go | 46 +--- .../auditreport}/report_generator_ce.go | 16 +- .../auditreport/report_generator_ce_test.go | 111 +++++++++ .../auditreport}/report_generator_test.go | 228 +----------------- .../auditreport}/report_html.go | 23 +- .../auditreport}/report_html_template.go | 3 +- .../auditreport}/templates/audit_report.html | 0 sqle/utils/report_csv.go | 55 ----- 12 files changed, 251 insertions(+), 356 deletions(-) create mode 100644 sqle/server/auditreport/report_csv.go rename sqle/{utils => server/auditreport}/report_generator.go (75%) rename sqle/{utils => server/auditreport}/report_generator_ce.go (66%) create mode 100644 sqle/server/auditreport/report_generator_ce_test.go rename sqle/{utils => server/auditreport}/report_generator_test.go (77%) rename sqle/{utils => server/auditreport}/report_html.go (77%) rename sqle/{utils => server/auditreport}/report_html_template.go (83%) rename sqle/{utils => server/auditreport}/templates/audit_report.html (100%) delete mode 100644 sqle/utils/report_csv.go diff --git a/sqle/api/controller/v1/report_data_builder.go b/sqle/api/controller/v1/report_data_builder.go index 985d9882d2..dd5b8c3551 100644 --- a/sqle/api/controller/v1/report_data_builder.go +++ b/sqle/api/controller/v1/report_data_builder.go @@ -9,14 +9,14 @@ import ( "github.com/actiontech/sqle/sqle/locale" "github.com/actiontech/sqle/sqle/model" + "github.com/actiontech/sqle/sqle/server/auditreport" "github.com/actiontech/sqle/sqle/server" - "github.com/actiontech/sqle/sqle/utils" ) // BuildAuditReportData 从 Task 和数据库查询构建报告数据。 -// 该函数放在 controller 层而非 utils 层,因为 utils 被 model 引用, +// 该函数放在 controller 层;报告数据模型在 server/auditreport。utils 被 model 引用, // 若 utils 反向引用 model 会产生循环依赖。 -func BuildAuditReportData(task *model.Task, s *model.Storage, noDuplicate bool, ctx context.Context) (*utils.AuditReportData, error) { +func BuildAuditReportData(task *model.Task, s *model.Storage, noDuplicate bool, ctx context.Context) (*auditreport.AuditReportData, error) { // 1. 获取 SQL 列表 data := map[string]interface{}{ "task_id": fmt.Sprintf("%d", task.ID), @@ -37,8 +37,8 @@ func BuildAuditReportData(task *model.Task, s *model.Storage, noDuplicate bool, // 3. 构建 SQL 列表和统计数据 levelDist := make(map[string]int) ruleHits := make(map[string]int) - var sqlList []utils.AuditSQLItem - var problemSQLs []utils.AuditSQLItem + var sqlList []auditreport.AuditSQLItem + var problemSQLs []auditreport.AuditSQLItem for _, td := range taskSQLsDetail { // 构造临时 ExecuteSQL 对象以复用状态描述方法 @@ -51,7 +51,7 @@ func BuildAuditReportData(task *model.Task, s *model.Storage, noDuplicate bool, // 提取规则名称和审核建议 ruleName, suggestion := extractRuleInfo(td.AuditResults, ctx) - item := utils.AuditSQLItem{ + item := auditreport.AuditSQLItem{ Number: td.Number, SQL: td.ExecSQL, AuditLevel: td.AuditLevel, @@ -100,7 +100,7 @@ func BuildAuditReportData(task *model.Task, s *model.Storage, noDuplicate bool, instanceName = "unknown" } - return &utils.AuditReportData{ + return &auditreport.AuditReportData{ TaskID: uint64(task.ID), Title: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelTitle), InstanceName: instanceName, @@ -108,7 +108,7 @@ func BuildAuditReportData(task *model.Task, s *model.Storage, noDuplicate bool, GeneratedAt: time.Now(), Lang: locale.Bundle.GetLangTagFromCtx(ctx).String(), LogoBase64: "", - Summary: utils.AuditSummary{ + Summary: auditreport.AuditSummary{ AuditTime: auditTime, InstanceName: instanceName, Schema: task.Schema, @@ -117,7 +117,7 @@ func BuildAuditReportData(task *model.Task, s *model.Storage, noDuplicate bool, Score: task.Score, AuditLevel: task.AuditLevel, }, - Statistics: utils.AuditStatistics{ + Statistics: auditreport.AuditStatistics{ LevelDistribution: toLevelCounts(levelDist), RuleHits: toRuleHits(ruleHits), }, @@ -153,9 +153,9 @@ func extractRuleInfo(auditResults model.AuditResults, ctx context.Context) (rule // toLevelCounts 将等级分布 map 转换为有序的 LevelCount 切片。 // 按 error > warn > notice > normal 顺序排列。 -func toLevelCounts(dist map[string]int) []utils.LevelCount { +func toLevelCounts(dist map[string]int) []auditreport.LevelCount { if len(dist) == 0 { - return []utils.LevelCount{} + return []auditreport.LevelCount{} } levelOrder := map[string]int{ @@ -165,9 +165,9 @@ func toLevelCounts(dist map[string]int) []utils.LevelCount { "normal": 3, } - result := make([]utils.LevelCount, 0, len(dist)) + result := make([]auditreport.LevelCount, 0, len(dist)) for level, count := range dist { - result = append(result, utils.LevelCount{ + result = append(result, auditreport.LevelCount{ Level: level, Count: count, }) @@ -189,14 +189,14 @@ func toLevelCounts(dist map[string]int) []utils.LevelCount { } // toRuleHits 将规则命中 map 转换为按命中次数降序排列的 RuleHit 切片。 -func toRuleHits(hits map[string]int) []utils.RuleHit { +func toRuleHits(hits map[string]int) []auditreport.RuleHit { if len(hits) == 0 { - return []utils.RuleHit{} + return []auditreport.RuleHit{} } - result := make([]utils.RuleHit, 0, len(hits)) + result := make([]auditreport.RuleHit, 0, len(hits)) for name, count := range hits { - result = append(result, utils.RuleHit{ + result = append(result, auditreport.RuleHit{ RuleName: name, HitCount: count, }) @@ -212,8 +212,8 @@ func toRuleHits(hits map[string]int) []utils.RuleHit { // buildReportLabels 构建报告中使用的国际化标签。 // 当前版本使用 locale 包已有的国际化消息和硬编码中文标签, // 后续阶段 8 将接入 go-i18n 框架实现完整国际化。 -func buildReportLabels(ctx context.Context) utils.ReportLabels { - return utils.ReportLabels{ +func buildReportLabels(ctx context.Context) auditreport.ReportLabels { + return auditreport.ReportLabels{ AuditSummary: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelAuditSummary), ResultStatistics: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelResultStatistics), ProblemSQLList: locale.Bundle.LocalizeMsgByCtx(ctx, locale.ReportLabelProblemSQLList), diff --git a/sqle/api/controller/v1/report_data_builder_test.go b/sqle/api/controller/v1/report_data_builder_test.go index e87b550852..12e08e7abb 100644 --- a/sqle/api/controller/v1/report_data_builder_test.go +++ b/sqle/api/controller/v1/report_data_builder_test.go @@ -5,7 +5,7 @@ import ( "testing" "github.com/actiontech/sqle/sqle/model" - "github.com/actiontech/sqle/sqle/utils" + "github.com/actiontech/sqle/sqle/server/auditreport" "golang.org/x/text/language" ) @@ -423,8 +423,8 @@ func TestExtractRuleInfoNilResults(t *testing.T) { // TestCSVHeaders verifies that CSVHeaders returns the correct number of columns // based on the report labels. func TestCSVHeaders(t *testing.T) { - data := &utils.AuditReportData{ - Labels: utils.ReportLabels{ + data := &auditreport.AuditReportData{ + Labels: auditreport.ReportLabels{ Number: "Number", SQL: "SQL", AuditStatus: "Audit Status", @@ -436,7 +436,8 @@ func TestCSVHeaders(t *testing.T) { }, } - headers := data.CSVHeaders() + gen := auditreport.NewCSVReportGenerator() + headers := gen.CSVHeaders(data) if len(headers) != 8 { t.Errorf("CSVHeaders() returned %d columns, want 8", len(headers)) } diff --git a/sqle/api/controller/v1/task.go b/sqle/api/controller/v1/task.go index 768270fab3..69fe3077f7 100644 --- a/sqle/api/controller/v1/task.go +++ b/sqle/api/controller/v1/task.go @@ -22,6 +22,7 @@ import ( "github.com/actiontech/sqle/sqle/errors" "github.com/actiontech/sqle/sqle/log" "github.com/actiontech/sqle/sqle/model" + "github.com/actiontech/sqle/sqle/server/auditreport" "github.com/actiontech/sqle/sqle/server" "github.com/actiontech/sqle/sqle/utils" @@ -632,7 +633,7 @@ func DownloadTaskSQLReportFile(c echo.Context) error { } // 根据格式生成报告 - result, err := utils.ExportAuditReport(exportFormat, reportData) + result, err := auditreport.ExportAuditReport(exportFormat, reportData) if err != nil { return controller.JSONBaseErrorReq(c, err) } diff --git a/sqle/server/auditreport/report_csv.go b/sqle/server/auditreport/report_csv.go new file mode 100644 index 0000000000..bfcb7b0595 --- /dev/null +++ b/sqle/server/auditreport/report_csv.go @@ -0,0 +1,75 @@ +package auditreport + +import ( + "fmt" + + "github.com/actiontech/sqle/sqle/utils" +) + +// CSVReportGenerator CSV 格式报告生成器 +// 复用已有的 CSVBuilder 生成 CSV 报告,实现 ReportGenerator 接口。 +type CSVReportGenerator struct{} + +// NewCSVReportGenerator 创建并返回一个新的 CSVReportGenerator 实例 +func NewCSVReportGenerator() *CSVReportGenerator { + return &CSVReportGenerator{} +} + +// ReportType 返回生成器支持的导出格式 +func (g *CSVReportGenerator) ReportType() utils.ExportFormat { + return utils.CsvExportFormat +} + +// CSVHeaders 返回 CSV 报告的表头列表 +func (g *CSVReportGenerator) CSVHeaders(data *AuditReportData) []string { + return []string{ + data.Labels.Number, + data.Labels.SQL, + data.Labels.AuditStatus, + data.Labels.AuditResult, + data.Labels.ExecStatus, + data.Labels.ExecResult, + data.Labels.RollbackSQL, + data.Labels.Description, + } +} + +// ToCSVRow 将单条审核 SQL 转为 CSV 行 +func (g *CSVReportGenerator) ToCSVRow(item *AuditSQLItem) []string { + return []string{ + fmt.Sprintf("%d", item.Number), + item.SQL, + item.AuditStatus, + item.AuditResult, + item.ExecStatus, + item.ExecResult, + item.RollbackSQL, + item.Description, + } +} + +// Generate 根据审核报告数据生成 CSV 格式的文件 +func (g *CSVReportGenerator) Generate(data *AuditReportData) (*utils.ExportDataResult, error) { + builder := utils.NewCSVBuilder() + + if err := builder.WriteHeader(g.CSVHeaders(data)); err != nil { + return nil, fmt.Errorf("write csv header failed: %v", err) + } + + for i := range data.SQLList { + if err := builder.WriteRow(g.ToCSVRow(&data.SQLList[i])); err != nil { + return nil, fmt.Errorf("write csv row failed: %v", err) + } + } + + content := builder.FlushAndGetBuffer().Bytes() + if err := builder.Error(); err != nil { + return nil, fmt.Errorf("csv builder error: %v", err) + } + + return &utils.ExportDataResult{ + Content: content, + ContentType: "text/csv", + FileName: fmt.Sprintf("SQL_audit_report_%s_%d.csv", data.InstanceName, data.TaskID), + }, nil +} diff --git a/sqle/utils/report_generator.go b/sqle/server/auditreport/report_generator.go similarity index 75% rename from sqle/utils/report_generator.go rename to sqle/server/auditreport/report_generator.go index 34e3366d44..f463bc23c5 100644 --- a/sqle/utils/report_generator.go +++ b/sqle/server/auditreport/report_generator.go @@ -1,8 +1,9 @@ -package utils +package auditreport import ( - "fmt" "time" + + "github.com/actiontech/sqle/sqle/utils" ) // AuditReportData 审核报告完整数据模型 @@ -30,20 +31,6 @@ type AuditReportData struct { Labels ReportLabels `json:"labels"` } -// CSVHeaders 返回 CSV 报告的表头列表 -func (d *AuditReportData) CSVHeaders() []string { - return []string{ - d.Labels.Number, - d.Labels.SQL, - d.Labels.AuditStatus, - d.Labels.AuditResult, - d.Labels.ExecStatus, - d.Labels.ExecResult, - d.Labels.RollbackSQL, - d.Labels.Description, - } -} - // AuditSummary 审核概要 type AuditSummary struct { AuditTime string `json:"audit_time"` @@ -58,7 +45,7 @@ type AuditSummary struct { // AuditStatistics 审核结果统计 type AuditStatistics struct { LevelDistribution []LevelCount `json:"level_distribution"` // 按等级分布 - RuleHits []RuleHit `json:"rule_hits"` // 规则命中统计 + RuleHits []RuleHit `json:"rule_hits"` // 规则命中统计 } // LevelCount 等级统计 @@ -89,20 +76,6 @@ type AuditSQLItem struct { Suggestion string `json:"suggestion"` // 优化建议 } -// ToCSVRow 将审核 SQL 项转换为 CSV 行数据 -func (item *AuditSQLItem) ToCSVRow() []string { - return []string{ - fmt.Sprintf("%d", item.Number), - item.SQL, - item.AuditStatus, - item.AuditResult, - item.ExecStatus, - item.ExecResult, - item.RollbackSQL, - item.Description, - } -} - // ReportLabels 报告中的国际化标签 type ReportLabels struct { AuditSummary string `json:"audit_summary"` @@ -133,12 +106,7 @@ type ReportLabels struct { // ReportGenerator 报告生成器接口 type ReportGenerator interface { // Generate 根据报告数据生成指定格式的文件 - Generate(data *AuditReportData) (*ExportDataResult, error) - // Format 返回生成器支持的格式 - Format() ExportFormat + Generate(data *AuditReportData) (*utils.ExportDataResult, error) + // ReportType 返回生成器对应的导出格式 + ReportType() utils.ExportFormat } - -// ExportAuditReport 统一导出入口(CE/EE 通过 build tags 区分实现) -// CE 版本支持 CSV 和 HTML 格式,EE 版本额外支持 PDF 和 WORD 格式。 -// 函数签名:func ExportAuditReport(format ExportFormat, data *AuditReportData) (*ExportDataResult, error) -// 实现分别位于 report_generator_ce.go 和 report_generator_ee.go 中。 diff --git a/sqle/utils/report_generator_ce.go b/sqle/server/auditreport/report_generator_ce.go similarity index 66% rename from sqle/utils/report_generator_ce.go rename to sqle/server/auditreport/report_generator_ce.go index ac2723c0fa..cd3b2af8cd 100644 --- a/sqle/utils/report_generator_ce.go +++ b/sqle/server/auditreport/report_generator_ce.go @@ -1,24 +1,28 @@ //go:build !enterprise -package utils +package auditreport -import "fmt" +import ( + "fmt" + + "github.com/actiontech/sqle/sqle/utils" +) // ExportAuditReport CE 版统一导出入口。 // CE 版仅支持 CSV 和 HTML 两种格式。 // 请求 PDF 或 WORD 格式时返回错误提示,提醒用户需要企业版。 // 无效格式返回错误(REQ-6.3)。 -func ExportAuditReport(format ExportFormat, data *AuditReportData) (*ExportDataResult, error) { +func ExportAuditReport(format utils.ExportFormat, data *AuditReportData) (*utils.ExportDataResult, error) { switch format { - case CsvExportFormat: + case utils.CsvExportFormat: return NewCSVReportGenerator().Generate(data) - case ExportFormatHTML: + case utils.ExportFormatHTML: gen, err := NewHTMLReportGenerator() if err != nil { return nil, err } return gen.Generate(data) - case ExportFormatPDF, ExportFormatWORD: + case utils.ExportFormatPDF, utils.ExportFormatWORD: return nil, fmt.Errorf("export format %s is only supported in enterprise edition", format) default: return nil, fmt.Errorf("unsupported export format: %s", format) diff --git a/sqle/server/auditreport/report_generator_ce_test.go b/sqle/server/auditreport/report_generator_ce_test.go new file mode 100644 index 0000000000..8f1a8b7849 --- /dev/null +++ b/sqle/server/auditreport/report_generator_ce_test.go @@ -0,0 +1,111 @@ +//go:build !enterprise + +package auditreport + +import ( + "strings" + "testing" + + "github.com/actiontech/sqle/sqle/utils" +) + +func TestExportAuditReport_CEEdition_PDFBlocked(t *testing.T) { + testCases := map[string]struct { + format utils.ExportFormat + wantErrSubstr string + }{ + "PDF format is blocked in CE edition": { + format: utils.ExportFormatPDF, + wantErrSubstr: "enterprise edition", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + data := buildTestReportData() + result, err := ExportAuditReport(tc.format, data) + if err == nil { + t.Fatal("ExportAuditReport(PDF) should return error in CE edition, got nil") + } + if result != nil { + t.Errorf("ExportAuditReport(PDF) should return nil result in CE edition, got %+v", result) + } + if !strings.Contains(err.Error(), tc.wantErrSubstr) { + t.Errorf("error message %q does not contain %q", err.Error(), tc.wantErrSubstr) + } + if !strings.Contains(err.Error(), string(tc.format)) { + t.Errorf("error message %q does not contain format name %q", err.Error(), tc.format) + } + }) + } +} + +func TestExportAuditReport_CEEdition_WORDBlocked(t *testing.T) { + testCases := map[string]struct { + format utils.ExportFormat + wantErrSubstr string + }{ + "WORD format is blocked in CE edition": { + format: utils.ExportFormatWORD, + wantErrSubstr: "enterprise edition", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + data := buildTestReportData() + result, err := ExportAuditReport(tc.format, data) + if err == nil { + t.Fatal("ExportAuditReport(WORD) should return error in CE edition, got nil") + } + if result != nil { + t.Errorf("ExportAuditReport(WORD) should return nil result in CE edition, got %+v", result) + } + if !strings.Contains(err.Error(), tc.wantErrSubstr) { + t.Errorf("error message %q does not contain %q", err.Error(), tc.wantErrSubstr) + } + if !strings.Contains(err.Error(), string(tc.format)) { + t.Errorf("error message %q does not contain format name %q", err.Error(), tc.format) + } + }) + } +} + +func TestExportAuditReport_DefaultCSV(t *testing.T) { + testCases := map[string]struct { + format utils.ExportFormat + wantErrSubstr string + }{ + "invalid format returns error": { + format: utils.ExportFormat("invalid"), + wantErrSubstr: "unsupported export format", + }, + "empty format returns error": { + format: utils.ExportFormat(""), + wantErrSubstr: "unsupported export format", + }, + "unknown format returns error": { + format: utils.ExportFormat("json"), + wantErrSubstr: "unsupported export format", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + data := buildTestReportData() + result, err := ExportAuditReport(tc.format, data) + if err == nil { + t.Fatalf("ExportAuditReport(%q) expected error, got nil", tc.format) + } + if result != nil { + t.Errorf("ExportAuditReport(%q) expected nil result, got %+v", tc.format, result) + } + if !strings.Contains(err.Error(), tc.wantErrSubstr) { + t.Errorf("error message %q does not contain %q", err.Error(), tc.wantErrSubstr) + } + if !strings.Contains(err.Error(), string(tc.format)) { + t.Errorf("error message %q does not contain format name %q", err.Error(), tc.format) + } + }) + } +} diff --git a/sqle/utils/report_generator_test.go b/sqle/server/auditreport/report_generator_test.go similarity index 77% rename from sqle/utils/report_generator_test.go rename to sqle/server/auditreport/report_generator_test.go index b97b4171f9..c557dce0e1 100644 --- a/sqle/utils/report_generator_test.go +++ b/sqle/server/auditreport/report_generator_test.go @@ -1,107 +1,13 @@ -package utils +package auditreport import ( "fmt" "strings" "testing" "time" -) - -func TestNormalizeExportFormatStr(t *testing.T) { - testCases := map[string]struct { - input string - expected ExportFormat - expectError bool - }{ - "empty string defaults to csv": { - input: "", - expected: CsvExportFormat, - }, - "csv returns csv": { - input: "csv", - expected: CsvExportFormat, - }, - "CSV uppercase returns csv": { - input: "CSV", - expected: CsvExportFormat, - }, - "excel returns excel": { - input: "excel", - expected: ExcelExportFormat, - }, - "xlsx returns excel": { - input: "xlsx", - expected: ExcelExportFormat, - }, - "html returns html": { - input: "html", - expected: ExportFormatHTML, - }, - "HTML uppercase returns html": { - input: "HTML", - expected: ExportFormatHTML, - }, - "pdf returns pdf": { - input: "pdf", - expected: ExportFormatPDF, - }, - "PDF uppercase returns pdf": { - input: "PDF", - expected: ExportFormatPDF, - }, - "word returns word": { - input: "word", - expected: ExportFormatWORD, - }, - "WORD uppercase returns word": { - input: "WORD", - expected: ExportFormatWORD, - }, - "docx returns word": { - input: "docx", - expected: ExportFormatWORD, - }, - "DOCX uppercase returns word": { - input: "DOCX", - expected: ExportFormatWORD, - }, - "invalid value returns error": { - input: "invalid", - expectError: true, - }, - "unknown format returns error": { - input: "json", - expectError: true, - }, - "whitespace-only defaults to csv": { - input: " ", - expected: CsvExportFormat, - }, - "leading and trailing spaces are trimmed": { - input: " pdf ", - expected: ExportFormatPDF, - }, - } - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - result, err := NormalizeExportFormatStr(tc.input) - if tc.expectError { - if err == nil { - t.Errorf("NormalizeExportFormatStr(%q) expected error, got nil", tc.input) - } - return - } - if err != nil { - t.Errorf("NormalizeExportFormatStr(%q) unexpected error: %v", tc.input, err) - return - } - if result != tc.expected { - t.Errorf("NormalizeExportFormatStr(%q) = %q, want %q", tc.input, result, tc.expected) - } - }) - } -} + "github.com/actiontech/sqle/sqle/utils" +) // buildTestReportData 构建测试用的 AuditReportData func buildTestReportData() *AuditReportData { @@ -442,10 +348,10 @@ func TestCSVReportGenerator_SpecialChars(t *testing.T) { } } -func TestCSVReportGenerator_Format(t *testing.T) { +func TestCSVReportGenerator_ReportType(t *testing.T) { gen := NewCSVReportGenerator() - if gen.Format() != CsvExportFormat { - t.Errorf("Format() = %q, want %q", gen.Format(), CsvExportFormat) + if gen.ReportType() != utils.CsvExportFormat { + t.Errorf("ReportType() = %q, want %q", gen.ReportType(), utils.CsvExportFormat) } } @@ -529,9 +435,9 @@ func TestHTMLReportGenerator_Normal(t *testing.T) { } } - // Verify Format() returns ExportFormatHTML - if gen.Format() != ExportFormatHTML { - t.Errorf("Format() = %q, want %q", gen.Format(), ExportFormatHTML) + // Verify ReportType() returns ExportFormatHTML + if gen.ReportType() != utils.ExportFormatHTML { + t.Errorf("ReportType() = %q, want %q", gen.ReportType(), utils.ExportFormatHTML) } }) } @@ -769,19 +675,14 @@ func TestHTMLReportGenerator_LargeData(t *testing.T) { } } -// ========================================================================= -// ExportAuditReport CE 版测试 -// 以下测试在默认(非 enterprise)构建条件下运行,验证 CE 版格式分发逻辑。 -// ========================================================================= - func TestExportAuditReport_CSVFormat(t *testing.T) { testCases := map[string]struct { - format ExportFormat + format utils.ExportFormat wantContentType string wantFileSuffix string }{ "CSV format returns valid CSV result": { - format: CsvExportFormat, + format: utils.CsvExportFormat, wantContentType: "text/csv", wantFileSuffix: ".csv", }, @@ -812,13 +713,13 @@ func TestExportAuditReport_CSVFormat(t *testing.T) { func TestExportAuditReport_HTMLFormat(t *testing.T) { testCases := map[string]struct { - format ExportFormat + format utils.ExportFormat wantContentType string wantFileSuffix string wantHTMLTags []string }{ "HTML format returns valid HTML result": { - format: ExportFormatHTML, + format: utils.ExportFormatHTML, wantContentType: "text/html", wantFileSuffix: ".html", wantHTMLTags: []string{"", "", "
"}, @@ -851,106 +752,3 @@ func TestExportAuditReport_HTMLFormat(t *testing.T) { } } -func TestExportAuditReport_CEEdition_PDFBlocked(t *testing.T) { - testCases := map[string]struct { - format ExportFormat - wantErrSubstr string - }{ - "PDF format is blocked in CE edition": { - format: ExportFormatPDF, - wantErrSubstr: "enterprise edition", - }, - } - - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - data := buildTestReportData() - result, err := ExportAuditReport(tc.format, data) - if err == nil { - t.Fatal("ExportAuditReport(PDF) should return error in CE edition, got nil") - } - if result != nil { - t.Errorf("ExportAuditReport(PDF) should return nil result in CE edition, got %+v", result) - } - if !strings.Contains(err.Error(), tc.wantErrSubstr) { - t.Errorf("error message %q does not contain %q", err.Error(), tc.wantErrSubstr) - } - // Verify the error message includes the format name - if !strings.Contains(err.Error(), string(tc.format)) { - t.Errorf("error message %q does not contain format name %q", err.Error(), tc.format) - } - }) - } -} - -func TestExportAuditReport_CEEdition_WORDBlocked(t *testing.T) { - testCases := map[string]struct { - format ExportFormat - wantErrSubstr string - }{ - "WORD format is blocked in CE edition": { - format: ExportFormatWORD, - wantErrSubstr: "enterprise edition", - }, - } - - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - data := buildTestReportData() - result, err := ExportAuditReport(tc.format, data) - if err == nil { - t.Fatal("ExportAuditReport(WORD) should return error in CE edition, got nil") - } - if result != nil { - t.Errorf("ExportAuditReport(WORD) should return nil result in CE edition, got %+v", result) - } - if !strings.Contains(err.Error(), tc.wantErrSubstr) { - t.Errorf("error message %q does not contain %q", err.Error(), tc.wantErrSubstr) - } - // Verify the error message includes the format name - if !strings.Contains(err.Error(), string(tc.format)) { - t.Errorf("error message %q does not contain format name %q", err.Error(), tc.format) - } - }) - } -} - -func TestExportAuditReport_DefaultCSV(t *testing.T) { - testCases := map[string]struct { - format ExportFormat - wantErrSubstr string - }{ - "invalid format returns error": { - format: ExportFormat("invalid"), - wantErrSubstr: "unsupported export format", - }, - "empty format returns error": { - format: ExportFormat(""), - wantErrSubstr: "unsupported export format", - }, - "unknown format returns error": { - format: ExportFormat("json"), - wantErrSubstr: "unsupported export format", - }, - } - - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - data := buildTestReportData() - result, err := ExportAuditReport(tc.format, data) - if err == nil { - t.Fatalf("ExportAuditReport(%q) expected error, got nil", tc.format) - } - if result != nil { - t.Errorf("ExportAuditReport(%q) expected nil result, got %+v", tc.format, result) - } - if !strings.Contains(err.Error(), tc.wantErrSubstr) { - t.Errorf("error message %q does not contain %q", err.Error(), tc.wantErrSubstr) - } - // Verify the error message includes the format name - if !strings.Contains(err.Error(), string(tc.format)) { - t.Errorf("error message %q does not contain format name %q", err.Error(), tc.format) - } - }) - } -} diff --git a/sqle/utils/report_html.go b/sqle/server/auditreport/report_html.go similarity index 77% rename from sqle/utils/report_html.go rename to sqle/server/auditreport/report_html.go index 9ca036334c..e59ecdca0e 100644 --- a/sqle/utils/report_html.go +++ b/sqle/server/auditreport/report_html.go @@ -1,9 +1,11 @@ -package utils +package auditreport import ( "bytes" "fmt" "html/template" + + "github.com/actiontech/sqle/sqle/utils" ) // HTMLReportGenerator HTML 格式报告生成器 @@ -29,28 +31,19 @@ func NewHTMLReportGenerator() (*HTMLReportGenerator, error) { return &HTMLReportGenerator{tmpl: tmpl}, nil } -// Format 返回生成器支持的导出格式 -func (g *HTMLReportGenerator) Format() ExportFormat { - return ExportFormatHTML +// ReportType 返回生成器支持的导出格式 +func (g *HTMLReportGenerator) ReportType() utils.ExportFormat { + return utils.ExportFormatHTML } // Generate 根据审核报告数据生成 HTML 格式的文件 -// -// 参数: -// -// data: 审核报告完整数据模型 -// -// 返回: -// -// *ExportDataResult: 包含 HTML 文件内容、ContentType 和文件名 -// error: 生成过程中的错误 -func (g *HTMLReportGenerator) Generate(data *AuditReportData) (*ExportDataResult, error) { +func (g *HTMLReportGenerator) Generate(data *AuditReportData) (*utils.ExportDataResult, error) { var buf bytes.Buffer if err := g.tmpl.Execute(&buf, data); err != nil { return nil, fmt.Errorf("render HTML report failed: %w", err) } - return &ExportDataResult{ + return &utils.ExportDataResult{ Content: buf.Bytes(), ContentType: "text/html", FileName: fmt.Sprintf("SQL_audit_report_%s_%d.html", data.InstanceName, data.TaskID), diff --git a/sqle/utils/report_html_template.go b/sqle/server/auditreport/report_html_template.go similarity index 83% rename from sqle/utils/report_html_template.go rename to sqle/server/auditreport/report_html_template.go index 773c506bd8..b8d2c1975b 100644 --- a/sqle/utils/report_html_template.go +++ b/sqle/server/auditreport/report_html_template.go @@ -1,11 +1,10 @@ -package utils +package auditreport import "embed" //go:embed templates/audit_report.html var auditReportTemplateFS embed.FS -// auditReportHTMLTemplatePath is the path to the embedded HTML template file. const auditReportHTMLTemplatePath = "templates/audit_report.html" // GetAuditReportHTMLTemplate reads the embedded HTML template and returns its content as a string. diff --git a/sqle/utils/templates/audit_report.html b/sqle/server/auditreport/templates/audit_report.html similarity index 100% rename from sqle/utils/templates/audit_report.html rename to sqle/server/auditreport/templates/audit_report.html diff --git a/sqle/utils/report_csv.go b/sqle/utils/report_csv.go deleted file mode 100644 index 7d16b07a7a..0000000000 --- a/sqle/utils/report_csv.go +++ /dev/null @@ -1,55 +0,0 @@ -package utils - -import "fmt" - -// CSVReportGenerator CSV 格式报告生成器 -// 复用已有的 CSVBuilder 生成 CSV 报告,实现 ReportGenerator 接口。 -type CSVReportGenerator struct{} - -// NewCSVReportGenerator 创建并返回一个新的 CSVReportGenerator 实例 -func NewCSVReportGenerator() *CSVReportGenerator { - return &CSVReportGenerator{} -} - -// Format 返回生成器支持的导出格式 -func (g *CSVReportGenerator) Format() ExportFormat { - return CsvExportFormat -} - -// Generate 根据审核报告数据生成 CSV 格式的文件 -// -// 参数: -// -// data: 审核报告完整数据模型 -// -// 返回: -// -// *ExportDataResult: 包含 CSV 文件内容、ContentType 和文件名 -// error: 生成过程中的错误 -func (g *CSVReportGenerator) Generate(data *AuditReportData) (*ExportDataResult, error) { - builder := NewCSVBuilder() - - // 写入表头 - if err := builder.WriteHeader(data.CSVHeaders()); err != nil { - return nil, fmt.Errorf("write csv header failed: %v", err) - } - - // 写入数据行 - for _, sql := range data.SQLList { - if err := builder.WriteRow(sql.ToCSVRow()); err != nil { - return nil, fmt.Errorf("write csv row failed: %v", err) - } - } - - // 刷新缓冲区并获取内容 - content := builder.FlushAndGetBuffer().Bytes() - if err := builder.Error(); err != nil { - return nil, fmt.Errorf("csv builder error: %v", err) - } - - return &ExportDataResult{ - Content: content, - ContentType: "text/csv", - FileName: fmt.Sprintf("SQL_audit_report_%s_%d.csv", data.InstanceName, data.TaskID), - }, nil -} From 0fc0407618cdbc11d96fafd7a88ab7cc2439f5cd Mon Sep 17 00:00:00 2001 From: yangzhongjiao Date: Fri, 10 Apr 2026 06:41:50 +0000 Subject: [PATCH 10/14] refactor(report): optimize SQL list initialization in BuildAuditReportData function Updated the BuildAuditReportData function to initialize sqlList and problemSQLs with a predefined capacity based on taskSQLsDetail length, improving performance and memory usage. --- sqle/api/controller/v1/report_data_builder.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sqle/api/controller/v1/report_data_builder.go b/sqle/api/controller/v1/report_data_builder.go index dd5b8c3551..0bf237b4ec 100644 --- a/sqle/api/controller/v1/report_data_builder.go +++ b/sqle/api/controller/v1/report_data_builder.go @@ -37,8 +37,9 @@ func BuildAuditReportData(task *model.Task, s *model.Storage, noDuplicate bool, // 3. 构建 SQL 列表和统计数据 levelDist := make(map[string]int) ruleHits := make(map[string]int) - var sqlList []auditreport.AuditSQLItem - var problemSQLs []auditreport.AuditSQLItem + n := len(taskSQLsDetail) + sqlList := make([]auditreport.AuditSQLItem, 0, n) + problemSQLs := make([]auditreport.AuditSQLItem, 0, n) for _, td := range taskSQLsDetail { // 构造临时 ExecuteSQL 对象以复用状态描述方法 From 52697f838bc26180885cf506de61bb6123dc4524 Mon Sep 17 00:00:00 2001 From: yangzhongjiao Date: Mon, 13 Apr 2026 03:07:47 +0000 Subject: [PATCH 11/14] feat(auditreport): move report ExportFormat to server layer Define ExportFormat and NormalizeExportFormatStr in auditreport; keep utils.ExportFormat for generic csv/excel data export only. --- sqle/server/auditreport/export_format.go | 36 +++++++++++++++ sqle/server/auditreport/report_csv.go | 4 +- sqle/server/auditreport/report_generator.go | 6 +-- .../server/auditreport/report_generator_ce.go | 8 ++-- .../auditreport/report_generator_ce_test.go | 18 ++++---- .../auditreport/report_generator_test.go | 45 +++++++++---------- sqle/server/auditreport/report_html.go | 4 +- sqle/utils/file.go | 26 +---------- 8 files changed, 77 insertions(+), 70 deletions(-) create mode 100644 sqle/server/auditreport/export_format.go diff --git a/sqle/server/auditreport/export_format.go b/sqle/server/auditreport/export_format.go new file mode 100644 index 0000000000..762d887ef8 --- /dev/null +++ b/sqle/server/auditreport/export_format.go @@ -0,0 +1,36 @@ +package auditreport + +import ( + "fmt" + "strings" +) + +// ExportFormat 审核报告等业务的导出格式(与 utils 中通用表格导出的 csv/excel 区分)。 +type ExportFormat string + +const ( + CsvExportFormat ExportFormat = "csv" + ExcelExportFormat ExportFormat = "excel" + ExportFormatHTML ExportFormat = "html" + ExportFormatPDF ExportFormat = "pdf" + ExportFormatWORD ExportFormat = "word" +) + +// NormalizeExportFormatStr 规范化导出格式查询参数。 +// 空字符串默认返回 CSV(向后兼容);无效格式返回错误。 +func NormalizeExportFormatStr(format string) (ExportFormat, error) { + switch strings.ToLower(strings.TrimSpace(format)) { + case "html": + return ExportFormatHTML, nil + case "pdf": + return ExportFormatPDF, nil + case "word", "docx": + return ExportFormatWORD, nil + case "excel", "xlsx": + return ExcelExportFormat, nil + case "csv", "": + return CsvExportFormat, nil + default: + return "", fmt.Errorf("unsupported export format: %s", format) + } +} diff --git a/sqle/server/auditreport/report_csv.go b/sqle/server/auditreport/report_csv.go index bfcb7b0595..39b5a4096a 100644 --- a/sqle/server/auditreport/report_csv.go +++ b/sqle/server/auditreport/report_csv.go @@ -16,8 +16,8 @@ func NewCSVReportGenerator() *CSVReportGenerator { } // ReportType 返回生成器支持的导出格式 -func (g *CSVReportGenerator) ReportType() utils.ExportFormat { - return utils.CsvExportFormat +func (g *CSVReportGenerator) ReportType() ExportFormat { + return CsvExportFormat } // CSVHeaders 返回 CSV 报告的表头列表 diff --git a/sqle/server/auditreport/report_generator.go b/sqle/server/auditreport/report_generator.go index f463bc23c5..658888a700 100644 --- a/sqle/server/auditreport/report_generator.go +++ b/sqle/server/auditreport/report_generator.go @@ -45,7 +45,7 @@ type AuditSummary struct { // AuditStatistics 审核结果统计 type AuditStatistics struct { LevelDistribution []LevelCount `json:"level_distribution"` // 按等级分布 - RuleHits []RuleHit `json:"rule_hits"` // 规则命中统计 + RuleHits []RuleHit `json:"rule_hits"` // 规则命中统计 } // LevelCount 等级统计 @@ -72,7 +72,7 @@ type AuditSQLItem struct { RollbackSQL string `json:"rollback_sql"` Description string `json:"description"` // HTML/PDF/WORD 报告扩展字段 - RuleName string `json:"rule_name"` // 触发的规则名称 + RuleName string `json:"rule_name"` // 触发的规则名称 Suggestion string `json:"suggestion"` // 优化建议 } @@ -108,5 +108,5 @@ type ReportGenerator interface { // Generate 根据报告数据生成指定格式的文件 Generate(data *AuditReportData) (*utils.ExportDataResult, error) // ReportType 返回生成器对应的导出格式 - ReportType() utils.ExportFormat + ReportType() ExportFormat } diff --git a/sqle/server/auditreport/report_generator_ce.go b/sqle/server/auditreport/report_generator_ce.go index cd3b2af8cd..d2d034eb44 100644 --- a/sqle/server/auditreport/report_generator_ce.go +++ b/sqle/server/auditreport/report_generator_ce.go @@ -12,17 +12,17 @@ import ( // CE 版仅支持 CSV 和 HTML 两种格式。 // 请求 PDF 或 WORD 格式时返回错误提示,提醒用户需要企业版。 // 无效格式返回错误(REQ-6.3)。 -func ExportAuditReport(format utils.ExportFormat, data *AuditReportData) (*utils.ExportDataResult, error) { +func ExportAuditReport(format ExportFormat, data *AuditReportData) (*utils.ExportDataResult, error) { switch format { - case utils.CsvExportFormat: + case CsvExportFormat: return NewCSVReportGenerator().Generate(data) - case utils.ExportFormatHTML: + case ExportFormatHTML: gen, err := NewHTMLReportGenerator() if err != nil { return nil, err } return gen.Generate(data) - case utils.ExportFormatPDF, utils.ExportFormatWORD: + case ExportFormatPDF, ExportFormatWORD: return nil, fmt.Errorf("export format %s is only supported in enterprise edition", format) default: return nil, fmt.Errorf("unsupported export format: %s", format) diff --git a/sqle/server/auditreport/report_generator_ce_test.go b/sqle/server/auditreport/report_generator_ce_test.go index 8f1a8b7849..e3f852ffa1 100644 --- a/sqle/server/auditreport/report_generator_ce_test.go +++ b/sqle/server/auditreport/report_generator_ce_test.go @@ -5,17 +5,15 @@ package auditreport import ( "strings" "testing" - - "github.com/actiontech/sqle/sqle/utils" ) func TestExportAuditReport_CEEdition_PDFBlocked(t *testing.T) { testCases := map[string]struct { - format utils.ExportFormat + format ExportFormat wantErrSubstr string }{ "PDF format is blocked in CE edition": { - format: utils.ExportFormatPDF, + format: ExportFormatPDF, wantErrSubstr: "enterprise edition", }, } @@ -42,11 +40,11 @@ func TestExportAuditReport_CEEdition_PDFBlocked(t *testing.T) { func TestExportAuditReport_CEEdition_WORDBlocked(t *testing.T) { testCases := map[string]struct { - format utils.ExportFormat + format ExportFormat wantErrSubstr string }{ "WORD format is blocked in CE edition": { - format: utils.ExportFormatWORD, + format: ExportFormatWORD, wantErrSubstr: "enterprise edition", }, } @@ -73,19 +71,19 @@ func TestExportAuditReport_CEEdition_WORDBlocked(t *testing.T) { func TestExportAuditReport_DefaultCSV(t *testing.T) { testCases := map[string]struct { - format utils.ExportFormat + format ExportFormat wantErrSubstr string }{ "invalid format returns error": { - format: utils.ExportFormat("invalid"), + format: ExportFormat("invalid"), wantErrSubstr: "unsupported export format", }, "empty format returns error": { - format: utils.ExportFormat(""), + format: ExportFormat(""), wantErrSubstr: "unsupported export format", }, "unknown format returns error": { - format: utils.ExportFormat("json"), + format: ExportFormat("json"), wantErrSubstr: "unsupported export format", }, } diff --git a/sqle/server/auditreport/report_generator_test.go b/sqle/server/auditreport/report_generator_test.go index c557dce0e1..f4ae8a24fe 100644 --- a/sqle/server/auditreport/report_generator_test.go +++ b/sqle/server/auditreport/report_generator_test.go @@ -5,8 +5,6 @@ import ( "strings" "testing" "time" - - "github.com/actiontech/sqle/sqle/utils" ) // buildTestReportData 构建测试用的 AuditReportData @@ -350,8 +348,8 @@ func TestCSVReportGenerator_SpecialChars(t *testing.T) { func TestCSVReportGenerator_ReportType(t *testing.T) { gen := NewCSVReportGenerator() - if gen.ReportType() != utils.CsvExportFormat { - t.Errorf("ReportType() = %q, want %q", gen.ReportType(), utils.CsvExportFormat) + if gen.ReportType() != CsvExportFormat { + t.Errorf("ReportType() = %q, want %q", gen.ReportType(), CsvExportFormat) } } @@ -365,13 +363,13 @@ func TestHTMLReportGenerator_Normal(t *testing.T) { var _ ReportGenerator = gen testCases := map[string]struct { - data *AuditReportData - wantContentType string - wantFilePrefix string - wantFileSuffix string - wantHTMLTags []string - wantSQLContents []string - wantLabels []string + data *AuditReportData + wantContentType string + wantFilePrefix string + wantFileSuffix string + wantHTMLTags []string + wantSQLContents []string + wantLabels []string }{ "normal data generates valid HTML report": { data: buildTestReportData(), @@ -436,8 +434,8 @@ func TestHTMLReportGenerator_Normal(t *testing.T) { } // Verify ReportType() returns ExportFormatHTML - if gen.ReportType() != utils.ExportFormatHTML { - t.Errorf("ReportType() = %q, want %q", gen.ReportType(), utils.ExportFormatHTML) + if gen.ReportType() != ExportFormatHTML { + t.Errorf("ReportType() = %q, want %q", gen.ReportType(), ExportFormatHTML) } }) } @@ -455,18 +453,18 @@ func TestHTMLReportGenerator_XSSPrevention(t *testing.T) { wantDescription string }{ "script tag in SQL is escaped": { - maliciousSQL: "", - wantAbsent: []string{""}, + maliciousSQL: "", + wantAbsent: []string{""}, wantDescription: "script tags should be HTML-escaped by html/template", }, "img onerror in SQL is escaped": { - maliciousSQL: ``, - wantAbsent: []string{`onerror="alert`}, + maliciousSQL: ``, + wantAbsent: []string{`onerror="alert`}, wantDescription: "event handler attributes should be HTML-escaped", }, "script tag in description is escaped": { - maliciousSQL: "SELECT 1", - wantAbsent: []string{"