diff --git a/api.go b/api.go index 253a587..fa8afa5 100644 --- a/api.go +++ b/api.go @@ -74,6 +74,7 @@ type MigrationBase struct { nonTransactional bool // set automatically or by ForceNonTransactional / inference forcedTx *bool // if not nil, explicitly chosen transactional mode (true=transactional, false=non-transactional) notes map[string]any + preserveComments bool } func (m MigrationBase) Copy() MigrationBase { @@ -328,6 +329,17 @@ func ForceTransactional() MigrationOption { } } +// PreserveComments prevents stripping of SQL comments before execution. +// This is primarily useful for testing scenarios where comment-only +// statements are needed to exercise specific code paths. +// PreserveComments can break DELIMITER handling so do not use in conjunction +// with SQL that includes DELIMITERs. +func PreserveComments() MigrationOption { + return func(m Migration) { + m.Base().preserveComments = true + } +} + // ApplyForceOverride overrides transactionality for any prior force call (ForceTransactional // or ForceNonTransactional) func (m *MigrationBase) ApplyForceOverride() { @@ -416,6 +428,9 @@ func (m *MigrationBase) ForcedTransactional() bool { return m.forcedTx != nil && // ForcedNonTransactional reports if ForceNonTransactional() was explicitly called. func (m *MigrationBase) ForcedNonTransactional() bool { return m.forcedTx != nil && !*m.forcedTx } +// PreserveComments reports if PreserveComments() was set on this migration. +func (m *MigrationBase) PreserveComments() bool { return m.preserveComments } + func (n MigrationName) String() string { return n.Library + ": " + n.Name } diff --git a/internal/mhelp/run_sql.go b/internal/mhelp/run_sql.go index 739c89e..c035af0 100644 --- a/internal/mhelp/run_sql.go +++ b/internal/mhelp/run_sql.go @@ -8,6 +8,7 @@ import ( "github.com/muir/libschema" "github.com/muir/libschema/classifysql" "github.com/muir/libschema/internal" + "github.com/muir/sqltoken" "github.com/pkg/errors" ) @@ -16,7 +17,42 @@ type CanExecContext interface { } func RunSQL(ctx context.Context, log *internal.Log, tx CanExecContext, statements classifysql.Statements, rowsAffected *int64, m libschema.Migration, d *libschema.Database) error { - for _, commandSQL := range statements.TokensList().Strings() { + for _, tokens := range statements.TokensList() { + if !m.Base().PreserveComments() { + tokens = tokens.Strip() + } + // Strip leading DelimiterStatement (e.g., "DELIMITER //\n") + if len(tokens) > 0 && tokens[0].Type == sqltoken.DelimiterStatement { + log.Debug("Stripping leading DelimiterStatement from migration", map[string]any{ + "name": m.Base().Name.Name, + "library": m.Base().Name.Library, + }) + tokens = tokens[1:] + } + // Strip trailing DelimiterStatement (e.g., "DELIMITER ;\n") and any whitespace before it + for len(tokens) > 0 && (tokens[len(tokens)-1].Type == sqltoken.DelimiterStatement || tokens[len(tokens)-1].Type == sqltoken.Whitespace) { + if tokens[len(tokens)-1].Type == sqltoken.DelimiterStatement { + log.Debug("Stripping trailing DelimiterStatement from migration", map[string]any{ + "name": m.Base().Name.Name, + "library": m.Base().Name.Library, + }) + } + tokens = tokens[:len(tokens)-1] + } + // Strip trailing Delimiter (e.g., "//") and any whitespace before it + for len(tokens) > 0 && (tokens[len(tokens)-1].Type == sqltoken.Delimiter || tokens[len(tokens)-1].Type == sqltoken.Whitespace) { + if tokens[len(tokens)-1].Type == sqltoken.Delimiter { + log.Debug("Stripping trailing Delimiter from migration", map[string]any{ + "name": m.Base().Name.Name, + "library": m.Base().Name.Library, + }) + } + tokens = tokens[:len(tokens)-1] + } + if len(tokens) == 0 { + continue + } + commandSQL := tokens.String() result, err := tx.ExecContext(ctx, commandSQL) if d.Options.DebugLogging { log.Debug("Executed SQL", map[string]any{ diff --git a/lsmysql/mysql.go b/lsmysql/mysql.go index a9255fa..81e7203 100644 --- a/lsmysql/mysql.go +++ b/lsmysql/mysql.go @@ -255,9 +255,9 @@ func (p *MySQL) DoOneMigration(ctx context.Context, log *internal.Log, d *libsch } sqlText = genSQL } - sqlText = strings.TrimSpace(sqlText) - m.Base().SetNote("sql", sqlText) - if sqlText == "" { + trimmedSQLText := strings.TrimSpace(sqlText) + m.Base().SetNote("sql", trimmedSQLText) + if trimmedSQLText == "" { return nil } statements, err := classifysql.ClassifyTokens(p.dialect, 0, sqlText) diff --git a/lsmysql/mysql_test.go b/lsmysql/mysql_test.go index cce14bf..fd612cc 100644 --- a/lsmysql/mysql_test.go +++ b/lsmysql/mysql_test.go @@ -273,3 +273,53 @@ func testMysqlNotAllowed(t *testing.T, dsn string, createPostfix string, driverN } } } + +func TestMysqlMigrationWithDelimiter(t *testing.T) { + testMysqlOneMigration(t, ` +DELIMITER // +CREATE PROCEDURE charge_account(IN id BIGINT, IN amount DECIMAL(18,4)) +BEGIN + DECLARE balance DECIMAL(18,4); + SELECT remaining_balance INTO balance + FROM account_balance + WHERE account_id = id; + IF balance > amount THEN + UPDATE account_balance + SET remaining_balance = balance - amount + WHERE account_id = id; + END IF; +END // +DELIMITER ; +`) +} + +func testMysqlOneMigration(t *testing.T, sqlText string) { + t.Parallel() + dsn := os.Getenv("LIBSCHEMA_MYSQL_TEST_DSN") + if dsn == "" { + t.Skip("Set $LIBSCHEMA_MYSQL_TEST_DSN to test libschema/lsmysql") + } + testOneMigration(t, dsn, sqlText, mysqlNew) +} + +func testOneMigration(t *testing.T, dsn string, sqlText string, driverNew driverNew) { + options, cleanup := lstesting.FakeSchema(t, "") + + t.Log("Doing migrations in database/schema", options.SchemaOverride) + + options.DebugLogging = true + db, err := sql.Open("mysql", dsn) + require.NoError(t, err, "open database") + defer func() { + assert.NoError(t, db.Close()) + }() + defer cleanup(db) + + s := libschema.New(context.Background(), options) + dbase, _, err := driverNew(t, "test", s, db) + require.NoError(t, err, "libschema NewDatabase") + + dbase.Migrations("L1", lsmysql.Script("M1", sqlText)) + err = s.Migrate(context.Background()) + assert.NoError(t, err) +} diff --git a/lsmysql/singlestore_test.go b/lsmysql/singlestore_test.go index bada3f9..5c79f94 100644 --- a/lsmysql/singlestore_test.go +++ b/lsmysql/singlestore_test.go @@ -60,6 +60,27 @@ func TestSingleStoreNotAllowed(t *testing.T) { testMysqlNotAllowed(t, dsn, "", singleStoreNew) } +func TestSingleStoreMigrationWithDelimiter(t *testing.T) { + testSingleStoreOneMigration(t, ` +DELIMITER // +CREATE OR REPLACE PROCEDURE test_proc() +AS +BEGIN + ECHO SELECT 1; +END // +DELIMITER ; +`) +} + +func testSingleStoreOneMigration(t *testing.T, sqlText string) { + t.Parallel() + dsn := os.Getenv("LIBSCHEMA_SINGLESTORE_TEST_DSN") + if dsn == "" { + t.Skip("Set $LIBSCHEMA_SINGLESTORE_TEST_DSN to test SingleStore support in libschema/lsmysql") + } + testOneMigration(t, dsn, sqlText, singleStoreNew) +} + func TestSingleStoreFailedMigration(t *testing.T) { t.Parallel() dsn := os.Getenv("LIBSCHEMA_SINGLESTORE_TEST_DSN") diff --git a/lspostgres/bad_test.go b/lspostgres/bad_test.go index 29b5aec..b175c50 100644 --- a/lspostgres/bad_test.go +++ b/lspostgres/bad_test.go @@ -121,7 +121,7 @@ func TestBadMigrationsPostgres(t *testing.T) { define: func(dbase *libschema.Database) { dbase.Migrations("L9", lspostgres.Script("T4", `CREATE TABLE T1 (id text)`), - lspostgres.Script("T5", ` -- just a comment`, libschema.RepeatUntilNoOp()), + lspostgres.Script("T5", ` -- just a comment`, libschema.RepeatUntilNoOp(), libschema.PreserveComments()), ) }, }, diff --git a/lspostgres/non_tx_test.go b/lspostgres/non_tx_test.go index 451f069..f4997c2 100644 --- a/lspostgres/non_tx_test.go +++ b/lspostgres/non_tx_test.go @@ -217,7 +217,8 @@ func TestRowsAffectedErrorLogged(t *testing.T) { require.NoError(t, err) lib := fmt.Sprintf("RA_%d", time.Now().UnixNano()) comment := lspostgres.Script("NOTHING", - " -- just a comment") + " -- just a comment", + libschema.PreserveComments()) dbase.Migrations(lib, comment) require.NoError(t, s.Migrate(ctx)) entries := capLog.Entries() diff --git a/lspostgres/postgres.go b/lspostgres/postgres.go index 359ce19..7837be7 100644 --- a/lspostgres/postgres.go +++ b/lspostgres/postgres.go @@ -219,9 +219,9 @@ func (p *Postgres) DoOneMigration(ctx context.Context, log *internal.Log, d *lib } scriptSQL = sqlText } - scriptSQL = strings.TrimSpace(scriptSQL) - m.Base().SetNote("sql", scriptSQL) - if scriptSQL == "" { + trimmedScriptSQL := strings.TrimSpace(scriptSQL) + m.Base().SetNote("sql", trimmedScriptSQL) + if trimmedScriptSQL == "" { return nil } // Classification & downgrade via classifysql