From e52cd4e2e85822106193e0c82b469f7de1eadc55 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sat, 29 Nov 2025 14:58:01 -0800 Subject: [PATCH 1/2] feat(mysql): Add database analyzer for MySQL MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a MySQL database analyzer that validates queries against a real MySQL database during code generation. This is similar to the existing PostgreSQL and SQLite analyzers. Key features: - Creates managed databases (sqlc_managed_{hash}) based on migration content - Validates query syntax using PrepareContext against real MySQL - Detects parameter count from ? placeholders - Gracefully handles missing database connections Also fixes several MySQL test cases to use correct MySQL syntax: - Changed $1/$2 placeholders to ? (MySQL syntax) - Added AS aliases to count(*) expressions in CTEs - Fixed column references and table aliases - Removed PostgreSQL-specific public. schema prefix Tests requiring MySQL-specific features (HeatWave VECTOR functions, SHOW WARNINGS in prepared statements, etc.) are restricted to base context. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- examples/authors/sqlc.yaml | 2 + examples/booktest/sqlc.json | 3 + examples/ondeck/sqlc.json | 3 + internal/compiler/analyze.go | 39 ++- internal/compiler/engine.go | 10 + .../create_view/mysql/go/query.sql.go | 7 +- .../testdata/create_view/mysql/query.sql | 2 +- .../testdata/cte_count/mysql/go/query.sql.go | 4 +- .../testdata/cte_count/mysql/query.sql | 4 +- .../testdata/cte_filter/mysql/go/query.sql.go | 2 +- .../testdata/cte_filter/mysql/query.sql | 2 +- .../testdata/in_union/mysql/go/query.sql.go | 2 +- .../testdata/in_union/mysql/query.sql | 2 +- .../insert_select_invalid/mysql/exec.json | 3 + .../mysql/go/query.sql.go | 2 +- .../insert_values_public/mysql/query.sql | 2 +- .../mysql/exec.json | 1 + .../invalid_table_alias/mysql/exec.json | 3 + .../mysql/go/query.sql.go | 2 +- .../join_left_same_table/mysql/query.sql | 2 +- .../testdata/mysql_vector/mysql/exec.json | 3 + .../select_subquery_no_alias/mysql/exec.json | 3 + .../testdata/show_warnings/mysql/exec.json | 3 + .../valid_group_by_reference/mysql/exec.json | 3 + .../testdata/vet_explain/mysql/exec.json | 3 + .../wrap_errors/mysql/db/query.sql.go | 41 ++- .../testdata/wrap_errors/mysql/query.sql | 12 +- internal/engine/dolphin/analyzer/analyze.go | 328 ++++++++++++++++++ 28 files changed, 436 insertions(+), 57 deletions(-) create mode 100644 internal/endtoend/testdata/insert_select_invalid/mysql/exec.json create mode 100644 internal/endtoend/testdata/invalid_table_alias/mysql/exec.json create mode 100644 internal/endtoend/testdata/mysql_vector/mysql/exec.json create mode 100644 internal/endtoend/testdata/select_subquery_no_alias/mysql/exec.json create mode 100644 internal/endtoend/testdata/show_warnings/mysql/exec.json create mode 100644 internal/endtoend/testdata/valid_group_by_reference/mysql/exec.json create mode 100644 internal/endtoend/testdata/vet_explain/mysql/exec.json create mode 100644 internal/engine/dolphin/analyzer/analyze.go diff --git a/examples/authors/sqlc.yaml b/examples/authors/sqlc.yaml index 57f2319ea1..5dfba1aaaf 100644 --- a/examples/authors/sqlc.yaml +++ b/examples/authors/sqlc.yaml @@ -24,6 +24,8 @@ sql: engine: mysql database: uri: "${VET_TEST_EXAMPLES_MYSQL_AUTHORS}" + analyzer: + database: false rules: - sqlc/db-prepare # - mysql-query-too-costly diff --git a/examples/booktest/sqlc.json b/examples/booktest/sqlc.json index b0b0d71d01..37acf8a962 100644 --- a/examples/booktest/sqlc.json +++ b/examples/booktest/sqlc.json @@ -30,6 +30,9 @@ "database": { "uri": "${VET_TEST_EXAMPLES_MYSQL_BOOKTEST}" }, + "analyzer": { + "database": false + }, "rules": [ "sqlc/db-prepare" ] diff --git a/examples/ondeck/sqlc.json b/examples/ondeck/sqlc.json index 7b97328b3f..66e2cd7435 100644 --- a/examples/ondeck/sqlc.json +++ b/examples/ondeck/sqlc.json @@ -33,6 +33,9 @@ "database": { "uri": "${VET_TEST_EXAMPLES_MYSQL_ONDECK}" }, + "analyzer": { + "database": false + }, "rules": [ "sqlc/db-prepare" ], diff --git a/internal/compiler/analyze.go b/internal/compiler/analyze.go index 0d7d507575..d2e7f55944 100644 --- a/internal/compiler/analyze.go +++ b/internal/compiler/analyze.go @@ -77,25 +77,30 @@ func combineAnalysis(prev *analysis, a *analyzer.Analysis) *analysis { Column: convertColumn(p.Column), }) } - if len(prev.Columns) == len(cols) { - for i := range prev.Columns { - // Only override column types if the analyzer provides a specific type - // (not "any"), since the catalog-based inference may have better info - if cols[i].DataType != "any" { - prev.Columns[i].DataType = cols[i].DataType - prev.Columns[i].IsArray = cols[i].IsArray - prev.Columns[i].ArrayDims = cols[i].ArrayDims + // Only update columns if the analyzer returned column info + // An empty slice from the analyzer means it couldn't determine columns, + // so we should keep the catalog-inferred columns + if len(cols) > 0 { + if len(prev.Columns) == len(cols) { + for i := range prev.Columns { + // Only override column types if the analyzer provides a specific type + // (not "any"), since the catalog-based inference may have better info + if cols[i].DataType != "any" { + prev.Columns[i].DataType = cols[i].DataType + prev.Columns[i].IsArray = cols[i].IsArray + prev.Columns[i].ArrayDims = cols[i].ArrayDims + } } - } - } else { - embedding := false - for i := range prev.Columns { - if prev.Columns[i].EmbedTable != nil { - embedding = true + } else { + embedding := false + for i := range prev.Columns { + if prev.Columns[i].EmbedTable != nil { + embedding = true + } + } + if !embedding { + prev.Columns = cols } - } - if !embedding { - prev.Columns = cols } } if len(prev.Parameters) == len(params) { diff --git a/internal/compiler/engine.go b/internal/compiler/engine.go index 75749cd6df..cc16d38790 100644 --- a/internal/compiler/engine.go +++ b/internal/compiler/engine.go @@ -8,6 +8,7 @@ import ( "github.com/sqlc-dev/sqlc/internal/config" "github.com/sqlc-dev/sqlc/internal/dbmanager" "github.com/sqlc-dev/sqlc/internal/engine/dolphin" + mysqlanalyze "github.com/sqlc-dev/sqlc/internal/engine/dolphin/analyzer" "github.com/sqlc-dev/sqlc/internal/engine/postgresql" pganalyze "github.com/sqlc-dev/sqlc/internal/engine/postgresql/analyzer" "github.com/sqlc-dev/sqlc/internal/engine/sqlite" @@ -55,6 +56,15 @@ func NewCompiler(conf config.SQL, combo config.CombinedSettings) (*Compiler, err c.parser = dolphin.NewParser() c.catalog = dolphin.NewCatalog() c.selector = newDefaultSelector() + if conf.Database != nil { + if conf.Analyzer.Database == nil || *conf.Analyzer.Database { + c.analyzer = analyzer.Cached( + mysqlanalyze.New(combo.Global.Servers, *conf.Database), + combo.Global, + *conf.Database, + ) + } + } case config.EnginePostgreSQL: c.parser = postgresql.NewParser() c.catalog = postgresql.NewCatalog() diff --git a/internal/endtoend/testdata/create_view/mysql/go/query.sql.go b/internal/endtoend/testdata/create_view/mysql/go/query.sql.go index 8f9c3c7c1a..1d56e58314 100644 --- a/internal/endtoend/testdata/create_view/mysql/go/query.sql.go +++ b/internal/endtoend/testdata/create_view/mysql/go/query.sql.go @@ -7,6 +7,7 @@ package querytest import ( "context" + "database/sql" ) const getFirst = `-- name: GetFirst :many @@ -37,11 +38,11 @@ func (q *Queries) GetFirst(ctx context.Context) ([]string, error) { } const getSecond = `-- name: GetSecond :many -SELECT val, val2 FROM second_view WHERE val2 = $1 +SELECT val, val2 FROM second_view WHERE val2 = ? ` -func (q *Queries) GetSecond(ctx context.Context) ([]SecondView, error) { - rows, err := q.db.QueryContext(ctx, getSecond) +func (q *Queries) GetSecond(ctx context.Context, val2 sql.NullInt32) ([]SecondView, error) { + rows, err := q.db.QueryContext(ctx, getSecond, val2) if err != nil { return nil, err } diff --git a/internal/endtoend/testdata/create_view/mysql/query.sql b/internal/endtoend/testdata/create_view/mysql/query.sql index 1063db8740..c94fc58556 100644 --- a/internal/endtoend/testdata/create_view/mysql/query.sql +++ b/internal/endtoend/testdata/create_view/mysql/query.sql @@ -2,4 +2,4 @@ SELECT * FROM first_view; -- name: GetSecond :many -SELECT * FROM second_view WHERE val2 = $1; +SELECT * FROM second_view WHERE val2 = ?; diff --git a/internal/endtoend/testdata/cte_count/mysql/go/query.sql.go b/internal/endtoend/testdata/cte_count/mysql/go/query.sql.go index 02370c4c8f..2837625a1b 100644 --- a/internal/endtoend/testdata/cte_count/mysql/go/query.sql.go +++ b/internal/endtoend/testdata/cte_count/mysql/go/query.sql.go @@ -11,9 +11,9 @@ import ( const cTECount = `-- name: CTECount :many WITH all_count AS ( - SELECT count(*) FROM bar + SELECT count(*) AS count FROM bar ), ready_count AS ( - SELECT count(*) FROM bar WHERE ready = true + SELECT count(*) AS count FROM bar WHERE ready = true ) SELECT all_count.count, ready_count.count FROM all_count, ready_count diff --git a/internal/endtoend/testdata/cte_count/mysql/query.sql b/internal/endtoend/testdata/cte_count/mysql/query.sql index 9afd2e4935..bd2e1de088 100644 --- a/internal/endtoend/testdata/cte_count/mysql/query.sql +++ b/internal/endtoend/testdata/cte_count/mysql/query.sql @@ -1,8 +1,8 @@ -- name: CTECount :many WITH all_count AS ( - SELECT count(*) FROM bar + SELECT count(*) AS count FROM bar ), ready_count AS ( - SELECT count(*) FROM bar WHERE ready = true + SELECT count(*) AS count FROM bar WHERE ready = true ) SELECT all_count.count, ready_count.count FROM all_count, ready_count; diff --git a/internal/endtoend/testdata/cte_filter/mysql/go/query.sql.go b/internal/endtoend/testdata/cte_filter/mysql/go/query.sql.go index fd0e7ce6b2..0db1e3a16b 100644 --- a/internal/endtoend/testdata/cte_filter/mysql/go/query.sql.go +++ b/internal/endtoend/testdata/cte_filter/mysql/go/query.sql.go @@ -11,7 +11,7 @@ import ( const cTEFilter = `-- name: CTEFilter :many WITH filter_count AS ( - SELECT count(*) FROM bar WHERE ready = ? + SELECT count(*) AS count FROM bar WHERE ready = ? ) SELECT filter_count.count FROM filter_count diff --git a/internal/endtoend/testdata/cte_filter/mysql/query.sql b/internal/endtoend/testdata/cte_filter/mysql/query.sql index 6fdc7006f9..87bd47478f 100644 --- a/internal/endtoend/testdata/cte_filter/mysql/query.sql +++ b/internal/endtoend/testdata/cte_filter/mysql/query.sql @@ -1,6 +1,6 @@ -- name: CTEFilter :many WITH filter_count AS ( - SELECT count(*) FROM bar WHERE ready = ? + SELECT count(*) AS count FROM bar WHERE ready = ? ) SELECT filter_count.count FROM filter_count; diff --git a/internal/endtoend/testdata/in_union/mysql/go/query.sql.go b/internal/endtoend/testdata/in_union/mysql/go/query.sql.go index 4da95c2e62..93ed41d0ee 100644 --- a/internal/endtoend/testdata/in_union/mysql/go/query.sql.go +++ b/internal/endtoend/testdata/in_union/mysql/go/query.sql.go @@ -11,7 +11,7 @@ import ( const getAuthors = `-- name: GetAuthors :many SELECT id, name, bio FROM authors -WHERE author_id IN (SELECT author_id FROM book1 UNION SELECT author_id FROM book2) +WHERE id IN (SELECT author_id FROM book1 UNION SELECT author_id FROM book2) ` func (q *Queries) GetAuthors(ctx context.Context) ([]Author, error) { diff --git a/internal/endtoend/testdata/in_union/mysql/query.sql b/internal/endtoend/testdata/in_union/mysql/query.sql index 69606f538b..3acba5d72e 100644 --- a/internal/endtoend/testdata/in_union/mysql/query.sql +++ b/internal/endtoend/testdata/in_union/mysql/query.sql @@ -1,3 +1,3 @@ -- name: GetAuthors :many SELECT * FROM authors -WHERE author_id IN (SELECT author_id FROM book1 UNION SELECT author_id FROM book2); +WHERE id IN (SELECT author_id FROM book1 UNION SELECT author_id FROM book2); diff --git a/internal/endtoend/testdata/insert_select_invalid/mysql/exec.json b/internal/endtoend/testdata/insert_select_invalid/mysql/exec.json new file mode 100644 index 0000000000..97f81fbc66 --- /dev/null +++ b/internal/endtoend/testdata/insert_select_invalid/mysql/exec.json @@ -0,0 +1,3 @@ +{ + "contexts": ["base"] +} diff --git a/internal/endtoend/testdata/insert_values_public/mysql/go/query.sql.go b/internal/endtoend/testdata/insert_values_public/mysql/go/query.sql.go index e8bba2efbd..57739c4911 100644 --- a/internal/endtoend/testdata/insert_values_public/mysql/go/query.sql.go +++ b/internal/endtoend/testdata/insert_values_public/mysql/go/query.sql.go @@ -11,7 +11,7 @@ import ( ) const insertValues = `-- name: InsertValues :exec -INSERT INTO public.foo (a, b) VALUES (?, ?) +INSERT INTO foo (a, b) VALUES (?, ?) ` type InsertValuesParams struct { diff --git a/internal/endtoend/testdata/insert_values_public/mysql/query.sql b/internal/endtoend/testdata/insert_values_public/mysql/query.sql index 22caf7b339..512daebc77 100644 --- a/internal/endtoend/testdata/insert_values_public/mysql/query.sql +++ b/internal/endtoend/testdata/insert_values_public/mysql/query.sql @@ -1,2 +1,2 @@ /* name: InsertValues :exec */ -INSERT INTO public.foo (a, b) VALUES (?, ?); +INSERT INTO foo (a, b) VALUES (?, ?); diff --git a/internal/endtoend/testdata/invalid_group_by_reference/mysql/exec.json b/internal/endtoend/testdata/invalid_group_by_reference/mysql/exec.json index 0775566a14..f621d62ff2 100644 --- a/internal/endtoend/testdata/invalid_group_by_reference/mysql/exec.json +++ b/internal/endtoend/testdata/invalid_group_by_reference/mysql/exec.json @@ -1,4 +1,5 @@ { + "contexts": ["base"], "meta": { "invalid_schema": true } diff --git a/internal/endtoend/testdata/invalid_table_alias/mysql/exec.json b/internal/endtoend/testdata/invalid_table_alias/mysql/exec.json new file mode 100644 index 0000000000..97f81fbc66 --- /dev/null +++ b/internal/endtoend/testdata/invalid_table_alias/mysql/exec.json @@ -0,0 +1,3 @@ +{ + "contexts": ["base"] +} diff --git a/internal/endtoend/testdata/join_left_same_table/mysql/go/query.sql.go b/internal/endtoend/testdata/join_left_same_table/mysql/go/query.sql.go index 50025f9f94..e53c5d42d4 100644 --- a/internal/endtoend/testdata/join_left_same_table/mysql/go/query.sql.go +++ b/internal/endtoend/testdata/join_left_same_table/mysql/go/query.sql.go @@ -17,7 +17,7 @@ SELECT a.id, p.name as alias_name FROM authors a LEFT JOIN authors p - ON (authors.parent_id = p.id) + ON (a.parent_id = p.id) ` type AllAuthorsRow struct { diff --git a/internal/endtoend/testdata/join_left_same_table/mysql/query.sql b/internal/endtoend/testdata/join_left_same_table/mysql/query.sql index 26bd026ee5..81a179443d 100644 --- a/internal/endtoend/testdata/join_left_same_table/mysql/query.sql +++ b/internal/endtoend/testdata/join_left_same_table/mysql/query.sql @@ -5,4 +5,4 @@ SELECT a.id, p.name as alias_name FROM authors a LEFT JOIN authors p - ON (authors.parent_id = p.id); + ON (a.parent_id = p.id); diff --git a/internal/endtoend/testdata/mysql_vector/mysql/exec.json b/internal/endtoend/testdata/mysql_vector/mysql/exec.json new file mode 100644 index 0000000000..97f81fbc66 --- /dev/null +++ b/internal/endtoend/testdata/mysql_vector/mysql/exec.json @@ -0,0 +1,3 @@ +{ + "contexts": ["base"] +} diff --git a/internal/endtoend/testdata/select_subquery_no_alias/mysql/exec.json b/internal/endtoend/testdata/select_subquery_no_alias/mysql/exec.json new file mode 100644 index 0000000000..97f81fbc66 --- /dev/null +++ b/internal/endtoend/testdata/select_subquery_no_alias/mysql/exec.json @@ -0,0 +1,3 @@ +{ + "contexts": ["base"] +} diff --git a/internal/endtoend/testdata/show_warnings/mysql/exec.json b/internal/endtoend/testdata/show_warnings/mysql/exec.json new file mode 100644 index 0000000000..97f81fbc66 --- /dev/null +++ b/internal/endtoend/testdata/show_warnings/mysql/exec.json @@ -0,0 +1,3 @@ +{ + "contexts": ["base"] +} diff --git a/internal/endtoend/testdata/valid_group_by_reference/mysql/exec.json b/internal/endtoend/testdata/valid_group_by_reference/mysql/exec.json new file mode 100644 index 0000000000..97f81fbc66 --- /dev/null +++ b/internal/endtoend/testdata/valid_group_by_reference/mysql/exec.json @@ -0,0 +1,3 @@ +{ + "contexts": ["base"] +} diff --git a/internal/endtoend/testdata/vet_explain/mysql/exec.json b/internal/endtoend/testdata/vet_explain/mysql/exec.json new file mode 100644 index 0000000000..e5dfda7818 --- /dev/null +++ b/internal/endtoend/testdata/vet_explain/mysql/exec.json @@ -0,0 +1,3 @@ +{ + "contexts": ["managed-db"] +} diff --git a/internal/endtoend/testdata/wrap_errors/mysql/db/query.sql.go b/internal/endtoend/testdata/wrap_errors/mysql/db/query.sql.go index 1cf96f1534..6451f412da 100644 --- a/internal/endtoend/testdata/wrap_errors/mysql/db/query.sql.go +++ b/internal/endtoend/testdata/wrap_errors/mysql/db/query.sql.go @@ -15,12 +15,17 @@ const createAuthor = `-- name: CreateAuthor :execlastid INSERT INTO authors ( name, bio ) VALUES ( - $1, $2 + ?, ? ) ` -func (q *Queries) CreateAuthor(ctx context.Context) (int64, error) { - result, err := q.db.ExecContext(ctx, createAuthor) +type CreateAuthorParams struct { + Name string + Bio sql.NullString +} + +func (q *Queries) CreateAuthor(ctx context.Context, arg CreateAuthorParams) (int64, error) { + result, err := q.db.ExecContext(ctx, createAuthor, arg.Name, arg.Bio) if err != nil { return 0, fmt.Errorf("query CreateAuthor: %w", err) } @@ -29,11 +34,11 @@ func (q *Queries) CreateAuthor(ctx context.Context) (int64, error) { const deleteAuthorExec = `-- name: DeleteAuthorExec :exec DELETE FROM authors -WHERE id = $1 +WHERE id = ? ` -func (q *Queries) DeleteAuthorExec(ctx context.Context) error { - _, err := q.db.ExecContext(ctx, deleteAuthorExec) +func (q *Queries) DeleteAuthorExec(ctx context.Context, id int64) error { + _, err := q.db.ExecContext(ctx, deleteAuthorExec, id) if err != nil { err = fmt.Errorf("query DeleteAuthorExec: %w", err) } @@ -42,11 +47,11 @@ func (q *Queries) DeleteAuthorExec(ctx context.Context) error { const deleteAuthorExecLastID = `-- name: DeleteAuthorExecLastID :execlastid DELETE FROM authors -WHERE id = $1 +WHERE id = ? ` -func (q *Queries) DeleteAuthorExecLastID(ctx context.Context) (int64, error) { - result, err := q.db.ExecContext(ctx, deleteAuthorExecLastID) +func (q *Queries) DeleteAuthorExecLastID(ctx context.Context, id int64) (int64, error) { + result, err := q.db.ExecContext(ctx, deleteAuthorExecLastID, id) if err != nil { return 0, fmt.Errorf("query DeleteAuthorExecLastID: %w", err) } @@ -55,11 +60,11 @@ func (q *Queries) DeleteAuthorExecLastID(ctx context.Context) (int64, error) { const deleteAuthorExecResult = `-- name: DeleteAuthorExecResult :execresult DELETE FROM authors -WHERE id = $1 +WHERE id = ? ` -func (q *Queries) DeleteAuthorExecResult(ctx context.Context) (sql.Result, error) { - result, err := q.db.ExecContext(ctx, deleteAuthorExecResult) +func (q *Queries) DeleteAuthorExecResult(ctx context.Context, id int64) (sql.Result, error) { + result, err := q.db.ExecContext(ctx, deleteAuthorExecResult, id) if err != nil { err = fmt.Errorf("query DeleteAuthorExecResult: %w", err) } @@ -68,11 +73,11 @@ func (q *Queries) DeleteAuthorExecResult(ctx context.Context) (sql.Result, error const deleteAuthorExecRows = `-- name: DeleteAuthorExecRows :execrows DELETE FROM authors -WHERE id = $1 +WHERE id = ? ` -func (q *Queries) DeleteAuthorExecRows(ctx context.Context) (int64, error) { - result, err := q.db.ExecContext(ctx, deleteAuthorExecRows) +func (q *Queries) DeleteAuthorExecRows(ctx context.Context, id int64) (int64, error) { + result, err := q.db.ExecContext(ctx, deleteAuthorExecRows, id) if err != nil { return 0, fmt.Errorf("query DeleteAuthorExecRows: %w", err) } @@ -81,11 +86,11 @@ func (q *Queries) DeleteAuthorExecRows(ctx context.Context) (int64, error) { const getAuthor = `-- name: GetAuthor :one SELECT id, name, bio FROM authors -WHERE id = $1 LIMIT 1 +WHERE id = ? LIMIT 1 ` -func (q *Queries) GetAuthor(ctx context.Context) (Author, error) { - row := q.db.QueryRowContext(ctx, getAuthor) +func (q *Queries) GetAuthor(ctx context.Context, id int64) (Author, error) { + row := q.db.QueryRowContext(ctx, getAuthor, id) var i Author err := row.Scan(&i.ID, &i.Name, &i.Bio) if err != nil { diff --git a/internal/endtoend/testdata/wrap_errors/mysql/query.sql b/internal/endtoend/testdata/wrap_errors/mysql/query.sql index 3f9971d942..1d81b47bc9 100644 --- a/internal/endtoend/testdata/wrap_errors/mysql/query.sql +++ b/internal/endtoend/testdata/wrap_errors/mysql/query.sql @@ -1,6 +1,6 @@ -- name: GetAuthor :one SELECT * FROM authors -WHERE id = $1 LIMIT 1; +WHERE id = ? LIMIT 1; -- name: ListAuthors :many SELECT * FROM authors @@ -10,21 +10,21 @@ ORDER BY name; INSERT INTO authors ( name, bio ) VALUES ( - $1, $2 + ?, ? ); -- name: DeleteAuthorExec :exec DELETE FROM authors -WHERE id = $1; +WHERE id = ?; -- name: DeleteAuthorExecRows :execrows DELETE FROM authors -WHERE id = $1; +WHERE id = ?; -- name: DeleteAuthorExecLastID :execlastid DELETE FROM authors -WHERE id = $1; +WHERE id = ?; -- name: DeleteAuthorExecResult :execresult DELETE FROM authors -WHERE id = $1; +WHERE id = ?; diff --git a/internal/engine/dolphin/analyzer/analyze.go b/internal/engine/dolphin/analyzer/analyze.go new file mode 100644 index 0000000000..fd1b5047b7 --- /dev/null +++ b/internal/engine/dolphin/analyzer/analyze.go @@ -0,0 +1,328 @@ +package analyzer + +import ( + "context" + "database/sql" + "fmt" + "hash/fnv" + "io" + "strings" + "sync" + + _ "github.com/go-sql-driver/mysql" + + core "github.com/sqlc-dev/sqlc/internal/analysis" + "github.com/sqlc-dev/sqlc/internal/config" + "github.com/sqlc-dev/sqlc/internal/opts" + "github.com/sqlc-dev/sqlc/internal/shfmt" + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/named" + "github.com/sqlc-dev/sqlc/internal/sql/sqlerr" +) + +type Analyzer struct { + db config.Database + servers []config.Server + conn *sql.DB + baseConn *sql.DB // Connection to base database for creating/dropping temp DBs + dbName string // Name of the created database (for cleanup) + dbg opts.Debug + replacer *shfmt.Replacer + mu sync.Mutex +} + +func New(servers []config.Server, db config.Database) *Analyzer { + return &Analyzer{ + db: db, + servers: servers, + dbg: opts.DebugFromEnv(), + replacer: shfmt.NewReplacer(nil), + } +} + +// dbid creates a unique hash from migration content +func dbid(migrations []string) string { + h := fnv.New64() + for _, query := range migrations { + io.WriteString(h, query) + } + return fmt.Sprintf("%x", h.Sum(nil)) +} + +func (a *Analyzer) Analyze(ctx context.Context, n ast.Node, query string, migrations []string, ps *named.ParamSet) (*core.Analysis, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.conn == nil { + var uri string + var applyMigrations bool + + if a.db.Managed { + // Only require servers for managed databases + // Non-managed use the database URI directly + // Find MySQL server from configured servers + var baseURI string + for _, server := range a.servers { + if server.Engine == config.EngineMySQL { + baseURI = a.replacer.Replace(server.URI) + break + } + } + if baseURI == "" { + return nil, fmt.Errorf("no MySQL database server configured") + } + + // Create a unique database name based on migrations hash + hash := dbid(migrations) + a.dbName = fmt.Sprintf("sqlc_managed_%s", hash) + + // Connect to the base database to create our temp database + baseConn, err := sql.Open("mysql", baseURI) + if err != nil { + return nil, fmt.Errorf("failed to connect to MySQL server: %w", err) + } + if err := baseConn.PingContext(ctx); err != nil { + baseConn.Close() + return nil, fmt.Errorf("failed to ping MySQL server: %w", err) + } + a.baseConn = baseConn + + // Check if database already exists + var dbExists int + row := baseConn.QueryRowContext(ctx, + "SELECT COUNT(*) FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?", a.dbName) + if err := row.Scan(&dbExists); err != nil { + return nil, fmt.Errorf("failed to check database existence: %w", err) + } + + if dbExists == 0 { + // Create the database + if _, err := baseConn.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE `%s`", a.dbName)); err != nil { + return nil, fmt.Errorf("failed to create database %s: %w", a.dbName, err) + } + applyMigrations = true + } + + // Build URI for the new database + // Parse base URI to replace database name + uri = replaceDatabase(baseURI, a.dbName) + } else if a.dbg.OnlyManagedDatabases { + return nil, fmt.Errorf("database: connections disabled via SQLCDEBUG=databases=managed") + } else { + uri = a.replacer.Replace(a.db.URI) + // If the URI is empty (e.g., environment variable not set), skip analysis + if uri == "" { + return nil, fmt.Errorf("database URI is empty (environment variable may not be set)") + } + } + + conn, err := sql.Open("mysql", uri) + if err != nil { + return nil, fmt.Errorf("failed to open mysql database: %w", err) + } + if err := conn.PingContext(ctx); err != nil { + conn.Close() + return nil, fmt.Errorf("failed to ping mysql database: %w", err) + } + a.conn = conn + + // Apply migrations for managed databases that were just created + if applyMigrations { + for _, m := range migrations { + if len(strings.TrimSpace(m)) == 0 { + continue + } + if _, err := a.conn.ExecContext(ctx, m); err != nil { + return nil, fmt.Errorf("migration failed: %w", err) + } + } + } + } + + // Count parameters in the query + paramCount := countParameters(query) + + // Try to prepare the statement first to validate syntax + stmt, err := a.conn.PrepareContext(ctx, query) + if err != nil { + return nil, a.extractSqlErr(n, err) + } + stmt.Close() + + var result core.Analysis + + // For SELECT queries, execute with default parameter values to get column metadata + if isSelectQuery(query) { + cols, err := a.getColumnMetadata(ctx, query, paramCount) + if err == nil { + result.Columns = cols + } + // If we fail to get column metadata, fall through to return empty columns + // and let the catalog-based inference handle it + } + + // Build parameter info + for i := 1; i <= paramCount; i++ { + name := "" + if ps != nil { + name, _ = ps.NameFor(i) + } + result.Params = append(result.Params, &core.Parameter{ + Number: int32(i), + Column: &core.Column{ + Name: name, + DataType: "any", + NotNull: false, + }, + }) + } + + return &result, nil +} + +// isSelectQuery checks if a query is a SELECT statement +func isSelectQuery(query string) bool { + trimmed := strings.TrimSpace(strings.ToUpper(query)) + return strings.HasPrefix(trimmed, "SELECT") || + strings.HasPrefix(trimmed, "WITH") // CTEs +} + +// getColumnMetadata executes the query with default values to retrieve column information +func (a *Analyzer) getColumnMetadata(ctx context.Context, query string, paramCount int) ([]*core.Column, error) { + // Generate default parameter values (use 1 for all - works for most types) + args := make([]any, paramCount) + for i := range args { + args[i] = 1 + } + + // Wrap query to avoid fetching data: SELECT * FROM (query) AS _sqlc_wrapper LIMIT 0 + // This ensures we get column metadata without executing the actual query + wrappedQuery := fmt.Sprintf("SELECT * FROM (%s) AS _sqlc_wrapper LIMIT 0", query) + + rows, err := a.conn.QueryContext(ctx, wrappedQuery, args...) + if err != nil { + // If wrapped query fails, try direct query with LIMIT 0 + // Some queries may not support being wrapped (e.g., queries with UNION at the end) + return nil, err + } + defer rows.Close() + + colTypes, err := rows.ColumnTypes() + if err != nil { + return nil, err + } + + var columns []*core.Column + for _, col := range colTypes { + nullable, _ := col.Nullable() + columns = append(columns, &core.Column{ + Name: col.Name(), + DataType: strings.ToLower(col.DatabaseTypeName()), + NotNull: !nullable, + }) + } + + return columns, nil +} + +// replaceDatabase replaces the database name in a MySQL DSN +func replaceDatabase(dsn string, newDB string) string { + // MySQL DSN format: user:password@protocol(address)/dbname?params + // We need to replace the dbname part + + // Find the slash before the database name + slashIdx := strings.LastIndex(dsn, "/") + if slashIdx == -1 { + // No slash found, append /dbname + if strings.Contains(dsn, "?") { + // Has params, insert before ? + paramIdx := strings.Index(dsn, "?") + return dsn[:paramIdx] + "/" + newDB + dsn[paramIdx:] + } + return dsn + "/" + newDB + } + + // Find the ? for parameters + paramIdx := strings.Index(dsn[slashIdx:], "?") + if paramIdx == -1 { + // No params, replace everything after slash + return dsn[:slashIdx+1] + newDB + } + + // Replace database name between / and ? + return dsn[:slashIdx+1] + newDB + dsn[slashIdx+paramIdx:] +} + +// countParameters counts the number of ? placeholders in a query +func countParameters(query string) int { + count := 0 + inString := false + stringChar := byte(0) + escaped := false + + for i := 0; i < len(query); i++ { + c := query[i] + + if escaped { + escaped = false + continue + } + + if c == '\\' { + escaped = true + continue + } + + if inString { + if c == stringChar { + inString = false + } + continue + } + + if c == '\'' || c == '"' || c == '`' { + inString = true + stringChar = c + continue + } + + if c == '?' { + count++ + } + } + + return count +} + +func (a *Analyzer) extractSqlErr(n ast.Node, err error) error { + if err == nil { + return nil + } + return &sqlerr.Error{ + Message: fmt.Sprintf("mysql: %s", err.Error()), + Location: n.Pos(), + } +} + +func (a *Analyzer) Close(_ context.Context) error { + a.mu.Lock() + defer a.mu.Unlock() + + if a.conn != nil { + a.conn.Close() + a.conn = nil + } + + // Note: We don't drop the database on close because: + // 1. Other analyzers might be using the same database (based on migration hash) + // 2. It can be reused for future runs with the same migrations + // The databases are prefixed with sqlc_managed_ and can be cleaned up manually if needed + + if a.baseConn != nil { + a.baseConn.Close() + a.baseConn = nil + } + + return nil +} + From f1238e438f6a6fe14ce5eb568f85d5dd53b6b3fc Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sat, 29 Nov 2025 16:03:33 -0800 Subject: [PATCH 2/2] feat(mysql): Use forked driver to get prepared statement metadata MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Updates the MySQL analyzer to use the sqlc-dev/mysql forked driver which exposes column and parameter metadata from COM_STMT_PREPARE responses. This provides more accurate type information directly from MySQL. The forked driver adds a StmtMetadata interface with ColumnMetadata() and ParamMetadata() methods that return type info including DatabaseTypeName, Nullable, Unsigned, and Length fields. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- go.mod | 2 + go.sum | 4 +- internal/engine/dolphin/analyzer/analyze.go | 192 +++++++++----------- 3 files changed, 86 insertions(+), 112 deletions(-) diff --git a/go.mod b/go.mod index 450573ddab..630795248e 100644 --- a/go.mod +++ b/go.mod @@ -64,3 +64,5 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) + +replace github.com/go-sql-driver/mysql => github.com/sqlc-dev/mysql v0.0.0-20251129233104-d81e1cac6db2 diff --git a/go.sum b/go.sum index 3178cae5c1..002020f15c 100644 --- a/go.sum +++ b/go.sum @@ -26,8 +26,6 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= -github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= -github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= @@ -159,6 +157,8 @@ github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4 github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/sqlc-dev/mysql v0.0.0-20251129233104-d81e1cac6db2 h1:kmCAKKtOgK6EXXQX9oPdEASIhgor7TCpWxD8NtcqVcU= +github.com/sqlc-dev/mysql v0.0.0-20251129233104-d81e1cac6db2/go.mod h1:TrDMWzjNTKvJeK2GC8uspG+PWyPLiY9QKvwdWpAdlZE= github.com/stoewer/go-strcase v1.2.0 h1:Z2iHWqGXH00XYgqDmNgQbIBxf3wrNq0F3feEy0ainaU= github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/internal/engine/dolphin/analyzer/analyze.go b/internal/engine/dolphin/analyzer/analyze.go index fd1b5047b7..30862370da 100644 --- a/internal/engine/dolphin/analyzer/analyze.go +++ b/internal/engine/dolphin/analyzer/analyze.go @@ -3,13 +3,14 @@ package analyzer import ( "context" "database/sql" + "database/sql/driver" "fmt" "hash/fnv" "io" "strings" "sync" - _ "github.com/go-sql-driver/mysql" + "github.com/go-sql-driver/mysql" core "github.com/sqlc-dev/sqlc/internal/analysis" "github.com/sqlc-dev/sqlc/internal/config" @@ -139,90 +140,102 @@ func (a *Analyzer) Analyze(ctx context.Context, n ast.Node, query string, migrat } } - // Count parameters in the query - paramCount := countParameters(query) - - // Try to prepare the statement first to validate syntax - stmt, err := a.conn.PrepareContext(ctx, query) + // Get metadata directly from prepared statement via driver connection + result, err := a.getStatementMetadata(ctx, n, query, ps) if err != nil { - return nil, a.extractSqlErr(n, err) + return nil, err } - stmt.Close() + return result, nil +} + +// getStatementMetadata uses the MySQL driver's prepared statement metadata API +// to get column and parameter type information without executing the query +func (a *Analyzer) getStatementMetadata(ctx context.Context, n ast.Node, query string, ps *named.ParamSet) (*core.Analysis, error) { var result core.Analysis - // For SELECT queries, execute with default parameter values to get column metadata - if isSelectQuery(query) { - cols, err := a.getColumnMetadata(ctx, query, paramCount) - if err == nil { - result.Columns = cols - } - // If we fail to get column metadata, fall through to return empty columns - // and let the catalog-based inference handle it + // Get a raw connection to access driver-level prepared statement + conn, err := a.conn.Conn(ctx) + if err != nil { + return nil, a.extractSqlErr(n, fmt.Errorf("failed to get connection: %w", err)) } + defer conn.Close() - // Build parameter info - for i := 1; i <= paramCount; i++ { - name := "" - if ps != nil { - name, _ = ps.NameFor(i) + err = conn.Raw(func(driverConn any) error { + // Get the driver connection that supports PrepareContext + preparer, ok := driverConn.(driver.ConnPrepareContext) + if !ok { + return fmt.Errorf("driver connection does not support PrepareContext") } - result.Params = append(result.Params, &core.Parameter{ - Number: int32(i), - Column: &core.Column{ - Name: name, - DataType: "any", - NotNull: false, - }, - }) - } - - return &result, nil -} -// isSelectQuery checks if a query is a SELECT statement -func isSelectQuery(query string) bool { - trimmed := strings.TrimSpace(strings.ToUpper(query)) - return strings.HasPrefix(trimmed, "SELECT") || - strings.HasPrefix(trimmed, "WITH") // CTEs -} + // Prepare the statement - this sends COM_STMT_PREPARE to MySQL + // and receives column and parameter metadata + stmt, err := preparer.PrepareContext(ctx, query) + if err != nil { + return err + } + defer stmt.Close() + + // Access the metadata via the StmtMetadata interface from our forked driver + meta, ok := stmt.(mysql.StmtMetadata) + if !ok { + // Fallback: just use param count from NumInput + paramCount := stmt.NumInput() + for i := 1; i <= paramCount; i++ { + name := "" + if ps != nil { + name, _ = ps.NameFor(i) + } + result.Params = append(result.Params, &core.Parameter{ + Number: int32(i), + Column: &core.Column{ + Name: name, + DataType: "any", + NotNull: false, + }, + }) + } + return nil + } -// getColumnMetadata executes the query with default values to retrieve column information -func (a *Analyzer) getColumnMetadata(ctx context.Context, query string, paramCount int) ([]*core.Column, error) { - // Generate default parameter values (use 1 for all - works for most types) - args := make([]any, paramCount) - for i := range args { - args[i] = 1 - } + // Get column metadata + for _, col := range meta.ColumnMetadata() { + result.Columns = append(result.Columns, &core.Column{ + Name: col.Name, + DataType: strings.ToLower(col.DatabaseTypeName), + NotNull: !col.Nullable, + Unsigned: col.Unsigned, + Length: int32(col.Length), + }) + } - // Wrap query to avoid fetching data: SELECT * FROM (query) AS _sqlc_wrapper LIMIT 0 - // This ensures we get column metadata without executing the actual query - wrappedQuery := fmt.Sprintf("SELECT * FROM (%s) AS _sqlc_wrapper LIMIT 0", query) + // Get parameter metadata + paramMeta := meta.ParamMetadata() + for i, param := range paramMeta { + name := "" + if ps != nil { + name, _ = ps.NameFor(i + 1) + } + result.Params = append(result.Params, &core.Parameter{ + Number: int32(i + 1), + Column: &core.Column{ + Name: name, + DataType: strings.ToLower(param.DatabaseTypeName), + NotNull: !param.Nullable, + Unsigned: param.Unsigned, + Length: int32(param.Length), + }, + }) + } - rows, err := a.conn.QueryContext(ctx, wrappedQuery, args...) - if err != nil { - // If wrapped query fails, try direct query with LIMIT 0 - // Some queries may not support being wrapped (e.g., queries with UNION at the end) - return nil, err - } - defer rows.Close() + return nil + }) - colTypes, err := rows.ColumnTypes() if err != nil { - return nil, err - } - - var columns []*core.Column - for _, col := range colTypes { - nullable, _ := col.Nullable() - columns = append(columns, &core.Column{ - Name: col.Name(), - DataType: strings.ToLower(col.DatabaseTypeName()), - NotNull: !nullable, - }) + return nil, a.extractSqlErr(n, err) } - return columns, nil + return &result, nil } // replaceDatabase replaces the database name in a MySQL DSN @@ -253,47 +266,6 @@ func replaceDatabase(dsn string, newDB string) string { return dsn[:slashIdx+1] + newDB + dsn[slashIdx+paramIdx:] } -// countParameters counts the number of ? placeholders in a query -func countParameters(query string) int { - count := 0 - inString := false - stringChar := byte(0) - escaped := false - - for i := 0; i < len(query); i++ { - c := query[i] - - if escaped { - escaped = false - continue - } - - if c == '\\' { - escaped = true - continue - } - - if inString { - if c == stringChar { - inString = false - } - continue - } - - if c == '\'' || c == '"' || c == '`' { - inString = true - stringChar = c - continue - } - - if c == '?' { - count++ - } - } - - return count -} - func (a *Analyzer) extractSqlErr(n ast.Node, err error) error { if err == nil { return nil