Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,10 @@ func (c ForceIndexConfig) IndexFor(schemaName, tableName string) string {
// used. The term `Cascading` to denote that greater specificity takes
// precedence.
type CascadingPaginationColumnConfig struct {
// PerTableComposite has highest specificity for composite keys (max 3 columns)
// SchemaName => TableName => [ColumnName1, ColumnName2, ColumnName3]
PerTableComposite map[string]map[string][]string

// PerTable has greatest specificity and takes precedence over the other options
PerTable map[string]map[string]string // SchemaName => TableName => ColumnName

Expand Down Expand Up @@ -404,6 +408,30 @@ func (c *CascadingPaginationColumnConfig) PaginationColumnFor(schemaName, tableN
return column, true
}

// CompositePaginationColumnsFor retrieves composite pagination columns for a table
func (c *CascadingPaginationColumnConfig) CompositePaginationColumnsFor(schemaName, tableName string) ([]string, bool) {
if c == nil || c.PerTableComposite == nil {
return nil, false
}

tableConfig, found := c.PerTableComposite[schemaName]
if !found {
return nil, false
}

columns, found := tableConfig[tableName]
if !found {
return nil, false
}

// Validate max 3 columns
if len(columns) > 3 || len(columns) == 0 {
return nil, false
}

return columns, true
}

// FallbackPaginationColumnName retreives the column name specified as a fallback when the Primary Key isn't suitable for pagination
func (c *CascadingPaginationColumnConfig) FallbackPaginationColumnName() (string, bool) {
if c == nil || c.FallbackColumn == "" {
Expand Down
214 changes: 196 additions & 18 deletions cursor.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,29 @@ func (c *CursorConfig) NewCursorWithoutRowLock(table *TableSchema, startPaginati
return cursor
}

// NewCompositeCursor creates a cursor for composite key pagination
func (c *CursorConfig) NewCompositeCursor(table *TableSchema, startKeys, maxKeys CompositeKey) *Cursor {
cursor := &Cursor{
CursorConfig: *c,
Table: table,
RowLock: true,
isComposite: true,
lastSuccessfulCompositeKey: startKeys,
maxCompositeKey: maxKeys,
// Set single key values from first column for backward compatibility
MaxPaginationKey: maxKeys.Values[0].(uint64),
lastSuccessfulPaginationKey: startKeys.Values[0].(uint64),
}
return cursor
}

// NewCompositeCursorWithoutRowLock creates a cursor for composite key pagination without row locks
func (c *CursorConfig) NewCompositeCursorWithoutRowLock(table *TableSchema, startKeys, maxKeys CompositeKey) *Cursor {
cursor := c.NewCompositeCursor(table, startKeys, maxKeys)
cursor.RowLock = false
return cursor
}

func (c CursorConfig) GetBatchSize(schemaName string, tableName string) uint64 {
if c.BatchSizePerTableOverride != nil {
if batchSize, found := c.BatchSizePerTableOverride.TableOverride[schemaName][tableName]; found {
Expand All @@ -73,30 +96,82 @@ func (c CursorConfig) GetBatchSize(schemaName string, tableName string) uint64 {
return *c.BatchSize
}

// CompositeKey represents a composite pagination key
type CompositeKey struct {
Values []interface{} // Can be uint64 or string
}

// NewCompositeKey creates a CompositeKey from values
func NewCompositeKey(values ...interface{}) CompositeKey {
return CompositeKey{Values: values}
}

// IsLessThan compares two composite keys
func (c CompositeKey) IsLessThan(other CompositeKey) bool {
for i := 0; i < len(c.Values) && i < len(other.Values); i++ {
switch v1 := c.Values[i].(type) {
case uint64:
v2 := other.Values[i].(uint64)
if v1 < v2 {
return true
} else if v1 > v2 {
return false
}
case string:
v2 := other.Values[i].(string)
if v1 < v2 {
return true
} else if v1 > v2 {
return false
}
}
}
return false
}

type Cursor struct {
CursorConfig

Table *TableSchema
MaxPaginationKey uint64
RowLock bool

// Single column pagination (backward compatibility)
paginationKeyColumn *schema.TableColumn
lastSuccessfulPaginationKey uint64

// Composite key pagination
isComposite bool
maxCompositeKey CompositeKey
lastSuccessfulCompositeKey CompositeKey

logger *logrus.Entry
}

// shouldContinue checks if cursor should continue iterating
func (c *Cursor) shouldContinue() bool {
if c.isComposite {
return c.lastSuccessfulCompositeKey.IsLessThan(c.maxCompositeKey)
}
return c.lastSuccessfulPaginationKey < c.MaxPaginationKey
}

func (c *Cursor) Each(f func(*RowBatch) error) error {
c.logger = logrus.WithFields(logrus.Fields{
"table": c.Table.String(),
"tag": "cursor",
})
c.paginationKeyColumn = c.Table.GetPaginationColumn()

if !c.isComposite {
c.paginationKeyColumn = c.Table.GetPaginationColumn()
}

if len(c.ColumnsToSelect) == 0 {
c.ColumnsToSelect = []string{"*"}
}

for c.lastSuccessfulPaginationKey < c.MaxPaginationKey {
// Use appropriate loop condition based on pagination type
for c.shouldContinue() {
var tx SqlPreparerAndRollbacker
var batch *RowBatch
var paginationKeypos uint64
Expand Down Expand Up @@ -153,17 +228,50 @@ func (c *Cursor) Each(f func(*RowBatch) error) error {

tx.Rollback()

c.lastSuccessfulPaginationKey = paginationKeypos
// Update pagination position
if c.isComposite {
c.updateCompositePosition(batch)
} else {
c.lastSuccessfulPaginationKey = paginationKeypos
}
}

return nil
}

// updateCompositePosition updates the last successful composite key from the batch
func (c *Cursor) updateCompositePosition(batch *RowBatch) {
if batch.Size() == 0 {
return
}

lastRow := batch.Values()[batch.Size()-1]
newKeys := make([]interface{}, len(c.Table.CompositePaginationIndexes))

for i, idx := range c.Table.CompositePaginationIndexes {
switch v := lastRow[idx].(type) {
case int64:
newKeys[i] = uint64(v)
default:
newKeys[i] = v
}
}

c.lastSuccessfulCompositeKey = NewCompositeKey(newKeys...)
// Update single key for backward compatibility
if v, ok := newKeys[0].(uint64); ok {
c.lastSuccessfulPaginationKey = v
}
}

func (c *Cursor) Fetch(db SqlPreparer) (batch *RowBatch, paginationKeypos uint64, err error) {
var selectBuilder squirrel.SelectBuilder
batchSize := c.CursorConfig.GetBatchSize(c.Table.Schema, c.Table.Name)

if c.BuildSelect != nil {
if c.isComposite {
// Use composite pagination
selectBuilder = DefaultBuildSelectComposite(c.ColumnsToSelect, c.Table, c.lastSuccessfulCompositeKey.Values, batchSize)
} else if c.BuildSelect != nil {
selectBuilder, err = c.BuildSelect(c.ColumnsToSelect, c.Table, c.lastSuccessfulPaginationKey, batchSize)
if err != nil {
c.logger.WithError(err).Error("failed to apply filter for select")
Expand Down Expand Up @@ -229,17 +337,43 @@ func (c *Cursor) Fetch(db SqlPreparer) (batch *RowBatch, paginationKeypos uint64
}

var paginationKeyIndex int = -1
for idx, col := range columns {
if col == c.paginationKeyColumn.Name {
paginationKeyIndex = idx
break
var compositePaginationIndexes []int

if c.isComposite {
// Find all composite pagination column indexes
compositePaginationIndexes = make([]int, len(c.Table.CompositePaginationColumns))
for i, paginationCol := range c.Table.CompositePaginationColumns {
found := false
for idx, col := range columns {
if col == paginationCol.Name {
compositePaginationIndexes[i] = idx
if i == 0 {
paginationKeyIndex = idx // First column for backward compatibility
}
found = true
break
}
}
if !found {
err = fmt.Errorf("composite paginationKey column %s not found in columns: %v", paginationCol.Name, columns)
logger.WithError(err).Error("failed to get composite paginationKey index")
return
}
}
} else {
// Single pagination key
for idx, col := range columns {
if col == c.paginationKeyColumn.Name {
paginationKeyIndex = idx
break
}
}

if paginationKeyIndex < 0 {
err = fmt.Errorf("paginationKey is not found during iteration with columns: %v", columns)
logger.WithError(err).Error("failed to get paginationKey index")
return
}
}

if paginationKeyIndex < 0 {
err = fmt.Errorf("paginationKey is not found during iteration with columns: %v", columns)
logger.WithError(err).Error("failed to get paginationKey index")
return
}

var rowData RowData
Expand Down Expand Up @@ -269,10 +403,12 @@ func (c *Cursor) Fetch(db SqlPreparer) (batch *RowBatch, paginationKeypos uint64
}

batch = &RowBatch{
values: batchData,
paginationKeyIndex: paginationKeyIndex,
table: c.Table,
columns: columns,
values: batchData,
paginationKeyIndex: paginationKeyIndex,
isCompositePagination: c.isComposite,
compositePaginationIndexes: compositePaginationIndexes,
table: c.Table,
columns: columns,
}

logger.Debugf("found %d rows", batch.Size())
Expand Down Expand Up @@ -304,6 +440,29 @@ func ScanByteRow(rows *sqlorig.Rows, columnCount int) ([][]byte, error) {
return values, err
}

// BuildCompositeTupleComparison creates a WHERE clause for composite key pagination
// For columns (a,b,c) and values (x,y,z), generates:
// WHERE a > x OR (a = x AND b > y) OR (a = x AND b = y AND c > z)
func BuildCompositeTupleComparison(columns []string, values []interface{}) squirrel.Or {
conditions := make(squirrel.Or, 0, len(columns))

for i := 0; i < len(columns); i++ {
condition := squirrel.And{}

// Add equality conditions for all columns before the current one
for j := 0; j < i; j++ {
condition = append(condition, squirrel.Eq{columns[j]: values[j]})
}

// Add greater than condition for the current column
condition = append(condition, squirrel.Gt{columns[i]: values[i]})

conditions = append(conditions, condition)
}

return conditions
}

func DefaultBuildSelect(columns []string, table *TableSchema, lastPaginationKey, batchSize uint64) squirrel.SelectBuilder {
quotedPaginationKey := QuoteField(table.GetPaginationColumn().Name)

Expand All @@ -313,3 +472,22 @@ func DefaultBuildSelect(columns []string, table *TableSchema, lastPaginationKey,
Limit(batchSize).
OrderBy(quotedPaginationKey)
}

// DefaultBuildSelectComposite builds a SELECT query for composite pagination
func DefaultBuildSelectComposite(columns []string, table *TableSchema, lastPaginationKeys []interface{}, batchSize uint64) squirrel.SelectBuilder {
quotedColumns := make([]string, len(table.CompositePaginationColumns))
orderByColumns := make([]string, len(table.CompositePaginationColumns))

for i, col := range table.CompositePaginationColumns {
quotedColumns[i] = QuoteField(col.Name)
orderByColumns[i] = QuoteField(col.Name)
}

whereClause := BuildCompositeTupleComparison(quotedColumns, lastPaginationKeys)

return squirrel.Select(columns...).
From(QuotedTableName(table)).
Where(whereClause).
Limit(batchSize).
OrderBy(orderByColumns...)
}
Loading
Loading