diff --git a/internal/migration_acceptance_tests/composite_type_cases_test.go b/internal/migration_acceptance_tests/composite_type_cases_test.go new file mode 100644 index 0000000..352c2fa --- /dev/null +++ b/internal/migration_acceptance_tests/composite_type_cases_test.go @@ -0,0 +1,248 @@ +package migration_acceptance_tests + +import ( + "testing" + + "github.com/stripe/pg-schema-diff/pkg/diff" +) + +var compositeTypeAcceptanceTestCases = []acceptanceTestCase{ + { + name: "no-op", + oldSchemaDDL: []string{` + CREATE TYPE pair AS (a int, b text); + `}, + newSchemaDDL: []string{` + CREATE TYPE pair AS (a int, b text); + `}, + expectEmptyPlan: true, + }, + { + name: "create composite type", + oldSchemaDDL: []string{}, + newSchemaDDL: []string{` + CREATE TYPE pair AS (a int, b text); + `}, + }, + { + name: "drop composite type", + oldSchemaDDL: []string{` + CREATE TYPE pair AS (a int, b text); + `}, + newSchemaDDL: []string{}, + }, + { + name: "drop nested composite types", + oldSchemaDDL: []string{` + CREATE TYPE inner_t AS (n int); + CREATE TYPE outer_t AS (i inner_t, label text); + `}, + newSchemaDDL: []string{}, + }, + { + name: "create composite type used by function", + oldSchemaDDL: []string{}, + newSchemaDDL: []string{` + CREATE TYPE pair AS (a int, b text); + CREATE FUNCTION mk_pair(x int, y text) RETURNS pair LANGUAGE sql AS 'SELECT (x, y)::pair'; + `}, + }, + { + name: "create schema-qualified composite type used by plpgsql function", + oldSchemaDDL: []string{}, + newSchemaDDL: []string{` + CREATE SCHEMA app; + CREATE TYPE app.result AS (status text, reason text); + CREATE FUNCTION app.resolve() RETURNS app.result LANGUAGE plpgsql AS $$ + DECLARE + v_result app.result; + BEGIN + SELECT ROW('ok', 'ready')::app.result INTO v_result; + RETURN v_result; + END + $$; + `}, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeHasUntrackableDependencies, + }, + }, + { + name: "create composite types before functions that use them in signatures", + oldSchemaDDL: []string{}, + newSchemaDDL: []string{` + CREATE SCHEMA casino_wager_stats; + CREATE TYPE casino_wager_stats.peak_source_row AS ( + id bigint, + round_id bigint + ); + CREATE TYPE casino_wager_stats.peak_candidate_row AS ( + metric_code text, + round_id bigint + ); + CREATE FUNCTION casino_wager_stats.project_peak_candidates(p_peak_sources casino_wager_stats.peak_source_row[]) + RETURNS SETOF casino_wager_stats.peak_candidate_row + LANGUAGE sql + STABLE + AS 'SELECT ''payout''::text AS metric_code, source$.round_id FROM pg_catalog.unnest(p_peak_sources) AS source$'; + CREATE FUNCTION casino_wager_stats.refresh_peaks(p_peak_sources casino_wager_stats.peak_source_row[]) + RETURNS void + LANGUAGE plpgsql + AS $$ + DECLARE + v_candidates casino_wager_stats.peak_candidate_row[]; + BEGIN + SELECT COALESCE(array_agg(candidate$), ARRAY[]::casino_wager_stats.peak_candidate_row[]) + INTO v_candidates + FROM casino_wager_stats.project_peak_candidates(p_peak_sources) candidate$; + END; + $$; + `}, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeHasUntrackableDependencies, + }, + expectedPlanDDL: []string{ + `CREATE SCHEMA "casino_wager_stats"`, + `CREATE TYPE "casino_wager_stats"."peak_candidate_row" AS ( + "metric_code" text COLLATE "pg_catalog"."default", + "round_id" bigint +)`, + `CREATE TYPE "casino_wager_stats"."peak_source_row" AS ( + "id" bigint, + "round_id" bigint +)`, + "CREATE OR REPLACE FUNCTION casino_wager_stats.project_peak_candidates(p_peak_sources casino_wager_stats.peak_source_row[])\n RETURNS SETOF casino_wager_stats.peak_candidate_row\n LANGUAGE sql\n STABLE\nAS $function$SELECT 'payout'::text AS metric_code, source$.round_id FROM pg_catalog.unnest(p_peak_sources) AS source$$function$\n", + "CREATE OR REPLACE FUNCTION casino_wager_stats.refresh_peaks(p_peak_sources casino_wager_stats.peak_source_row[])\n RETURNS void\n LANGUAGE plpgsql\nAS $function$\n\t\t\t\tDECLARE\n\t\t\t\t\tv_candidates casino_wager_stats.peak_candidate_row[];\n\t\t\t\tBEGIN\n\t\t\t\t\tSELECT COALESCE(array_agg(candidate$), ARRAY[]::casino_wager_stats.peak_candidate_row[])\n\t\t\t\t\tINTO v_candidates\n\t\t\t\t\tFROM casino_wager_stats.project_peak_candidates(p_peak_sources) candidate$;\n\t\t\t\tEND;\n\t\t\t\t$function$\n", + }, + }, + { + name: "drop composite type after dropping function that used it", + oldSchemaDDL: []string{` + CREATE TYPE pair AS (a int, b text); + CREATE FUNCTION mk_pair(x int, y text) RETURNS pair LANGUAGE sql AS 'SELECT (x, y)::pair'; + `}, + newSchemaDDL: []string{}, + }, + { + name: "create composite type with attributes that have collation", + oldSchemaDDL: []string{}, + newSchemaDDL: []string{` + CREATE TYPE labelled AS (id int, label text COLLATE "C"); + `}, + }, + { + name: "create nested composite types (one references the other)", + oldSchemaDDL: []string{}, + newSchemaDDL: []string{` + CREATE TYPE inner_t AS (n int); + CREATE TYPE outer_t AS (i inner_t, label text); + `}, + }, + // ─── Phase 2: drop+recreate cascade for function-only dependents ─── + { + name: "alter composite type attrs - cascade through dependent function", + oldSchemaDDL: []string{` + CREATE TYPE pair AS (a int, b text); + CREATE FUNCTION mk_pair(x int, y text) RETURNS pair LANGUAGE sql AS 'SELECT (x, y)::pair'; + `}, + newSchemaDDL: []string{` + CREATE TYPE pair AS (a int, b text, c boolean); + CREATE FUNCTION mk_pair(x int, y text, z boolean) RETURNS pair LANGUAGE sql AS 'SELECT (x, y, z)::pair'; + `}, + }, + { + name: "alter composite type attrs - cascade through dependent procedure", + oldSchemaDDL: []string{` + CREATE TYPE pair AS (a int, b text); + CREATE PROCEDURE use_pair(p pair) LANGUAGE plpgsql AS $$ BEGIN END $$; + `}, + newSchemaDDL: []string{` + CREATE TYPE pair AS (a int, b text, c boolean); + CREATE PROCEDURE use_pair(p pair) LANGUAGE plpgsql AS $$ BEGIN END $$; + `}, + // Procedures always carry the untrackable-deps hazard regardless of the + // underlying composite-type recreation; pg-schema-diff cannot follow plpgsql + // body references through pg_depend. + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeHasUntrackableDependencies, + }, + }, + { + name: "alter composite type attrs - cascade through multiple dependent functions", + oldSchemaDDL: []string{` + CREATE TYPE pair AS (a int, b text); + CREATE FUNCTION f_a(p pair) RETURNS int LANGUAGE sql AS 'SELECT (p).a'; + CREATE FUNCTION f_b(p pair) RETURNS text LANGUAGE sql AS 'SELECT (p).b'; + `}, + newSchemaDDL: []string{` + CREATE TYPE pair AS (a int, b text, c boolean); + CREATE FUNCTION f_a(p pair) RETURNS int LANGUAGE sql AS 'SELECT (p).a'; + CREATE FUNCTION f_b(p pair) RETURNS text LANGUAGE sql AS 'SELECT (p).b'; + `}, + }, + { + name: "alter composite type attrs - cascade through dependent function using array argument", + oldSchemaDDL: []string{` + CREATE TYPE pair AS (a int, b text); + CREATE FUNCTION f_items(p pair[]) RETURNS int LANGUAGE sql AS 'SELECT pg_catalog.cardinality(p)'; + `}, + newSchemaDDL: []string{` + CREATE TYPE pair AS (a int, b text, c boolean); + CREATE FUNCTION f_items(p pair[]) RETURNS int LANGUAGE sql AS 'SELECT pg_catalog.cardinality(p)'; + `}, + }, + { + name: "alter composite type attrs - cascade through dependent composite type and function", + oldSchemaDDL: []string{` + CREATE TYPE inner_t AS (n int); + CREATE TYPE outer_t AS (i inner_t, label text); + CREATE FUNCTION f_outer(p outer_t) RETURNS int LANGUAGE sql AS 'SELECT ((p).i).n'; + `}, + newSchemaDDL: []string{` + CREATE TYPE inner_t AS (n int, extra text); + CREATE TYPE outer_t AS (i inner_t, label text); + CREATE FUNCTION f_outer(p outer_t) RETURNS int LANGUAGE sql AS 'SELECT ((p).i).n'; + `}, + }, + { + name: "alter composite type attrs is unsupported when used by a table column", + oldSchemaDDL: []string{` + CREATE TYPE pair AS (a int, b text); + CREATE TABLE users (id int, attrs pair); + `}, + newSchemaDDL: []string{` + CREATE TYPE pair AS (a int, b text, c boolean); + CREATE TABLE users (id int, attrs pair); + `}, + expectedPlanErrorIs: diff.ErrNotImplemented, + }, + { + name: "alter composite type attrs is unsupported when dependent composite type is used by a table column", + oldSchemaDDL: []string{` + CREATE TYPE inner_t AS (n int); + CREATE TYPE outer_t AS (i inner_t, label text); + CREATE TABLE users (id int, attrs outer_t); + `}, + newSchemaDDL: []string{` + CREATE TYPE inner_t AS (n int, extra text); + CREATE TYPE outer_t AS (i inner_t, label text); + CREATE TABLE users (id int, attrs outer_t); + `}, + expectedPlanErrorIs: diff.ErrNotImplemented, + }, + { + name: "alter composite type attrs is unsupported when used by a table array column", + oldSchemaDDL: []string{` + CREATE TYPE pair AS (a int, b text); + CREATE TABLE users (id int, attrs pair[]); + `}, + newSchemaDDL: []string{` + CREATE TYPE pair AS (a int, b text, c boolean); + CREATE TABLE users (id int, attrs pair[]); + `}, + expectedPlanErrorIs: diff.ErrNotImplemented, + }, +} + +func TestCompositeTypeTestCases(t *testing.T) { + runTestCases(t, compositeTypeAcceptanceTestCases) +} diff --git a/internal/queries/queries.sql b/internal/queries/queries.sql index d8397c8..695b523 100644 --- a/internal/queries/queries.sql +++ b/internal/queries/queries.sql @@ -289,6 +289,64 @@ WHERE AND depend.deptype = 'e' ); +-- name: GetDependsOnCompositeTypes :many +-- Returns the composite types (typtype = 'c') that the given object depends on. +-- This includes dependencies through PostgreSQL's automatically-created array +-- type for a composite type, e.g. `some_type[]`. +-- Used to drive cascading drop+recreate of functions and procedures when the +-- attribute list of a composite type changes. +SELECT DISTINCT + pg_type.typname::TEXT AS type_name, + type_namespace.nspname::TEXT AS type_schema_name +FROM pg_catalog.pg_depend AS depend +INNER JOIN pg_catalog.pg_type AS referenced_type + ON + depend.refclassid = 'pg_type'::REGCLASS + AND depend.refobjid = referenced_type.oid +INNER JOIN pg_catalog.pg_type AS pg_type + ON + ( + referenced_type.oid = pg_type.oid + OR referenced_type.typelem = pg_type.oid + ) + AND pg_type.typtype = 'c' +INNER JOIN + pg_catalog.pg_namespace AS type_namespace + ON pg_type.typnamespace = type_namespace.oid +INNER JOIN pg_catalog.pg_class AS rel + ON pg_type.typrelid = rel.oid AND rel.relkind = 'c' +WHERE + depend.classid = sqlc.arg(system_catalog)::REGCLASS + AND depend.objid = sqlc.arg(object_id) + AND depend.deptype = 'n'; + +-- name: GetCompositeTypeTableConsumers :many +-- Returns the tables (relkind in r,p) whose columns are typed by the given +-- composite type. Used to refuse a `CREATE TYPE` attribute change when a +-- table column depends on the type — recreating the type would require +-- rewriting the table, which is out of scope for this generator. +SELECT + consumer_c.relname::TEXT AS table_name, + consumer_ns.nspname::TEXT AS table_schema_name +FROM pg_catalog.pg_attribute AS att +INNER JOIN pg_catalog.pg_class AS consumer_c + ON + att.attrelid = consumer_c.oid + AND consumer_c.relkind IN ('r', 'p') +INNER JOIN pg_catalog.pg_namespace AS consumer_ns + ON consumer_c.relnamespace = consumer_ns.oid +WHERE + ( + att.atttypid = sqlc.arg(type_oid) + OR att.atttypid = ( + SELECT typarray + FROM pg_catalog.pg_type + WHERE oid = sqlc.arg(type_oid) + ) + ) + AND att.attnum > 0 + AND NOT att.attisdropped; + -- name: GetDependsOnFunctions :many SELECT pg_proc.proname::TEXT AS func_name, @@ -401,6 +459,60 @@ WHERE AND extension_namespace.nspname !~ '^pg_temp'; +-- name: GetCompositeTypes :many +-- Returns one row per (composite type, attribute) pair, ordered so that the consumer +-- can rebuild attribute lists in their declared order. Types with zero attributes still +-- get a single row with attribute_name = '' so the type itself is not lost. +SELECT + pg_type.oid AS type_oid, + rel.oid AS type_rel_oid, + pg_type.typname::TEXT AS type_name, + type_namespace.nspname::TEXT AS type_schema_name, + COALESCE(att.attname, '')::TEXT AS attribute_name, + COALESCE( + pg_catalog.format_type(att.atttypid, att.atttypmod), '' + )::TEXT AS attribute_type, + COALESCE(coll.collname, '')::TEXT AS collation_name, + COALESCE(coll_ns.nspname, '')::TEXT AS collation_schema_name +FROM pg_catalog.pg_type AS pg_type +INNER JOIN + pg_catalog.pg_namespace AS type_namespace + ON pg_type.typnamespace = type_namespace.oid +INNER JOIN + pg_catalog.pg_class AS rel + -- A user-defined composite type's underlying class has relkind = 'c'. Implicit + -- row types created for tables/views/sequences have relkind in ('r','p','v','m','S') + -- and must be excluded. + ON pg_type.typrelid = rel.oid AND rel.relkind = 'c' +LEFT JOIN + pg_catalog.pg_attribute AS att + ON + att.attrelid = rel.oid + AND att.attnum > 0 + AND NOT att.attisdropped +LEFT JOIN + pg_catalog.pg_collation AS coll + ON att.attcollation = coll.oid +LEFT JOIN + pg_catalog.pg_namespace AS coll_ns + ON coll.collnamespace = coll_ns.oid +WHERE + pg_type.typtype = 'c' + AND type_namespace.nspname NOT IN ('pg_catalog', 'information_schema') + AND type_namespace.nspname !~ '^pg_toast' + AND type_namespace.nspname !~ '^pg_temp' + -- Exclude composite types belonging to extensions + AND NOT EXISTS ( + SELECT ext_depend.objid + FROM pg_catalog.pg_depend AS ext_depend + WHERE + ext_depend.classid = 'pg_type'::REGCLASS + AND ext_depend.objid = pg_type.oid + AND ext_depend.deptype = 'e' + ) +ORDER BY pg_type.oid, att.attnum; + + -- name: GetEnums :many SELECT pg_type.typname::TEXT AS enum_name, diff --git a/internal/queries/queries.sql.go b/internal/queries/queries.sql.go index 4315102..e13a032 100644 --- a/internal/queries/queries.sql.go +++ b/internal/queries/queries.sql.go @@ -223,6 +223,224 @@ func (q *Queries) GetColumnsForTable(ctx context.Context, attrelid interface{}) return items, nil } +const getCompositeTypeTableConsumers = `-- name: GetCompositeTypeTableConsumers :many +SELECT + consumer_c.relname::TEXT AS table_name, + consumer_ns.nspname::TEXT AS table_schema_name +FROM pg_catalog.pg_attribute AS att +INNER JOIN pg_catalog.pg_class AS consumer_c + ON + att.attrelid = consumer_c.oid + AND consumer_c.relkind IN ('r', 'p') +INNER JOIN pg_catalog.pg_namespace AS consumer_ns + ON consumer_c.relnamespace = consumer_ns.oid +WHERE + ( + att.atttypid = $1 + OR att.atttypid = ( + SELECT typarray + FROM pg_catalog.pg_type + WHERE oid = $1 + ) + ) + AND att.attnum > 0 + AND NOT att.attisdropped +` + +type GetCompositeTypeTableConsumersRow struct { + TableName string + TableSchemaName string +} + +// Returns the tables (relkind in r,p) whose columns are typed by the given +// composite type. Used to refuse a `CREATE TYPE` attribute change when a +// table column depends on the type — recreating the type would require +// rewriting the table, which is out of scope for this generator. +func (q *Queries) GetCompositeTypeTableConsumers(ctx context.Context, typeOid interface{}) ([]GetCompositeTypeTableConsumersRow, error) { + rows, err := q.db.QueryContext(ctx, getCompositeTypeTableConsumers, typeOid) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetCompositeTypeTableConsumersRow + for rows.Next() { + var i GetCompositeTypeTableConsumersRow + if err := rows.Scan(&i.TableName, &i.TableSchemaName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getCompositeTypes = `-- name: GetCompositeTypes :many +SELECT + pg_type.oid AS type_oid, + rel.oid AS type_rel_oid, + pg_type.typname::TEXT AS type_name, + type_namespace.nspname::TEXT AS type_schema_name, + COALESCE(att.attname, '')::TEXT AS attribute_name, + COALESCE( + pg_catalog.format_type(att.atttypid, att.atttypmod), '' + )::TEXT AS attribute_type, + COALESCE(coll.collname, '')::TEXT AS collation_name, + COALESCE(coll_ns.nspname, '')::TEXT AS collation_schema_name +FROM pg_catalog.pg_type AS pg_type +INNER JOIN + pg_catalog.pg_namespace AS type_namespace + ON pg_type.typnamespace = type_namespace.oid +INNER JOIN + pg_catalog.pg_class AS rel + -- A user-defined composite type's underlying class has relkind = 'c'. Implicit + -- row types created for tables/views/sequences have relkind in ('r','p','v','m','S') + -- and must be excluded. + ON pg_type.typrelid = rel.oid AND rel.relkind = 'c' +LEFT JOIN + pg_catalog.pg_attribute AS att + ON + att.attrelid = rel.oid + AND att.attnum > 0 + AND NOT att.attisdropped +LEFT JOIN + pg_catalog.pg_collation AS coll + ON att.attcollation = coll.oid +LEFT JOIN + pg_catalog.pg_namespace AS coll_ns + ON coll.collnamespace = coll_ns.oid +WHERE + pg_type.typtype = 'c' + AND type_namespace.nspname NOT IN ('pg_catalog', 'information_schema') + AND type_namespace.nspname !~ '^pg_toast' + AND type_namespace.nspname !~ '^pg_temp' + -- Exclude composite types belonging to extensions + AND NOT EXISTS ( + SELECT ext_depend.objid + FROM pg_catalog.pg_depend AS ext_depend + WHERE + ext_depend.classid = 'pg_type'::REGCLASS + AND ext_depend.objid = pg_type.oid + AND ext_depend.deptype = 'e' + ) +ORDER BY pg_type.oid, att.attnum +` + +type GetCompositeTypesRow struct { + TypeOid interface{} + TypeRelOid interface{} + TypeName string + TypeSchemaName string + AttributeName string + AttributeType string + CollationName string + CollationSchemaName string +} + +// Returns one row per (composite type, attribute) pair, ordered so that the consumer +// can rebuild attribute lists in their declared order. Types with zero attributes still +// get a single row with attribute_name = ” so the type itself is not lost. +func (q *Queries) GetCompositeTypes(ctx context.Context) ([]GetCompositeTypesRow, error) { + rows, err := q.db.QueryContext(ctx, getCompositeTypes) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetCompositeTypesRow + for rows.Next() { + var i GetCompositeTypesRow + if err := rows.Scan( + &i.TypeOid, + &i.TypeRelOid, + &i.TypeName, + &i.TypeSchemaName, + &i.AttributeName, + &i.AttributeType, + &i.CollationName, + &i.CollationSchemaName, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getDependsOnCompositeTypes = `-- name: GetDependsOnCompositeTypes :many +SELECT DISTINCT + pg_type.typname::TEXT AS type_name, + type_namespace.nspname::TEXT AS type_schema_name +FROM pg_catalog.pg_depend AS depend +INNER JOIN pg_catalog.pg_type AS referenced_type + ON + depend.refclassid = 'pg_type'::REGCLASS + AND depend.refobjid = referenced_type.oid +INNER JOIN pg_catalog.pg_type AS pg_type + ON + ( + referenced_type.oid = pg_type.oid + OR referenced_type.typelem = pg_type.oid + ) + AND pg_type.typtype = 'c' +INNER JOIN + pg_catalog.pg_namespace AS type_namespace + ON pg_type.typnamespace = type_namespace.oid +INNER JOIN pg_catalog.pg_class AS rel + ON pg_type.typrelid = rel.oid AND rel.relkind = 'c' +WHERE + depend.classid = $1::REGCLASS + AND depend.objid = $2 + AND depend.deptype = 'n' +` + +type GetDependsOnCompositeTypesParams struct { + SystemCatalog interface{} + ObjectID interface{} +} + +type GetDependsOnCompositeTypesRow struct { + TypeName string + TypeSchemaName string +} + +// Returns the composite types (typtype = 'c') that the given object depends on. +// This includes dependencies through PostgreSQL's automatically-created array +// type for a composite type, e.g. `some_type[]`. +// Used to drive cascading drop+recreate of functions and procedures when the +// attribute list of a composite type changes. +func (q *Queries) GetDependsOnCompositeTypes(ctx context.Context, arg GetDependsOnCompositeTypesParams) ([]GetDependsOnCompositeTypesRow, error) { + rows, err := q.db.QueryContext(ctx, getDependsOnCompositeTypes, arg.SystemCatalog, arg.ObjectID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetDependsOnCompositeTypesRow + for rows.Next() { + var i GetDependsOnCompositeTypesRow + if err := rows.Scan(&i.TypeName, &i.TypeSchemaName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getDependsOnFunctions = `-- name: GetDependsOnFunctions :many SELECT pg_proc.proname::TEXT AS func_name, diff --git a/internal/schema/schema.go b/internal/schema/schema.go index b3945d5..8ae1466 100644 --- a/internal/schema/schema.go +++ b/internal/schema/schema.go @@ -55,6 +55,7 @@ type Schema struct { NamedSchemas []NamedSchema Extensions []Extension Enums []Enum + CompositeTypes []CompositeType Tables []Table Indexes []Index ForeignKeyConstraints []ForeignKeyConstraint @@ -73,6 +74,15 @@ func (s Schema) Normalize() Schema { s.Extensions = sortSchemaObjectsByName(s.Extensions) s.Enums = sortSchemaObjectsByName(s.Enums) + // Composite type attribute order is meaningful (it determines the layout of every value + // of that type), so do NOT sort attributes — only sort the types themselves. + var normCompositeTypes []CompositeType + for _, compositeType := range sortSchemaObjectsByName(s.CompositeTypes) { + compositeType.DependsOnCompositeTypes = sortSchemaObjectsByName(compositeType.DependsOnCompositeTypes) + normCompositeTypes = append(normCompositeTypes, compositeType) + } + s.CompositeTypes = normCompositeTypes + var normTables []Table for _, t := range sortSchemaObjectsByName(s.Tables) { normTables = append(normTables, normalizeTable(t)) @@ -86,6 +96,7 @@ func (s Schema) Normalize() Schema { var normFunctions []Function for _, function := range sortSchemaObjectsByName(s.Functions) { function.DependsOnFunctions = sortSchemaObjectsByName(function.DependsOnFunctions) + function.DependsOnCompositeTypes = sortSchemaObjectsByName(function.DependsOnCompositeTypes) normFunctions = append(normFunctions, function) } s.Functions = normFunctions @@ -212,6 +223,34 @@ type Enum struct { Labels []string } +// CompositeTypeAttribute represents a single attribute (field) of a composite type. +type CompositeTypeAttribute struct { + Name string + // Type is the formatted type, as returned by `pg_catalog.format_type` (e.g. `integer`, `numeric(10,2)`, + // `text[]`, or another schema-qualified composite type). + Type string + Collation SchemaQualifiedName +} + +func (a CompositeTypeAttribute) GetName() string { + return a.Name +} + +// CompositeType represents a user-defined composite type (`CREATE TYPE foo AS (a int, b text)`). +// It does NOT include the implicit row types created for tables/views — those have a backing +// pg_class entry of relkind != 'c' and are filtered out by GetCompositeTypes. +type CompositeType struct { + SchemaQualifiedName + Attributes []CompositeTypeAttribute + // DependsOnCompositeTypes is the list of user-defined composite types referenced + // by this composite type's attributes, including references through array types. + DependsOnCompositeTypes []SchemaQualifiedName + // IsUsedByTable is true iff at least one table column has this composite type as its + // declared type. When true, attribute-level changes to the type are unsupported by the + // diff generator (recreating the type would require rewriting every consumer table). + IsUsedByTable bool +} + type Table struct { SchemaQualifiedName Columns []Column @@ -449,6 +488,11 @@ type Function struct { // can track the dependencies of the function (or not) Language string DependsOnFunctions []SchemaQualifiedName + // DependsOnCompositeTypes is the list of user-defined composite types referenced + // (by argument, return, or body resolution) by this function. When any of those + // types' attributes change, this function must be dropped and recreated alongside + // the type recreation. + DependsOnCompositeTypes []SchemaQualifiedName } type Procedure struct { @@ -457,6 +501,8 @@ type Procedure struct { // the procedure, as returned by `pg_get_functiondef`. It is a CREATE OR REPLACE // statement. Def string + // DependsOnCompositeTypes — see Function.DependsOnCompositeTypes. + DependsOnCompositeTypes []SchemaQualifiedName } var ( @@ -702,6 +748,13 @@ func (s *schemaFetcher) getSchema(ctx context.Context) (Schema, error) { return Schema{}, fmt.Errorf("starting enums future: %w", err) } + compositeTypesFuture, err := concurrent.SubmitFuture(ctx, goroutineRunner, func() ([]CompositeType, error) { + return s.fetchCompositeTypes(ctx) + }) + if err != nil { + return Schema{}, fmt.Errorf("starting composite types future: %w", err) + } + tablesFuture, err := concurrent.SubmitFuture(ctx, goroutineRunner, func() ([]Table, error) { return s.fetchTables(ctx) }) @@ -780,6 +833,11 @@ func (s *schemaFetcher) getSchema(ctx context.Context) (Schema, error) { return Schema{}, fmt.Errorf("getting enums: %w", err) } + compositeTypes, err := compositeTypesFuture.Get(ctx) + if err != nil { + return Schema{}, fmt.Errorf("getting composite types: %w", err) + } + tables, err := tablesFuture.Get(ctx) if err != nil { return Schema{}, fmt.Errorf("getting tables: %w", err) @@ -829,6 +887,7 @@ func (s *schemaFetcher) getSchema(ctx context.Context) (Schema, error) { NamedSchemas: schemas, Extensions: extensions, Enums: enums, + CompositeTypes: compositeTypes, Tables: tables, Indexes: indexes, ForeignKeyConstraints: fkCons, @@ -924,6 +983,80 @@ func (s *schemaFetcher) fetchEnums(ctx context.Context) ([]Enum, error) { return enums, nil } +func (s *schemaFetcher) fetchCompositeTypes(ctx context.Context) ([]CompositeType, error) { + rawAttrs, err := s.q.GetCompositeTypes(ctx) + if err != nil { + return nil, fmt.Errorf("GetCompositeTypes: %w", err) + } + + type ctWithOid struct { + typeOid interface{} + relOid interface{} + ct *CompositeType + } + byOid := make(map[interface{}]*CompositeType) + var ordered []ctWithOid + for _, row := range rawAttrs { + ct, ok := byOid[row.TypeOid] + if !ok { + ct = &CompositeType{ + SchemaQualifiedName: SchemaQualifiedName{ + SchemaName: row.TypeSchemaName, + EscapedName: EscapeIdentifier(row.TypeName), + }, + } + byOid[row.TypeOid] = ct + ordered = append(ordered, ctWithOid{typeOid: row.TypeOid, relOid: row.TypeRelOid, ct: ct}) + } + // rawAttrs may include a synthetic row with attribute_name = '' for types that + // have zero attributes (rare but valid for types being constructed). Skip those. + if row.AttributeName == "" { + continue + } + collation := SchemaQualifiedName{} + if row.CollationName != "" { + collation = SchemaQualifiedName{ + EscapedName: EscapeIdentifier(row.CollationName), + SchemaName: row.CollationSchemaName, + } + } + ct.Attributes = append(ct.Attributes, CompositeTypeAttribute{ + Name: row.AttributeName, + Type: row.AttributeType, + Collation: collation, + }) + } + + // Mark each composite type as IsUsedByTable iff at least one table column has it as + // its declared type. We probe pg_attribute per-type because pg_depend doesn't track + // the column→type relationship in a way we can rely on here. + var compositeTypes []CompositeType + for _, e := range ordered { + dependsOnTypes, err := s.fetchDependsOnCompositeTypes(ctx, "pg_class", e.relOid) + if err != nil { + return nil, fmt.Errorf("fetchDependsOnCompositeTypes(%s): %w", e.relOid, err) + } + e.ct.DependsOnCompositeTypes = dependsOnTypes + + consumers, err := s.q.GetCompositeTypeTableConsumers(ctx, e.typeOid) + if err != nil { + return nil, fmt.Errorf("GetCompositeTypeTableConsumers: %w", err) + } + e.ct.IsUsedByTable = len(consumers) > 0 + compositeTypes = append(compositeTypes, *e.ct) + } + + compositeTypes = filterSliceByName( + compositeTypes, + func(ct CompositeType) SchemaQualifiedName { + return ct.SchemaQualifiedName + }, + s.nameFilter, + ) + + return compositeTypes, nil +} + func (s *schemaFetcher) fetchTables(ctx context.Context) ([]Table, error) { rawTables, err := s.q.GetTables(ctx) if err != nil { @@ -1319,12 +1452,17 @@ func (s *schemaFetcher) buildFunction(ctx context.Context, rawFunction queries.G if err != nil { return Function{}, fmt.Errorf("fetchDependsOnFunctions(%s): %w", rawFunction.Oid, err) } + dependsOnTypes, err := s.fetchDependsOnCompositeTypes(ctx, "pg_proc", rawFunction.Oid) + if err != nil { + return Function{}, fmt.Errorf("fetchDependsOnCompositeTypes(%s): %w", rawFunction.Oid, err) + } return Function{ - SchemaQualifiedName: buildProcName(rawFunction.FuncName, rawFunction.FuncIdentityArguments, rawFunction.FuncSchemaName), - FunctionDef: rawFunction.FuncDef, - Language: rawFunction.FuncLang, - DependsOnFunctions: dependsOnFunctions, + SchemaQualifiedName: buildProcName(rawFunction.FuncName, rawFunction.FuncIdentityArguments, rawFunction.FuncSchemaName), + FunctionDef: rawFunction.FuncDef, + Language: rawFunction.FuncLang, + DependsOnFunctions: dependsOnFunctions, + DependsOnCompositeTypes: dependsOnTypes, }, nil } @@ -1345,6 +1483,25 @@ func (s *schemaFetcher) fetchDependsOnFunctions(ctx context.Context, systemCatal return functionNames, nil } +func (s *schemaFetcher) fetchDependsOnCompositeTypes(ctx context.Context, systemCatalog string, oid any) ([]SchemaQualifiedName, error) { + rows, err := s.q.GetDependsOnCompositeTypes(ctx, queries.GetDependsOnCompositeTypesParams{ + SystemCatalog: systemCatalog, + ObjectID: oid, + }) + if err != nil { + return nil, err + } + + var names []SchemaQualifiedName + for _, r := range rows { + names = append(names, SchemaQualifiedName{ + SchemaName: r.TypeSchemaName, + EscapedName: EscapeIdentifier(r.TypeName), + }) + } + return names, nil +} + func (s *schemaFetcher) fetchProcedures(ctx context.Context) ([]Procedure, error) { rawProcedures, err := s.q.GetProcs(ctx, 'p') if err != nil { @@ -1353,9 +1510,14 @@ func (s *schemaFetcher) fetchProcedures(ctx context.Context) ([]Procedure, error var procedures []Procedure for _, rawProcedure := range rawProcedures { + dependsOnTypes, err := s.fetchDependsOnCompositeTypes(ctx, "pg_proc", rawProcedure.Oid) + if err != nil { + return nil, fmt.Errorf("fetchDependsOnCompositeTypes(%s): %w", rawProcedure.Oid, err) + } p := Procedure{ - SchemaQualifiedName: buildProcName(rawProcedure.FuncName, rawProcedure.FuncIdentityArguments, rawProcedure.FuncSchemaName), - Def: rawProcedure.FuncDef, + SchemaQualifiedName: buildProcName(rawProcedure.FuncName, rawProcedure.FuncIdentityArguments, rawProcedure.FuncSchemaName), + Def: rawProcedure.FuncDef, + DependsOnCompositeTypes: dependsOnTypes, } procedures = append(procedures, p) } diff --git a/internal/schema/schema_test.go b/internal/schema/schema_test.go index c73bb17..6b2a4d5 100644 --- a/internal/schema/schema_test.go +++ b/internal/schema/schema_test.go @@ -239,7 +239,7 @@ var ( GRANT SELECT ON schema_2.foo TO some_role_1; GRANT INSERT ON schema_2.foo TO some_role_2 WITH GRANT OPTION; `}, - expectedHash: "4c2174e2cac3956b", + expectedHash: "3901f9071795f21b", expectedSchema: Schema{ NamedSchemas: []NamedSchema{ {Name: "public"}, @@ -591,7 +591,7 @@ var ( ALTER TABLE foo_fk_1 ADD CONSTRAINT foo_fk_1_fk FOREIGN KEY (author, content) REFERENCES foo_1 (author, content) NOT VALID; `}, - expectedHash: "32c5a9c52dcfb15e", + expectedHash: "2e424d75a012ed5e", expectedSchema: Schema{ NamedSchemas: []NamedSchema{ {Name: "public"}, @@ -1173,7 +1173,7 @@ var ( CREATE TYPE pg_temp.color AS ENUM ('red', 'green', 'blue'); `}, // Assert empty schema hash, since we want to validate specifically that this hash is deterministic - expectedHash: "9c413c6ad2f4a042", + expectedHash: "ab91b603e898f324", expectedSchema: Schema{ NamedSchemas: []NamedSchema{ {Name: "public"}, diff --git a/pkg/diff/composite_type_sql_generator.go b/pkg/diff/composite_type_sql_generator.go new file mode 100644 index 0000000..bfcb739 --- /dev/null +++ b/pkg/diff/composite_type_sql_generator.go @@ -0,0 +1,244 @@ +package diff + +import ( + "fmt" + "strings" + + "github.com/google/go-cmp/cmp" + "github.com/stripe/pg-schema-diff/internal/schema" +) + +// compositeTypeSQLVertexGenerator handles `CREATE TYPE foo AS (...)` and +// `DROP TYPE foo` for user-defined composite types. It is a SQLVertexGenerator +// (rather than a plain SQLGenerator) so it can carry explicit ordering +// dependencies relative to the consumers of a composite type — namely tables, +// functions, procedures, and triggers, which may reference the type in column +// types, parameter types, or return types. +// +// Phase 1 scope: Add and Drop only. ALTER (when attributes change) is +// returned as ErrNotImplemented so the diff machinery fails loudly rather +// than silently dropping the change. A future phase can extend Alter to +// drop+recreate cascade for the function-only-dependents case. +type compositeTypeSQLVertexGenerator struct { + // newSchema and oldSchema are used to set up dependency edges from every + // table/function/procedure/trigger that may reference a composite type to + // the type's add/delete vertices. We do not parse signatures to know + // which functions actually reference a given composite — instead we + // take a blanket approach and let topo-sort untangle it. This matches + // how procedureSQLVertexGenerator handles its own untrackable deps. + newSchema schema.Schema + oldSchema schema.Schema + + recreatedCompositeTypes map[string]bool +} + +func newCompositeTypeSQLVertexGenerator(oldSchema, newSchema schema.Schema, recreatedCompositeTypes map[string]bool) sqlVertexGenerator[schema.CompositeType, compositeTypeDiff] { + return &compositeTypeSQLVertexGenerator{ + newSchema: newSchema, + oldSchema: oldSchema, + recreatedCompositeTypes: recreatedCompositeTypes, + } +} + +func (c *compositeTypeSQLVertexGenerator) Add(ct schema.CompositeType) (partialSQLGraph, error) { + addVertexId := buildCompositeTypeVertexId(ct.SchemaQualifiedName, diffTypeAddAlter) + + stmts := []Statement{{ + DDL: buildCreateCompositeTypeDDL(ct), + Timeout: statementTimeoutDefault, + LockTimeout: lockTimeoutDefault, + }} + + // The type must exist before any consumer (table/function/procedure/trigger) + // is added or altered. We add blanket "after-this" dependencies on the type + // from every such consumer in the new schema; the topo sort will pick the + // correct order. + deps := c.consumerDepsForAddAlter(ct) + deps = append(deps, c.compositeTypeDepsForAddAlter(ct)...) + // Run after re-create (if recreated). Mirrors the view/mview pattern. + deps = append(deps, mustRun(addVertexId).after(buildCompositeTypeVertexId(ct.SchemaQualifiedName, diffTypeDelete))) + + return partialSQLGraph{ + vertices: []sqlVertex{{ + id: addVertexId, + priority: sqlPrioritySooner, + statements: stmts, + }}, + dependencies: deps, + }, nil +} + +func (c *compositeTypeSQLVertexGenerator) Delete(ct schema.CompositeType) (partialSQLGraph, error) { + deleteVertexId := buildCompositeTypeVertexId(ct.SchemaQualifiedName, diffTypeDelete) + + // The type must be dropped after every consumer (in the OLD schema) is + // dropped or altered to no longer reference it. Add blanket + // "before-this" dependencies on the type's delete from every consumer's + // delete and add/alter vertices. + deps := c.consumerDepsForDelete(ct) + deps = append(deps, c.compositeTypeDepsForDelete(ct)...) + + return partialSQLGraph{ + vertices: []sqlVertex{{ + id: deleteVertexId, + priority: sqlPriorityLater, + statements: []Statement{{ + DDL: fmt.Sprintf("DROP TYPE %s", ct.GetFQEscapedName()), + Timeout: statementTimeoutDefault, + LockTimeout: lockTimeoutDefault, + }}, + }}, + dependencies: deps, + }, nil +} + +func (c *compositeTypeSQLVertexGenerator) Alter(d compositeTypeDiff) (partialSQLGraph, error) { + if cmp.Equal(d.old, d.new) { + return partialSQLGraph{}, nil + } + + // Anything beyond a description change requires altering the type's + // attribute list. Phase 2 (drop dependent functions → drop type → recreate + // type → recreate functions) is not implemented yet. For now we surface + // ErrNotImplemented so the user gets an explicit error rather than a + // silent drop. + return partialSQLGraph{}, fmt.Errorf("altering composite type attributes: %w", ErrNotImplemented) +} + +func buildCompositeTypeVertexId(name schema.SchemaQualifiedName, d diffType) sqlVertexId { + return buildSchemaObjVertexId("composite_type", name.GetFQEscapedName(), d) +} + +func buildCreateCompositeTypeDDL(ct schema.CompositeType) string { + if len(ct.Attributes) == 0 { + // PostgreSQL does not allow `CREATE TYPE foo AS ()` with zero columns at + // creation time, but it does allow ALTER TYPE ... DROP ATTRIBUTE down + // to zero. Such a state is not reachable from a declarative schema + // (which always describes what it wants from scratch), so emitting an + // empty parens is fine — apply will fail with a clear PG error if it + // ever happens. + return fmt.Sprintf("CREATE TYPE %s AS ()", ct.GetFQEscapedName()) + } + var attrDefs []string + for _, a := range ct.Attributes { + def := fmt.Sprintf("\t%s %s", schema.EscapeIdentifier(a.Name), a.Type) + if !a.Collation.IsEmpty() { + def += fmt.Sprintf(" COLLATE %s", a.Collation.GetFQEscapedName()) + } + attrDefs = append(attrDefs, def) + } + return fmt.Sprintf("CREATE TYPE %s AS (\n%s\n)", ct.GetFQEscapedName(), strings.Join(attrDefs, ",\n")) +} + +// dependsOnAnyRecreatedType reports whether any of the given composite-type +// references is in the set of types being recreated (their attribute list is +// changing). Used to force-recreate functions and procedures so PostgreSQL +// resolves their argument/return types against the new layout. +func dependsOnAnyRecreatedType(deps []schema.SchemaQualifiedName, recreated map[string]bool) bool { + for _, d := range deps { + if recreated[d.GetName()] { + return true + } + } + return false +} + +// consumerDepsForAddAlter returns dependency edges that force the +// composite type's CREATE to run before any consumer's CREATE/ALTER in +// the new schema. +func (c *compositeTypeSQLVertexGenerator) consumerDepsForAddAlter(ct schema.CompositeType) []dependency { + addVertexId := buildCompositeTypeVertexId(ct.SchemaQualifiedName, diffTypeAddAlter) + + var deps []dependency + for _, t := range c.newSchema.Tables { + deps = append(deps, mustRun(addVertexId).before(buildTableVertexId(t.SchemaQualifiedName, diffTypeAddAlter))) + } + for _, f := range c.newSchema.Functions { + deps = append(deps, mustRun(addVertexId).before(buildFunctionVertexId(f.SchemaQualifiedName, diffTypeAddAlter))) + } + for _, p := range c.newSchema.Procedures { + deps = append(deps, mustRun(addVertexId).before(buildProcedureVertexId(p.SchemaQualifiedName, diffTypeAddAlter))) + } + return deps +} + +func (c *compositeTypeSQLVertexGenerator) compositeTypeDepsForAddAlter(ct schema.CompositeType) []dependency { + addVertexId := buildCompositeTypeVertexId(ct.SchemaQualifiedName, diffTypeAddAlter) + + var deps []dependency + for _, dep := range ct.DependsOnCompositeTypes { + deps = append(deps, mustRun(addVertexId).after(buildCompositeTypeVertexId(dep, diffTypeAddAlter))) + } + return deps +} + +// consumerDepsForDelete returns dependency edges that force the +// composite type's DROP to run after every consumer in the old schema is +// dropped or (in the case of a pure delete) altered to no longer reference +// the type. +// +// When a consumer (function/procedure) STILL references this type in the +// new schema — i.e. the type is being recreated because its attributes +// changed and the consumer is force-recreated alongside it — we must NOT +// add the `typeDelete > consumerAddAlter` edge: doing so would force the +// consumer to be created BEFORE the type is dropped, which contradicts +// the required order +// +// consumerDelete < typeDelete < typeAdd < consumerAddAlter +// +// Tables are unconditional because we explicitly refuse type-recreation +// when a table column depends on the type (see buildSchemaDiff), so the +// recreation edge case never arises for tables. +func (c *compositeTypeSQLVertexGenerator) consumerDepsForDelete(ct schema.CompositeType) []dependency { + deleteVertexId := buildCompositeTypeVertexId(ct.SchemaQualifiedName, diffTypeDelete) + ctName := ct.GetName() + + newFunctionsByName := make(map[string]schema.Function, len(c.newSchema.Functions)) + for _, f := range c.newSchema.Functions { + newFunctionsByName[f.GetName()] = f + } + newProceduresByName := make(map[string]schema.Procedure, len(c.newSchema.Procedures)) + for _, p := range c.newSchema.Procedures { + newProceduresByName[p.GetName()] = p + } + + var deps []dependency + for _, t := range c.oldSchema.Tables { + deps = append(deps, mustRun(deleteVertexId).after(buildTableVertexId(t.SchemaQualifiedName, diffTypeDelete))) + deps = append(deps, mustRun(deleteVertexId).after(buildTableVertexId(t.SchemaQualifiedName, diffTypeAddAlter))) + } + for _, f := range c.oldSchema.Functions { + deps = append(deps, mustRun(deleteVertexId).after(buildFunctionVertexId(f.SchemaQualifiedName, diffTypeDelete))) + newDeps := newFunctionsByName[f.GetName()].DependsOnCompositeTypes + if !consumerStillDependsOnType(newDeps, ctName) && !dependsOnAnyRecreatedType(newDeps, c.recreatedCompositeTypes) { + deps = append(deps, mustRun(deleteVertexId).after(buildFunctionVertexId(f.SchemaQualifiedName, diffTypeAddAlter))) + } + } + for _, p := range c.oldSchema.Procedures { + deps = append(deps, mustRun(deleteVertexId).after(buildProcedureVertexId(p.SchemaQualifiedName, diffTypeDelete))) + newDeps := newProceduresByName[p.GetName()].DependsOnCompositeTypes + if !consumerStillDependsOnType(newDeps, ctName) && !dependsOnAnyRecreatedType(newDeps, c.recreatedCompositeTypes) { + deps = append(deps, mustRun(deleteVertexId).after(buildProcedureVertexId(p.SchemaQualifiedName, diffTypeAddAlter))) + } + } + return deps +} + +func (c *compositeTypeSQLVertexGenerator) compositeTypeDepsForDelete(ct schema.CompositeType) []dependency { + deleteVertexId := buildCompositeTypeVertexId(ct.SchemaQualifiedName, diffTypeDelete) + + var deps []dependency + for _, dep := range ct.DependsOnCompositeTypes { + deps = append(deps, mustRun(deleteVertexId).before(buildCompositeTypeVertexId(dep, diffTypeDelete))) + } + return deps +} + +func consumerStillDependsOnType(deps []schema.SchemaQualifiedName, ctName string) bool { + for _, d := range deps { + if d.GetName() == ctName { + return true + } + } + return false +} diff --git a/pkg/diff/sql_generator.go b/pkg/diff/sql_generator.go index 6c31e22..8f6ed9f 100644 --- a/pkg/diff/sql_generator.go +++ b/pkg/diff/sql_generator.go @@ -94,6 +94,10 @@ type ( oldAndNew[schema.Enum] } + compositeTypeDiff struct { + oldAndNew[schema.CompositeType] + } + extensionDiff struct { oldAndNew[schema.Extension] } @@ -146,6 +150,7 @@ type schemaDiff struct { namedSchemaDiffs listDiff[schema.NamedSchema, namedSchemaDiff] extensionDiffs listDiff[schema.Extension, extensionDiff] enumDiffs listDiff[schema.Enum, enumDiff] + compositeTypeDiffs listDiff[schema.CompositeType, compositeTypeDiff] tableDiffs listDiff[schema.Table, tableDiff] indexDiffs listDiff[schema.Index, indexDiff] foreignKeyConstraintDiffs listDiff[schema.ForeignKeyConstraint, foreignKeyConstraintDiff] @@ -234,6 +239,24 @@ func buildSchemaDiff(old, new schema.Schema) (schemaDiff, bool, error) { return schemaDiff{}, false, fmt.Errorf("diffing enums: %w", err) } + // compositeTypesBeingRecreated tracks types whose attribute layout is changing, + // including composite types that must be recreated because one of their + // attribute types is being recreated. Functions and procedures that reference + // any of these must be force-recreated so they pick up the new layout — + // `CREATE OR REPLACE FUNCTION` cannot change a function's argument or return type. + compositeTypesBeingRecreated, err := identifyCompositeTypesToRecreate(old.CompositeTypes, new.CompositeTypes) + if err != nil { + return schemaDiff{}, false, err + } + compositeTypeDiffs, err := diffLists(old.CompositeTypes, new.CompositeTypes, func(old, new schema.CompositeType, _, _ int) (compositeTypeDiff, bool, error) { + return compositeTypeDiff{ + oldAndNew[schema.CompositeType]{old: old, new: new}, + }, compositeTypesBeingRecreated[new.GetName()], nil + }) + if err != nil { + return schemaDiff{}, false, fmt.Errorf("diffing composite types: %w", err) + } + tableDiffs, err := diffLists(old.Tables, new.Tables, buildTableDiff) if err != nil { return schemaDiff{}, false, fmt.Errorf("diffing tables: %w", err) @@ -285,6 +308,14 @@ func buildSchemaDiff(old, new schema.Schema) (schemaDiff, bool, error) { } functionDiffs, err := diffLists(old.Functions, new.Functions, func(old, new schema.Function, _, _ int) (functionDiff, bool, error) { + // If the new function references a composite type whose attributes are being + // recreated, the function must be dropped and recreated alongside the type + // (CREATE OR REPLACE cannot change a function's argument or return type). + if dependsOnAnyRecreatedType(new.DependsOnCompositeTypes, compositeTypesBeingRecreated) { + return functionDiff{ + oldAndNew[schema.Function]{old: old, new: new}, + }, true, nil + } return functionDiff{ oldAndNew[schema.Function]{ old: old, @@ -297,6 +328,11 @@ func buildSchemaDiff(old, new schema.Schema) (schemaDiff, bool, error) { } procedureDiffs, err := diffLists(old.Procedures, new.Procedures, func(old, new schema.Procedure, _, _ int) (procedureDiff, bool, error) { + if dependsOnAnyRecreatedType(new.DependsOnCompositeTypes, compositeTypesBeingRecreated) { + return procedureDiff{ + oldAndNew[schema.Procedure]{old: old, new: new}, + }, true, nil + } return procedureDiff{ oldAndNew[schema.Procedure]{ old: old, @@ -347,6 +383,7 @@ func buildSchemaDiff(old, new schema.Schema) (schemaDiff, bool, error) { namedSchemaDiffs: schemaDiffs, extensionDiffs: extensionDiffs, enumDiffs: enumDiffs, + compositeTypeDiffs: compositeTypeDiffs, tableDiffs: tableDiffs, indexDiffs: indexesDiff, foreignKeyConstraintDiffs: foreignKeyConstraintDiffs, @@ -653,6 +690,17 @@ func (s schemaSQLGenerator) Alter(diff schemaDiff) ([]Statement, error) { } partialGraph = concatPartialGraphs(partialGraph, sequenceOwnershipsPartialGraph) + compositeTypesBeingRecreated, err := identifyCompositeTypesToRecreate(diff.old.CompositeTypes, diff.new.CompositeTypes) + if err != nil { + return nil, fmt.Errorf("identifying composite types to recreate: %w", err) + } + compositeTypeGenerator := newCompositeTypeSQLVertexGenerator(diff.old, diff.new, compositeTypesBeingRecreated) + compositeTypesPartialGraph, err := generatePartialGraph(compositeTypeGenerator, diff.compositeTypeDiffs) + if err != nil { + return nil, fmt.Errorf("resolving composite type diff: %w", err) + } + partialGraph = concatPartialGraphs(partialGraph, compositeTypesPartialGraph) + functionGenerator := newFunctionSqlVertexGenerator(functionsInNewSchemaByName) functionsPartialGraph, err := generatePartialGraph(functionGenerator, diff.functionDiffs) if err != nil { @@ -777,6 +825,46 @@ func buildSchemaObjByNameMap[S schema.Object](s []S) map[string]S { }) } +func identifyCompositeTypesToRecreate(old, new []schema.CompositeType) (map[string]bool, error) { + oldByName := buildSchemaObjByNameMap(old) + recreated := make(map[string]bool) + + for _, newType := range new { + oldType, ok := oldByName[newType.GetName()] + if !ok { + continue + } + if !cmp.Equal(oldType.Attributes, newType.Attributes) { + if oldType.IsUsedByTable { + return nil, fmt.Errorf("changing attributes of composite type %s used by a table column: %w", newType.GetFQEscapedName(), ErrNotImplemented) + } + recreated[newType.GetName()] = true + } + } + + for changed := true; changed; { + changed = false + for _, newType := range new { + if recreated[newType.GetName()] { + continue + } + oldType, ok := oldByName[newType.GetName()] + if !ok { + continue + } + if dependsOnAnyRecreatedType(newType.DependsOnCompositeTypes, recreated) { + if oldType.IsUsedByTable { + return nil, fmt.Errorf("recreating composite type %s used by a table column because one of its composite attributes changed: %w", newType.GetFQEscapedName(), ErrNotImplemented) + } + recreated[newType.GetName()] = true + changed = true + } + } + } + + return recreated, nil +} + func buildDiffByNameMap[S schema.Object, D diff[S]](d []D) map[string]D { return buildMap(d, func(d D) string { return d.GetNew().GetName()