@@ -12,32 +12,38 @@ import (
1212 "github.com/jackc/pgx/v5"
1313)
1414
15- // Table provides basic CRUD operations for database records.
16- // Records must implement GetID() and Validate() methods.
17- type Table [T any , PT interface {
15+ // ID is a comparable type used for record IDs.
16+ type ID comparable
17+
18+ // Records must be a pointer with the methods defined on the pointer.
19+ type Record [T any , I ID ] interface {
1820 * T // Enforce T is a pointer.
19- GetID () IDT
21+ GetID () I
2022 Validate () error
21- }, IDT comparable ] struct {
23+ }
24+
25+ // Table provides basic CRUD operations for database records.
26+ type Table [T any , P Record [T , I ], I ID ] struct {
2227 * DB
2328 Name string
2429 IDColumn string
2530}
2631
27- type hasSetCreatedAt interface {
28- SetCreatedAt (time.Time )
29- }
30-
31- type hasSetUpdatedAt interface {
32- SetUpdatedAt (time.Time )
33- }
34-
35- type hasSetDeletedAt interface {
36- SetDeletedAt (time.Time )
37- }
32+ // helpers for setting timestamp fields
33+ type (
34+ hasSetCreatedAt interface {
35+ SetCreatedAt (time.Time )
36+ }
37+ hasSetUpdatedAt interface {
38+ SetUpdatedAt (time.Time )
39+ }
40+ hasSetDeletedAt interface {
41+ SetDeletedAt (time.Time )
42+ }
43+ )
3844
3945// Save inserts or updates given records. Auto-detects insert vs update by ID based on zerovalue of ID from GetID() method on record.
40- func (t * Table [T , PT , IDT ]) Save (ctx context.Context , records ... PT ) error {
46+ func (t * Table [T , P , I ]) Save (ctx context.Context , records ... P ) error {
4147 switch len (records ) {
4248 case 0 :
4349 return nil
@@ -48,7 +54,7 @@ func (t *Table[T, PT, IDT]) Save(ctx context.Context, records ...PT) error {
4854 }
4955}
5056
51- func (t * Table [T , PT , IDT ]) saveOne (ctx context.Context , record PT ) error {
57+ func (t * Table [T , P , I ]) saveOne (ctx context.Context , record P ) error {
5258 if record == nil {
5359 return fmt .Errorf ("record is nil" )
5460 }
@@ -62,7 +68,7 @@ func (t *Table[T, PT, IDT]) saveOne(ctx context.Context, record PT) error {
6268 }
6369
6470 // Insert
65- var zero IDT
71+ var zero I
6672 if record .GetID () == zero {
6773 q := t .SQL .
6874 InsertRecord (record ).
@@ -87,10 +93,10 @@ func (t *Table[T, PT, IDT]) saveOne(ctx context.Context, record PT) error {
8793
8894const chunkSize = 1000
8995
90- func (t * Table [T , PT , IDT ]) saveAll (ctx context.Context , records []PT ) error {
96+ func (t * Table [T , P , I ]) saveAll (ctx context.Context , records []P ) error {
9197 now := time .Now ().UTC ()
9298
93- insertRecords := make ([]PT , 0 )
99+ insertRecords := make ([]P , 0 )
94100 insertIndices := make ([]int , 0 ) // keep track of original indices, so we can update the records with IDs in passed slice
95101
96102 updateQueries := make (Queries , 0 )
@@ -108,7 +114,7 @@ func (t *Table[T, PT, IDT]) saveAll(ctx context.Context, records []PT) error {
108114 row .SetUpdatedAt (now )
109115 }
110116
111- var zero IDT
117+ var zero I
112118 if r .GetID () == zero {
113119 if row , ok := any (r ).(hasSetCreatedAt ); ok {
114120 row .SetCreatedAt (now )
@@ -159,7 +165,7 @@ func (t *Table[T, PT, IDT]) saveAll(ctx context.Context, records []PT) error {
159165}
160166
161167// getListQuery builds a base select query for listing records.
162- func (t * Table [T , PT , IDT ]) getListQuery (where sq.Sqlizer , orderBy []string ) sq.SelectBuilder {
168+ func (t * Table [T , P , I ]) getListQuery (where sq.Sqlizer , orderBy []string ) sq.SelectBuilder {
163169 if len (orderBy ) == 0 {
164170 orderBy = []string {t .IDColumn }
165171 }
@@ -173,7 +179,7 @@ func (t *Table[T, PT, IDT]) getListQuery(where sq.Sqlizer, orderBy []string) sq.
173179}
174180
175181// Get returns the first record matching the condition.
176- func (t * Table [T , PT , IDT ]) Get (ctx context.Context , where sq.Sqlizer , orderBy []string ) (PT , error ) {
182+ func (t * Table [T , P , I ]) Get (ctx context.Context , where sq.Sqlizer , orderBy []string ) (P , error ) {
177183 record := new (T )
178184
179185 q := t .getListQuery (where , orderBy ).Limit (1 )
@@ -186,9 +192,9 @@ func (t *Table[T, PT, IDT]) Get(ctx context.Context, where sq.Sqlizer, orderBy [
186192}
187193
188194// List returns all records matching the condition.
189- func (t * Table [T , PT , IDT ]) List (ctx context.Context , where sq.Sqlizer , orderBy []string ) ([]PT , error ) {
195+ func (t * Table [T , P , I ]) List (ctx context.Context , where sq.Sqlizer , orderBy []string ) ([]P , error ) {
190196 q := t .getListQuery (where , orderBy )
191- var records []PT
197+ var records []P
192198 if err := t .Query .GetAll (ctx , q , & records ); err != nil {
193199 return nil , err
194200 }
@@ -197,14 +203,14 @@ func (t *Table[T, PT, IDT]) List(ctx context.Context, where sq.Sqlizer, orderBy
197203}
198204
199205// Iter returns an iterator for records matching the condition.
200- func (t * Table [T , PT , IDT ]) Iter (ctx context.Context , where sq.Sqlizer , orderBy []string ) (iter.Seq2 [PT , error ], error ) {
206+ func (t * Table [T , P , I ]) Iter (ctx context.Context , where sq.Sqlizer , orderBy []string ) (iter.Seq2 [P , error ], error ) {
201207 q := t .getListQuery (where , orderBy )
202208 rows , err := t .Query .QueryRows (ctx , q )
203209 if err != nil {
204210 return nil , fmt .Errorf ("query rows: %w" , err )
205211 }
206212
207- return func (yield func (PT , error ) bool ) {
213+ return func (yield func (P , error ) bool ) {
208214 defer rows .Close ()
209215 for rows .Next () {
210216 var record T
@@ -222,17 +228,17 @@ func (t *Table[T, PT, IDT]) Iter(ctx context.Context, where sq.Sqlizer, orderBy
222228}
223229
224230// GetByID returns a record by its ID.
225- func (t * Table [T , PT , IDT ]) GetByID (ctx context.Context , id IDT ) (PT , error ) {
231+ func (t * Table [T , P , I ]) GetByID (ctx context.Context , id I ) (P , error ) {
226232 return t .Get (ctx , sq.Eq {t .IDColumn : id }, []string {t .IDColumn })
227233}
228234
229235// ListByIDs returns records by their IDs.
230- func (t * Table [T , PT , IDT ]) ListByIDs (ctx context.Context , ids []IDT ) ([]PT , error ) {
236+ func (t * Table [T , P , I ]) ListByIDs (ctx context.Context , ids []I ) ([]P , error ) {
231237 return t .List (ctx , sq.Eq {t .IDColumn : ids }, nil )
232238}
233239
234240// Count returns the number of matching records.
235- func (t * Table [T , PT , IDT ]) Count (ctx context.Context , where sq.Sqlizer ) (uint64 , error ) {
241+ func (t * Table [T , P , I ]) Count (ctx context.Context , where sq.Sqlizer ) (uint64 , error ) {
236242 var count uint64
237243 q := t .SQL .
238244 Select ("COUNT(1)" ).
@@ -247,7 +253,7 @@ func (t *Table[T, PT, IDT]) Count(ctx context.Context, where sq.Sqlizer) (uint64
247253}
248254
249255// DeleteByID deletes a record by ID. Uses soft delete if .SetDeletedAt() method exists.
250- func (t * Table [T , PT , IDT ]) DeleteByID (ctx context.Context , id IDT ) error {
256+ func (t * Table [T , P , I ]) DeleteByID (ctx context.Context , id I ) error {
251257 record , err := t .GetByID (ctx , id )
252258 if err != nil {
253259 return fmt .Errorf ("delete: %w" , err )
@@ -267,7 +273,7 @@ func (t *Table[T, PT, IDT]) DeleteByID(ctx context.Context, id IDT) error {
267273}
268274
269275// HardDeleteByID permanently deletes a record by ID.
270- func (t * Table [T , PT , IDT ]) HardDeleteByID (ctx context.Context , id IDT ) error {
276+ func (t * Table [T , P , I ]) HardDeleteByID (ctx context.Context , id I ) error {
271277 q := t .SQL .Delete (t .Name ).Where (sq.Eq {t .IDColumn : id })
272278 if _ , err := t .Query .Exec (ctx , q ); err != nil {
273279 return fmt .Errorf ("hard delete: %w" , err )
@@ -276,8 +282,8 @@ func (t *Table[T, PT, IDT]) HardDeleteByID(ctx context.Context, id IDT) error {
276282}
277283
278284// WithTx returns a table instance bound to the given transaction.
279- func (t * Table [T , PT , IDT ]) WithTx (tx pgx.Tx ) * Table [T , PT , IDT ] {
280- return & Table [T , PT , IDT ]{
285+ func (t * Table [T , P , I ]) WithTx (tx pgx.Tx ) * Table [T , P , I ] {
286+ return & Table [T , P , I ]{
281287 DB : & DB {
282288 Conn : t .DB .Conn ,
283289 SQL : t .DB .SQL ,
@@ -296,10 +302,10 @@ func (t *Table[T, PT, IDT]) WithTx(tx pgx.Tx) *Table[T, PT, IDT] {
296302// to update status to "completed" or "failed".
297303//
298304// Returns ErrNoRows if no matching records are available for locking.
299- func (t * Table [T , PT , IDT ]) LockForUpdate (ctx context.Context , where sq.Sqlizer , orderBy []string , updateFn func (record PT )) error {
305+ func (t * Table [T , P , I ]) LockForUpdate (ctx context.Context , where sq.Sqlizer , orderBy []string , updateFn func (record P )) error {
300306 var noRows bool
301307
302- err := t .LockForUpdates (ctx , where , orderBy , 1 , func (records []PT ) {
308+ err := t .LockForUpdates (ctx , where , orderBy , 1 , func (records []P ) {
303309 if len (records ) > 0 {
304310 updateFn (records [0 ])
305311 } else {
@@ -324,7 +330,7 @@ func (t *Table[T, PT, IDT]) LockForUpdate(ctx context.Context, where sq.Sqlizer,
324330// Keep updateFn() fast to avoid holding the transaction. For long-running work, update status
325331// to "processing" and return early, then process asynchronously. Use defer LockForUpdate()
326332// to update status to "completed" or "failed".
327- func (t * Table [T , PT , IDT ]) LockForUpdates (ctx context.Context , where sq.Sqlizer , orderBy []string , limit uint64 , updateFn func (records []PT )) error {
333+ func (t * Table [T , P , I ]) LockForUpdates (ctx context.Context , where sq.Sqlizer , orderBy []string , limit uint64 , updateFn func (records []P )) error {
328334 // Check if we're already in a transaction
329335 if t .DB .Query .Tx != nil {
330336 if err := t .lockForUpdatesWithTx (ctx , t .DB .Query .Tx , where , orderBy , limit , updateFn ); err != nil {
@@ -340,7 +346,7 @@ func (t *Table[T, PT, IDT]) LockForUpdates(ctx context.Context, where sq.Sqlizer
340346 })
341347}
342348
343- func (t * Table [T , PT , IDT ]) lockForUpdatesWithTx (ctx context.Context , pgTx pgx.Tx , where sq.Sqlizer , orderBy []string , limit uint64 , updateFn func (records []PT )) error {
349+ func (t * Table [T , P , I ]) lockForUpdatesWithTx (ctx context.Context , pgTx pgx.Tx , where sq.Sqlizer , orderBy []string , limit uint64 , updateFn func (records []P )) error {
344350 if len (orderBy ) == 0 {
345351 orderBy = []string {t .IDColumn }
346352 }
@@ -355,7 +361,7 @@ func (t *Table[T, PT, IDT]) lockForUpdatesWithTx(ctx context.Context, pgTx pgx.T
355361
356362 txQuery := t .DB .TxQuery (pgTx )
357363
358- var records []PT
364+ var records []P
359365 if err := txQuery .GetAll (ctx , q , & records ); err != nil {
360366 return fmt .Errorf ("select for update skip locked: %w" , err )
361367 }
0 commit comments