Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ func (m *DbMap) createIndexImpl(ctx context.Context, dialect reflect.Type,
s.WriteString(" unique")
}
s.WriteString(" index")
s.WriteString(fmt.Sprintf(" %s on %s", index.IndexName, table.TableName))
s.WriteString(fmt.Sprintf(
" %s on %s",
m.Dialect.QuoteField(index.IndexName),
m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName),
))
if dname := dialect.Name(); dname == "PostgresDialect" && index.IndexType != "" {
s.WriteString(fmt.Sprintf(" %s %s", m.Dialect.CreateIndexSuffix(), index.IndexType))
}
Expand Down Expand Up @@ -129,10 +133,14 @@ func (t *TableMap) DropIndex(ctx context.Context, name string) error {
for _, idx := range t.indexes {
if idx.IndexName == name {
s := bytes.Buffer{}
s.WriteString(fmt.Sprintf("DROP INDEX %s", idx.IndexName))
s.WriteString(fmt.Sprintf("DROP INDEX %s", t.dbmap.Dialect.QuoteField(idx.IndexName)))

if dname := dialect.Name(); dname == "MySQLDialect" {
s.WriteString(fmt.Sprintf(" %s %s", t.dbmap.Dialect.DropIndexSuffix(), t.TableName))
s.WriteString(fmt.Sprintf(
" %s %s",
t.dbmap.Dialect.DropIndexSuffix(),
t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName),
))
}
s.WriteString(";")
_, e := t.dbmap.ExecContext(ctx, s.String())
Expand Down
4 changes: 2 additions & 2 deletions dialect_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,15 +146,15 @@ func (d MySQLDialect) InsertAutoIncr(ctx context.Context, exec SqlExecutor, inse
}

func (d MySQLDialect) QuoteField(f string) string {
return "`" + f + "`"
return "`" + strings.ReplaceAll(f, "`", "``") + "`"
}

func (d MySQLDialect) QuotedTableForQuery(schema string, table string) string {
if strings.TrimSpace(schema) == "" {
return d.QuoteField(table)
}

return schema + "." + d.QuoteField(table)
return d.QuoteField(schema) + "." + d.QuoteField(table)
}

func (d MySQLDialect) IfSchemaNotExists(command, schema string) string {
Expand Down
4 changes: 3 additions & 1 deletion dialect_mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ func TestMySQLDialect(t *testing.T) {

o.Spec("QuoteField", func(tcx testContext) {
tcx.expect(tcx.dialect.QuoteField("foo")).To(matchers.Equal("`foo`"))
tcx.expect(tcx.dialect.QuoteField("fo`o")).To(matchers.Equal("`fo``o`"))
})

o.Group("QuotedTableForQuery", func() {
Expand All @@ -149,7 +150,8 @@ func TestMySQLDialect(t *testing.T) {
})

o.Spec("with a supplied schema", func(tcx testContext) {
tcx.expect(tcx.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal("foo.`bar`"))
tcx.expect(tcx.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal("`foo`.`bar`"))
tcx.expect(tcx.dialect.QuotedTableForQuery("fo`o", "ba`r")).To(matchers.Equal("`fo``o`.`ba``r`"))
})
})

Expand Down
6 changes: 3 additions & 3 deletions dialect_postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,17 +124,17 @@ func (d PostgresDialect) InsertAutoIncrToTarget(ctx context.Context, exec SqlExe

func (d PostgresDialect) QuoteField(f string) string {
if d.LowercaseFields {
return `"` + strings.ToLower(f) + `"`
f = strings.ToLower(f)
}
return `"` + f + `"`
return `"` + strings.ReplaceAll(f, `"`, `""`) + `"`
}

func (d PostgresDialect) QuotedTableForQuery(schema string, table string) string {
if strings.TrimSpace(schema) == "" {
return d.QuoteField(table)
}

return schema + "." + d.QuoteField(table)
return d.QuoteField(schema) + "." + d.QuoteField(table)
}

func (d PostgresDialect) IfSchemaNotExists(command, schema string) string {
Expand Down
5 changes: 4 additions & 1 deletion dialect_postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ func TestPostgresDialect(t *testing.T) {
o.Spec("By default, case is preserved", func(tcx postgresTestContext) {
tcx.expect(tcx.dialect.QuoteField("Foo")).To(matchers.Equal(`"Foo"`))
tcx.expect(tcx.dialect.QuoteField("bar")).To(matchers.Equal(`"bar"`))
tcx.expect(tcx.dialect.QuoteField(`Fo"o`)).To(matchers.Equal(`"Fo""o"`))
})

o.Group("With LowercaseFields set to true", func() {
Expand All @@ -130,6 +131,7 @@ func TestPostgresDialect(t *testing.T) {

o.Spec("fields are lowercased", func(tcx postgresTestContext) {
tcx.expect(tcx.dialect.QuoteField("Foo")).To(matchers.Equal(`"foo"`))
tcx.expect(tcx.dialect.QuoteField(`Fo"O`)).To(matchers.Equal(`"fo""o"`))
})
})
})
Expand All @@ -140,7 +142,8 @@ func TestPostgresDialect(t *testing.T) {
})

o.Spec("with a supplied schema", func(tcx postgresTestContext) {
tcx.expect(tcx.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal(`foo."bar"`))
tcx.expect(tcx.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal(`"foo"."bar"`))
tcx.expect(tcx.dialect.QuotedTableForQuery(`fo"o`, `ba"r`)).To(matchers.Equal(`"fo""o"."ba""r"`))
})
})

Expand Down
3 changes: 2 additions & 1 deletion dialect_sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"context"
"fmt"
"reflect"
"strings"
)

type SqliteDialect struct {
Expand Down Expand Up @@ -92,7 +93,7 @@ func (d SqliteDialect) InsertAutoIncr(ctx context.Context, exec SqlExecutor, ins
}

func (d SqliteDialect) QuoteField(f string) string {
return `"` + f + `"`
return `"` + strings.ReplaceAll(f, `"`, `""`) + `"`
}

// sqlite does not have schemas like PostgreSQL does, so just escape it like normal
Expand Down
196 changes: 196 additions & 0 deletions identifier_quote_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
package borp

import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"sync"
"testing"

_ "github.com/mattn/go-sqlite3"
)

type identifierCapturedExec struct {
query string
args []driver.NamedValue
}

var identifierCaptureState = struct {
sync.Mutex
registered sync.Once
execs []identifierCapturedExec
}{}

type identifierCaptureDriver struct{}

func (identifierCaptureDriver) Open(string) (driver.Conn, error) {
return identifierCaptureConn{}, nil
}

type identifierCaptureConn struct{}

func (identifierCaptureConn) Prepare(string) (driver.Stmt, error) {
return nil, errors.New("identifier capture driver does not prepare statements")
}

func (identifierCaptureConn) Close() error {
return nil
}

func (identifierCaptureConn) Begin() (driver.Tx, error) {
return nil, errors.New("identifier capture driver does not begin transactions")
}

func (identifierCaptureConn) ExecContext(
_ context.Context,
query string,
args []driver.NamedValue,
) (driver.Result, error) {
argsCopy := append([]driver.NamedValue(nil), args...)

identifierCaptureState.Lock()
defer identifierCaptureState.Unlock()
identifierCaptureState.execs = append(identifierCaptureState.execs, identifierCapturedExec{
query: query,
args: argsCopy,
})
return driver.RowsAffected(0), nil
}

func newIdentifierCaptureDbMap(t *testing.T, dialect Dialect) *DbMap {
t.Helper()

identifierCaptureState.registered.Do(func() {
sql.Register("borp_identifier_capture", identifierCaptureDriver{})
})

identifierCaptureState.Lock()
identifierCaptureState.execs = nil
identifierCaptureState.Unlock()

db, err := sql.Open("borp_identifier_capture", "")
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
closeErr := db.Close()
if closeErr != nil {
t.Fatal(closeErr)
}
})

