Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions package_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1118,18 +1118,11 @@ func (s *PackageSuite) TestTransactionErrors(c *C) {
// Test error when running query after rollback against the public error variable.
tx, err = db.Begin(ctx, nil)
c.Assert(err, IsNil)
// Create Query.
q = tx.Query(ctx, insertStmt, &derek)
// Rollback.

err = tx.Rollback()
c.Assert(err, IsNil)

err = tx.Query(ctx, insertStmt, &derek).Run()
// Check against sqlair package error.
if !errors.Is(err, sqlair.ErrTXDone) {
c.Errorf("expected %q, got %q", sqlair.ErrTXDone, err)
}
err = q.Run()
// Check against sql package error.
if !errors.Is(err, sql.ErrTxDone) {
c.Errorf("expected %q, got %q", sql.ErrTxDone, err)
}
Expand Down
29 changes: 2 additions & 27 deletions sqlair.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"database/sql"
"fmt"
"reflect"
"sync/atomic"

"github.com/canonical/sqlair/internal/expr"
)
Expand All @@ -33,7 +32,6 @@ type M map[string]any
type S []any

var ErrNoRows = sql.ErrNoRows
var ErrTXDone = sql.ErrTxDone

// stmtCache stores the driver prepared statements associated to the SQLair
// Statement objects.
Expand Down Expand Up @@ -413,18 +411,6 @@ func (q *Query) GetAll(sliceArgs ...any) (err error) {
type TX struct {
sqltx *sql.Tx
db *DB
done int32
}

func (tx *TX) isDone() bool {
return atomic.LoadInt32(&tx.done) == 1
}

func (tx *TX) setDone() error {
if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) {
return ErrTXDone
}
return nil
}

// Begin starts a transaction. A transaction must be ended
Expand All @@ -442,20 +428,12 @@ func (db *DB) Begin(ctx context.Context, opts *TXOptions) (*TX, error) {

// Commit commits the transaction.
func (tx *TX) Commit() error {
err := tx.setDone()
if err == nil {
err = tx.sqltx.Commit()
}
return err
return tx.sqltx.Commit()
}

// Rollback aborts the transaction.
func (tx *TX) Rollback() error {
err := tx.setDone()
if err == nil {
err = tx.sqltx.Rollback()
}
return err
return tx.sqltx.Rollback()
}

// TXOptions holds the transaction options to be used in [DB.Begin].
Expand Down Expand Up @@ -483,9 +461,6 @@ func (tx *TX) Query(ctx context.Context, s *Statement, inputArgs ...any) *Query
if ctx == nil {
ctx = context.Background()
}
if tx.isDone() {
return &Query{ctx: ctx, err: ErrTXDone}
}

pq, err := s.te.BindInputs(inputArgs...)
if err != nil {
Expand Down
Loading