Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ type MigrationBase struct {
nonTransactional bool // set automatically or by ForceNonTransactional / inference
forcedTx *bool // if not nil, explicitly chosen transactional mode (true=transactional, false=non-transactional)
notes map[string]any
preserveComments bool
}

func (m MigrationBase) Copy() MigrationBase {
Expand Down Expand Up @@ -328,6 +329,17 @@ func ForceTransactional() MigrationOption {
}
}

// PreserveComments prevents stripping of SQL comments before execution.
// This is primarily useful for testing scenarios where comment-only
// statements are needed to exercise specific code paths.
// PreserveComments can break DELIMITER handling so do not use in conjunction
// with SQL that includes DELIMITERs.
func PreserveComments() MigrationOption {
return func(m Migration) {
m.Base().preserveComments = true
}
}

// ApplyForceOverride overrides transactionality for any prior force call (ForceTransactional
// or ForceNonTransactional)
func (m *MigrationBase) ApplyForceOverride() {
Expand Down Expand Up @@ -416,6 +428,9 @@ func (m *MigrationBase) ForcedTransactional() bool { return m.forcedTx != nil &&
// ForcedNonTransactional reports if ForceNonTransactional() was explicitly called.
func (m *MigrationBase) ForcedNonTransactional() bool { return m.forcedTx != nil && !*m.forcedTx }

// PreserveComments reports if PreserveComments() was set on this migration.
func (m *MigrationBase) PreserveComments() bool { return m.preserveComments }

func (n MigrationName) String() string {
return n.Library + ": " + n.Name
}
38 changes: 37 additions & 1 deletion internal/mhelp/run_sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/muir/libschema"
"github.com/muir/libschema/classifysql"
"github.com/muir/libschema/internal"
"github.com/muir/sqltoken"
"github.com/pkg/errors"
)

Expand All @@ -16,7 +17,42 @@ type CanExecContext interface {
}

func RunSQL(ctx context.Context, log *internal.Log, tx CanExecContext, statements classifysql.Statements, rowsAffected *int64, m libschema.Migration, d *libschema.Database) error {
for _, commandSQL := range statements.TokensList().Strings() {
for _, tokens := range statements.TokensList() {
if !m.Base().PreserveComments() {
tokens = tokens.Strip()
}
// Strip leading DelimiterStatement (e.g., "DELIMITER //\n")
if len(tokens) > 0 && tokens[0].Type == sqltoken.DelimiterStatement {
log.Debug("Stripping leading DelimiterStatement from migration", map[string]any{
"name": m.Base().Name.Name,
"library": m.Base().Name.Library,
})
tokens = tokens[1:]
}
// Strip trailing DelimiterStatement (e.g., "DELIMITER ;\n") and any whitespace before it
for len(tokens) > 0 && (tokens[len(tokens)-1].Type == sqltoken.DelimiterStatement || tokens[len(tokens)-1].Type == sqltoken.Whitespace) {
if tokens[len(tokens)-1].Type == sqltoken.DelimiterStatement {
log.Debug("Stripping trailing DelimiterStatement from migration", map[string]any{
"name": m.Base().Name.Name,
"library": m.Base().Name.Library,
})
}
tokens = tokens[:len(tokens)-1]
}
// Strip trailing Delimiter (e.g., "//") and any whitespace before it
for len(tokens) > 0 && (tokens[len(tokens)-1].Type == sqltoken.Delimiter || tokens[len(tokens)-1].Type == sqltoken.Whitespace) {
if tokens[len(tokens)-1].Type == sqltoken.Delimiter {
log.Debug("Stripping trailing Delimiter from migration", map[string]any{
"name": m.Base().Name.Name,
"library": m.Base().Name.Library,
})
}
tokens = tokens[:len(tokens)-1]
}
if len(tokens) == 0 {
continue
}
commandSQL := tokens.String()
result, err := tx.ExecContext(ctx, commandSQL)
if d.Options.DebugLogging {
log.Debug("Executed SQL", map[string]any{
Expand Down
6 changes: 3 additions & 3 deletions lsmysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,9 @@ func (p *MySQL) DoOneMigration(ctx context.Context, log *internal.Log, d *libsch
}
sqlText = genSQL
}
sqlText = strings.TrimSpace(sqlText)
m.Base().SetNote("sql", sqlText)
if sqlText == "" {
trimmedSQLText := strings.TrimSpace(sqlText)
m.Base().SetNote("sql", trimmedSQLText)
if trimmedSQLText == "" {
return nil
}
statements, err := classifysql.ClassifyTokens(p.dialect, 0, sqlText)
Expand Down
50 changes: 50 additions & 0 deletions lsmysql/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,53 @@ func testMysqlNotAllowed(t *testing.T, dsn string, createPostfix string, driverN
}
}
}

func TestMysqlMigrationWithDelimiter(t *testing.T) {
testMysqlOneMigration(t, `
DELIMITER //
CREATE PROCEDURE charge_account(IN id BIGINT, IN amount DECIMAL(18,4))
BEGIN
DECLARE balance DECIMAL(18,4);
SELECT remaining_balance INTO balance
FROM account_balance
WHERE account_id = id;
IF balance > amount THEN
UPDATE account_balance
SET remaining_balance = balance - amount
WHERE account_id = id;
END IF;
END //
DELIMITER ;
`)
}

func testMysqlOneMigration(t *testing.T, sqlText string) {
t.Parallel()
dsn := os.Getenv("LIBSCHEMA_MYSQL_TEST_DSN")
if dsn == "" {
t.Skip("Set $LIBSCHEMA_MYSQL_TEST_DSN to test libschema/lsmysql")
}
testOneMigration(t, dsn, sqlText, mysqlNew)
}

func testOneMigration(t *testing.T, dsn string, sqlText string, driverNew driverNew) {
options, cleanup := lstesting.FakeSchema(t, "")

t.Log("Doing migrations in database/schema", options.SchemaOverride)

options.DebugLogging = true
db, err := sql.Open("mysql", dsn)
require.NoError(t, err, "open database")
defer func() {
assert.NoError(t, db.Close())
}()
defer cleanup(db)

s := libschema.New(context.Background(), options)
dbase, _, err := driverNew(t, "test", s, db)
require.NoError(t, err, "libschema NewDatabase")

dbase.Migrations("L1", lsmysql.Script("M1", sqlText))
err = s.Migrate(context.Background())
assert.NoError(t, err)
}
21 changes: 21 additions & 0 deletions lsmysql/singlestore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,27 @@ func TestSingleStoreNotAllowed(t *testing.T) {
testMysqlNotAllowed(t, dsn, "", singleStoreNew)
}

func TestSingleStoreMigrationWithDelimiter(t *testing.T) {
testSingleStoreOneMigration(t, `
DELIMITER //
CREATE OR REPLACE PROCEDURE test_proc()
AS
BEGIN
ECHO SELECT 1;
END //
DELIMITER ;
`)
}

func testSingleStoreOneMigration(t *testing.T, sqlText string) {
t.Parallel()
dsn := os.Getenv("LIBSCHEMA_SINGLESTORE_TEST_DSN")
if dsn == "" {
t.Skip("Set $LIBSCHEMA_SINGLESTORE_TEST_DSN to test SingleStore support in libschema/lsmysql")
}
testOneMigration(t, dsn, sqlText, singleStoreNew)
}

func TestSingleStoreFailedMigration(t *testing.T) {
t.Parallel()
dsn := os.Getenv("LIBSCHEMA_SINGLESTORE_TEST_DSN")
Expand Down
2 changes: 1 addition & 1 deletion lspostgres/bad_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func TestBadMigrationsPostgres(t *testing.T) {
define: func(dbase *libschema.Database) {
dbase.Migrations("L9",
lspostgres.Script("T4", `CREATE TABLE T1 (id text)`),
lspostgres.Script("T5", ` -- just a comment`, libschema.RepeatUntilNoOp()),
lspostgres.Script("T5", ` -- just a comment`, libschema.RepeatUntilNoOp(), libschema.PreserveComments()),
)
},
},
Expand Down
3 changes: 2 additions & 1 deletion lspostgres/non_tx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,8 @@ func TestRowsAffectedErrorLogged(t *testing.T) {
require.NoError(t, err)
lib := fmt.Sprintf("RA_%d", time.Now().UnixNano())
comment := lspostgres.Script("NOTHING",
" -- just a comment")
" -- just a comment",
libschema.PreserveComments())
dbase.Migrations(lib, comment)
require.NoError(t, s.Migrate(ctx))
entries := capLog.Entries()
Expand Down
6 changes: 3 additions & 3 deletions lspostgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,9 @@ func (p *Postgres) DoOneMigration(ctx context.Context, log *internal.Log, d *lib
}
scriptSQL = sqlText
}
scriptSQL = strings.TrimSpace(scriptSQL)
m.Base().SetNote("sql", scriptSQL)
if scriptSQL == "" {
trimmedScriptSQL := strings.TrimSpace(scriptSQL)
m.Base().SetNote("sql", trimmedScriptSQL)
if trimmedScriptSQL == "" {
return nil
}
// Classification & downgrade via classifysql
Expand Down
Loading