Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 47 additions & 3 deletions cgosqlite/cgosqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ package cgosqlite
*/
import "C"
import (
"expvar"
"sync"
"sync/atomic"
"time"
Expand All @@ -64,6 +65,12 @@ var emptyChar [1]C.char

var alwaysCopyBlob atomic.Bool

// UsesAfterClose counts the number of times an operation was attempted on a
// connection after it was closed.
//
// This is reassigned by the sqlite package in init.
var UsesAfterClose = new(expvar.Map)

// SetAlwaysCopyBlob sets whether [Stmt.ColumnBlob] should copy the blob data
// instead of returning a slice that aliases SQLite's internal memory. This is
// safe to call at runtime; the setting will apply to subsequent calls to
Expand All @@ -84,14 +91,17 @@ func init() {
type DB struct {
db *C.sqlite3

checkpointing atomic.Int32

declTypes map[string]string
}

// Stmt implements sqliteh.Stmt.
type Stmt struct {
db *DB
stmt *C.sqlite3_stmt
start C.struct_timespec
db *DB
stmt *C.sqlite3_stmt
start C.struct_timespec
finalized atomic.Bool

// used as scratch space when calling into cgo
rowid, changes C.sqlite3_int64
Expand Down Expand Up @@ -154,6 +164,11 @@ func (db *DB) BusyTimeout(d time.Duration) {
}

func (db *DB) Checkpoint(dbName string, mode sqliteh.Checkpoint) (int, int, error) {
depth := db.checkpointing.Add(1)
defer db.checkpointing.Add(-1)
if depth > 1 {
UsesAfterClose.Add("nested-checkpoint", 1)
}
var cDB *C.char
if dbName != "" {
// Docs say: "If parameter zDb is NULL or points to a zero length string",
Expand Down Expand Up @@ -192,6 +207,9 @@ func (db *DB) TxnState(schema string) sqliteh.TxnState {
func (db *DB) Prepare(query string, prepFlags sqliteh.PrepareFlags) (stmt sqliteh.Stmt, remainingQuery string, err error) {
csql := C.CString(query)
defer C.free(unsafe.Pointer(csql))
if db.checkpointing.Load() > 0 {
UsesAfterClose.Add("prepare-during-checkpoint", 1)
}

var cstmt *C.sqlite3_stmt
var csqlTail *C.char
Expand Down Expand Up @@ -234,14 +252,25 @@ func (stmt *Stmt) Reset() error {
}

func (stmt *Stmt) Finalize() error {
stmt.finalized.Store(true)
return errCode(C.sqlite3_finalize(stmt.stmt))
}

func (stmt *Stmt) ClearBindings() error {
if stmt.finalized.Load() {
UsesAfterClose.Add("clear-bindings-after-finalize", 1)
}
return errCode(C.sqlite3_clear_bindings(stmt.stmt))
}

func (stmt *Stmt) ResetAndClear() (time.Duration, error) {
if stmt.finalized.Load() {
UsesAfterClose.Add("reset-and-clear-after-finalize", 1)
}
if stmt.db.checkpointing.Load() > 0 {
UsesAfterClose.Add("resetandclear-during-checkpoint", 1)
}

if stmt.start != (C.struct_timespec{}) {
stmt.duration = 0
err := errCode(C.reset_and_clear(stmt.stmt, &stmt.start, &stmt.duration))
Expand All @@ -261,14 +290,23 @@ func (stmt *Stmt) StartTimer() {
}

func (stmt *Stmt) ColumnDatabaseName(col int) string {
if stmt.finalized.Load() {
UsesAfterClose.Add("col-db-name-after-finalize", 1)
}
return C.GoString(C.sqlite3_column_database_name(stmt.stmt, C.int(col)))
}

func (stmt *Stmt) ColumnTableName(col int) string {
if stmt.finalized.Load() {
UsesAfterClose.Add("col-table-name-after-finalize", 1)
}
return C.GoString(C.sqlite3_column_table_name(stmt.stmt, C.int(col)))
}

func (stmt *Stmt) Step(colType []sqliteh.ColumnType) (row bool, err error) {
if stmt.finalized.Load() {
UsesAfterClose.Add("step-after-finalize", 1)
}
var ptr *C.char
if len(colType) > 0 {
ptr = (*C.char)(unsafe.Pointer(&colType[0]))
Expand All @@ -285,6 +323,12 @@ func (stmt *Stmt) Step(colType []sqliteh.ColumnType) (row bool, err error) {
}

func (stmt *Stmt) StepResult() (row bool, lastInsertRowID, changes int64, d time.Duration, err error) {
if stmt.finalized.Load() {
UsesAfterClose.Add("stepresult-after-finalize", 1)
}
if stmt.db.checkpointing.Load() > 0 {
UsesAfterClose.Add("stepresult-during-checkpoint", 1)
}
stmt.rowid, stmt.changes, stmt.duration = 0, 0, 0
res := C.step_result(stmt.stmt, &stmt.rowid, &stmt.changes, &stmt.duration)
lastInsertRowID = int64(stmt.rowid)
Expand Down
2 changes: 2 additions & 0 deletions sqlite_cgo.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (

func init() {
Open = cgosqlite.Open

cgosqlite.UsesAfterClose = &UsesAfterClose
}

// SetLogCallback sets the global SQLite log callback.
Expand Down
Loading