diff --git a/table_bindings.go b/table_bindings.go index 6c2f146..1b2757c 100644 --- a/table_bindings.go +++ b/table_bindings.go @@ -163,62 +163,70 @@ func (t *TableMap) bindInsert(elem reflect.Value) (bindInstance, error) { func (t *TableMap) bindUpdate(elem reflect.Value, colFilter ColumnFilter) (bindInstance, error) { if colFilter == nil { - colFilter = acceptAllFilter + plan := &t.updatePlan + plan.once.Do(func() { + t.buildUpdatePlan(plan, acceptAllFilter) + }) + return plan.createBindInstance(elem, t.dbmap.TypeConverter) } - plan := &t.updatePlan - plan.once.Do(func() { - s := bytes.Buffer{} - s.WriteString(fmt.Sprintf("update %s set ", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName))) - x := 0 - - for y := range t.Columns { - col := t.Columns[y] - if !col.isAutoIncr && !col.Transient && colFilter(col) { - if x > 0 { - s.WriteString(", ") - } - s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName)) - s.WriteString("=") - s.WriteString(t.dbmap.Dialect.BindVar(x)) + // ColumnFilter functions are caller-provided and do not have a stable cache key. + // Build filtered plans per call so UpdateColumns always reflects this invocation. + plan := &bindPlan{} + t.buildUpdatePlan(plan, colFilter) + return plan.createBindInstance(elem, t.dbmap.TypeConverter) +} - if col == t.version { - plan.versField = col.fieldName - plan.argFields = append(plan.argFields, versFieldConst) - } else { - plan.argFields = append(plan.argFields, col.fieldName) - } - x++ - } - } +func (t *TableMap) buildUpdatePlan(plan *bindPlan, colFilter ColumnFilter) { + s := bytes.Buffer{} + tableName := t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName) + s.WriteString(fmt.Sprintf("update %s set ", tableName)) + x := 0 - s.WriteString(" where ") - for y := range t.keys { - col := t.keys[y] - if y > 0 { - s.WriteString(" and ") + for y := range t.Columns { + col := t.Columns[y] + if !col.isAutoIncr && !col.Transient && colFilter(col) { + if x > 0 { + s.WriteString(", ") } s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName)) s.WriteString("=") s.WriteString(t.dbmap.Dialect.BindVar(x)) - plan.argFields = append(plan.argFields, col.fieldName) - plan.keyFields = append(plan.keyFields, col.fieldName) + if col == t.version { + plan.versField = col.fieldName + plan.argFields = append(plan.argFields, versFieldConst) + } else { + plan.argFields = append(plan.argFields, col.fieldName) + } x++ } - if plan.versField != "" { + } + + s.WriteString(" where ") + for y := range t.keys { + col := t.keys[y] + if y > 0 { s.WriteString(" and ") - s.WriteString(t.dbmap.Dialect.QuoteField(t.version.ColumnName)) - s.WriteString("=") - s.WriteString(t.dbmap.Dialect.BindVar(x)) - plan.argFields = append(plan.argFields, plan.versField) } - s.WriteString(t.dbmap.Dialect.QuerySuffix()) + s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName)) + s.WriteString("=") + s.WriteString(t.dbmap.Dialect.BindVar(x)) - plan.query = s.String() - }) + plan.argFields = append(plan.argFields, col.fieldName) + plan.keyFields = append(plan.keyFields, col.fieldName) + x++ + } + if plan.versField != "" { + s.WriteString(" and ") + s.WriteString(t.dbmap.Dialect.QuoteField(t.version.ColumnName)) + s.WriteString("=") + s.WriteString(t.dbmap.Dialect.BindVar(x)) + plan.argFields = append(plan.argFields, plan.versField) + } + s.WriteString(t.dbmap.Dialect.QuerySuffix()) - return plan.createBindInstance(elem, t.dbmap.TypeConverter) + plan.query = s.String() } func (t *TableMap) bindDelete(elem reflect.Value) (bindInstance, error) { diff --git a/update_columns_test.go b/update_columns_test.go new file mode 100644 index 0000000..3948f19 --- /dev/null +++ b/update_columns_test.go @@ -0,0 +1,210 @@ +package borp + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "reflect" + "sync" + "testing" +) + +type updateColumnsSecurityRow struct { + ID int64 + PublicNote string + Admin bool + Balance int64 +} + +type capturedExec struct { + query string + args []driver.NamedValue +} + +var captureDriverState = struct { + sync.Mutex + registered sync.Once + execs []capturedExec +}{} + +type captureDriver struct{} + +func (captureDriver) Open(string) (driver.Conn, error) { + return captureConn{}, nil +} + +type captureConn struct{} + +func (captureConn) Prepare(string) (driver.Stmt, error) { + return nil, errors.New("capture driver does not prepare statements") +} + +func (captureConn) Close() error { + return nil +} + +func (captureConn) Begin() (driver.Tx, error) { + return nil, errors.New("capture driver does not begin transactions") +} + +func (captureConn) ExecContext( + _ context.Context, + query string, + args []driver.NamedValue, +) (driver.Result, error) { + copied := append([]driver.NamedValue(nil), args...) + captureDriverState.Lock() + captureDriverState.execs = append(captureDriverState.execs, capturedExec{ + query: query, + args: copied, + }) + captureDriverState.Unlock() + return driver.RowsAffected(1), nil +} + +func newCaptureDbMap(t *testing.T) *DbMap { + t.Helper() + captureDriverState.registered.Do(func() { + sql.Register("borp_capture", captureDriver{}) + }) + captureDriverState.Lock() + captureDriverState.execs = nil + captureDriverState.Unlock() + + db, err := sql.Open("borp_capture", "") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + err := db.Close() + if err != nil { + t.Error(err) + } + }) + + dbmap := &DbMap{Db: db, Dialect: SqliteDialect{}} + dbmap.AddTableWithName(updateColumnsSecurityRow{}, "security_rows").SetKeys(true, "ID") + return dbmap +} + +func captureExecs() []capturedExec { + captureDriverState.Lock() + defer captureDriverState.Unlock() + return append([]capturedExec(nil), captureDriverState.execs...) +} + +func capturedExecValues(exec capturedExec) []interface{} { + values := make([]interface{}, len(exec.args)) + for i, arg := range exec.args { + values[i] = arg.Value + } + return values +} + +func requireCapturedExec( + t *testing.T, + got capturedExec, + wantQuery string, + wantValues []interface{}, +) { + t.Helper() + + if got.query != wantQuery { + t.Fatalf("generated %q, want %q; args: %+v", got.query, wantQuery, got.args) + } + + values := capturedExecValues(got) + if !reflect.DeepEqual(values, wantValues) { + t.Fatalf("bound values = %#v, want %#v; query: %s", values, wantValues, got.query) + } +} + +func TestUpdateColumnsDoesNotReusePriorFullUpdatePlan(t *testing.T) { + dbmap := newCaptureDbMap(t) + ctx := context.Background() + + _, err := dbmap.Update(ctx, &updateColumnsSecurityRow{ + ID: 7, + PublicNote: "initial", + Admin: false, + Balance: 10, + }) + if err != nil { + t.Fatal(err) + } + + onlyPublicNote := func(col *ColumnMap) bool { + return col.ColumnName == "PublicNote" + } + _, err = dbmap.UpdateColumns(ctx, onlyPublicNote, &updateColumnsSecurityRow{ + ID: 7, + PublicNote: "attacker controlled note", + Admin: true, + Balance: 999, + }) + if err != nil { + t.Fatal(err) + } + + execs := captureExecs() + if len(execs) != 2 { + t.Fatalf("expected two captured execs, got %d: %+v", len(execs), execs) + } + + got := execs[1] + wantQuery := `update "security_rows" set "PublicNote"=? where "ID"=?;` + requireCapturedExec(t, got, wantQuery, []interface{}{ + "attacker controlled note", + int64(7), + }) +} + +func TestUpdateColumnsDoesNotReusePriorFilteredUpdatePlan(t *testing.T) { + dbmap := newCaptureDbMap(t) + ctx := context.Background() + + adminAndBalance := func(col *ColumnMap) bool { + return col.ColumnName == "Admin" || col.ColumnName == "Balance" + } + _, err := dbmap.UpdateColumns(ctx, adminAndBalance, &updateColumnsSecurityRow{ + ID: 7, + PublicNote: "not updated", + Admin: true, + Balance: 999, + }) + if err != nil { + t.Fatal(err) + } + + onlyPublicNote := func(col *ColumnMap) bool { + return col.ColumnName == "PublicNote" + } + _, err = dbmap.UpdateColumns(ctx, onlyPublicNote, &updateColumnsSecurityRow{ + ID: 7, + PublicNote: "second filtered update", + Admin: false, + Balance: 10, + }) + if err != nil { + t.Fatal(err) + } + + execs := captureExecs() + if len(execs) != 2 { + t.Fatalf("expected two captured execs, got %d: %+v", len(execs), execs) + } + + requireCapturedExec( + t, + execs[0], + `update "security_rows" set "Admin"=?, "Balance"=? where "ID"=?;`, + []interface{}{true, int64(999), int64(7)}, + ) + requireCapturedExec( + t, + execs[1], + `update "security_rows" set "PublicNote"=? where "ID"=?;`, + []interface{}{"second filtered update", int64(7)}, + ) +}