From 6e0a7a94088470aba137ccc758c24e3488897cb9 Mon Sep 17 00:00:00 2001 From: eric Date: Fri, 5 Jun 2026 11:13:54 -0700 Subject: [PATCH] fix(iceberg): Tier-1 PostgreSQL-compatibility fixes Addresses the highest-severity data-integrity gaps from the Iceberg QA report (docs/iceberg-pg-syntax-qa.md). All verified live against the Iceberg backend and covered by unit tests. - DDL hybrid warn/error (Iceberg only; DuckLake keeps silent-strip for sqlmesh/dbt): WARNING for unenforced PK/UNIQUE/CHECK/FK; ERROR (0A000) for the silently-NULL features SERIAL/BIGSERIAL, GENERATED ... STORED, and DEFAULT /now(). DEFAULT NULL and NOT NULL preserved. Adds a Warnings channel on the transpile Result, surfaced as NoticeResponse in both simple and extended protocols. - EXPLAIN (ANALYZE) of a write no longer double-executes: GetQuerySchema returns a synthetic schema for EXPLAIN without running it, and the extended-protocol Describe path no longer probe-executes EXPLAIN. - DROP COLUMN guard: DDL-time WARNING, and the Iceberg "newer schema id" scan failure is mapped to a clear 0A000 message instead of a raw XX000. - bytea hex literals '\xDEADBEEF'::bytea now decode to the correct bytes (unhex), and B'101' bit-string literals map to '101'::BIT. - jsonb || now merges objects (json_merge_patch) instead of silently string-concatenating; plain string/array || is left untouched. - Writable-CTE UPDATE ... RETURNING that reads a modified column is rejected (would return pre-update values); RETURNING an unmodified key and RETURNING * still work (Airbyte pattern preserved). Co-Authored-By: Claude Opus 4.8 (1M context) --- duckdbservice/arrow_helpers.go | 32 +++ duckdbservice/arrow_helpers_test.go | 48 ++++ server/conn.go | 121 ++++++++-- server/conn_describe_test.go | 36 +++ server/schema_evolution_error_test.go | 44 ++++ transpiler/backend/profile.go | 40 +++- transpiler/config.go | 5 + transpiler/transform/ddl.go | 217 +++++++++++++++--- transpiler/transform/ddl_test.go | 198 ++++++++++++++++ transpiler/transform/literals.go | 122 ++++++++++ transpiler/transform/literals_test.go | 83 +++++++ transpiler/transform/operators.go | 59 +++++ transpiler/transform/transform.go | 6 + transpiler/transform/writablecte.go | 59 ++++- .../transform/writablecte_returning_test.go | 73 ++++++ transpiler/transpiler.go | 40 +++- transpiler/transpiler_test.go | 26 ++- 17 files changed, 1146 insertions(+), 63 deletions(-) create mode 100644 server/schema_evolution_error_test.go create mode 100644 transpiler/transform/ddl_test.go create mode 100644 transpiler/transform/literals.go create mode 100644 transpiler/transform/literals_test.go create mode 100644 transpiler/transform/writablecte_returning_test.go diff --git a/duckdbservice/arrow_helpers.go b/duckdbservice/arrow_helpers.go index 44e2273e..817d1523 100644 --- a/duckdbservice/arrow_helpers.go +++ b/duckdbservice/arrow_helpers.go @@ -76,6 +76,24 @@ type contextQueryer interface { QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) } +// isExplainQuery reports whether the (already upper-cased) query is an EXPLAIN +// statement, i.e. starts with the EXPLAIN keyword followed by a delimiter. +func isExplainQuery(upper string) bool { + s := strings.TrimSpace(upper) + const kw = "EXPLAIN" + if !strings.HasPrefix(s, kw) { + return false + } + if len(s) == len(kw) { + return true + } + switch s[len(kw)] { + case ' ', '\t', '\n', '\r', '(': + return true + } + return false +} + func isNil(i contextQueryer) bool { if i == nil { return true @@ -89,6 +107,20 @@ func GetQuerySchema(ctx context.Context, db contextQueryer, query string, tx con q := strings.TrimRight(strings.TrimSpace(query), ";") queryWithLimit := q upper := strings.ToUpper(q) + // EXPLAIN [ANALYZE] returns a fixed single-column textual plan. We must NOT + // execute it to discover its schema: EXPLAIN ANALYZE runs the statement to + // gather statistics, so executing it here as a schema probe would run (and, + // for a write, mutate) the statement a second time on top of the real DoGet + // execution. Return a synthetic schema without executing. + if isExplainQuery(upper) { + name := "physical_plan" + if strings.Contains(upper, "ANALYZE") { + name = "analyzed_plan" + } + return arrow.NewSchema([]arrow.Field{ + {Name: name, Type: arrowmap.DuckDBTypeToArrow("VARCHAR"), Nullable: true}, + }, nil), nil + } // Only append LIMIT 0 for SELECT/WITH/VALUES/TABLE statements. // SHOW, DESCRIBE, EXPLAIN, PRAGMA, CALL etc. don't support LIMIT. if !strings.Contains(upper, "LIMIT") && arrowmap.SupportsLimit(upper) { diff --git a/duckdbservice/arrow_helpers_test.go b/duckdbservice/arrow_helpers_test.go index 5ad83195..1234afb9 100644 --- a/duckdbservice/arrow_helpers_test.go +++ b/duckdbservice/arrow_helpers_test.go @@ -1021,6 +1021,54 @@ func TestNestedTypesRoundTrip(t *testing.T) { } } +func TestGetQuerySchemaExplainDoesNotExecute(t *testing.T) { + // Regression: GetQuerySchema used to execute EXPLAIN ANALYZE to learn its + // schema, which for a write mutates — and DoGet then executes it again, + // double-inserting. EXPLAIN must now yield a synthetic schema without running. + db, err := sql.Open("duckdb", "") + if err != nil { + t.Fatalf("failed to open DuckDB: %v", err) + } + defer func() { _ = db.Close() }() + + if _, err := db.Exec("CREATE TABLE t (id INTEGER)"); err != nil { + t.Fatalf("create table: %v", err) + } + + cases := []struct { + name string + query string + wantCol string + }{ + {"explain select", "EXPLAIN SELECT 1", "physical_plan"}, + {"explain analyze insert", "EXPLAIN ANALYZE INSERT INTO t VALUES (1)", "analyzed_plan"}, + {"explain analyze parens", "EXPLAIN (ANALYZE) INSERT INTO t VALUES (2)", "analyzed_plan"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + schema, err := GetQuerySchema(context.Background(), db, tc.query, nil) + if err != nil { + t.Fatalf("GetQuerySchema(%q) error: %v", tc.query, err) + } + if schema.NumFields() != 1 { + t.Fatalf("GetQuerySchema(%q) = %d fields, want 1", tc.query, schema.NumFields()) + } + if got := schema.Field(0).Name; got != tc.wantCol { + t.Errorf("column name = %q, want %q", got, tc.wantCol) + } + }) + } + + // The schema probes above must NOT have inserted any rows. + var n int + if err := db.QueryRow("SELECT count(*) FROM t").Scan(&n); err != nil { + t.Fatalf("count: %v", err) + } + if n != 0 { + t.Errorf("EXPLAIN ANALYZE schema probe executed the write: %d rows, want 0", n) + } +} + func TestGetQuerySchemaTrailingSemicolon(t *testing.T) { // Regression test: queries ending with ";" caused "syntax error at or near LIMIT" // because GetQuerySchema appended " LIMIT 0" after the semicolon, producing "; LIMIT 0". diff --git a/server/conn.go b/server/conn.go index 6d06a419..93ca7407 100644 --- a/server/conn.go +++ b/server/conn.go @@ -93,6 +93,7 @@ type preparedStmt struct { cursorName string // Cursor name cursorQuery string // Transpiled inner SELECT (for DECLARE) fetchCount int64 // FETCH row count + warnings []string // Transpiler warnings to surface as NoticeResponse at Execute } type portal struct { @@ -450,6 +451,13 @@ func classifyErrorCode(err error) string { // classifiers below apply; otherwise every Iceberg/DuckLake worker error // would fall through to XX000 and the client would see the raw rpc string. msg := unwrapFlightError(err.Error()) + // Iceberg schema-evolution failure: after a column is dropped (or other + // schema churn) DuckDB's iceberg extension can fail every scan with + // "Tried to scan a snapshot created with a newer schema id ...". Surface it + // as feature_not_supported with an actionable message rather than XX000. + if isSchemaEvolutionErrorMsg(msg) { + return "0A000" + } switch { case strings.HasPrefix(msg, "Catalog Error:"): return catalogErrorCode(msg) @@ -496,6 +504,24 @@ func transformErrorSQLState(err error) string { return "42704" } +// isSchemaEvolutionErrorMsg reports whether an (already Flight-unwrapped) DuckDB +// error message is the Iceberg "newer schema id" scan failure that a DROP COLUMN +// can leave a table in. Matched on a stable substring; if the wording changes +// the error simply falls back to its generic class (still safe). +func isSchemaEvolutionErrorMsg(msg string) bool { + return strings.Contains(msg, "newer schema id") +} + +// friendlyExecError rewrites known-cryptic DuckDB/Iceberg execution errors into +// actionable client messages. Returns the original message unchanged otherwise. +func friendlyExecError(err error) string { + msg := err.Error() + if isSchemaEvolutionErrorMsg(unwrapFlightError(msg)) { + return "reading this table failed because a column was dropped: DROP COLUMN is not safely supported on this catalog (it can leave the table unreadable after schema changes); recreate the table instead" + } + return msg +} + // unwrapFlightError recovers the underlying error message from an Arrow Flight / // gRPC wrapper. gRPC errors stringify as "… rpc error: code = X desc = ", // and the control plane further prefixes worker failures with "flight execute @@ -1494,6 +1520,12 @@ func (c *clientConn) handleQuery(body []byte) error { return nil } + // Surface any transpiler warnings (e.g. an unenforced constraint stripped on a + // lake catalog) as NoticeResponse before the command result. + for _, w := range result.Warnings { + c.sendNotice("WARNING", "01000", w) + } + // Handle ignored SET parameters if result.IsIgnoredSet { slog.Debug("Ignoring PostgreSQL-specific SET.", "user", c.username, "query", query) @@ -1587,7 +1619,7 @@ func (c *clientConn) handleQuery(body []byte) error { } if err != nil { errCode := classifyErrorCode(err) - errMsg := err.Error() + errMsg := friendlyExecError(err) if c.isCallerCancellation(err) { errMsg = "canceling statement due to user request" } else { @@ -1675,7 +1707,7 @@ func (c *clientConn) executeQueryDirect(query, cmdType string) error { if err != nil { queryFinalErr = err errCode := classifyErrorCode(err) - errMsg := err.Error() + errMsg := friendlyExecError(err) if c.isCallerCancellation(err) { errMsg = "canceling statement due to user request" } else { @@ -1869,7 +1901,7 @@ func (c *clientConn) executeSelectQuery(query string, cmdType string) (int64, st if err != nil { queryFinalErr = err errCode := classifyErrorCode(err) - errMsg := err.Error() + errMsg := friendlyExecError(err) if c.isCallerCancellation(err) { errMsg = "canceling statement due to user request" } else { @@ -1887,7 +1919,7 @@ func (c *clientConn) executeSelectQuery(query string, cmdType string) (int64, st if err != nil { queryFinalErr = err errCode := "42000" - errMsg := err.Error() + errMsg := friendlyExecError(err) if !c.isCallerCancellation(err) { c.logQueryError(query, err) } @@ -1902,7 +1934,7 @@ func (c *clientConn) executeSelectQuery(query string, cmdType string) (int64, st if err != nil { queryFinalErr = err errCode := "42000" - errMsg := err.Error() + errMsg := friendlyExecError(err) if !c.isCallerCancellation(err) { c.logQueryError(query, err) } @@ -1947,7 +1979,7 @@ func (c *clientConn) executeSelectQuery(query string, cmdType string) (int64, st if err := rows.Scan(valuePtrs...); err != nil { queryFinalErr = err errCode := "42000" - errMsg := err.Error() + errMsg := friendlyExecError(err) if !c.isCallerCancellation(err) { c.logQueryError(query, err) } @@ -1972,7 +2004,7 @@ func (c *clientConn) executeSelectQuery(query string, cmdType string) (int64, st if err := rows.Err(); err != nil { queryFinalErr = err errCode := "42000" - errMsg := err.Error() + errMsg := friendlyExecError(err) if c.isCallerCancellation(err) { errCode = "57014" errMsg = "canceling statement due to user request" @@ -2294,7 +2326,7 @@ func (c *clientConn) executeSingleStatement(query string) (errSent bool, fatalEr if err != nil { queryFinalErr = err errCode := classifyErrorCode(err) - errMsg := err.Error() + errMsg := friendlyExecError(err) if c.isCallerCancellation(err) { errMsg = "canceling statement due to user request" } else { @@ -2346,7 +2378,7 @@ func (c *clientConn) executeSingleStatement(query string) (errSent bool, fatalEr if err != nil { queryFinalErr = err errCode := classifyErrorCode(err) - errMsg := err.Error() + errMsg := friendlyExecError(err) if c.isCallerCancellation(err) { errMsg = "canceling statement due to user request" } else { @@ -3127,6 +3159,41 @@ func isWithDML(query string) bool { strings.HasPrefix(outer, "DELETE") } +// isExplainStmt reports whether the query is an EXPLAIN statement (the EXPLAIN +// keyword followed by a space or '('). Used to avoid executing EXPLAIN at +// Describe time — EXPLAIN ANALYZE of a write mutates, and a describe-probe +// execution would run it a second time. +func isExplainStmt(query string) bool { + upper := strings.ToUpper(stripLeadingNoise(query)) + const kw = "EXPLAIN" + if !strings.HasPrefix(upper, kw) { + return false + } + if len(upper) == len(kw) { + return true + } + switch upper[len(kw)] { + case ' ', '\t', '\n', '\r', '(': + return true + } + return false +} + +// explainPlanColumn returns the single column name DuckDB uses for an EXPLAIN +// result: "analyzed_plan" for EXPLAIN ANALYZE, "physical_plan" otherwise. +func explainPlanColumn(query string) string { + if strings.Contains(strings.ToUpper(query), "ANALYZE") { + return "analyzed_plan" + } + return "physical_plan" +} + +// staticColumnType is a minimal ColumnTyper reporting a fixed DuckDB type name, +// used to synthesize a RowDescription without executing a query. +type staticColumnType string + +func (s staticColumnType) DatabaseTypeName() string { return string(s) } + // skipBalancedParens advances past a parenthesized group in an uppercased SQL // string. i must point to the character immediately after the opening '('. // It tracks paren depth while correctly skipping SQL constructs that may @@ -5338,6 +5405,7 @@ func (c *clientConn) handleParse(body []byte) { noOpTag: result.NoOpTag, statements: result.Statements, // Multi-statement rewrite (writable CTE) cleanupStatements: result.CleanupStatements, // Cleanup statements + warnings: result.Warnings, // Surfaced as NoticeResponse at Execute } slog.Debug("Prepared statement.", "user", c.username, "name", stmtName, "query", query) @@ -5535,6 +5603,16 @@ func (c *clientConn) handleDescribe(body []byte) { return } + // EXPLAIN [ANALYZE] returns a single textual plan column. Describing it via + // the LIMIT-0 probe below would EXECUTE it — and EXPLAIN ANALYZE of a write + // mutates — so the statement would run at Describe and again at Execute. + // Send a synthetic RowDescription without executing. + if isExplainStmt(ps.query) { + _ = c.sendRowDescription([]string{explainPlanColumn(ps.query)}, []ColumnTyper{staticColumnType("VARCHAR")}) + ps.described = true + return + } + // For SELECT, we need to describe the result columns // The cleanest approach is to add a "WHERE false" or "LIMIT 0" clause // to get column info without actually running the query @@ -5639,6 +5717,15 @@ func (c *clientConn) handleDescribe(body []byte) { return } + // EXPLAIN [ANALYZE]: synthesize the single plan column without executing + // (see the statement-Describe branch above). + if isExplainStmt(p.stmt.query) { + _ = c.sendRowDescriptionWithFormats([]string{explainPlanColumn(p.stmt.query)}, []ColumnTyper{staticColumnType("VARCHAR")}, p.resultFormats) + p.described = true + p.stmt.described = true + return + } + // For SELECT, we need to describe the result columns // We'll do a trial query with LIMIT 0 to get column info args, err := p.decodeParams() @@ -5772,6 +5859,12 @@ func (c *clientConn) handleExecute(body []byte) { slog.Debug("Execute portal.", "user", c.username, "portal", portalName, "params", len(args), "query", p.stmt.query) + // Surface any transpiler warnings (e.g. an unenforced constraint stripped on a + // lake catalog) as NoticeResponse before the command result. + for _, w := range p.stmt.warnings { + c.sendNotice("WARNING", "01000", w) + } + // Check if this is a PostgreSQL-specific SET command that should be ignored // (determined by transpiler during Parse) if p.stmt.isIgnoredSet { @@ -5861,7 +5954,7 @@ func (c *clientConn) handleExecute(body []byte) { if err != nil { queryFinalErr = err errCode := classifyErrorCode(err) - errMsg := err.Error() + errMsg := friendlyExecError(err) if c.isCallerCancellation(err) { errMsg = "canceling statement due to user request" } else { @@ -5913,7 +6006,7 @@ func (c *clientConn) handleExecute(body []byte) { if err != nil { queryFinalErr = err errCode := classifyErrorCode(err) - errMsg := err.Error() + errMsg := friendlyExecError(err) if c.isCallerCancellation(err) { errMsg = "canceling statement due to user request" } else { @@ -5987,7 +6080,7 @@ func (c *clientConn) handleExecute(body []byte) { if err := rows.Err(); err != nil { queryFinalErr = err errCode := "42000" - errMsg := err.Error() + errMsg := friendlyExecError(err) if c.isCallerCancellation(err) { errCode = "57014" errMsg = "canceling statement due to user request" @@ -6547,7 +6640,7 @@ func (c *clientConn) handleFetchCursor(query string, stmt *pg_query.FetchStmt) e if cursor.rows == nil { if err := c.openCursor(cursor); err != nil { errCode := "42000" - errMsg := err.Error() + errMsg := friendlyExecError(err) if c.isCallerCancellation(err) { errCode = "57014" errMsg = "canceling statement due to user request" @@ -6622,7 +6715,7 @@ func (c *clientConn) handleFetchCursor(query string, stmt *pg_query.FetchStmt) e if err := cursor.rows.Err(); err != nil { errCode := "42000" - errMsg := err.Error() + errMsg := friendlyExecError(err) if c.isCallerCancellation(err) { errCode = "57014" errMsg = "canceling statement due to user request" diff --git a/server/conn_describe_test.go b/server/conn_describe_test.go index c92bae31..f7a35214 100644 --- a/server/conn_describe_test.go +++ b/server/conn_describe_test.go @@ -96,6 +96,42 @@ func TestHandleDescribePortalUsesLimitZeroProbe(t *testing.T) { } } +func TestHandleDescribeExplainDoesNotExecute(t *testing.T) { + // Describing an EXPLAIN [ANALYZE] must NOT execute a probe query — for + // EXPLAIN ANALYZE of a write that would mutate, then Execute mutates again. + for _, tc := range []struct { + name string + query string + wantCol string + }{ + {"statement explain", "EXPLAIN SELECT 1", "physical_plan"}, + {"portal explain analyze", "EXPLAIN ANALYZE INSERT INTO t VALUES (1)", "analyzed_plan"}, + } { + t.Run(tc.name, func(t *testing.T) { + exec := &describeRecordingExecutor{} + var out bytes.Buffer + c := &clientConn{ + executor: exec, + writer: bufio.NewWriter(&out), + stmts: map[string]*preparedStmt{ + "s1": {query: tc.query, convertedQuery: tc.query}, + }, + portals: map[string]*portal{ + "p1": {stmt: &preparedStmt{query: tc.query, convertedQuery: tc.query}}, + }, + cursors: map[string]*cursorState{}, + } + + c.handleDescribe([]byte{'S', 's', '1', 0}) + c.handleDescribe([]byte{'P', 'p', '1', 0}) + + if len(exec.queries) != 0 { + t.Fatalf("EXPLAIN describe executed probe queries: %v", exec.queries) + } + }) + } +} + func TestHandleDescribePortalPreservesExistingLimit(t *testing.T) { exec := &describeRecordingExecutor{ rowSet: &describeStaticRowSet{ diff --git a/server/schema_evolution_error_test.go b/server/schema_evolution_error_test.go new file mode 100644 index 00000000..621f114b --- /dev/null +++ b/server/schema_evolution_error_test.go @@ -0,0 +1,44 @@ +package server + +import ( + "errors" + "strings" + "testing" +) + +// The Iceberg "newer schema id" scan failure (which a DROP COLUMN can leave a +// table in) must be classified as feature_not_supported (0A000) and rewritten +// into an actionable message, both when raw and when Flight-wrapped. +func TestSchemaEvolutionErrorMapping(t *testing.T) { + cases := []struct { + name string + err error + }{ + { + name: "raw", + err: errors.New("INTERNAL Error: Tried to scan a snapshot created with a newer schema id (1) than the schema id selected for the scan (0)"), + }, + { + name: "flight-wrapped", + err: errors.New("flight execute: rpc error: code = Internal desc = Invalid Input Error: Tried to scan a snapshot created with a newer schema id (1) than the schema id selected for the scan (0)"), + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := classifyErrorCode(tc.err); got != "0A000" { + t.Errorf("classifyErrorCode = %q, want 0A000", got) + } + msg := friendlyExecError(tc.err) + if !strings.Contains(msg, "DROP COLUMN is not safely supported") { + t.Errorf("friendlyExecError = %q, want the DROP COLUMN guidance", msg) + } + }) + } +} + +func TestFriendlyExecErrorPassesThroughUnrelated(t *testing.T) { + err := errors.New("Catalog Error: Table with name foo does not exist!") + if got := friendlyExecError(err); got != err.Error() { + t.Errorf("friendlyExecError rewrote an unrelated error: %q", got) + } +} diff --git a/transpiler/backend/profile.go b/transpiler/backend/profile.go index 00719a0d..662cdfe5 100644 --- a/transpiler/backend/profile.go +++ b/transpiler/backend/profile.go @@ -36,6 +36,18 @@ type DDLPolicy struct { UnsupportedDDL UnsupportedDDLHandling RewriteCascadeDrop bool SplitMultiAlter bool + + // WarnOnStrippedConstraints emits a NoticeResponse (WARNING) when an + // unenforceable constraint (PK/UNIQUE/CHECK/FK/EXCLUSION) is stripped, so + // the client knows the constraint is accepted-but-not-enforced rather than + // silently dropped. + WarnOnStrippedConstraints bool + + // ErrorOnSilentNullDefaults rejects (rather than silently stripping) the DDL + // features that otherwise produce silently-NULL data on a lake catalog: + // SERIAL/BIGSERIAL, GENERATED ... STORED, and non-literal DEFAULT expressions + // (including DEFAULT now()/CURRENT_TIMESTAMP and DEFAULT true/false). + ErrorOnSilentNullDefaults bool } func (p DDLPolicy) NeedsTransform() bool { @@ -44,7 +56,9 @@ func (p DDLPolicy) NeedsTransform() bool { p.StripVolatileDefaults || p.UnsupportedDDL == NoOpUnsupportedDDL || p.RewriteCascadeDrop || - p.SplitMultiAlter + p.SplitMultiAlter || + p.WarnOnStrippedConstraints || + p.ErrorOnSilentNullDefaults } type DMLPolicy struct { @@ -86,9 +100,13 @@ func (p Profile) Metadata() MetadataPolicy { func ForName(name Name) Profile { switch name { case DuckLake: - return lakeProfile(DuckLake, "ducklake", true) + // DuckLake keeps the historical silent-strip behavior (sqlmesh/dbt issue + // PK/serial/DEFAULT now() DDL and rely on it succeeding). + return lakeProfile(DuckLake, "ducklake", true, false) case Iceberg: - return lakeProfile(Iceberg, "iceberg", false) + // Iceberg surfaces the dropped Postgres semantics: WARNING for unenforced + // constraints, ERROR for silently-NULL data features. + return lakeProfile(Iceberg, "iceberg", false, true) default: return Profile{ name: Memory, @@ -102,7 +120,7 @@ func ForName(name Name) Profile { } } -func lakeProfile(name Name, physical string, mapPublicToMain bool) Profile { +func lakeProfile(name Name, physical string, mapPublicToMain, hybridDDLGuards bool) Profile { return Profile{ name: name, catalog: CatalogPolicy{ @@ -111,12 +129,14 @@ func lakeProfile(name Name, physical string, mapPublicToMain bool) Profile { QualifyMacros: true, }, ddl: DDLPolicy{ - ConstraintHandling: StripConstraints, - RewriteSerial: true, - StripVolatileDefaults: true, - UnsupportedDDL: NoOpUnsupportedDDL, - RewriteCascadeDrop: true, - SplitMultiAlter: true, + ConstraintHandling: StripConstraints, + RewriteSerial: true, + StripVolatileDefaults: true, + UnsupportedDDL: NoOpUnsupportedDDL, + RewriteCascadeDrop: true, + SplitMultiAlter: true, + WarnOnStrippedConstraints: hybridDDLGuards, + ErrorOnSilentNullDefaults: hybridDDLGuards, }, dml: DMLPolicy{ConflictHandling: RewriteToMerge}, metadata: MetadataPolicy{InterceptShowCreate: true}, diff --git a/transpiler/config.go b/transpiler/config.go index f3118784..3410f70a 100644 --- a/transpiler/config.go +++ b/transpiler/config.go @@ -98,4 +98,9 @@ type Result struct { // should be attempted directly against DuckDB. This enables two-tier query // processing where DuckDB-specific syntax works automatically. FallbackToNative bool + + // Warnings holds human-readable messages to surface to the client as + // NoticeResponse (WARNING) while still executing the statement (e.g. an + // unenforced constraint that was stripped on a lake catalog). + Warnings []string } diff --git a/transpiler/transform/ddl.go b/transpiler/transform/ddl.go index 1fffc04b..6e27ea4a 100644 --- a/transpiler/transform/ddl.go +++ b/transpiler/transform/ddl.go @@ -1,6 +1,7 @@ package transform import ( + "fmt" "strings" pg_query "github.com/pganalyze/pg_query_go/v6" @@ -35,9 +36,12 @@ func (t *DDLTransform) Transform(tree *pg_query.ParseResult, result *Result) (bo switch n := stmt.Stmt.Node.(type) { case *pg_query.Node_CreateStmt: if n.CreateStmt != nil { - if t.transformCreateStmt(n.CreateStmt) { + if t.transformCreateStmt(n.CreateStmt, result) { changed = true } + if result.Error != nil { + return true, nil + } } case *pg_query.Node_IndexStmt: @@ -140,8 +144,11 @@ func (t *DDLTransform) Transform(tree *pg_query.ParseResult, result *Result) (bo return changed, nil } -// transformCreateStmt modifies a CREATE TABLE statement for DuckLake compatibility -func (t *DDLTransform) transformCreateStmt(stmt *pg_query.CreateStmt) bool { +// transformCreateStmt modifies a CREATE TABLE statement for DuckLake compatibility. +// Unenforceable constraints (PK/UNIQUE/CHECK/FK) are stripped (with a WARNING when +// WarnOnStrippedConstraints is set); silently-NULL data features (SERIAL, GENERATED +// STORED, DEFAULT ) raise an error when ErrorOnSilentNullDefaults is set. +func (t *DDLTransform) transformCreateStmt(stmt *pg_query.CreateStmt, result *Result) bool { changed := false // Process column definitions @@ -151,9 +158,12 @@ func (t *DDLTransform) transformCreateStmt(stmt *pg_query.CreateStmt) bool { case *pg_query.Node_ColumnDef: if n.ColumnDef != nil { // Transform the column definition - if t.transformColumnDef(n.ColumnDef) { + if t.transformColumnDef(n.ColumnDef, result) { changed = true } + if result.Error != nil { + return changed + } newTableElts = append(newTableElts, elt) } case *pg_query.Node_Constraint: @@ -161,6 +171,7 @@ func (t *DDLTransform) transformCreateStmt(stmt *pg_query.CreateStmt) bool { if n.Constraint != nil { if t.policy.ConstraintHandling == backend.StripConstraints && t.isUnsupportedConstraint(n.Constraint) { changed = true + t.warnStrippedConstraint(result, n.Constraint, "") continue // Skip this constraint } newTableElts = append(newTableElts, elt) @@ -180,6 +191,7 @@ func (t *DDLTransform) transformCreateStmt(stmt *pg_query.CreateStmt) bool { newConstraints = append(newConstraints, c) } else { changed = true + t.warnStrippedConstraint(result, constraint, "") } } else { newConstraints = append(newConstraints, c) @@ -191,14 +203,26 @@ func (t *DDLTransform) transformCreateStmt(stmt *pg_query.CreateStmt) bool { return changed } -// transformColumnDef modifies a column definition for DuckLake compatibility -func (t *DDLTransform) transformColumnDef(col *pg_query.ColumnDef) bool { +// transformColumnDef modifies a column definition for DuckLake/Iceberg compatibility. +// SERIAL, GENERATED STORED, and DEFAULT would silently produce NULL data on a +// lake catalog; when ErrorOnSilentNullDefaults is set they raise a feature_not_supported +// error instead. Unenforceable column constraints (PK/UNIQUE/CHECK/FK) are stripped, with +// a WARNING when WarnOnStrippedConstraints is set. +func (t *DDLTransform) transformColumnDef(col *pg_query.ColumnDef, result *Result) bool { changed := false - // Convert SERIAL types to INTEGER types - if t.policy.RewriteSerial && col.TypeName != nil { - if t.convertSerialType(col.TypeName) { - changed = true + // SERIAL/BIGSERIAL: no backing sequence on a lake catalog -> ids silently NULL. + if col.TypeName != nil { + if serialName := serialTypeName(col.TypeName); serialName != "" { + if t.policy.ErrorOnSilentNullDefaults { + result.Error = NewFeatureNotSupported(fmt.Sprintf( + "%s is not supported on this catalog: there is no backing sequence, so generated ids would be silently NULL; use an explicit INTEGER/BIGINT column and supply values", + strings.ToUpper(serialName))) + return changed + } + if t.policy.RewriteSerial && t.convertSerialType(col.TypeName) { + changed = true + } } } @@ -206,27 +230,46 @@ func (t *DDLTransform) transformColumnDef(col *pg_query.ColumnDef) bool { if len(col.Constraints) > 0 { newConstraints := make([]*pg_query.Node, 0) for _, c := range col.Constraints { - if constraint := c.GetConstraint(); constraint != nil { - if t.policy.ConstraintHandling == backend.StripConstraints && t.isUnsupportedColumnConstraint(constraint) { + constraint := c.GetConstraint() + if constraint == nil { + newConstraints = append(newConstraints, c) + continue + } + // Unenforceable constraints (PK/UNIQUE/CHECK/FK/EXCLUSION): strip + warn. + if t.policy.ConstraintHandling == backend.StripConstraints && t.isUnsupportedColumnConstraint(constraint) { + changed = true + t.warnStrippedConstraint(result, constraint, col.Colname) + continue + } + // GENERATED ALWAYS AS (...) STORED: computed value would be silently NULL. + if constraint.Contype == pg_query.ConstrType_CONSTR_GENERATED { + if t.policy.ErrorOnSilentNullDefaults { + result.Error = NewFeatureNotSupported(fmt.Sprintf( + "GENERATED ALWAYS AS (...) STORED column %q is not supported on this catalog: the computed value would be silently NULL", + col.Colname)) + return changed + } + if t.policy.StripVolatileDefaults { changed = true continue } - // Check for DEFAULT now()/current_timestamp - if t.policy.StripVolatileDefaults && constraint.Contype == pg_query.ConstrType_CONSTR_DEFAULT { - if t.isUnsupportedDefault(constraint.RawExpr) { - changed = true - continue - } + } + // DEFAULT /now(): the default would be silently dropped (NULL). + // Literal int/float/string and DEFAULT NULL are not "unsupported" here and + // are passed through to the engine. + if constraint.Contype == pg_query.ConstrType_CONSTR_DEFAULT && t.isUnsupportedDefault(constraint.RawExpr) { + if t.policy.ErrorOnSilentNullDefaults { + result.Error = NewFeatureNotSupported(fmt.Sprintf( + "DEFAULT on column %q is not supported on this catalog: the default would be silently dropped and the column left NULL; supply the value explicitly on INSERT", + col.Colname)) + return changed } - // Check for GENERATED columns - if t.policy.StripVolatileDefaults && constraint.Contype == pg_query.ConstrType_CONSTR_GENERATED { + if t.policy.StripVolatileDefaults { changed = true continue } - newConstraints = append(newConstraints, c) - } else { - newConstraints = append(newConstraints, c) } + newConstraints = append(newConstraints, c) } col.Constraints = newConstraints } @@ -319,13 +362,18 @@ func (t *DDLTransform) isUnsupportedDefault(expr *pg_query.Node) bool { return t.isUnsupportedDefault(typeCast.Arg) } - // Check for A_Const nodes - only allow Integer and String + // Check for A_Const nodes - allow NULL and Integer/Float/String literals. if aconst := expr.GetAConst(); aconst != nil { + // DEFAULT NULL is fine: NULL is the implicit default anyway, so it is + // neither stripped nor an error (must not trip ErrorOnSilentNullDefaults). + if aconst.Isnull { + return false + } switch aconst.Val.(type) { case *pg_query.A_Const_Ival, *pg_query.A_Const_Fval, *pg_query.A_Const_Sval: return false // These are supported default: - return true // Booleans, NULLs, etc. are not supported + return true // Booleans, etc. are not supported } } @@ -351,14 +399,34 @@ func (t *DDLTransform) transformAlterTableStmt(stmt *pg_query.AlterTableStmt, re "ALTER COLUMN TYPE ... USING is not supported on this catalog") return false, nil } - if t.isUnsupportedAlterCommand(alterCmd) { + // ADD COLUMN carries a full ColumnDef: apply the same SERIAL/GENERATED/ + // DEFAULT-expr error and constraint-warn handling as CREATE TABLE. + if alterCmd.Subtype == pg_query.AlterTableType_AT_AddColumn { + if def := alterCmd.Def.GetColumnDef(); def != nil { + t.transformColumnDef(def, result) + if result.Error != nil { + return false, nil + } + } + // Make ADD COLUMN idempotent by adding IF NOT EXISTS. + // DuckLake reuses physical tables across sqlmesh snapshots, so + // columns may already exist when sqlmesh issues ALTER TABLE ADD COLUMN. + if !alterCmd.MissingOk { + alterCmd.MissingOk = true + } + supported = append(supported, cmd) continue } - // Make ADD COLUMN idempotent by adding IF NOT EXISTS. - // DuckLake reuses physical tables across sqlmesh snapshots, so - // columns may already exist when sqlmesh issues ALTER TABLE ADD COLUMN. - if alterCmd.Subtype == pg_query.AlterTableType_AT_AddColumn && !alterCmd.MissingOk { - alterCmd.MissingOk = true + // DROP COLUMN is allowed through, but warn: on the Iceberg catalog dropping + // a column after schema churn can make the table unreadable (see PR3 guard). + if alterCmd.Subtype == pg_query.AlterTableType_AT_DropColumn { + t.warnDropColumn(result) + supported = append(supported, cmd) + continue + } + if t.isUnsupportedAlterCommand(alterCmd) { + t.warnDroppedAlterCommand(result, alterCmd) + continue } supported = append(supported, cmd) } @@ -435,3 +503,90 @@ func (t *DDLTransform) isUnsupportedAlterCommand(cmd *pg_query.AlterTableCmd) bo } return false } + +// serialTypeName returns the lowercased serial type name (serial/bigserial/...) +// if the column type is a SERIAL pseudo-type, or "" otherwise. +func serialTypeName(typeName *pg_query.TypeName) string { + if typeName == nil || len(typeName.Names) == 0 { + return "" + } + var typeStr string + for _, name := range typeName.Names { + if str := name.GetString_(); str != nil { + typeStr = strings.ToLower(str.Sval) + break + } + } + switch typeStr { + case "serial", "serial4", "bigserial", "serial8", "smallserial", "serial2": + return typeStr + } + return "" +} + +// constraintLabel returns the SQL keyword for a constraint type, for messages. +func constraintLabel(ct pg_query.ConstrType) string { + switch ct { + case pg_query.ConstrType_CONSTR_PRIMARY: + return "PRIMARY KEY" + case pg_query.ConstrType_CONSTR_UNIQUE: + return "UNIQUE" + case pg_query.ConstrType_CONSTR_FOREIGN: + return "FOREIGN KEY" + case pg_query.ConstrType_CONSTR_CHECK: + return "CHECK" + case pg_query.ConstrType_CONSTR_EXCLUSION: + return "EXCLUSION" + default: + return "constraint" + } +} + +// warnStrippedConstraint records a WARNING that an unenforceable constraint was +// accepted-but-ignored, when the policy enables it. +func (t *DDLTransform) warnStrippedConstraint(result *Result, c *pg_query.Constraint, colName string) { + if !t.policy.WarnOnStrippedConstraints || result == nil { + return + } + label := constraintLabel(c.Contype) + if colName != "" { + result.Warnings = append(result.Warnings, fmt.Sprintf( + "%s constraint on column %q is not enforced on this catalog and was ignored", label, colName)) + } else { + result.Warnings = append(result.Warnings, fmt.Sprintf( + "%s constraint is not enforced on this catalog and was ignored", label)) + } +} + +// warnDroppedAlterCommand records a WARNING that an unsupported ALTER TABLE +// command was silently dropped, when the policy enables it. +func (t *DDLTransform) warnDroppedAlterCommand(result *Result, cmd *pg_query.AlterTableCmd) { + if !t.policy.WarnOnStrippedConstraints || result == nil { + return + } + var msg string + switch cmd.Subtype { + case pg_query.AlterTableType_AT_AddConstraint, + pg_query.AlterTableType_AT_ValidateConstraint, + pg_query.AlterTableType_AT_DropConstraint: + msg = "ALTER TABLE constraint changes are not enforced on this catalog and were ignored" + case pg_query.AlterTableType_AT_SetNotNull, + pg_query.AlterTableType_AT_DropNotNull: + msg = "ALTER TABLE SET/DROP NOT NULL is not supported on this catalog and was ignored" + case pg_query.AlterTableType_AT_ColumnDefault: + msg = "ALTER TABLE SET DEFAULT is not supported on this catalog and was ignored" + default: + return + } + result.Warnings = append(result.Warnings, msg) +} + +// warnDropColumn records a WARNING that DROP COLUMN may make existing data +// unreadable on the Iceberg catalog (see the runtime guard in server/conn.go). +func (t *DDLTransform) warnDropColumn(result *Result) { + if !t.policy.WarnOnStrippedConstraints || result == nil { + return + } + result.Warnings = append(result.Warnings, + "DROP COLUMN may make existing data unreadable on this catalog after schema changes; prefer recreating the table") +} diff --git a/transpiler/transform/ddl_test.go b/transpiler/transform/ddl_test.go new file mode 100644 index 00000000..e04b69ea --- /dev/null +++ b/transpiler/transform/ddl_test.go @@ -0,0 +1,198 @@ +package transform + +import ( + "strings" + "testing" + + pg_query "github.com/pganalyze/pg_query_go/v6" + "github.com/posthog/duckgres/transpiler/backend" +) + +// lakeDDLPolicy mirrors the lake (DuckLake/Iceberg) DDL policy: strip + hybrid +// warn/error. +func lakeDDLPolicy() backend.DDLPolicy { + return backend.ForName(backend.Iceberg).DDL() +} + +func runDDL(t *testing.T, policy backend.DDLPolicy, sql string) (*Result, bool) { + t.Helper() + tree, err := pg_query.Parse(sql) + if err != nil { + t.Fatalf("Parse(%q) error: %v", sql, err) + } + tr := NewDDLTransform(policy) + result := &Result{} + changed, err := tr.Transform(tree, result) + if err != nil { + t.Fatalf("Transform error: %v", err) + } + return result, changed +} + +func TestDDL_ConstraintsWarnNotError(t *testing.T) { + policy := lakeDDLPolicy() + cases := []struct { + name string + sql string + want string // substring expected in a warning + }{ + {"primary key", "CREATE TABLE t (id int PRIMARY KEY)", "PRIMARY KEY"}, + {"unique", "CREATE TABLE t (id int UNIQUE)", "UNIQUE"}, + {"check", "CREATE TABLE t (id int CHECK (id > 0))", "CHECK"}, + {"foreign key", "CREATE TABLE t (id int REFERENCES other(id))", "FOREIGN KEY"}, + {"table-level pk", "CREATE TABLE t (id int, PRIMARY KEY (id))", "PRIMARY KEY"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + result, _ := runDDL(t, policy, tc.sql) + if result.Error != nil { + t.Fatalf("expected no error, got %v", result.Error) + } + if len(result.Warnings) == 0 { + t.Fatalf("expected a warning, got none") + } + joined := strings.Join(result.Warnings, "\n") + if !strings.Contains(joined, tc.want) { + t.Errorf("warning %q does not contain %q", joined, tc.want) + } + if !strings.Contains(joined, "not enforced") { + t.Errorf("warning %q should explain the constraint is not enforced", joined) + } + }) + } +} + +func TestDDL_SilentNullFeaturesError(t *testing.T) { + policy := lakeDDLPolicy() + cases := []struct { + name string + sql string + }{ + {"serial", "CREATE TABLE t (id serial)"}, + {"bigserial", "CREATE TABLE t (id bigserial, v text)"}, + {"smallserial", "CREATE TABLE t (id smallserial)"}, + {"generated stored", "CREATE TABLE t (a int, b int GENERATED ALWAYS AS (a * 2) STORED)"}, + {"default now()", "CREATE TABLE t (id int, ts timestamp DEFAULT now())"}, + {"default current_timestamp", "CREATE TABLE t (id int, ts timestamp DEFAULT CURRENT_TIMESTAMP)"}, + {"default expression", "CREATE TABLE t (id int, n int DEFAULT (1 + 2))"}, + {"default boolean", "CREATE TABLE t (id int, b boolean DEFAULT true)"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + result, _ := runDDL(t, policy, tc.sql) + if result.Error == nil { + t.Fatalf("expected a feature-not-supported error, got none") + } + if got := transformErrSQLState(result.Error); got != "0A000" { + t.Errorf("SQLSTATE = %q, want 0A000 (err: %v)", got, result.Error) + } + }) + } +} + +func TestDDL_SafeDefaultsPreserved(t *testing.T) { + policy := lakeDDLPolicy() + cases := []struct { + name string + sql string + }{ + {"default null", "CREATE TABLE t (id int, n int DEFAULT NULL)"}, + {"default int literal", "CREATE TABLE t (id int, n int DEFAULT 5)"}, + {"default string literal", "CREATE TABLE t (id int, s text DEFAULT 'x')"}, + {"not null", "CREATE TABLE t (id int NOT NULL)"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + result, _ := runDDL(t, policy, tc.sql) + if result.Error != nil { + t.Fatalf("expected no error for %q, got %v", tc.sql, result.Error) + } + }) + } +} + +// TestDDL_NotNullPreserved locks the invariant that NOT NULL is not stripped on +// a lake catalog (it is the one constraint that is actually enforced). +func TestDDL_NotNullPreserved(t *testing.T) { + policy := lakeDDLPolicy() + tree, err := pg_query.Parse("CREATE TABLE t (id int NOT NULL)") + if err != nil { + t.Fatal(err) + } + tr := NewDDLTransform(policy) + result := &Result{} + if _, err := tr.Transform(tree, result); err != nil { + t.Fatal(err) + } + deparsed, err := pg_query.Deparse(tree) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(strings.ToUpper(deparsed), "NOT NULL") { + t.Errorf("NOT NULL should be preserved, got %q", deparsed) + } +} + +func TestDDL_MemoryProfileUnaffected(t *testing.T) { + policy := backend.ForName(backend.Memory).DDL() + result, _ := runDDL(t, policy, "CREATE TABLE t (id serial PRIMARY KEY, ts timestamp DEFAULT now())") + if result.Error != nil { + t.Errorf("memory profile should not error on PK/serial/default now(), got %v", result.Error) + } + if len(result.Warnings) != 0 { + t.Errorf("memory profile should not warn, got %v", result.Warnings) + } +} + +// TestDDL_DuckLakeStillStripsSilently locks that DuckLake keeps the historical +// silent-strip behavior (no warn/error) so sqlmesh/dbt-issued DDL still succeeds. +func TestDDL_DuckLakeStillStripsSilently(t *testing.T) { + policy := backend.ForName(backend.DuckLake).DDL() + result, _ := runDDL(t, policy, "CREATE TABLE t (id bigserial PRIMARY KEY, ts timestamp DEFAULT now())") + if result.Error != nil { + t.Errorf("DuckLake should not error on serial/default now(), got %v", result.Error) + } + if len(result.Warnings) != 0 { + t.Errorf("DuckLake should not warn, got %v", result.Warnings) + } +} + +func TestDDL_AlterAddColumnSilentNullErrors(t *testing.T) { + policy := lakeDDLPolicy() + cases := []string{ + "ALTER TABLE t ADD COLUMN id serial", + "ALTER TABLE t ADD COLUMN ts timestamp DEFAULT now()", + "ALTER TABLE t ADD COLUMN b int GENERATED ALWAYS AS (1) STORED", + } + for _, sql := range cases { + t.Run(sql, func(t *testing.T) { + result, _ := runDDL(t, policy, sql) + if result.Error == nil { + t.Fatalf("expected error for %q", sql) + } + if got := transformErrSQLState(result.Error); got != "0A000" { + t.Errorf("SQLSTATE = %q, want 0A000", got) + } + }) + } +} + +func TestDDL_AlterDropColumnWarns(t *testing.T) { + policy := lakeDDLPolicy() + result, _ := runDDL(t, policy, "ALTER TABLE t DROP COLUMN x") + if result.Error != nil { + t.Fatalf("DROP COLUMN should not error, got %v", result.Error) + } + if len(result.Warnings) == 0 || !strings.Contains(strings.Join(result.Warnings, "\n"), "DROP COLUMN") { + t.Errorf("expected a DROP COLUMN warning, got %v", result.Warnings) + } +} + +// transformErrSQLState extracts the SQLSTATE from a transform error. +func transformErrSQLState(err error) string { + type sqlStater interface{ SQLState() string } + if s, ok := err.(sqlStater); ok { + return s.SQLState() + } + return "" +} diff --git a/transpiler/transform/literals.go b/transpiler/transform/literals.go new file mode 100644 index 00000000..ea7babc7 --- /dev/null +++ b/transpiler/transform/literals.go @@ -0,0 +1,122 @@ +package transform + +import ( + "strings" + + pg_query "github.com/pganalyze/pg_query_go/v6" +) + +// LiteralTransform rewrites PostgreSQL literal syntaxes that DuckDB misparses: +// +// - bytea hex literals: '\xDEADBEEF'::bytea — DuckDB does not understand the +// Postgres \x hex-prefix in a string literal and drops the first byte. We +// rewrite to unhex('DEADBEEF'), which yields the correct BLOB bytes. +// - bit-string literals: B'101' — pg_query carries these as an A_Const bsval +// ("b101"); DuckDB turns the deparsed form into the string 'b101'. We rewrite +// to '101'::BIT. +// +// It walks every node and mutates matching literal nodes in place, so it must run +// BEFORE TypeMappingTransform (which would rewrite the `bytea` type name to +// `blob`, losing the signal this transform keys on). +type LiteralTransform struct{} + +func NewLiteralTransform() *LiteralTransform { return &LiteralTransform{} } + +func (t *LiteralTransform) Name() string { return "literals" } + +func (t *LiteralTransform) Transform(tree *pg_query.ParseResult, result *Result) (bool, error) { + changed := false + WalkFunc(tree, func(node *pg_query.Node) bool { + // bytea hex literal: TypeCast(arg=A_Const sval "\x..", type bytea) -> unhex('..') + if tc := node.GetTypeCast(); tc != nil && isByteaTypeName(tc.TypeName) { + if s, ok := stringConstValue(tc.Arg); ok && hasHexPrefix(s) { + node.Node = &pg_query.Node_FuncCall{FuncCall: unhexFuncCall(s[2:])} + changed = true + return true + } + } + // bit-string literal: A_Const{bsval:"b101"} -> '101'::BIT + if ac := node.GetAConst(); ac != nil { + if bs := ac.GetBsval(); bs != nil { + if cast := bitStringToCast(bs.Bsval); cast != nil { + node.Node = cast.Node + changed = true + return true + } + } + } + return true + }) + return changed, nil +} + +// isByteaTypeName reports whether a TypeName's final element is `bytea`. +func isByteaTypeName(tn *pg_query.TypeName) bool { + if tn == nil || len(tn.Names) == 0 { + return false + } + last := tn.Names[len(tn.Names)-1].GetString_() + return last != nil && strings.EqualFold(last.Sval, "bytea") +} + +// stringConstValue returns the string value of an A_Const string literal. +func stringConstValue(node *pg_query.Node) (string, bool) { + if node == nil { + return "", false + } + ac := node.GetAConst() + if ac == nil { + return "", false + } + if sv := ac.GetSval(); sv != nil { + return sv.Sval, true + } + return "", false +} + +// hasHexPrefix reports whether a string uses Postgres bytea hex format (\x...). +func hasHexPrefix(s string) bool { + return len(s) >= 2 && s[0] == '\\' && (s[1] == 'x' || s[1] == 'X') +} + +// unhexFuncCall builds unhex(''), which DuckDB evaluates to a BLOB. +func unhexFuncCall(hex string) *pg_query.FuncCall { + return &pg_query.FuncCall{ + Funcname: []*pg_query.Node{ + {Node: &pg_query.Node_String_{String_: &pg_query.String{Sval: "unhex"}}}, + }, + Args: []*pg_query.Node{strConst(hex)}, + } +} + +// bitStringToCast converts a pg_query bit-string value ("b101" / "B101") to a +// '101'::BIT TypeCast node. Returns nil for hex bit-strings (X'..') or anything +// it does not recognize, leaving the original node untouched. +func bitStringToCast(bsval string) *pg_query.Node { + if len(bsval) < 1 { + return nil + } + switch bsval[0] { + case 'b', 'B': + bits := bsval[1:] + if bits == "" { + return nil + } + return &pg_query.Node{Node: &pg_query.Node_TypeCast{TypeCast: &pg_query.TypeCast{ + Arg: strConst(bits), + TypeName: &pg_query.TypeName{ + Names: []*pg_query.Node{{Node: &pg_query.Node_String_{String_: &pg_query.String{Sval: "BIT"}}}}, + Typemod: -1, + }, + }}} + } + // Hex bit-strings (X'..') are left as-is (out of scope). + return nil +} + +// strConst builds an A_Const string-literal node. +func strConst(s string) *pg_query.Node { + return &pg_query.Node{Node: &pg_query.Node_AConst{AConst: &pg_query.A_Const{ + Val: &pg_query.A_Const_Sval{Sval: &pg_query.String{Sval: s}}, + }}} +} diff --git a/transpiler/transform/literals_test.go b/transpiler/transform/literals_test.go new file mode 100644 index 00000000..59a2d0fd --- /dev/null +++ b/transpiler/transform/literals_test.go @@ -0,0 +1,83 @@ +package transform + +import ( + "strings" + "testing" + + pg_query "github.com/pganalyze/pg_query_go/v6" +) + +func deparseAfter(t *testing.T, tr Transform, sql string) string { + t.Helper() + tree, err := pg_query.Parse(sql) + if err != nil { + t.Fatalf("Parse(%q): %v", sql, err) + } + if _, err := tr.Transform(tree, &Result{}); err != nil { + t.Fatalf("Transform(%q): %v", sql, err) + } + out, err := pg_query.Deparse(tree) + if err != nil { + t.Fatalf("Deparse(%q): %v", sql, err) + } + return out +} + +func TestLiteralTransform_ByteaHex(t *testing.T) { + tr := NewLiteralTransform() + out := deparseAfter(t, tr, `SELECT '\xDEADBEEF'::bytea`) + if !strings.Contains(strings.ToLower(out), "unhex('deadbeef')") { + t.Errorf("expected unhex('DEADBEEF'), got %q", out) + } + if strings.Contains(out, `\x`) { + t.Errorf("hex prefix should be gone, got %q", out) + } +} + +func TestLiteralTransform_ByteaEscapeLeftAlone(t *testing.T) { + tr := NewLiteralTransform() + // Non-hex bytea literal must be left untouched (no unhex rewrite). + out := deparseAfter(t, tr, `SELECT 'plain'::bytea`) + if strings.Contains(strings.ToLower(out), "unhex") { + t.Errorf("escape-format bytea should not be rewritten, got %q", out) + } +} + +func TestLiteralTransform_BitString(t *testing.T) { + tr := NewLiteralTransform() + out := deparseAfter(t, tr, `SELECT B'101'`) + up := strings.ToUpper(out) + if !strings.Contains(up, "'101'") || !strings.Contains(up, "BIT") { + t.Errorf("expected '101'::BIT, got %q", out) + } + if strings.Contains(strings.ToLower(out), "b101") { + t.Errorf("bit literal should not deparse to the string b101, got %q", out) + } +} + +func TestLiteralTransform_NonLiteralUnchanged(t *testing.T) { + tr := NewLiteralTransform() + out := deparseAfter(t, tr, `SELECT 1 + 2`) + if !strings.Contains(out, "1 + 2") { + t.Errorf("unrelated query changed: %q", out) + } +} + +func TestOperatorTransform_JSONBMerge(t *testing.T) { + tr := NewOperatorTransform() + out := deparseAfter(t, tr, `SELECT '{"a":1}'::jsonb || '{"b":2}'::jsonb`) + if !strings.Contains(strings.ToLower(out), "json_merge_patch") { + t.Errorf("expected json_merge_patch, got %q", out) + } +} + +func TestOperatorTransform_StringConcatUnchanged(t *testing.T) { + tr := NewOperatorTransform() + out := deparseAfter(t, tr, `SELECT 'a' || 'b'`) + if strings.Contains(strings.ToLower(out), "json_merge_patch") { + t.Errorf("plain string concat should not become json_merge_patch, got %q", out) + } + if !strings.Contains(out, "||") { + t.Errorf("string concat || should be preserved, got %q", out) + } +} diff --git a/transpiler/transform/operators.go b/transpiler/transform/operators.go index ce7efd9c..f87e1159 100644 --- a/transpiler/transform/operators.go +++ b/transpiler/transform/operators.go @@ -1,6 +1,8 @@ package transform import ( + "strings" + pg_query "github.com/pganalyze/pg_query_go/v6" ) @@ -288,6 +290,14 @@ func (t *OperatorTransform) transformExpression(node *pg_query.Node) *pg_query.N return t.createJsonExtractFuncCall(aexpr.Lexpr, aexpr.Rexpr, false) case "->>": return t.createJsonExtractFuncCall(aexpr.Lexpr, aexpr.Rexpr, true) + // jsonb || jsonb is an object merge in Postgres, but DuckDB treats || as + // string concatenation, silently producing invalid JSON. Rewrite to + // json_merge_patch only when an operand is clearly JSON; otherwise leave + // || alone (string/array concat is the safe default). + case "||": + if looksJSON(aexpr.Lexpr) || looksJSON(aexpr.Rexpr) { + return t.createJSONMergeFuncCall(aexpr.Lexpr, aexpr.Rexpr) + } // Regex operators — only match binary ~ (both operands present). // Unary ~ (bitwise NOT, e.g. ~id) has Lexpr=nil and must be left as-is; // DuckDB supports ~ as bitwise NOT natively. Passing nil into @@ -496,6 +506,55 @@ func (t *OperatorTransform) createJsonExtractFuncCall(left, right *pg_query.Node } } +// looksJSON reports whether a node is syntactically JSON: a cast to json/jsonb, +// or a json* function call (including the json_extract calls this transform +// produces from -> / ->>). Used to gate the || -> json_merge_patch rewrite so +// genuine string/array concatenation is left untouched. +func looksJSON(node *pg_query.Node) bool { + if node == nil { + return false + } + if tc := node.GetTypeCast(); tc != nil && tc.TypeName != nil && len(tc.TypeName.Names) > 0 { + if last := tc.TypeName.Names[len(tc.TypeName.Names)-1].GetString_(); last != nil { + switch strings.ToLower(last.Sval) { + case "json", "jsonb": + return true + } + } + } + if fc := node.GetFuncCall(); fc != nil && len(fc.Funcname) > 0 { + if last := fc.Funcname[len(fc.Funcname)-1].GetString_(); last != nil { + if strings.HasPrefix(strings.ToLower(last.Sval), "json") { + return true + } + } + } + return false +} + +// createJSONMergeFuncCall rewrites `a || b` to json_merge_patch(a, b). Note this +// is RFC 7396 merge-patch semantics (recursive, and a null value deletes a key), +// which matches Postgres jsonb || for the common flat-object-merge case but +// diverges for nested objects and explicit nulls. +func (t *OperatorTransform) createJSONMergeFuncCall(left, right *pg_query.Node) *pg_query.Node { + if newLeft := t.transformExpression(left); newLeft != nil { + left = newLeft + } + if newRight := t.transformExpression(right); newRight != nil { + right = newRight + } + return &pg_query.Node{ + Node: &pg_query.Node_FuncCall{ + FuncCall: &pg_query.FuncCall{ + Funcname: []*pg_query.Node{ + {Node: &pg_query.Node_String_{String_: &pg_query.String{Sval: "json_merge_patch"}}}, + }, + Args: []*pg_query.Node{left, right}, + }, + }, + } +} + // createRegexFuncCall creates a regexp_matches function call node // For negated operators, wraps in NOT func (t *OperatorTransform) createRegexFuncCall(left, right *pg_query.Node, caseInsensitive, negated bool) *pg_query.Node { diff --git a/transpiler/transform/transform.go b/transpiler/transform/transform.go index 35c10996..94def626 100644 --- a/transpiler/transform/transform.go +++ b/transpiler/transform/transform.go @@ -36,6 +36,12 @@ type Result struct { // for the final query but before streaming results. Typically DROP TEMP TABLE // and COMMIT statements. Execute these with best-effort (ignore errors). CleanupStatements []string + + // Warnings holds human-readable messages a transform wants surfaced to the + // client as NoticeResponse (WARNING) while still allowing the statement to + // run. Used e.g. when a constraint is accepted-but-not-enforced on a lake + // catalog so the client is told its Postgres semantics were dropped. + Warnings []string } // Transform defines the interface for SQL transformations. diff --git a/transpiler/transform/writablecte.go b/transpiler/transform/writablecte.go index 4726d601..f6464f36 100644 --- a/transpiler/transform/writablecte.go +++ b/transpiler/transform/writablecte.go @@ -308,7 +308,17 @@ func (t *WritableCTETransform) rewriteWritableCTE( quotedTempName := quoteIdentifier(tempName) if cte.isWrite { - // WRITABLE CTE: capture RETURNING output, then execute DML + // WRITABLE CTE: capture RETURNING output, then execute DML. + // The RETURNING capture runs as a SELECT of the rows BEFORE the + // UPDATE, so it yields pre-update (OLD) values. That is correct only + // for columns the UPDATE does not modify (the common "RETURNING id" + // row-identification pattern). If RETURNING reads a modified column, + // the captured value would be stale — reject rather than return wrong + // data. + if u := cte.node.Ctequery.GetUpdateStmt(); u != nil && updateReturningReadsModifiedColumn(u) { + return nil, NewFeatureNotSupported( + "UPDATE ... RETURNING a modified column inside a writable CTE is not supported on this catalog: it would return the pre-update value; return only unmodified columns (e.g. the key) or run the UPDATE without wrapping it in a CTE") + } // 2a. Generate SELECT that captures what RETURNING would produce returningSelect, err := t.generateReturningSelect(cte.node, tempTableNames) @@ -386,6 +396,53 @@ func (t *WritableCTETransform) generateReturningSelect( } } +// updateReturningReadsModifiedColumn reports whether an UPDATE's RETURNING list +// references, by name, a column the UPDATE modifies (so the pre-update capture +// would be stale). RETURNING * is exempt (see returningRefsColumn). Returns false +// when there is no RETURNING list or it reads only unmodified columns. +func updateReturningReadsModifiedColumn(u *pg_query.UpdateStmt) bool { + if u == nil || len(u.ReturningList) == 0 { + return false + } + setCols := map[string]bool{} + for _, tgt := range u.TargetList { + if rt := tgt.GetResTarget(); rt != nil && rt.Name != "" { + setCols[strings.ToLower(rt.Name)] = true + } + } + if len(setCols) == 0 { + return false + } + for _, ret := range u.ReturningList { + if returningRefsColumn(ret, setCols) { + return true + } + } + return false +} + +// returningRefsColumn reports whether a RETURNING item references, by name, any +// column in cols. RETURNING * is deliberately NOT treated as a hit: it is the +// common row-identification idiom (callers typically consume only the key), and +// rejecting it would break legitimate upsert patterns. The narrower risk that a +// caller reads a modified column out of a RETURNING * is left as a documented +// limitation. +func returningRefsColumn(ret *pg_query.Node, cols map[string]bool) bool { + found := false + WalkFunc(&pg_query.ParseResult{Stmts: []*pg_query.RawStmt{{Stmt: ret}}}, func(n *pg_query.Node) bool { + if cr := n.GetColumnRef(); cr != nil { + for _, f := range cr.Fields { + if s := f.GetString_(); s != nil && cols[strings.ToLower(s.Sval)] { + found = true + return false + } + } + } + return true + }) + return found +} + // updateToSelect converts UPDATE...FROM...WHERE...RETURNING to equivalent SELECT func (t *WritableCTETransform) updateToSelect( update *pg_query.UpdateStmt, diff --git a/transpiler/transform/writablecte_returning_test.go b/transpiler/transform/writablecte_returning_test.go new file mode 100644 index 00000000..4806e343 --- /dev/null +++ b/transpiler/transform/writablecte_returning_test.go @@ -0,0 +1,73 @@ +package transform + +import ( + "testing" + + pg_query "github.com/pganalyze/pg_query_go/v6" +) + +func transpileCTE(t *testing.T, sql string) *Result { + t.Helper() + tree, err := pg_query.Parse(sql) + if err != nil { + t.Fatalf("Parse(%q): %v", sql, err) + } + result := &Result{} + if _, err := NewWritableCTETransform().Transform(tree, result); err != nil { + t.Fatalf("Transform(%q): %v", sql, err) + } + return result +} + +func sqlState(err error) string { + if s, ok := err.(interface{ SQLState() string }); ok { + return s.SQLState() + } + return "" +} + +// UPDATE ... RETURNING a modified column inside a writable CTE must be rejected +// (the pre-update capture would return a stale value). +func TestWritableCTE_UpdateReturningModifiedColumnRejected(t *testing.T) { + r := transpileCTE(t, `WITH u AS (UPDATE t SET val = val + 1 WHERE id = 1 RETURNING id, val) SELECT * FROM u`) + if r.Error == nil { + t.Fatalf("expected rejection, got statements=%v", r.Statements) + } + if sqlState(r.Error) != "0A000" { + t.Errorf("SQLSTATE = %q, want 0A000", sqlState(r.Error)) + } +} + +// RETURNING * is exempt from the guard (common row-identification idiom; callers +// typically read only the key). It must still rewrite, not error. +func TestWritableCTE_UpdateReturningStarAllowed(t *testing.T) { + r := transpileCTE(t, `WITH u AS (UPDATE t SET val = 5 WHERE id = 1 RETURNING *) SELECT * FROM u`) + if r.Error != nil { + t.Fatalf("RETURNING * should be allowed, got %v", r.Error) + } + if len(r.Statements) == 0 { + t.Errorf("expected a multi-statement rewrite, got none") + } +} + +// The Airbyte-style pattern — RETURNING only an unmodified key — must keep working. +func TestWritableCTE_UpdateReturningUnmodifiedKeyAllowed(t *testing.T) { + r := transpileCTE(t, `WITH u AS (UPDATE t SET val = val + 1 WHERE id = 1 RETURNING id) SELECT * FROM u`) + if r.Error != nil { + t.Fatalf("RETURNING an unmodified key should be allowed, got %v", r.Error) + } + if len(r.Statements) == 0 { + t.Errorf("expected a multi-statement rewrite, got none") + } +} + +// DELETE ... RETURNING in a CTE is correct (OLD == the deleted rows) and allowed. +func TestWritableCTE_DeleteReturningAllowed(t *testing.T) { + r := transpileCTE(t, `WITH d AS (DELETE FROM t WHERE id = 1 RETURNING id, val) SELECT * FROM d`) + if r.Error != nil { + t.Fatalf("DELETE ... RETURNING should be allowed, got %v", r.Error) + } + if len(r.Statements) == 0 { + t.Errorf("expected a multi-statement rewrite, got none") + } +} diff --git a/transpiler/transpiler.go b/transpiler/transpiler.go index 348a3460..888333eb 100644 --- a/transpiler/transpiler.go +++ b/transpiler/transpiler.go @@ -34,6 +34,7 @@ const ( FlagCtid // ctid -> rowid mapping FlagDDL // DDL constraint stripping FlagPlaceholder // $1/$2 placeholder conversion + FlagLiterals // bytea hex / bit-string literal rewrites flagSentinel // must be last — used to derive FlagAll FlagAll TransformFlags = flagSentinel - 1 // All flags set @@ -97,6 +98,10 @@ func New(cfg Config) *Transpiler { // 3.2 Map logical database catalog references to the physical backend catalog t.transforms = append(t.transforms, taggedTransform{FlagLogicalCatalog, transform.NewLogicalCatalogTransform(cfg.LogicalDatabaseName, catalogPolicy.PhysicalName, catalogPolicy.MapPublicToMain)}) + // 3.3 Literal rewrites (bytea \x hex, bit-string B'..') — MUST run before + // TypeMapping so it still sees the `bytea` type name. + t.transforms = append(t.transforms, taggedTransform{FlagLiterals, transform.NewLiteralTransform()}) + // 4. Type mappings (JSONB->JSON, CHAR->TEXT, etc.) t.transforms = append(t.transforms, taggedTransform{FlagTypeMapping, transform.NewTypeMappingTransform()}) @@ -265,6 +270,7 @@ func (t *Transpiler) transpileWithFlags(sql string, flags TransformFlags) (*Resu Statements: restoreLongIdentifiersAll(transformResult.Statements, longIdents), CleanupStatements: restoreLongIdentifiersAll(transformResult.CleanupStatements, longIdents), ParamCount: transformResult.ParamCount, + Warnings: transformResult.Warnings, }, nil } @@ -277,6 +283,7 @@ func (t *Transpiler) transpileWithFlags(sql string, flags TransformFlags) (*Resu IsNoOp: transformResult.IsNoOp, NoOpTag: transformResult.NoOpTag, IsIgnoredSet: transformResult.IsIgnoredSet, + Warnings: transformResult.Warnings, }, nil } @@ -293,6 +300,7 @@ func (t *Transpiler) transpileWithFlags(sql string, flags TransformFlags) (*Resu IsNoOp: transformResult.IsNoOp, NoOpTag: transformResult.NoOpTag, IsIgnoredSet: transformResult.IsIgnoredSet, + Warnings: transformResult.Warnings, }, nil } @@ -309,6 +317,7 @@ func (t *Transpiler) transpileWithFlags(sql string, flags TransformFlags) (*Resu IsNoOp: transformResult.IsNoOp, NoOpTag: transformResult.NoOpTag, IsIgnoredSet: transformResult.IsIgnoredSet, + Warnings: transformResult.Warnings, }, nil } @@ -395,6 +404,14 @@ func Classify(sql string, cfg Config) Classification { flags |= FlagTypeMapping } + // Literal rewrites: bytea \x hex literals and B'..' bit-string literals. + // Over-triggering is safe — LiteralTransform is a no-op when nothing matches — + // but we must not match B' inside an identifier (e.g. "my.db'") or it would + // pull plain DuckDB statements out of the Tier-0 direct path. + if strings.Contains(upper, `'\X`) || containsBitStringLiteral(upper) { + flags |= FlagLiterals + } + // Type casts that need rewriting if containsAny(upper, "::REGTYPE", "::REGCLASS", "::REGNAMESPACE", "::REGPROC", "::OID") { flags |= FlagTypeCast @@ -458,7 +475,7 @@ func Classify(sql string, cfg Config) Classification { // This is acceptable: false positive just runs the operator transform (cheap), // while a false negative would break PostgreSQL regex queries that DuckDB // handles via regexp_matches rewrite. - if containsAny(upper, "->", "~") { + if containsAny(upper, "->", "~", "||") { flags |= FlagOperators } // Also check for SIMILAR TO which gets transformed through operators @@ -566,6 +583,27 @@ func containsAny(s string, substrs ...string) bool { return false } +// containsBitStringLiteral reports whether the (upper-cased) SQL contains a +// bit-string literal B'..'. The B must be a standalone token (not preceded by an +// identifier character) so it doesn't match B' inside names like "my.db'". +func containsBitStringLiteral(upper string) bool { + for i := 0; i+1 < len(upper); i++ { + if upper[i] == 'B' && upper[i+1] == '\'' { + if i == 0 || !isIdentChar(upper[i-1]) { + return true + } + } + } + return false +} + +func isIdentChar(b byte) bool { + return b == '_' || + (b >= 'A' && b <= 'Z') || + (b >= 'a' && b <= 'z') || + (b >= '0' && b <= '9') +} + // hasAnyPrefix returns true if s starts with any of the given prefixes. // The check is performed after trimming leading whitespace, line comments (-- ...), // and block comments (/* ... */). diff --git a/transpiler/transpiler_test.go b/transpiler/transpiler_test.go index e7d99461..97255165 100644 --- a/transpiler/transpiler_test.go +++ b/transpiler/transpiler_test.go @@ -369,20 +369,34 @@ func TestTranspile_PublicSchema_Iceberg(t *testing.T) { } func TestTranspile_DDL_Iceberg(t *testing.T) { - // The Iceberg backend uses the same DDL policy as DuckLake: strip enforced - // constraints, rewrite SERIAL, no-op unsupported DDL. + // The Iceberg backend strips unenforceable constraints (with a WARNING) and + // no-ops unsupported DDL, but ERRORs on silently-NULL data features (SERIAL, + // GENERATED STORED, DEFAULT ). tr := New(Config{Backend: BackendIceberg}) - t.Run("strip PRIMARY KEY and convert SERIAL", func(t *testing.T) { - result, err := tr.Transpile("CREATE TABLE t (id SERIAL PRIMARY KEY, name TEXT)") + t.Run("strip PRIMARY KEY with a warning", func(t *testing.T) { + result, err := tr.Transpile("CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT)") if err != nil { t.Fatalf("Transpile error: %v", err) } + if result.Error != nil { + t.Fatalf("expected no error, got %v", result.Error) + } if strings.Contains(strings.ToUpper(result.SQL), "PRIMARY KEY") { t.Errorf("PRIMARY KEY not stripped: %q", result.SQL) } - if strings.Contains(strings.ToUpper(result.SQL), "SERIAL") { - t.Errorf("SERIAL not converted: %q", result.SQL) + if len(result.Warnings) == 0 { + t.Errorf("expected a PRIMARY KEY warning, got none") + } + }) + + t.Run("SERIAL is rejected", func(t *testing.T) { + result, err := tr.Transpile("CREATE TABLE t (id SERIAL PRIMARY KEY, name TEXT)") + if err != nil { + t.Fatalf("Transpile error: %v", err) + } + if result.Error == nil { + t.Errorf("expected SERIAL to be rejected, got SQL=%q", result.SQL) } })