return &DbMap{Db: db, Dialect: dialect}
}

func identifierCapturedExecs() []identifierCapturedExec {
identifierCaptureState.Lock()
defer identifierCaptureState.Unlock()
return append([]identifierCapturedExec(nil), identifierCaptureState.execs...)
}

func TestSqliteDialectEscapesIdentifierQuotes(t *testing.T) {
dialect := SqliteDialect{}
got := dialect.QuoteField(`fo"o`)
want := `"fo""o"`
if got != want {
t.Fatalf("QuoteField() = %q, want %q", got, want)
}
got = dialect.QuotedTableForQuery("", `ta"ble`)
want = `"ta""ble"`
if got != want {
t.Fatalf("QuotedTableForQuery() = %q, want %q", got, want)
}
}

func TestQuotedTableNameCannotRewriteUpdateTarget(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()

_, err = db.Exec("CREATE TABLE victim (id integer primary key, value text, admin integer)")
if err != nil {
t.Fatal(err)
}
_, err = db.Exec("INSERT INTO victim (id, value, admin) VALUES (1, 'unchanged', 0)")
if err != nil {
t.Fatal(err)
}

type row struct {
ID int64 `db:"ID"`
Value string `db:"Value"`
}

dbmap := &DbMap{Db: db, Dialect: SqliteDialect{}}
injectedTable := `victim" SET admin = 1 WHERE ? <> ? -- `
dbmap.AddTableWithName(row{}, injectedTable).SetKeys(false, "ID")

_, err = dbmap.Update(context.Background(), &row{ID: 1, Value: "unused"})
if err == nil {
t.Fatal("Update succeeded for escaped malicious table name")
}

var admin int
err = db.QueryRow("SELECT admin FROM victim WHERE id = 1").Scan(&admin)
if err != nil {
t.Fatal(err)
}
if admin != 0 {
t.Fatalf("victim.admin = %d, want 0", admin)
}
}

