From 36967228f2077b63f64a8bdaadac4cb81367fc4b Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Wed, 21 Jan 2026 15:52:08 -0800 Subject: [PATCH] cgosqlite: add some paranoia counters Just in case database/sql or layers above this package call it in unexpected ways, add some extra paranoia checks that callers are holding the API correctly. Updates tailscale/corp#35671 Signed-off-by: Brad Fitzpatrick --- cgosqlite/cgosqlite.go | 50 +++++++++++++++++++++++++++++++++++++++--- sqlite_cgo.go | 2 ++ 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/cgosqlite/cgosqlite.go b/cgosqlite/cgosqlite.go index bfcc095..c998173 100644 --- a/cgosqlite/cgosqlite.go +++ b/cgosqlite/cgosqlite.go @@ -50,6 +50,7 @@ package cgosqlite */ import "C" import ( + "expvar" "sync" "sync/atomic" "time" @@ -64,6 +65,12 @@ var emptyChar [1]C.char var alwaysCopyBlob atomic.Bool +// UsesAfterClose counts the number of times an operation was attempted on a +// connection after it was closed. +// +// This is reassigned by the sqlite package in init. +var UsesAfterClose = new(expvar.Map) + // SetAlwaysCopyBlob sets whether [Stmt.ColumnBlob] should copy the blob data // instead of returning a slice that aliases SQLite's internal memory. This is // safe to call at runtime; the setting will apply to subsequent calls to @@ -84,14 +91,17 @@ func init() { type DB struct { db *C.sqlite3 + checkpointing atomic.Int32 + declTypes map[string]string } // Stmt implements sqliteh.Stmt. type Stmt struct { - db *DB - stmt *C.sqlite3_stmt - start C.struct_timespec + db *DB + stmt *C.sqlite3_stmt + start C.struct_timespec + finalized atomic.Bool // used as scratch space when calling into cgo rowid, changes C.sqlite3_int64 @@ -154,6 +164,11 @@ func (db *DB) BusyTimeout(d time.Duration) { } func (db *DB) Checkpoint(dbName string, mode sqliteh.Checkpoint) (int, int, error) { + depth := db.checkpointing.Add(1) + defer db.checkpointing.Add(-1) + if depth > 1 { + UsesAfterClose.Add("nested-checkpoint", 1) + } var cDB *C.char if dbName != "" { // Docs say: "If parameter zDb is NULL or points to a zero length string", @@ -192,6 +207,9 @@ func (db *DB) TxnState(schema string) sqliteh.TxnState { func (db *DB) Prepare(query string, prepFlags sqliteh.PrepareFlags) (stmt sqliteh.Stmt, remainingQuery string, err error) { csql := C.CString(query) defer C.free(unsafe.Pointer(csql)) + if db.checkpointing.Load() > 0 { + UsesAfterClose.Add("prepare-during-checkpoint", 1) + } var cstmt *C.sqlite3_stmt var csqlTail *C.char @@ -234,14 +252,25 @@ func (stmt *Stmt) Reset() error { } func (stmt *Stmt) Finalize() error { + stmt.finalized.Store(true) return errCode(C.sqlite3_finalize(stmt.stmt)) } func (stmt *Stmt) ClearBindings() error { + if stmt.finalized.Load() { + UsesAfterClose.Add("clear-bindings-after-finalize", 1) + } return errCode(C.sqlite3_clear_bindings(stmt.stmt)) } func (stmt *Stmt) ResetAndClear() (time.Duration, error) { + if stmt.finalized.Load() { + UsesAfterClose.Add("reset-and-clear-after-finalize", 1) + } + if stmt.db.checkpointing.Load() > 0 { + UsesAfterClose.Add("resetandclear-during-checkpoint", 1) + } + if stmt.start != (C.struct_timespec{}) { stmt.duration = 0 err := errCode(C.reset_and_clear(stmt.stmt, &stmt.start, &stmt.duration)) @@ -261,14 +290,23 @@ func (stmt *Stmt) StartTimer() { } func (stmt *Stmt) ColumnDatabaseName(col int) string { + if stmt.finalized.Load() { + UsesAfterClose.Add("col-db-name-after-finalize", 1) + } return C.GoString(C.sqlite3_column_database_name(stmt.stmt, C.int(col))) } func (stmt *Stmt) ColumnTableName(col int) string { + if stmt.finalized.Load() { + UsesAfterClose.Add("col-table-name-after-finalize", 1) + } return C.GoString(C.sqlite3_column_table_name(stmt.stmt, C.int(col))) } func (stmt *Stmt) Step(colType []sqliteh.ColumnType) (row bool, err error) { + if stmt.finalized.Load() { + UsesAfterClose.Add("step-after-finalize", 1) + } var ptr *C.char if len(colType) > 0 { ptr = (*C.char)(unsafe.Pointer(&colType[0])) @@ -285,6 +323,12 @@ func (stmt *Stmt) Step(colType []sqliteh.ColumnType) (row bool, err error) { } func (stmt *Stmt) StepResult() (row bool, lastInsertRowID, changes int64, d time.Duration, err error) { + if stmt.finalized.Load() { + UsesAfterClose.Add("stepresult-after-finalize", 1) + } + if stmt.db.checkpointing.Load() > 0 { + UsesAfterClose.Add("stepresult-during-checkpoint", 1) + } stmt.rowid, stmt.changes, stmt.duration = 0, 0, 0 res := C.step_result(stmt.stmt, &stmt.rowid, &stmt.changes, &stmt.duration) lastInsertRowID = int64(stmt.rowid) diff --git a/sqlite_cgo.go b/sqlite_cgo.go index b4e903d..a8a9664 100644 --- a/sqlite_cgo.go +++ b/sqlite_cgo.go @@ -8,6 +8,8 @@ import ( func init() { Open = cgosqlite.Open + + cgosqlite.UsesAfterClose = &UsesAfterClose } // SetLogCallback sets the global SQLite log callback.