Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/release-notes/release-notes-0.22.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
the close transaction is actually broadcast, and
`WaitingCloseChannel.ClosingTx` is never empty.

* [Fixed `ListPayments`](https://github.com/lightningnetwork/lnd/pull/10874)
so `count_total_payments` respects the `creation_date_start` and
`creation_date_end` filters.

# New Features

## Functional Enhancements
Expand Down
191 changes: 175 additions & 16 deletions payments/db/kv_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -1053,6 +1053,26 @@ func (p *KVStore) QueryPayments(_ context.Context,

var resp Response

matchesCreationDate := func(creationTime time.Time) bool {
// Get the creation time in Unix seconds, this always rounds down
// the nanoseconds to full seconds.
createTime := creationTime.Unix()

// Skip any payments that were created before the specified time.
if createTime < query.CreationDateStart {
return false
}

// Skip any payments that were created after the specified time.
if query.CreationDateEnd != 0 &&
createTime > query.CreationDateEnd {

return false
}

return true
}

if err := kvdb.View(p.db, func(tx kvdb.RTx) error {
// Get the root payments bucket.
paymentsBucket := tx.ReadBucket(paymentsRootBucket)
Expand Down Expand Up @@ -1096,21 +1116,7 @@ func (p *KVStore) QueryPayments(_ context.Context,
return false, err
}

// Get the creation time in Unix seconds, this always
// rounds down the nanoseconds to full seconds.
createTime := payment.Info.CreationTime.Unix()

// Skip any payments that were created before the
// specified time.
if createTime < query.CreationDateStart {
return false, nil
}

// Skip any payments that were created after the
// specified time.
if query.CreationDateEnd != 0 &&
createTime > query.CreationDateEnd {

if !matchesCreationDate(payment.Info.CreationTime) {
return false, nil
}

Expand Down Expand Up @@ -1141,7 +1147,35 @@ func (p *KVStore) QueryPayments(_ context.Context,
totalPayments uint64
err error
)
countFn := func(_, _ []byte) error {
countFn := func(sequenceKey, hash []byte) error {
if query.CreationDateStart != 0 ||
query.CreationDateEnd != 0 ||
!query.IncludeIncomplete {

r := bytes.NewReader(hash)
paymentHash, err := deserializePaymentIndex(r)
if err != nil {
return err
}

creationTime, status, err := fetchPaymentQueryMetadata(
tx, paymentHash, sequenceKey,
)
if err != nil {
return err
}

if !matchesCreationDate(creationTime) {
return nil
}

if status != StatusSucceeded &&
!query.IncludeIncomplete {

return nil
}
}

totalPayments++

return nil
Expand Down Expand Up @@ -1189,6 +1223,131 @@ func (p *KVStore) QueryPayments(_ context.Context,
return resp, nil
}

// fetchPaymentQueryMetadata gets the creation time and status for the payment
// that matches the payment hash and sequence number. This is the lightweight
// counterpart to fetchPaymentWithSequenceNumber for callers that only need
// query metadata and should avoid loading/deserializing all HTLC attempts.
func fetchPaymentQueryMetadata(tx kvdb.RTx,
paymentHash lntypes.Hash,
sequenceNumber []byte) (time.Time, PaymentStatus, error) {

bucket, err := fetchPaymentBucket(tx, paymentHash)
if err != nil {
return time.Time{}, 0, err
}

seqBytes := bucket.Get(paymentSequenceKey)
if seqBytes == nil {
return time.Time{}, 0, ErrNoSequenceNumber
}

if bytes.Equal(seqBytes, sequenceNumber) {
creationInfo, err := fetchCreationInfo(bucket)
if err != nil {
return time.Time{}, 0, err
}

status, err := fetchPaymentQueryStatus(bucket)
if err != nil {
return time.Time{}, 0, err
}

return creationInfo.CreationTime, status, nil
}

dup := bucket.NestedReadBucket(duplicatePaymentsBucket)
if dup == nil {
return time.Time{}, 0, ErrNoDuplicateBucket
}

var (
creationTime time.Time
status PaymentStatus
found bool
)
err = dup.ForEach(func(k, _ []byte) error {
if found {
return nil
}

subBucket := dup.NestedReadBucket(k)
if subBucket == nil {
return ErrNoDuplicateNestedBucket
}

seqBytes := subBucket.Get(duplicatePaymentSequenceKey)
if seqBytes == nil {
return ErrNoDuplicateSequenceNumber
}

if !bytes.Equal(seqBytes, sequenceNumber) {
return nil
}

b := subBucket.Get(duplicatePaymentCreationInfoKey)
if b == nil {
return fmt.Errorf("creation info not found")
}

creationInfo, err := deserializeDuplicatePaymentCreationInfo(
bytes.NewReader(b),
)
if err != nil {
return err
}

status, err = fetchDuplicatePaymentStatus(subBucket)
if err != nil {
return err
}

creationTime = creationInfo.CreationTime
found = true

return nil
})
if err != nil {
return time.Time{}, 0, err
}

if !found {
return time.Time{}, 0, ErrDuplicateNotFound
}

return creationTime, status, nil
}

func fetchPaymentQueryStatus(bucket kvdb.RBucket) (PaymentStatus, error) {
if bucket.Get(paymentCreationInfoKey) == nil {
return 0, ErrPaymentNotInitiated
}

htlcsBucket := bucket.NestedReadBucket(paymentHtlcsBucket)
if htlcsBucket != nil {
var settled bool
err := htlcsBucket.ForEach(func(k, _ []byte) error {
if bytes.HasPrefix(k, htlcSettleInfoKey) {
settled = true
}

return nil
})
if err != nil {
return 0, err
}

if settled {
return StatusSucceeded, nil
}
}

if bucket.Get(paymentFailInfoKey) != nil {
return StatusFailed, nil
}

return StatusInFlight, nil
}

// fetchPaymentWithSequenceNumber get the payment which matches the payment hash
// *and* sequence number provided from the database. This is required because
// we previously had more than one payment per hash, so we have multiple indexes
Expand Down
42 changes: 41 additions & 1 deletion payments/db/payment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2424,6 +2424,10 @@ func TestQueryPayments(t *testing.T) {
firstIndex uint64
lastIndex uint64

// expectedTotal is the expected TotalCount when CountTotal is
// set. A zero value means the unfiltered total is expected.
expectedTotal uint64

// expectedSeqNrs contains the set of sequence numbers we expect
// our query to return.
expectedSeqNrs []uint64
Expand Down Expand Up @@ -2705,7 +2709,7 @@ func TestQueryPayments(t *testing.T) {
expectedSeqNrs: []uint64{5, 6},
},
{
name: "count total with filters",
name: "count total excludes incomplete payments",
query: Query{
IndexOffset: 0,
MaxPayments: math.MaxUint64,
Expand All @@ -2715,6 +2719,39 @@ func TestQueryPayments(t *testing.T) {
},
firstIndex: 6,
lastIndex: 6,
expectedTotal: 1,
expectedSeqNrs: []uint64{6},
},
{
name: "count total with date filters and incomplete",
query: Query{
IndexOffset: 0,
MaxPayments: 2,
Reversed: false,
IncludeIncomplete: true,
CountTotal: true,
CreationDateStart: 3,
CreationDateEnd: 5,
},
firstIndex: 3,
lastIndex: 4,
expectedTotal: 3,
expectedSeqNrs: []uint64{3, 4},
},
{
name: "count total with date filters excludes incomplete",
query: Query{
IndexOffset: 0,
MaxPayments: 2,
Reversed: false,
IncludeIncomplete: false,
CountTotal: true,
CreationDateStart: 3,
CreationDateEnd: 6,
},
firstIndex: 6,
lastIndex: 6,
expectedTotal: 1,
expectedSeqNrs: []uint64{6},
},
}
Expand Down Expand Up @@ -2870,6 +2907,9 @@ func TestQueryPayments(t *testing.T) {
// We should have 5 total payments
// (6 created - 1 deleted).
expectedTotal := uint64(5)
if tt.expectedTotal != 0 {
expectedTotal = tt.expectedTotal
}
require.Equal(
t, expectedTotal, querySlice.TotalCount,
"expected total count %v, got %v",
Expand Down
50 changes: 27 additions & 23 deletions payments/db/sql_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ type SQLQueries interface {
FetchNonTerminalPayments(ctx context.Context, arg sqlc.FetchNonTerminalPaymentsParams) ([]sqlc.FetchNonTerminalPaymentsRow, error)

CountPayments(ctx context.Context) (int64, error)
CountFilteredPayments(ctx context.Context, query sqlc.CountFilteredPaymentsParams) (int64, error)

FetchHtlcAttemptsForPayments(ctx context.Context, paymentIDs []int64) ([]sqlc.FetchHtlcAttemptsForPaymentsRow, error)
FetchHtlcAttemptResolutionsForPayments(ctx context.Context, paymentIDs []int64) ([]sqlc.FetchHtlcAttemptResolutionsForPaymentsRow, error)
Expand Down Expand Up @@ -690,11 +691,36 @@ func (s *SQLStore) QueryPayments(ctx context.Context, query Query) (Response,
return row.Payment.ID
}

// Default date bounds: epoch start and far future. These are always
// provided so the SQL query uses simple comparisons instead of COALESCE
// (which causes type mismatch on Postgres) or OR-based optional filters
// (which can prevent index usage).
createdAfter := time.Unix(0, 0).UTC()
if query.CreationDateStart != 0 {
createdAfter = time.Unix(query.CreationDateStart, 0).UTC()
}

createdBefore := time.Date(
9999, 12, 31, 23, 59, 59, 0, time.UTC,
)
if query.CreationDateEnd != 0 {
createdBefore = time.Unix(query.CreationDateEnd, 0).UTC()
}

err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
// We first count all payments to determine the total count
// if requested.
if query.CountTotal {
totalPayments, err := db.CountPayments(ctx)
totalPayments, err := db.CountFilteredPayments(
ctx, sqlc.CountFilteredPaymentsParams{
CreatedAfter: createdAfter,
CreatedBefore: createdBefore,
IncludeIncomplete: query.IncludeIncomplete,
IntentType: sqldb.SQLInt16(
PaymentIntentTypeBolt11,
),
},
)
if err != nil {
return fmt.Errorf("failed to count "+
"payments: %w", err)
Expand Down Expand Up @@ -768,28 +794,6 @@ func (s *SQLStore) QueryPayments(ctx context.Context, query Query) (Response,
queryFunc := func(ctx context.Context, lastID int64,
limit int32) ([]sqlc.FilterPaymentsRow, error) {

// Default date bounds: epoch start and far
// future. These are always provided so the SQL
// query uses simple comparisons instead of
// COALESCE (which causes type mismatch on
// Postgres) or OR-based optional filters (which
// can prevent index usage).
createdAfter := time.Unix(0, 0).UTC()
if query.CreationDateStart != 0 {
createdAfter = time.Unix(
query.CreationDateStart, 0,
).UTC()
}

createdBefore := time.Date(
9999, 12, 31, 23, 59, 59, 0, time.UTC,
)
if query.CreationDateEnd != 0 {
createdBefore = time.Unix(
query.CreationDateEnd, 0,
).UTC()
}

filterParams := sqlc.FilterPaymentsParams{
NumLimit: limit,
CreatedAfter: createdAfter,
Expand Down
Loading
Loading