func TestCreateIndexQuotesIdentifierMetadata(t *testing.T) {
type indexedRow struct {
ID int64 `db:"ID"`
}

dbmap := newIdentifierCaptureDbMap(t, PostgresDialect{})
table := dbmap.AddTableWithNameAndSchema(indexedRow{}, `sche"ma`, `security"rows`)
table.SetKeys(false, "ID")
table.AddIndex(`idx"name`, "btree", []string{"ID"})

err := dbmap.CreateIndex(context.Background())
if err != nil {
t.Fatal(err)
}

execs := identifierCapturedExecs()
if len(execs) != 1 {
t.Fatalf("expected one captured exec, got %d: %+v", len(execs), execs)
}

want := `create index "idx""name" on "sche""ma"."security""rows" using btree ("ID");`
if execs[0].query != want {
t.Fatalf("generated %q, want %q", execs[0].query, want)
}
}

func TestDropIndexQuotesIdentifierMetadata(t *testing.T) {
type indexedRow struct {
ID int64 `db:"ID"`
}

dbmap := newIdentifierCaptureDbMap(t, MySQLDialect{Engine: "InnoDB", Encoding: "UTF8"})
table := dbmap.AddTableWithNameAndSchema(indexedRow{}, "sche`ma", "security`rows")
table.SetKeys(false, "ID")
table.AddIndex("idx`name", "Btree", []string{"ID"})

err := table.DropIndex(context.Background(), "idx`name")
if err != nil {
t.Fatal(err)
}

execs := identifierCapturedExecs()
if len(execs) != 1 {
t.Fatalf("expected one captured exec, got %d: %+v", len(execs), execs)
}

want := "DROP INDEX `idx``name` on `sche``ma`.`security``rows`;"
if execs[0].query != want {
t.Fatalf("generated %q, want %q", execs[0].query, want)
}
}