Skip to content

Commit b2ce7a7

Browse files
committed
Rename types for clarity
1 parent 320b0f3 commit b2ce7a7

1 file changed

Lines changed: 46 additions & 40 deletions

File tree

table.go

Lines changed: 46 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -12,32 +12,38 @@ import (
1212
"github.com/jackc/pgx/v5"
1313
)
1414

15-
// Table provides basic CRUD operations for database records.
16-
// Records must implement GetID() and Validate() methods.
17-
type Table[T any, PT interface {
15+
// ID is a comparable type used for record IDs.
16+
type ID comparable
17+
18+
// Records must be a pointer with the methods defined on the pointer.
19+
type Record[T any, I ID] interface {
1820
*T // Enforce T is a pointer.
19-
GetID() IDT
21+
GetID() I
2022
Validate() error
21-
}, IDT comparable] struct {
23+
}
24+
25+
// Table provides basic CRUD operations for database records.
26+
type Table[T any, P Record[T, I], I ID] struct {
2227
*DB
2328
Name string
2429
IDColumn string
2530
}
2631

27-
type hasSetCreatedAt interface {
28-
SetCreatedAt(time.Time)
29-
}
30-
31-
type hasSetUpdatedAt interface {
32-
SetUpdatedAt(time.Time)
33-
}
34-
35-
type hasSetDeletedAt interface {
36-
SetDeletedAt(time.Time)
37-
}
32+
// helpers for setting timestamp fields
33+
type (
34+
hasSetCreatedAt interface {
35+
SetCreatedAt(time.Time)
36+
}
37+
hasSetUpdatedAt interface {
38+
SetUpdatedAt(time.Time)
39+
}
40+
hasSetDeletedAt interface {
41+
SetDeletedAt(time.Time)
42+
}
43+
)
3844

3945
// Save inserts or updates given records. Auto-detects insert vs update by ID based on zerovalue of ID from GetID() method on record.
40-
func (t *Table[T, PT, IDT]) Save(ctx context.Context, records ...PT) error {
46+
func (t *Table[T, P, I]) Save(ctx context.Context, records ...P) error {
4147
switch len(records) {
4248
case 0:
4349
return nil
@@ -48,7 +54,7 @@ func (t *Table[T, PT, IDT]) Save(ctx context.Context, records ...PT) error {
4854
}
4955
}
5056

51-
func (t *Table[T, PT, IDT]) saveOne(ctx context.Context, record PT) error {
57+
func (t *Table[T, P, I]) saveOne(ctx context.Context, record P) error {
5258
if record == nil {
5359
return fmt.Errorf("record is nil")
5460
}
@@ -62,7 +68,7 @@ func (t *Table[T, PT, IDT]) saveOne(ctx context.Context, record PT) error {
6268
}
6369

6470
// Insert
65-
var zero IDT
71+
var zero I
6672
if record.GetID() == zero {
6773
q := t.SQL.
6874
InsertRecord(record).
@@ -87,10 +93,10 @@ func (t *Table[T, PT, IDT]) saveOne(ctx context.Context, record PT) error {
8793

8894
const chunkSize = 1000
8995

90-
func (t *Table[T, PT, IDT]) saveAll(ctx context.Context, records []PT) error {
96+
func (t *Table[T, P, I]) saveAll(ctx context.Context, records []P) error {
9197
now := time.Now().UTC()
9298

93-
insertRecords := make([]PT, 0)
99+
insertRecords := make([]P, 0)
94100
insertIndices := make([]int, 0) // keep track of original indices, so we can update the records with IDs in passed slice
95101

96102
updateQueries := make(Queries, 0)
@@ -108,7 +114,7 @@ func (t *Table[T, PT, IDT]) saveAll(ctx context.Context, records []PT) error {
108114
row.SetUpdatedAt(now)
109115
}
110116

111-
var zero IDT
117+
var zero I
112118
if r.GetID() == zero {
113119
if row, ok := any(r).(hasSetCreatedAt); ok {
114120
row.SetCreatedAt(now)
@@ -159,7 +165,7 @@ func (t *Table[T, PT, IDT]) saveAll(ctx context.Context, records []PT) error {
159165
}
160166

161167
// getListQuery builds a base select query for listing records.
162-
func (t *Table[T, PT, IDT]) getListQuery(where sq.Sqlizer, orderBy []string) sq.SelectBuilder {
168+
func (t *Table[T, P, I]) getListQuery(where sq.Sqlizer, orderBy []string) sq.SelectBuilder {
163169
if len(orderBy) == 0 {
164170
orderBy = []string{t.IDColumn}
165171
}
@@ -173,7 +179,7 @@ func (t *Table[T, PT, IDT]) getListQuery(where sq.Sqlizer, orderBy []string) sq.
173179
}
174180

175181
// Get returns the first record matching the condition.
176-
func (t *Table[T, PT, IDT]) Get(ctx context.Context, where sq.Sqlizer, orderBy []string) (PT, error) {
182+
func (t *Table[T, P, I]) Get(ctx context.Context, where sq.Sqlizer, orderBy []string) (P, error) {
177183
record := new(T)
178184

179185
q := t.getListQuery(where, orderBy).Limit(1)
@@ -186,9 +192,9 @@ func (t *Table[T, PT, IDT]) Get(ctx context.Context, where sq.Sqlizer, orderBy [
186192
}
187193

188194
// List returns all records matching the condition.
189-
func (t *Table[T, PT, IDT]) List(ctx context.Context, where sq.Sqlizer, orderBy []string) ([]PT, error) {
195+
func (t *Table[T, P, I]) List(ctx context.Context, where sq.Sqlizer, orderBy []string) ([]P, error) {
190196
q := t.getListQuery(where, orderBy)
191-
var records []PT
197+
var records []P
192198
if err := t.Query.GetAll(ctx, q, &records); err != nil {
193199
return nil, err
194200
}
@@ -197,14 +203,14 @@ func (t *Table[T, PT, IDT]) List(ctx context.Context, where sq.Sqlizer, orderBy
197203
}
198204

199205
// Iter returns an iterator for records matching the condition.
200-
func (t *Table[T, PT, IDT]) Iter(ctx context.Context, where sq.Sqlizer, orderBy []string) (iter.Seq2[PT, error], error) {
206+
func (t *Table[T, P, I]) Iter(ctx context.Context, where sq.Sqlizer, orderBy []string) (iter.Seq2[P, error], error) {
201207
q := t.getListQuery(where, orderBy)
202208
rows, err := t.Query.QueryRows(ctx, q)
203209
if err != nil {
204210
return nil, fmt.Errorf("query rows: %w", err)
205211
}
206212

207-
return func(yield func(PT, error) bool) {
213+
return func(yield func(P, error) bool) {
208214
defer rows.Close()
209215
for rows.Next() {
210216
var record T
@@ -222,17 +228,17 @@ func (t *Table[T, PT, IDT]) Iter(ctx context.Context, where sq.Sqlizer, orderBy
222228
}
223229

224230
// GetByID returns a record by its ID.
225-
func (t *Table[T, PT, IDT]) GetByID(ctx context.Context, id IDT) (PT, error) {
231+
func (t *Table[T, P, I]) GetByID(ctx context.Context, id I) (P, error) {
226232
return t.Get(ctx, sq.Eq{t.IDColumn: id}, []string{t.IDColumn})
227233
}
228234

229235
// ListByIDs returns records by their IDs.
230-
func (t *Table[T, PT, IDT]) ListByIDs(ctx context.Context, ids []IDT) ([]PT, error) {
236+
func (t *Table[T, P, I]) ListByIDs(ctx context.Context, ids []I) ([]P, error) {
231237
return t.List(ctx, sq.Eq{t.IDColumn: ids}, nil)
232238
}
233239

234240
// Count returns the number of matching records.
235-
func (t *Table[T, PT, IDT]) Count(ctx context.Context, where sq.Sqlizer) (uint64, error) {
241+
func (t *Table[T, P, I]) Count(ctx context.Context, where sq.Sqlizer) (uint64, error) {
236242
var count uint64
237243
q := t.SQL.
238244
Select("COUNT(1)").
@@ -247,7 +253,7 @@ func (t *Table[T, PT, IDT]) Count(ctx context.Context, where sq.Sqlizer) (uint64
247253
}
248254

249255
// DeleteByID deletes a record by ID. Uses soft delete if .SetDeletedAt() method exists.
250-
func (t *Table[T, PT, IDT]) DeleteByID(ctx context.Context, id IDT) error {
256+
func (t *Table[T, P, I]) DeleteByID(ctx context.Context, id I) error {
251257
record, err := t.GetByID(ctx, id)
252258
if err != nil {
253259
return fmt.Errorf("delete: %w", err)
@@ -267,7 +273,7 @@ func (t *Table[T, PT, IDT]) DeleteByID(ctx context.Context, id IDT) error {
267273
}
268274

269275
// HardDeleteByID permanently deletes a record by ID.
270-
func (t *Table[T, PT, IDT]) HardDeleteByID(ctx context.Context, id IDT) error {
276+
func (t *Table[T, P, I]) HardDeleteByID(ctx context.Context, id I) error {
271277
q := t.SQL.Delete(t.Name).Where(sq.Eq{t.IDColumn: id})
272278
if _, err := t.Query.Exec(ctx, q); err != nil {
273279
return fmt.Errorf("hard delete: %w", err)
@@ -276,8 +282,8 @@ func (t *Table[T, PT, IDT]) HardDeleteByID(ctx context.Context, id IDT) error {
276282
}
277283

278284
// WithTx returns a table instance bound to the given transaction.
279-
func (t *Table[T, PT, IDT]) WithTx(tx pgx.Tx) *Table[T, PT, IDT] {
280-
return &Table[T, PT, IDT]{
285+
func (t *Table[T, P, I]) WithTx(tx pgx.Tx) *Table[T, P, I] {
286+
return &Table[T, P, I]{
281287
DB: &DB{
282288
Conn: t.DB.Conn,
283289
SQL: t.DB.SQL,
@@ -296,10 +302,10 @@ func (t *Table[T, PT, IDT]) WithTx(tx pgx.Tx) *Table[T, PT, IDT] {
296302
// to update status to "completed" or "failed".
297303
//
298304
// Returns ErrNoRows if no matching records are available for locking.
299-
func (t *Table[T, PT, IDT]) LockForUpdate(ctx context.Context, where sq.Sqlizer, orderBy []string, updateFn func(record PT)) error {
305+
func (t *Table[T, P, I]) LockForUpdate(ctx context.Context, where sq.Sqlizer, orderBy []string, updateFn func(record P)) error {
300306
var noRows bool
301307

302-
err := t.LockForUpdates(ctx, where, orderBy, 1, func(records []PT) {
308+
err := t.LockForUpdates(ctx, where, orderBy, 1, func(records []P) {
303309
if len(records) > 0 {
304310
updateFn(records[0])
305311
} else {
@@ -324,7 +330,7 @@ func (t *Table[T, PT, IDT]) LockForUpdate(ctx context.Context, where sq.Sqlizer,
324330
// Keep updateFn() fast to avoid holding the transaction. For long-running work, update status
325331
// to "processing" and return early, then process asynchronously. Use defer LockForUpdate()
326332
// to update status to "completed" or "failed".
327-
func (t *Table[T, PT, IDT]) LockForUpdates(ctx context.Context, where sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []PT)) error {
333+
func (t *Table[T, P, I]) LockForUpdates(ctx context.Context, where sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []P)) error {
328334
// Check if we're already in a transaction
329335
if t.DB.Query.Tx != nil {
330336
if err := t.lockForUpdatesWithTx(ctx, t.DB.Query.Tx, where, orderBy, limit, updateFn); err != nil {
@@ -340,7 +346,7 @@ func (t *Table[T, PT, IDT]) LockForUpdates(ctx context.Context, where sq.Sqlizer
340346
})
341347
}
342348

343-
func (t *Table[T, PT, IDT]) lockForUpdatesWithTx(ctx context.Context, pgTx pgx.Tx, where sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []PT)) error {
349+
func (t *Table[T, P, I]) lockForUpdatesWithTx(ctx context.Context, pgTx pgx.Tx, where sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []P)) error {
344350
if len(orderBy) == 0 {
345351
orderBy = []string{t.IDColumn}
346352
}
@@ -355,7 +361,7 @@ func (t *Table[T, PT, IDT]) lockForUpdatesWithTx(ctx context.Context, pgTx pgx.T
355361

356362
txQuery := t.DB.TxQuery(pgTx)
357363

358-
var records []PT
364+
var records []P
359365
if err := txQuery.GetAll(ctx, q, &records); err != nil {
360366
return fmt.Errorf("select for update skip locked: %w", err)
361367
}

0 commit comments

Comments
 (0)