diff --git a/package_test.go b/package_test.go index 0d427786..436b634e 100644 --- a/package_test.go +++ b/package_test.go @@ -1118,18 +1118,11 @@ func (s *PackageSuite) TestTransactionErrors(c *C) { // Test error when running query after rollback against the public error variable. tx, err = db.Begin(ctx, nil) c.Assert(err, IsNil) - // Create Query. - q = tx.Query(ctx, insertStmt, &derek) - // Rollback. + err = tx.Rollback() c.Assert(err, IsNil) + err = tx.Query(ctx, insertStmt, &derek).Run() - // Check against sqlair package error. - if !errors.Is(err, sqlair.ErrTXDone) { - c.Errorf("expected %q, got %q", sqlair.ErrTXDone, err) - } - err = q.Run() - // Check against sql package error. if !errors.Is(err, sql.ErrTxDone) { c.Errorf("expected %q, got %q", sql.ErrTxDone, err) } diff --git a/sqlair.go b/sqlair.go index 0acb6628..ad23d0b1 100644 --- a/sqlair.go +++ b/sqlair.go @@ -8,7 +8,6 @@ import ( "database/sql" "fmt" "reflect" - "sync/atomic" "github.com/canonical/sqlair/internal/expr" ) @@ -33,7 +32,6 @@ type M map[string]any type S []any var ErrNoRows = sql.ErrNoRows -var ErrTXDone = sql.ErrTxDone // stmtCache stores the driver prepared statements associated to the SQLair // Statement objects. @@ -413,18 +411,6 @@ func (q *Query) GetAll(sliceArgs ...any) (err error) { type TX struct { sqltx *sql.Tx db *DB - done int32 -} - -func (tx *TX) isDone() bool { - return atomic.LoadInt32(&tx.done) == 1 -} - -func (tx *TX) setDone() error { - if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) { - return ErrTXDone - } - return nil } // Begin starts a transaction. A transaction must be ended @@ -442,20 +428,12 @@ func (db *DB) Begin(ctx context.Context, opts *TXOptions) (*TX, error) { // Commit commits the transaction. func (tx *TX) Commit() error { - err := tx.setDone() - if err == nil { - err = tx.sqltx.Commit() - } - return err + return tx.sqltx.Commit() } // Rollback aborts the transaction. func (tx *TX) Rollback() error { - err := tx.setDone() - if err == nil { - err = tx.sqltx.Rollback() - } - return err + return tx.sqltx.Rollback() } // TXOptions holds the transaction options to be used in [DB.Begin]. @@ -483,9 +461,6 @@ func (tx *TX) Query(ctx context.Context, s *Statement, inputArgs ...any) *Query if ctx == nil { ctx = context.Background() } - if tx.isDone() { - return &Query{ctx: ctx, err: ErrTXDone} - } pq, err := s.te.BindInputs(inputArgs...) if err != nil {