diff --git a/internal/migration_acceptance_tests/named_schema_cases_test.go b/internal/migration_acceptance_tests/named_schema_cases_test.go index 515360c..d625421 100644 --- a/internal/migration_acceptance_tests/named_schema_cases_test.go +++ b/internal/migration_acceptance_tests/named_schema_cases_test.go @@ -1,6 +1,10 @@ package migration_acceptance_tests -import "testing" +import ( + "testing" + + "github.com/stripe/pg-schema-diff/pkg/diff" +) var namedSchemaAcceptanceTestCases = []acceptanceTestCase{ { @@ -29,12 +33,92 @@ var namedSchemaAcceptanceTestCases = []acceptanceTestCase{ name: "Drop schema", oldSchemaDDL: []string{` CREATE SCHEMA "schema 1"; - CREATE SCHEMA "schema 2"; + CREATE SCHEMA "schema 2"; `}, newSchemaDDL: []string{` CREATE SCHEMA "schema 1"; `}, }, + { + name: "Grant usage on existing schema", + roles: []string{"app_user"}, + oldSchemaDDL: []string{` + CREATE SCHEMA app_schema; + `}, + newSchemaDDL: []string{` + CREATE SCHEMA app_schema; + GRANT USAGE ON SCHEMA app_schema TO app_user; + `}, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + expectedPlanDDL: []string{`GRANT USAGE ON SCHEMA "app_schema" TO "app_user"`}, + }, + { + name: "Revoke usage on existing schema", + roles: []string{"app_user"}, + oldSchemaDDL: []string{` + CREATE SCHEMA app_schema; + GRANT USAGE ON SCHEMA app_schema TO app_user; + `}, + newSchemaDDL: []string{` + CREATE SCHEMA app_schema; + `}, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + expectedPlanDDL: []string{`REVOKE USAGE ON SCHEMA "app_schema" FROM "app_user"`}, + }, + { + name: "Change schema grant option", + roles: []string{"app_user"}, + oldSchemaDDL: []string{` + CREATE SCHEMA app_schema; + GRANT USAGE ON SCHEMA app_schema TO app_user; + `}, + newSchemaDDL: []string{` + CREATE SCHEMA app_schema; + GRANT USAGE ON SCHEMA app_schema TO app_user WITH GRANT OPTION; + `}, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + }, + { + name: "Revoke create when target only grants usage", + roles: []string{"app_user"}, + oldSchemaDDL: []string{` + CREATE SCHEMA app_schema; + GRANT USAGE, CREATE ON SCHEMA app_schema TO app_user; + `}, + newSchemaDDL: []string{` + CREATE SCHEMA app_schema; + GRANT USAGE ON SCHEMA app_schema TO app_user; + `}, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + expectedPlanDDL: []string{ + `REVOKE CREATE ON SCHEMA "app_schema" FROM "app_user"`, + }, + }, + { + name: "Do not grant usage to existing schema owner", + roles: []string{"service"}, + oldSchemaDDL: []string{` + CREATE SCHEMA casino_wager_stats AUTHORIZATION service; + GRANT USAGE, CREATE ON SCHEMA casino_wager_stats TO service; + `}, + newSchemaDDL: []string{` + CREATE SCHEMA casino_wager_stats; + GRANT USAGE ON SCHEMA casino_wager_stats TO service; + `}, + expectedDBSchemaDDL: []string{` + CREATE SCHEMA casino_wager_stats AUTHORIZATION service; + GRANT USAGE, CREATE ON SCHEMA casino_wager_stats TO service; + `}, + expectEmptyPlan: true, + }, } func TestNamedSchemaTestCases(t *testing.T) { diff --git a/internal/queries/queries.sql b/internal/queries/queries.sql index d8397c8..fa70e22 100644 --- a/internal/queries/queries.sql +++ b/internal/queries/queries.sql @@ -1,6 +1,10 @@ -- name: GetSchemas :many -SELECT nspname::TEXT AS schema_name +SELECT + pg_namespace.nspname::TEXT AS schema_name, + owner_role.rolname::TEXT AS owner FROM pg_catalog.pg_namespace +INNER JOIN pg_catalog.pg_roles AS owner_role + ON pg_namespace.nspowner = owner_role.oid WHERE nspname NOT IN ('pg_catalog', 'information_schema') AND nspname !~ '^pg_toast' @@ -15,6 +19,42 @@ WHERE AND depend.deptype = 'e' ); +-- name: GetSchemaPrivileges :many +WITH parsed_acl AS ( + SELECT + n.nspname AS schema_name, + n.nspowner AS owner_oid, + (ACLEXPLODE(n.nspacl)).grantee AS grantee_oid, + (ACLEXPLODE(n.nspacl)).privilege_type AS privilege_type, + (ACLEXPLODE(n.nspacl)).is_grantable AS is_grantable + FROM pg_catalog.pg_namespace AS n + WHERE + n.nspname NOT IN ('pg_catalog', 'information_schema') + AND n.nspname !~ '^pg_toast' + AND n.nspname !~ '^pg_temp' + -- Exclude schemas owned by extensions + AND NOT EXISTS ( + SELECT depend.objid + FROM pg_catalog.pg_depend AS depend + WHERE + depend.classid = 'pg_namespace'::REGCLASS + AND depend.objid = n.oid + AND depend.deptype = 'e' + ) +) + +SELECT + pa.schema_name::TEXT AS schema_name, + COALESCE(grantee_role.rolname, '')::TEXT AS grantee, + pa.privilege_type::TEXT AS privilege, + pa.is_grantable +FROM parsed_acl AS pa +LEFT JOIN pg_catalog.pg_roles AS grantee_role + ON pa.grantee_oid = grantee_role.oid +-- Exclude privileges granted to the schema owner (these are implicit) +WHERE pa.grantee_oid != pa.owner_oid OR pa.grantee_oid = 0 +ORDER BY pa.schema_name, grantee, pa.privilege_type; + -- name: GetTables :many SELECT c.oid, diff --git a/internal/queries/queries.sql.go b/internal/queries/queries.sql.go index 4315102..78c7f96 100644 --- a/internal/queries/queries.sql.go +++ b/internal/queries/queries.sql.go @@ -863,9 +863,85 @@ func (q *Queries) GetProcs(ctx context.Context, prokind interface{}) ([]GetProcs return items, nil } +const getSchemaPrivileges = `-- name: GetSchemaPrivileges :many +WITH parsed_acl AS ( + SELECT + n.nspname AS schema_name, + n.nspowner AS owner_oid, + (ACLEXPLODE(n.nspacl)).grantee AS grantee_oid, + (ACLEXPLODE(n.nspacl)).privilege_type AS privilege_type, + (ACLEXPLODE(n.nspacl)).is_grantable AS is_grantable + FROM pg_catalog.pg_namespace AS n + WHERE + n.nspname NOT IN ('pg_catalog', 'information_schema') + AND n.nspname !~ '^pg_toast' + AND n.nspname !~ '^pg_temp' + -- Exclude schemas owned by extensions + AND NOT EXISTS ( + SELECT depend.objid + FROM pg_catalog.pg_depend AS depend + WHERE + depend.classid = 'pg_namespace'::REGCLASS + AND depend.objid = n.oid + AND depend.deptype = 'e' + ) +) + +SELECT + pa.schema_name::TEXT AS schema_name, + COALESCE(grantee_role.rolname, '')::TEXT AS grantee, + pa.privilege_type::TEXT AS privilege, + pa.is_grantable +FROM parsed_acl AS pa +LEFT JOIN pg_catalog.pg_roles AS grantee_role + ON pa.grantee_oid = grantee_role.oid +WHERE pa.grantee_oid != pa.owner_oid OR pa.grantee_oid = 0 +ORDER BY pa.schema_name, grantee, pa.privilege_type +` + +type GetSchemaPrivilegesRow struct { + SchemaName string + Grantee string + Privilege string + IsGrantable interface{} +} + +// Exclude privileges granted to the schema owner (these are implicit) +func (q *Queries) GetSchemaPrivileges(ctx context.Context) ([]GetSchemaPrivilegesRow, error) { + rows, err := q.db.QueryContext(ctx, getSchemaPrivileges) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetSchemaPrivilegesRow + for rows.Next() { + var i GetSchemaPrivilegesRow + if err := rows.Scan( + &i.SchemaName, + &i.Grantee, + &i.Privilege, + &i.IsGrantable, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getSchemas = `-- name: GetSchemas :many -SELECT nspname::TEXT AS schema_name +SELECT + pg_namespace.nspname::TEXT AS schema_name, + owner_role.rolname::TEXT AS owner FROM pg_catalog.pg_namespace +INNER JOIN pg_catalog.pg_roles AS owner_role + ON pg_namespace.nspowner = owner_role.oid WHERE nspname NOT IN ('pg_catalog', 'information_schema') AND nspname !~ '^pg_toast' @@ -881,19 +957,24 @@ WHERE ) ` -func (q *Queries) GetSchemas(ctx context.Context) ([]string, error) { +type GetSchemasRow struct { + SchemaName string + Owner string +} + +func (q *Queries) GetSchemas(ctx context.Context) ([]GetSchemasRow, error) { rows, err := q.db.QueryContext(ctx, getSchemas) if err != nil { return nil, err } defer rows.Close() - var items []string + var items []GetSchemasRow for rows.Next() { - var schema_name string - if err := rows.Scan(&schema_name); err != nil { + var i GetSchemasRow + if err := rows.Scan(&i.SchemaName, &i.Owner); err != nil { return nil, err } - items = append(items, schema_name) + items = append(items, i) } if err := rows.Close(); err != nil { return nil, err diff --git a/internal/schema/schema.go b/internal/schema/schema.go index b3945d5..9a53e4d 100644 --- a/internal/schema/schema.go +++ b/internal/schema/schema.go @@ -69,7 +69,12 @@ type Schema struct { // Normalize normalizes the schema (alphabetically sorts tables and columns in tables). // Useful for hashing and testing. func (s Schema) Normalize() Schema { - s.NamedSchemas = sortSchemaObjectsByName(s.NamedSchemas) + var normNamedSchemas []NamedSchema + for _, namedSchema := range sortSchemaObjectsByName(s.NamedSchemas) { + namedSchema.Privileges = sortSchemaObjectsByName(namedSchema.Privileges) + normNamedSchemas = append(normNamedSchemas, namedSchema) + } + s.NamedSchemas = normNamedSchemas s.Extensions = sortSchemaObjectsByName(s.Extensions) s.Enums = sortSchemaObjectsByName(s.Enums) @@ -196,12 +201,33 @@ const ( // schema type NamedSchema struct { Name string + // Owner is used to classify implicit owner privileges; schema ownership changes are not generated. + Owner string + Privileges []SchemaPrivilege } func (n NamedSchema) GetName() string { return n.Name } +// SchemaPrivilege represents a privilege granted on a schema. +type SchemaPrivilege struct { + // Grantee is the role that has the privilege. Empty string means PUBLIC. + Grantee string + // Privilege is the type of privilege (USAGE, CREATE) + Privilege string + // IsGrantable indicates if the grantee can grant this privilege to others (WITH GRANT OPTION) + IsGrantable bool +} + +func (p SchemaPrivilege) GetName() string { + grantee := p.Grantee + if grantee == "" { + grantee = "PUBLIC" + } + return fmt.Sprintf("%s:%s", grantee, p.Privilege) +} + type Extension struct { SchemaQualifiedName Version string @@ -842,15 +868,26 @@ func (s *schemaFetcher) getSchema(ctx context.Context) (Schema, error) { } func (s *schemaFetcher) fetchNamedSchemas(ctx context.Context) ([]NamedSchema, error) { - schemaNames, err := s.q.GetSchemas(ctx) + schemaRows, err := s.q.GetSchemas(ctx) if err != nil { return nil, fmt.Errorf("GetSchemas(): %w", err) } + schemaPrivileges, err := s.fetchSchemaPrivileges(ctx) + if err != nil { + return nil, fmt.Errorf("fetchSchemaPrivileges(): %w", err) + } + privilegesBySchema := make(map[string][]SchemaPrivilege) + for _, p := range schemaPrivileges { + privilegesBySchema[p.schemaName] = append(privilegesBySchema[p.schemaName], p.privilege) + } + var schemas []NamedSchema - for _, schemaName := range schemaNames { + for _, schemaRow := range schemaRows { schemas = append(schemas, NamedSchema{ - Name: schemaName, + Name: schemaRow.SchemaName, + Owner: schemaRow.Owner, + Privileges: privilegesBySchema[schemaRow.SchemaName], }) } @@ -1381,6 +1418,51 @@ type privilegeAndTable struct { table SchemaQualifiedName } +type privilegeAndSchema struct { + privilege SchemaPrivilege + schemaName string +} + +func (s *schemaFetcher) fetchSchemaPrivileges(ctx context.Context) ([]privilegeAndSchema, error) { + rawPrivileges, err := s.q.GetSchemaPrivileges(ctx) + if err != nil { + return nil, fmt.Errorf("GetSchemaPrivileges: %w", err) + } + + var privileges []privilegeAndSchema + for _, rp := range rawPrivileges { + // Handle the is_grantable field which may be returned as interface{} + isGrantable := false + if rp.IsGrantable != nil { + if b, ok := rp.IsGrantable.(bool); ok { + isGrantable = b + } + } + + privileges = append(privileges, privilegeAndSchema{ + privilege: SchemaPrivilege{ + Grantee: rp.Grantee, + Privilege: rp.Privilege, + IsGrantable: isGrantable, + }, + schemaName: rp.SchemaName, + }) + } + + privileges = filterSliceByName( + privileges, + func(p privilegeAndSchema) SchemaQualifiedName { + return SchemaQualifiedName{ + SchemaName: p.schemaName, + EscapedName: EscapeIdentifier(p.schemaName), + } + }, + s.nameFilter, + ) + + return privileges, nil +} + func (s *schemaFetcher) fetchPolicies(ctx context.Context) ([]policyAndTable, error) { rawPolicies, err := s.q.GetPolicies(ctx) if err != nil { diff --git a/internal/schema/schema_test.go b/internal/schema/schema_test.go index c73bb17..feb8858 100644 --- a/internal/schema/schema_test.go +++ b/internal/schema/schema_test.go @@ -34,6 +34,16 @@ var ( EscapedName: `"C"`, SchemaName: "pg_catalog", } + publicSchema = NamedSchema{ + Name: "public", + Owner: "pg_database_owner", + Privileges: []SchemaPrivilege{ + {Grantee: "", Privilege: "USAGE", IsGrantable: false}, + }, + } + postgresOwnedSchema = func(name string) NamedSchema { + return NamedSchema{Name: name, Owner: "postgres"} + } testCases = []*testCase{ // Exclude materialized views from the test for now because Postgres 14-15 fully qualify column names while Postgres @@ -239,12 +249,12 @@ var ( GRANT SELECT ON schema_2.foo TO some_role_1; GRANT INSERT ON schema_2.foo TO some_role_2 WITH GRANT OPTION; `}, - expectedHash: "4c2174e2cac3956b", + expectedHash: "b8bd15ebe3d99c4b", expectedSchema: Schema{ NamedSchemas: []NamedSchema{ - {Name: "public"}, - {Name: "schema_1"}, - {Name: "schema_2"}, + publicSchema, + postgresOwnedSchema("schema_1"), + postgresOwnedSchema("schema_2"), }, Extensions: []Extension{ { @@ -591,10 +601,10 @@ var ( ALTER TABLE foo_fk_1 ADD CONSTRAINT foo_fk_1_fk FOREIGN KEY (author, content) REFERENCES foo_1 (author, content) NOT VALID; `}, - expectedHash: "32c5a9c52dcfb15e", + expectedHash: "bb1cdc1ffff18ddd", expectedSchema: Schema{ NamedSchemas: []NamedSchema{ - {Name: "public"}, + publicSchema, }, Tables: []Table{ { @@ -913,7 +923,7 @@ var ( `}, expectedSchema: Schema{ NamedSchemas: []NamedSchema{ - {Name: "public"}, + publicSchema, }, Tables: []Table{ { @@ -970,7 +980,7 @@ var ( `}, expectedSchema: Schema{ NamedSchemas: []NamedSchema{ - {Name: "public"}, + publicSchema, }, Tables: []Table{ { @@ -1043,7 +1053,7 @@ var ( `}, expectedSchema: Schema{ NamedSchemas: []NamedSchema{ - {Name: "public"}, + publicSchema, }, Tables: []Table{ { @@ -1077,7 +1087,7 @@ var ( `}, expectedSchema: Schema{ NamedSchemas: []NamedSchema{ - {Name: "public"}, + publicSchema, }, Tables: []Table{ { @@ -1105,7 +1115,7 @@ var ( `}, expectedSchema: Schema{ NamedSchemas: []NamedSchema{ - {Name: "public"}, + publicSchema, }, Tables: []Table{ { @@ -1130,8 +1140,8 @@ var ( `}, expectedSchema: Schema{ NamedSchemas: []NamedSchema{ - {Name: "public"}, - {Name: "schema_1"}, + publicSchema, + postgresOwnedSchema("schema_1"), }, Tables: []Table{ { @@ -1173,10 +1183,10 @@ var ( CREATE TYPE pg_temp.color AS ENUM ('red', 'green', 'blue'); `}, // Assert empty schema hash, since we want to validate specifically that this hash is deterministic - expectedHash: "9c413c6ad2f4a042", + expectedHash: "bfda373852505980", expectedSchema: Schema{ NamedSchemas: []NamedSchema{ - {Name: "public"}, + publicSchema, }, }, }, @@ -1189,7 +1199,7 @@ var ( `}, expectedSchema: Schema{ NamedSchemas: []NamedSchema{ - {Name: "public"}, + publicSchema, }, Tables: []Table{ { @@ -1217,8 +1227,8 @@ var ( `}, expectedSchema: Schema{ NamedSchemas: []NamedSchema{ - {Name: "public"}, - {Name: "schema_2"}, + publicSchema, + postgresOwnedSchema("schema_2"), }, Tables: []Table{ { @@ -1248,7 +1258,7 @@ var ( `}, expectedSchema: Schema{ NamedSchemas: []NamedSchema{ - {Name: "schema_1"}, + postgresOwnedSchema("schema_1"), }, Tables: []Table{ { diff --git a/pkg/diff/named_schema_sql_generator.go b/pkg/diff/named_schema_sql_generator.go index a8d758f..2dee91a 100644 --- a/pkg/diff/named_schema_sql_generator.go +++ b/pkg/diff/named_schema_sql_generator.go @@ -11,11 +11,23 @@ import ( type namedSchemaSQLGenerator struct{} func (n *namedSchemaSQLGenerator) Add(s schema.NamedSchema) ([]Statement, error) { - return []Statement{{ + stmts := []Statement{{ DDL: fmt.Sprintf("CREATE SCHEMA %s", schema.EscapeIdentifier(s.Name)), Timeout: statementTimeoutDefault, LockTimeout: lockTimeoutDefault, - }}, nil + }} + + privilegeGenerator := &schemaPrivilegeSQLGenerator{schemaName: s.Name} + for _, privilege := range s.Privileges { + addPrivilegeStmts, err := privilegeGenerator.Add(privilege) + if err != nil { + return nil, fmt.Errorf("generating add schema privilege statements for privilege %s: %w", privilege.GetName(), err) + } + // Remove hazards from statements since the schema is brand new. + stmts = append(stmts, stripMigrationHazards(addPrivilegeStmts...)...) + } + + return stmts, nil } func (n *namedSchemaSQLGenerator) Delete(s schema.NamedSchema) ([]Statement, error) { @@ -26,6 +38,16 @@ func (n *namedSchemaSQLGenerator) Delete(s schema.NamedSchema) ([]Statement, err }}, nil } -func (n *namedSchemaSQLGenerator) Alter(_ namedSchemaDiff) ([]Statement, error) { - return nil, nil +func (n *namedSchemaSQLGenerator) Alter(diff namedSchemaDiff) ([]Statement, error) { + privilegeGenerator := &schemaPrivilegeSQLGenerator{schemaName: diff.new.Name} + privilegeStatements, err := diff.privilegesDiff.resolveToSQLGroupedByEffect(privilegeGenerator) + if err != nil { + return nil, fmt.Errorf("resolving schema privilege sql: %w", err) + } + + var stmts []Statement + stmts = append(stmts, privilegeStatements.Deletes...) + stmts = append(stmts, privilegeStatements.Alters...) + stmts = append(stmts, privilegeStatements.Adds...) + return stmts, nil } diff --git a/pkg/diff/plan_generator.go b/pkg/diff/plan_generator.go index 2c2c92b..0017e1e 100644 --- a/pkg/diff/plan_generator.go +++ b/pkg/diff/plan_generator.go @@ -279,9 +279,17 @@ func schemaFromTempDb(ctx context.Context, db *tempdb.Database, plan *planOption return schema.GetSchema(ctx, db.ConnPool, append(plan.getSchemaOpts, db.ExcludeMetadataOptions...)...) } -// clearTablePrivileges returns a copy of the schema with all table privileges cleared. +// clearSkippedPrivileges returns a copy of the schema with all privileges cleared that are emitted as +// SkipValidation statements. // This is used during plan validation because privilege statements are skipped (roles don't exist in temp DB). -func clearTablePrivileges(s schema.Schema) schema.Schema { +func clearSkippedPrivileges(s schema.Schema) schema.Schema { + namedSchemas := make([]schema.NamedSchema, len(s.NamedSchemas)) + for i, namedSchema := range s.NamedSchemas { + namedSchema.Privileges = nil + namedSchemas[i] = namedSchema + } + s.NamedSchemas = namedSchemas + tables := make([]schema.Table, len(s.Tables)) for i, t := range s.Tables { t.Privileges = nil @@ -294,8 +302,8 @@ func clearTablePrivileges(s schema.Schema) schema.Schema { func assertMigratedSchemaMatchesTarget(migratedSchema, targetSchema schema.Schema, planOptions *planOptions) error { // Clear privileges from both schemas since privilege statements are skipped during validation // (roles don't exist in temp DB). We make copies to avoid modifying the original schemas. - migratedSchema = clearTablePrivileges(migratedSchema) - targetSchema = clearTablePrivileges(targetSchema) + migratedSchema = clearSkippedPrivileges(migratedSchema) + targetSchema = clearSkippedPrivileges(targetSchema) toTargetSchemaStmts, err := generateMigrationStatements(migratedSchema, targetSchema, planOptions) if err != nil { diff --git a/pkg/diff/schema_privilege_sql_generator.go b/pkg/diff/schema_privilege_sql_generator.go new file mode 100644 index 0000000..aa49df0 --- /dev/null +++ b/pkg/diff/schema_privilege_sql_generator.go @@ -0,0 +1,58 @@ +package diff + +import ( + "fmt" + + "github.com/stripe/pg-schema-diff/internal/schema" +) + +type schemaPrivilegeSQLGenerator struct { + schemaName string +} + +func (spg *schemaPrivilegeSQLGenerator) Add(p schema.SchemaPrivilege) ([]Statement, error) { + grantee := p.Grantee + if grantee == "" { + grantee = "PUBLIC" + } else { + grantee = schema.EscapeIdentifier(grantee) + } + + ddl := fmt.Sprintf("GRANT %s ON SCHEMA %s TO %s", p.Privilege, schema.EscapeIdentifier(spg.schemaName), grantee) + if p.IsGrantable { + ddl += " WITH GRANT OPTION" + } + + return []Statement{{ + DDL: ddl, + Timeout: statementTimeoutDefault, + LockTimeout: lockTimeoutDefault, + Hazards: []MigrationHazard{migrationHazardPrivilegeGranted}, + SkipValidation: true, + }}, nil +} + +func (spg *schemaPrivilegeSQLGenerator) Delete(p schema.SchemaPrivilege) ([]Statement, error) { + grantee := p.Grantee + if grantee == "" { + grantee = "PUBLIC" + } else { + grantee = schema.EscapeIdentifier(grantee) + } + + ddl := fmt.Sprintf("REVOKE %s ON SCHEMA %s FROM %s", p.Privilege, schema.EscapeIdentifier(spg.schemaName), grantee) + + return []Statement{{ + DDL: ddl, + Timeout: statementTimeoutDefault, + LockTimeout: lockTimeoutDefault, + Hazards: []MigrationHazard{migrationHazardPrivilegeRevoked}, + SkipValidation: true, + }}, nil +} + +func (spg *schemaPrivilegeSQLGenerator) Alter(_ schemaPrivilegeDiff) ([]Statement, error) { + // Privileges don't support ALTER - if IsGrantable changes, we need to recreate + // (handled via requiresRecreation in buildSchemaDiff). + return nil, nil +} diff --git a/pkg/diff/sql_generator.go b/pkg/diff/sql_generator.go index 6c31e22..333d1c9 100644 --- a/pkg/diff/sql_generator.go +++ b/pkg/diff/sql_generator.go @@ -88,6 +88,11 @@ func (o oldAndNew[S]) GetOld() S { type ( namedSchemaDiff struct { oldAndNew[schema.NamedSchema] + privilegesDiff listDiff[schema.SchemaPrivilege, schemaPrivilegeDiff] + } + + schemaPrivilegeDiff struct { + oldAndNew[schema.SchemaPrivilege] } enumDiff struct { @@ -196,11 +201,27 @@ func buildSchemaDiff(old, new schema.Schema) (schemaDiff, bool, error) { old.NamedSchemas, new.NamedSchemas, func(old, new schema.NamedSchema, _, _ int) (namedSchemaDiff, bool, error) { + oldPrivileges := filterSchemaOwnerPrivileges(old.Privileges, old.Owner) + newPrivileges := filterSchemaOwnerPrivileges(new.Privileges, old.Owner) + privilegesDiff, err := diffLists( + oldPrivileges, + newPrivileges, + func(old, new schema.SchemaPrivilege, _, _ int) (schemaPrivilegeDiff, bool, error) { + // Recreate the privilege if IsGrantable changes + recreate := old.IsGrantable != new.IsGrantable + return schemaPrivilegeDiff{oldAndNew[schema.SchemaPrivilege]{old: old, new: new}}, recreate, nil + }, + ) + if err != nil { + return namedSchemaDiff{}, false, fmt.Errorf("diffing schema privileges: %w", err) + } + return namedSchemaDiff{ oldAndNew[schema.NamedSchema]{ old: old, new: new, }, + privilegesDiff, }, false, nil }) if err != nil { @@ -359,6 +380,21 @@ func buildSchemaDiff(old, new schema.Schema) (schemaDiff, bool, error) { }, false, nil } +func filterSchemaOwnerPrivileges(privileges []schema.SchemaPrivilege, owner string) []schema.SchemaPrivilege { + if owner == "" { + return privileges + } + + var filtered []schema.SchemaPrivilege + for _, privilege := range privileges { + if privilege.Grantee == owner && (privilege.Privilege == "USAGE" || privilege.Privilege == "CREATE") { + continue + } + filtered = append(filtered, privilege) + } + return filtered +} + func buildTableDiff(oldTable, newTable schema.Table, _, _ int) (diff tableDiff, requiresRecreation bool, err error) { if oldTable.IsPartitioned() != newTable.IsPartitioned() { return tableDiff{}, true, nil diff --git a/pkg/tempdb/factory_test.go b/pkg/tempdb/factory_test.go index d271fb6..9716a90 100644 --- a/pkg/tempdb/factory_test.go +++ b/pkg/tempdb/factory_test.go @@ -174,7 +174,13 @@ func (suite *onInstanceTempDbFactorySuite) TestCreate_CreateAndDropFlow() { suite.Require().NoError(err) suite.Equal(internalschema.Schema{ NamedSchemas: []internalschema.NamedSchema{{ - Name: "public", + Name: "public", + Owner: "pg_database_owner", + Privileges: []internalschema.SchemaPrivilege{{ + Grantee: "", + Privilege: "USAGE", + IsGrantable: false, + }}, }}, }, schema)