From f4612e128632adeec6dc429a549fc783b8cdf121 Mon Sep 17 00:00:00 2001 From: Julio Cesar Date: Wed, 3 Jun 2026 22:21:31 +0200 Subject: [PATCH] payments: count total with date filters --- docs/release-notes/release-notes-0.22.0.md | 4 + payments/db/kv_store.go | 211 +++++++++++++++++++-- payments/db/payment_test.go | 133 ++++++++++++- payments/db/sql_store.go | 50 ++--- sqldb/sqlc/payments.sql.go | 53 ++++++ sqldb/sqlc/querier.go | 1 + sqldb/sqlc/queries/payments.sql | 33 ++++ 7 files changed, 445 insertions(+), 40 deletions(-) diff --git a/docs/release-notes/release-notes-0.22.0.md b/docs/release-notes/release-notes-0.22.0.md index 45e76051922..ff08d579c0f 100644 --- a/docs/release-notes/release-notes-0.22.0.md +++ b/docs/release-notes/release-notes-0.22.0.md @@ -34,6 +34,10 @@ the close transaction is actually broadcast, and `WaitingCloseChannel.ClosingTx` is never empty. +* [Fixed `ListPayments`](https://github.com/lightningnetwork/lnd/pull/10874) + so `count_total_payments` respects the `creation_date_start` and + `creation_date_end` filters. + # New Features ## Functional Enhancements diff --git a/payments/db/kv_store.go b/payments/db/kv_store.go index b3fffe1e182..a708696e703 100644 --- a/payments/db/kv_store.go +++ b/payments/db/kv_store.go @@ -1053,6 +1053,26 @@ func (p *KVStore) QueryPayments(_ context.Context, var resp Response + matchesCreationDate := func(creationTime time.Time) bool { + // Get the creation time in Unix seconds, this always rounds down + // the nanoseconds to full seconds. + createTime := creationTime.Unix() + + // Skip any payments that were created before the specified time. + if createTime < query.CreationDateStart { + return false + } + + // Skip any payments that were created after the specified time. + if query.CreationDateEnd != 0 && + createTime > query.CreationDateEnd { + + return false + } + + return true + } + if err := kvdb.View(p.db, func(tx kvdb.RTx) error { // Get the root payments bucket. paymentsBucket := tx.ReadBucket(paymentsRootBucket) @@ -1096,21 +1116,7 @@ func (p *KVStore) QueryPayments(_ context.Context, return false, err } - // Get the creation time in Unix seconds, this always - // rounds down the nanoseconds to full seconds. - createTime := payment.Info.CreationTime.Unix() - - // Skip any payments that were created before the - // specified time. - if createTime < query.CreationDateStart { - return false, nil - } - - // Skip any payments that were created after the - // specified time. - if query.CreationDateEnd != 0 && - createTime > query.CreationDateEnd { - + if !matchesCreationDate(payment.Info.CreationTime) { return false, nil } @@ -1141,7 +1147,35 @@ func (p *KVStore) QueryPayments(_ context.Context, totalPayments uint64 err error ) - countFn := func(_, _ []byte) error { + countFn := func(sequenceKey, hash []byte) error { + if query.CreationDateStart != 0 || + query.CreationDateEnd != 0 || + !query.IncludeIncomplete { + + r := bytes.NewReader(hash) + paymentHash, err := deserializePaymentIndex(r) + if err != nil { + return err + } + + creationTime, status, err := fetchPaymentQueryMetadata( + tx, paymentHash, sequenceKey, + ) + if err != nil { + return err + } + + if !matchesCreationDate(creationTime) { + return nil + } + + if status != StatusSucceeded && + !query.IncludeIncomplete { + + return nil + } + } + totalPayments++ return nil @@ -1189,6 +1223,151 @@ func (p *KVStore) QueryPayments(_ context.Context, return resp, nil } +// fetchPaymentQueryMetadata gets the creation time and status for the payment +// that matches the payment hash and sequence number. This is the lightweight +// counterpart to fetchPaymentWithSequenceNumber for callers that only need +// query metadata and should avoid loading/deserializing all HTLC attempts. +func fetchPaymentQueryMetadata(tx kvdb.RTx, + paymentHash lntypes.Hash, + sequenceNumber []byte) (time.Time, PaymentStatus, error) { + + bucket, err := fetchPaymentBucket(tx, paymentHash) + if err != nil { + return time.Time{}, 0, err + } + + seqBytes := bucket.Get(paymentSequenceKey) + if seqBytes == nil { + return time.Time{}, 0, ErrNoSequenceNumber + } + + if bytes.Equal(seqBytes, sequenceNumber) { + creationInfo, err := fetchCreationInfo(bucket) + if err != nil { + return time.Time{}, 0, err + } + + status, err := fetchPaymentQueryStatus(bucket) + if err != nil { + return time.Time{}, 0, err + } + + return creationInfo.CreationTime, status, nil + } + + dup := bucket.NestedReadBucket(duplicatePaymentsBucket) + if dup == nil { + return time.Time{}, 0, ErrNoDuplicateBucket + } + + var ( + creationTime time.Time + status PaymentStatus + found bool + ) + err = dup.ForEach(func(k, _ []byte) error { + if found { + return nil + } + + subBucket := dup.NestedReadBucket(k) + if subBucket == nil { + return ErrNoDuplicateNestedBucket + } + + seqBytes := subBucket.Get(duplicatePaymentSequenceKey) + if seqBytes == nil { + return ErrNoDuplicateSequenceNumber + } + + if !bytes.Equal(seqBytes, sequenceNumber) { + return nil + } + + b := subBucket.Get(duplicatePaymentCreationInfoKey) + if b == nil { + return fmt.Errorf("creation info not found") + } + + creationInfo, err := deserializeDuplicatePaymentCreationInfo( + bytes.NewReader(b), + ) + if err != nil { + return err + } + + status, err = fetchDuplicatePaymentStatus(subBucket) + if err != nil { + return err + } + + creationTime = creationInfo.CreationTime + found = true + + return nil + }) + if err != nil { + return time.Time{}, 0, err + } + + if !found { + return time.Time{}, 0, ErrDuplicateNotFound + } + + return creationTime, status, nil +} + +func fetchPaymentQueryStatus(bucket kvdb.RBucket) (PaymentStatus, error) { + if bucket.Get(paymentCreationInfoKey) == nil { + return 0, ErrPaymentNotInitiated + } + + var failureReason *FailureReason + if failInfo := bucket.Get(paymentFailInfoKey); failInfo != nil { + reason := FailureReason(failInfo[0]) + failureReason = &reason + } + + htlcsBucket := bucket.NestedReadBucket(paymentHtlcsBucket) + if htlcsBucket == nil { + return decidePaymentStatus(nil, failureReason) + } + + htlcs := make(map[uint64]*HTLCAttempt) + err := htlcsBucket.ForEach(func(k, _ []byte) error { + aid := byteOrder.Uint64(k[len(k)-8:]) + if _, ok := htlcs[aid]; !ok { + htlcs[aid] = &HTLCAttempt{} + } + + switch { + case bytes.HasPrefix(k, htlcAttemptInfoKey): + return nil + + case bytes.HasPrefix(k, htlcSettleInfoKey): + htlcs[aid].Settle = &HTLCSettleInfo{} + + case bytes.HasPrefix(k, htlcFailInfoKey): + htlcs[aid].Failure = &HTLCFailInfo{} + + default: + return fmt.Errorf("unknown htlc attempt key") + } + + return nil + }) + if err != nil { + return 0, err + } + + htlcList := make([]HTLCAttempt, 0, len(htlcs)) + for _, htlc := range htlcs { + htlcList = append(htlcList, *htlc) + } + + return decidePaymentStatus(htlcList, failureReason) +} + // fetchPaymentWithSequenceNumber get the payment which matches the payment hash // *and* sequence number provided from the database. This is required because // we previously had more than one payment per hash, so we have multiple indexes diff --git a/payments/db/payment_test.go b/payments/db/payment_test.go index e304b136dd2..ed526e6f6eb 100644 --- a/payments/db/payment_test.go +++ b/payments/db/payment_test.go @@ -2424,6 +2424,10 @@ func TestQueryPayments(t *testing.T) { firstIndex uint64 lastIndex uint64 + // expectedTotal is the expected TotalCount when CountTotal is + // set. A zero value means the unfiltered total is expected. + expectedTotal uint64 + // expectedSeqNrs contains the set of sequence numbers we expect // our query to return. expectedSeqNrs []uint64 @@ -2705,7 +2709,7 @@ func TestQueryPayments(t *testing.T) { expectedSeqNrs: []uint64{5, 6}, }, { - name: "count total with filters", + name: "count total excludes incomplete payments", query: Query{ IndexOffset: 0, MaxPayments: math.MaxUint64, @@ -2715,6 +2719,39 @@ func TestQueryPayments(t *testing.T) { }, firstIndex: 6, lastIndex: 6, + expectedTotal: 1, + expectedSeqNrs: []uint64{6}, + }, + { + name: "count total with date filters and incomplete", + query: Query{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + CountTotal: true, + CreationDateStart: 3, + CreationDateEnd: 5, + }, + firstIndex: 3, + lastIndex: 4, + expectedTotal: 3, + expectedSeqNrs: []uint64{3, 4}, + }, + { + name: "count total with date filters excludes incomplete", + query: Query{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: false, + CountTotal: true, + CreationDateStart: 3, + CreationDateEnd: 6, + }, + firstIndex: 6, + lastIndex: 6, + expectedTotal: 1, expectedSeqNrs: []uint64{6}, }, } @@ -2870,6 +2907,9 @@ func TestQueryPayments(t *testing.T) { // We should have 5 total payments // (6 created - 1 deleted). expectedTotal := uint64(5) + if tt.expectedTotal != 0 { + expectedTotal = tt.expectedTotal + } require.Equal( t, expectedTotal, querySlice.TotalCount, "expected total count %v, got %v", @@ -2884,6 +2924,97 @@ func TestQueryPayments(t *testing.T) { } } +// TestQueryPaymentsCountTotalMixedInflight asserts that CountTotal uses the +// same succeeded-only semantics as the returned payment list when +// IncludeIncomplete is false. A payment with one settled HTLC and one +// unresolved HTLC is still in-flight and must not be counted as succeeded. +func TestQueryPaymentsCountTotalMixedInflight(t *testing.T) { + t.Parallel() + + ctx := t.Context() + paymentDB, _ := NewTestDB(t) + + succeededInfo, _ := genInfo(t) + succeededInfo.CreationTime = time.Unix(10, 0) + err := paymentDB.InitPayment( + ctx, succeededInfo.PaymentIdentifier, succeededInfo, + ) + require.NoError(t, err) + + settledAttempt, err := NewHtlcAttempt( + 1, genSessionKey(t), testRoute, time.Unix(11, 0), + &succeededInfo.PaymentIdentifier, + ) + require.NoError(t, err) + + _, err = paymentDB.RegisterAttempt( + ctx, succeededInfo.PaymentIdentifier, + &settledAttempt.HTLCAttemptInfo, + ) + require.NoError(t, err) + + preimg := genPreimage(t) + _, err = paymentDB.SettleAttempt( + ctx, succeededInfo.PaymentIdentifier, settledAttempt.AttemptID, + &HTLCSettleInfo{ + Preimage: preimg, + }, + ) + require.NoError(t, err) + + mixedInfo, _ := genInfo(t) + mixedInfo.CreationTime = time.Unix(20, 0) + mixedInfo.Value = testRoute.ReceiverAmt() * 2 + err = paymentDB.InitPayment(ctx, mixedInfo.PaymentIdentifier, mixedInfo) + require.NoError(t, err) + + mixedSettledAttempt, err := NewHtlcAttempt( + 2, genSessionKey(t), testRoute, time.Unix(21, 0), + &mixedInfo.PaymentIdentifier, + ) + require.NoError(t, err) + + _, err = paymentDB.RegisterAttempt( + ctx, mixedInfo.PaymentIdentifier, + &mixedSettledAttempt.HTLCAttemptInfo, + ) + require.NoError(t, err) + + mixedInflightAttempt, err := NewHtlcAttempt( + 3, genSessionKey(t), testRoute, time.Unix(22, 0), + &mixedInfo.PaymentIdentifier, + ) + require.NoError(t, err) + + _, err = paymentDB.RegisterAttempt( + ctx, mixedInfo.PaymentIdentifier, + &mixedInflightAttempt.HTLCAttemptInfo, + ) + require.NoError(t, err) + + preimg = genPreimage(t) + _, err = paymentDB.SettleAttempt( + ctx, mixedInfo.PaymentIdentifier, mixedSettledAttempt.AttemptID, + &HTLCSettleInfo{ + Preimage: preimg, + }, + ) + require.NoError(t, err) + + resp, err := paymentDB.QueryPayments(ctx, Query{ + IndexOffset: 0, + MaxPayments: math.MaxUint64, + IncludeIncomplete: false, + CountTotal: true, + CreationDateStart: 1, + CreationDateEnd: 30, + }) + require.NoError(t, err) + require.Len(t, resp.Payments, 1) + require.Equal(t, StatusSucceeded, resp.Payments[0].Status) + require.Equal(t, uint64(1), resp.TotalCount) +} + // TestFetchInFlightPayments tests that FetchInFlightPayments correctly returns // only payments that are in-flight. func TestFetchInFlightPayments(t *testing.T) { diff --git a/payments/db/sql_store.go b/payments/db/sql_store.go index 3d92385fd0f..a855d66fcb8 100644 --- a/payments/db/sql_store.go +++ b/payments/db/sql_store.go @@ -52,6 +52,7 @@ type SQLQueries interface { FetchNonTerminalPayments(ctx context.Context, arg sqlc.FetchNonTerminalPaymentsParams) ([]sqlc.FetchNonTerminalPaymentsRow, error) CountPayments(ctx context.Context) (int64, error) + CountFilteredPayments(ctx context.Context, query sqlc.CountFilteredPaymentsParams) (int64, error) FetchHtlcAttemptsForPayments(ctx context.Context, paymentIDs []int64) ([]sqlc.FetchHtlcAttemptsForPaymentsRow, error) FetchHtlcAttemptResolutionsForPayments(ctx context.Context, paymentIDs []int64) ([]sqlc.FetchHtlcAttemptResolutionsForPaymentsRow, error) @@ -690,11 +691,36 @@ func (s *SQLStore) QueryPayments(ctx context.Context, query Query) (Response, return row.Payment.ID } + // Default date bounds: epoch start and far future. These are always + // provided so the SQL query uses simple comparisons instead of COALESCE + // (which causes type mismatch on Postgres) or OR-based optional filters + // (which can prevent index usage). + createdAfter := time.Unix(0, 0).UTC() + if query.CreationDateStart != 0 { + createdAfter = time.Unix(query.CreationDateStart, 0).UTC() + } + + createdBefore := time.Date( + 9999, 12, 31, 23, 59, 59, 0, time.UTC, + ) + if query.CreationDateEnd != 0 { + createdBefore = time.Unix(query.CreationDateEnd, 0).UTC() + } + err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { // We first count all payments to determine the total count // if requested. if query.CountTotal { - totalPayments, err := db.CountPayments(ctx) + totalPayments, err := db.CountFilteredPayments( + ctx, sqlc.CountFilteredPaymentsParams{ + CreatedAfter: createdAfter, + CreatedBefore: createdBefore, + IncludeIncomplete: query.IncludeIncomplete, + IntentType: sqldb.SQLInt16( + PaymentIntentTypeBolt11, + ), + }, + ) if err != nil { return fmt.Errorf("failed to count "+ "payments: %w", err) @@ -768,28 +794,6 @@ func (s *SQLStore) QueryPayments(ctx context.Context, query Query) (Response, queryFunc := func(ctx context.Context, lastID int64, limit int32) ([]sqlc.FilterPaymentsRow, error) { - // Default date bounds: epoch start and far - // future. These are always provided so the SQL - // query uses simple comparisons instead of - // COALESCE (which causes type mismatch on - // Postgres) or OR-based optional filters (which - // can prevent index usage). - createdAfter := time.Unix(0, 0).UTC() - if query.CreationDateStart != 0 { - createdAfter = time.Unix( - query.CreationDateStart, 0, - ).UTC() - } - - createdBefore := time.Date( - 9999, 12, 31, 23, 59, 59, 0, time.UTC, - ) - if query.CreationDateEnd != 0 { - createdBefore = time.Unix( - query.CreationDateEnd, 0, - ).UTC() - } - filterParams := sqlc.FilterPaymentsParams{ NumLimit: limit, CreatedAfter: createdAfter, diff --git a/sqldb/sqlc/payments.sql.go b/sqldb/sqlc/payments.sql.go index 42a8fb826f0..8b454645f52 100644 --- a/sqldb/sqlc/payments.sql.go +++ b/sqldb/sqlc/payments.sql.go @@ -12,6 +12,59 @@ import ( "time" ) +const countFilteredPayments = `-- name: CountFilteredPayments :one +SELECT COUNT(*) +FROM payments p +LEFT JOIN payment_intents i ON i.payment_id = p.id +WHERE p.created_at >= $1 + AND p.created_at <= $2 + AND ( + i.intent_type = $3 OR + $3 IS NULL OR i.intent_type IS NULL + ) + AND ( + CAST($4 AS BOOLEAN) OR ( + EXISTS ( + SELECT 1 + FROM payment_htlc_attempts ha + JOIN payment_htlc_attempt_resolutions hr + ON hr.attempt_index = ha.attempt_index + WHERE ha.payment_id = p.id + AND hr.resolution_type = 1 + ) + AND NOT EXISTS ( + SELECT 1 + FROM payment_htlc_attempts ha + WHERE ha.payment_id = p.id + AND NOT EXISTS ( + SELECT 1 + FROM payment_htlc_attempt_resolutions hr + WHERE hr.attempt_index = ha.attempt_index + ) + ) + ) + ) +` + +type CountFilteredPaymentsParams struct { + CreatedAfter time.Time + CreatedBefore time.Time + IntentType sql.NullInt16 + IncludeIncomplete bool +} + +func (q *Queries) CountFilteredPayments(ctx context.Context, arg CountFilteredPaymentsParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countFilteredPayments, + arg.CreatedAfter, + arg.CreatedBefore, + arg.IntentType, + arg.IncludeIncomplete, + ) + var count int64 + err := row.Scan(&count) + return count, err +} + const countPayments = `-- name: CountPayments :one SELECT COUNT(*) FROM payments ` diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index 9b95a669917..7497cf7f728 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -15,6 +15,7 @@ type Querier interface { AddV1ChannelProof(ctx context.Context, arg AddV1ChannelProofParams) (sql.Result, error) AddV2ChannelProof(ctx context.Context, arg AddV2ChannelProofParams) (sql.Result, error) ClearKVInvoiceHashIndex(ctx context.Context) error + CountFilteredPayments(ctx context.Context, arg CountFilteredPaymentsParams) (int64, error) CountPayments(ctx context.Context) (int64, error) CountZombieChannels(ctx context.Context, version int16) (int64, error) CreateChannel(ctx context.Context, arg CreateChannelParams) (int64, error) diff --git a/sqldb/sqlc/queries/payments.sql b/sqldb/sqlc/queries/payments.sql index 16682c83bc3..828268add5e 100644 --- a/sqldb/sqlc/queries/payments.sql +++ b/sqldb/sqlc/queries/payments.sql @@ -47,6 +47,39 @@ WHERE p.id > COALESCE(sqlc.narg('index_offset_get'), -1) ORDER BY p.id DESC LIMIT @num_limit; +-- name: CountFilteredPayments :one +SELECT COUNT(*) +FROM payments p +LEFT JOIN payment_intents i ON i.payment_id = p.id +WHERE p.created_at >= @created_after + AND p.created_at <= @created_before + AND ( + i.intent_type = sqlc.narg('intent_type') OR + sqlc.narg('intent_type') IS NULL OR i.intent_type IS NULL + ) + AND ( + CAST(sqlc.arg('include_incomplete') AS BOOLEAN) OR ( + EXISTS ( + SELECT 1 + FROM payment_htlc_attempts ha + JOIN payment_htlc_attempt_resolutions hr + ON hr.attempt_index = ha.attempt_index + WHERE ha.payment_id = p.id + AND hr.resolution_type = 1 + ) + AND NOT EXISTS ( + SELECT 1 + FROM payment_htlc_attempts ha + WHERE ha.payment_id = p.id + AND NOT EXISTS ( + SELECT 1 + FROM payment_htlc_attempt_resolutions hr + WHERE hr.attempt_index = ha.attempt_index + ) + ) + ) + ); + -- name: FetchPayment :one SELECT sqlc.embed(p),