diff --git a/batch_writer.go b/batch_writer.go index 550f1656..28dfb4bd 100644 --- a/batch_writer.go +++ b/batch_writer.go @@ -7,7 +7,6 @@ import ( sql "github.com/Shopify/ghostferry/sqlwrapper" - "github.com/go-mysql-org/go-mysql/schema" "github.com/sirupsen/logrus" ) @@ -57,65 +56,15 @@ func (w *BatchWriter) WriteRowBatch(batch *RowBatch) error { return nil } - var startPaginationKeypos, endPaginationKeypos PaginationKey - var err error - paginationColumn := batch.TableSchema().GetPaginationColumn() - switch paginationColumn.Type { - case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: - var startValue, endValue uint64 - startValue, err = values[0].GetUint64(batch.PaginationKeyIndex()) - if err != nil { - return err - } - endValue, err = values[len(values)-1].GetUint64(batch.PaginationKeyIndex()) - if err != nil { - return err - } - startPaginationKeypos = NewUint64Key(startValue) - endPaginationKeypos = NewUint64Key(endValue) - - case schema.TYPE_BINARY, schema.TYPE_STRING: - startValueInterface := values[0][batch.PaginationKeyIndex()] - endValueInterface := values[len(values)-1][batch.PaginationKeyIndex()] - - getBytes := func(val interface{}) ([]byte, error) { - switch v := val.(type) { - case []byte: - return v, nil - case string: - return []byte(v), nil - default: - return nil, fmt.Errorf("expected binary/string pagination key, got %T", val) - } - } - - startValue, err := getBytes(startValueInterface) - if err != nil { - return err - } - - endValue, err := getBytes(endValueInterface) - if err != nil { - return err - } - - startPaginationKeypos = NewBinaryKey(startValue) - endPaginationKeypos = NewBinaryKey(endValue) - - default: - var startValue, endValue uint64 - startValue, err = values[0].GetUint64(batch.PaginationKeyIndex()) - if err != nil { - return err - } - endValue, err = values[len(values)-1].GetUint64(batch.PaginationKeyIndex()) - if err != nil { - return err - } - startPaginationKeypos = NewUint64Key(startValue) - endPaginationKeypos = NewUint64Key(endValue) + startPaginationKeypos, err := NewPaginationKeyFromRow(values[0], batch.PaginationKeyIndex(), paginationColumn) + if err != nil { + return err + } + endPaginationKeypos, err := NewPaginationKeyFromRow(values[len(values)-1], batch.PaginationKeyIndex(), paginationColumn) + if err != nil { + return err } db := batch.TableSchema().Schema diff --git a/cursor.go b/cursor.go index 08f5a892..c685f270 100644 --- a/cursor.go +++ b/cursor.go @@ -262,43 +262,10 @@ func (c *Cursor) Fetch(db SqlPreparer) (batch *RowBatch, paginationKeypos Pagina if len(batchData) > 0 { lastRowData := batchData[len(batchData)-1] - - switch c.paginationKeyColumn.Type { - case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: - var value uint64 - value, err = lastRowData.GetUint64(paginationKeyIndex) - if err != nil { - logger.WithError(err).Error("failed to get uint64 paginationKey value") - return - } - paginationKeypos = NewUint64Key(value) - - case schema.TYPE_BINARY, schema.TYPE_STRING: - valueInterface := lastRowData[paginationKeyIndex] - - var valueBytes []byte - switch v := valueInterface.(type) { - case []byte: - valueBytes = v - case string: - valueBytes = []byte(v) - default: - err = fmt.Errorf("expected binary pagination key to be []byte or string, got %T", valueInterface) - logger.WithError(err).Error("failed to get binary paginationKey value") - return - } - - paginationKeypos = NewBinaryKey(valueBytes) - - default: - // Fallback for other integer types - var value uint64 - value, err = lastRowData.GetUint64(paginationKeyIndex) - if err != nil { - logger.WithError(err).Error("failed to get uint64 paginationKey value") - return - } - paginationKeypos = NewUint64Key(value) + paginationKeypos, err = NewPaginationKeyFromRow(lastRowData, paginationKeyIndex, c.paginationKeyColumn) + if err != nil { + logger.WithError(err).Error("failed to get paginationKey value") + return } } diff --git a/data_iterator.go b/data_iterator.go index efe7d1b7..dee62cfd 100644 --- a/data_iterator.go +++ b/data_iterator.go @@ -6,7 +6,6 @@ import ( sql "github.com/Shopify/ghostferry/sqlwrapper" - "github.com/go-mysql-org/go-mysql/schema" "github.com/sirupsen/logrus" ) @@ -115,40 +114,13 @@ func (d *DataIterator) Run(tables []*TableSchema) { paginationColumn := table.GetPaginationColumn() for i, rowData := range batch.Values() { - var paginationKeyStr string - - switch paginationColumn.Type { - case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: - paginationKeyUint, err := rowData.GetUint64(batch.PaginationKeyIndex()) - if err != nil { - logger.WithError(err).Error("failed to get uint64 paginationKey data") - return err - } - paginationKeyStr = NewUint64Key(paginationKeyUint).String() - - case schema.TYPE_BINARY, schema.TYPE_STRING: - paginationKeyInterface := rowData[batch.PaginationKeyIndex()] - var paginationKeyBytes []byte - switch v := paginationKeyInterface.(type) { - case []byte: - paginationKeyBytes = v - case string: - paginationKeyBytes = []byte(v) - default: - return fmt.Errorf("expected binary/string pagination key, got %T", paginationKeyInterface) - } - paginationKeyStr = NewBinaryKey(paginationKeyBytes).String() - - default: - paginationKeyUint, err := rowData.GetUint64(batch.PaginationKeyIndex()) - if err != nil { - logger.WithError(err).Error("failed to get paginationKey data") - return err - } - paginationKeyStr = NewUint64Key(paginationKeyUint).String() + paginationKey, err := NewPaginationKeyFromRow(rowData, batch.PaginationKeyIndex(), paginationColumn) + if err != nil { + logger.WithError(err).Error("failed to get paginationKey data") + return err } - fingerprints[paginationKeyStr] = rowData[len(rowData)-1].([]byte) + fingerprints[paginationKey.String()] = rowData[len(rowData)-1].([]byte) rows[i] = rowData[:len(rowData)-1] } diff --git a/dml_events.go b/dml_events.go index 7d96c250..7b8a00f1 100644 --- a/dml_events.go +++ b/dml_events.go @@ -576,35 +576,9 @@ func paginationKeyFromEventData(table *TableSchema, rowData RowData) (string, er return "", err } - paginationColumn := table.GetPaginationColumn() - paginationKeyIndex := table.GetPaginationKeyIndex() - - switch paginationColumn.Type { - case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: - paginationKeyUint, err := rowData.GetUint64(paginationKeyIndex) - if err != nil { - return "", err - } - return NewUint64Key(paginationKeyUint).String(), nil - - case schema.TYPE_BINARY, schema.TYPE_STRING: - paginationKeyInterface := rowData[paginationKeyIndex] - var paginationKeyBytes []byte - switch v := paginationKeyInterface.(type) { - case []byte: - paginationKeyBytes = v - case string: - paginationKeyBytes = []byte(v) - default: - return "", fmt.Errorf("expected binary/string pagination key, got %T", paginationKeyInterface) - } - return NewBinaryKey(paginationKeyBytes).String(), nil - - default: - paginationKeyUint, err := rowData.GetUint64(paginationKeyIndex) - if err != nil { - return "", err - } - return NewUint64Key(paginationKeyUint).String(), nil + paginationKey, err := NewPaginationKeyFromRow(rowData, table.GetPaginationKeyIndex(), table.GetPaginationColumn()) + if err != nil { + return "", err } + return paginationKey.String(), nil } diff --git a/inline_verifier.go b/inline_verifier.go index 16c0fab1..df8e2055 100644 --- a/inline_verifier.go +++ b/inline_verifier.go @@ -335,34 +335,11 @@ func (v *InlineVerifier) CheckFingerprintInline(tx *sql.Tx, targetSchema, target paginationKeys := make([]interface{}, len(sourceBatch.Values())) for i, row := range sourceBatch.Values() { - switch paginationColumn.Type { - case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: - paginationKeyUint, err := row.GetUint64(sourceBatch.PaginationKeyIndex()) - if err != nil { - return nil, err - } - paginationKeys[i] = paginationKeyUint - - case schema.TYPE_BINARY, schema.TYPE_STRING: - paginationKeyInterface := row[sourceBatch.PaginationKeyIndex()] - var paginationKeyBytes []byte - switch v := paginationKeyInterface.(type) { - case []byte: - paginationKeyBytes = v - case string: - paginationKeyBytes = []byte(v) - default: - return nil, fmt.Errorf("expected binary/string pagination key, got %T", paginationKeyInterface) - } - paginationKeys[i] = paginationKeyBytes - - default: - paginationKeyUint, err := row.GetUint64(sourceBatch.PaginationKeyIndex()) - if err != nil { - return nil, err - } - paginationKeys[i] = paginationKeyUint + paginationKey, err := NewPaginationKeyFromRow(row, sourceBatch.PaginationKeyIndex(), paginationColumn) + if err != nil { + return nil, err } + paginationKeys[i] = paginationKey.SQLValue() } // Fetch target data @@ -376,36 +353,11 @@ func (v *InlineVerifier) CheckFingerprintInline(tx *sql.Tx, targetSchema, target sourceDecompressedData := make(map[string]map[string][]byte) for _, rowData := range sourceBatch.Values() { - var paginationKeyStr string - - switch paginationColumn.Type { - case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: - paginationKeyUint, err := rowData.GetUint64(sourceBatch.PaginationKeyIndex()) - if err != nil { - return nil, err - } - paginationKeyStr = NewUint64Key(paginationKeyUint).String() - - case schema.TYPE_BINARY, schema.TYPE_STRING: - paginationKeyInterface := rowData[sourceBatch.PaginationKeyIndex()] - var paginationKeyBytes []byte - switch v := paginationKeyInterface.(type) { - case []byte: - paginationKeyBytes = v - case string: - paginationKeyBytes = []byte(v) - default: - return nil, fmt.Errorf("expected binary/string pagination key, got %T", paginationKeyInterface) - } - paginationKeyStr = NewBinaryKey(paginationKeyBytes).String() - - default: - paginationKeyUint, err := rowData.GetUint64(sourceBatch.PaginationKeyIndex()) - if err != nil { - return nil, err - } - paginationKeyStr = NewUint64Key(paginationKeyUint).String() + paginationKey, err := NewPaginationKeyFromRow(rowData, sourceBatch.PaginationKeyIndex(), paginationColumn) + if err != nil { + return nil, err } + paginationKeyStr := paginationKey.String() sourceDecompressedData[paginationKeyStr] = make(map[string][]byte) for idx, col := range table.Columns { diff --git a/iterative_verifier.go b/iterative_verifier.go index cc896179..3476ebe6 100644 --- a/iterative_verifier.go +++ b/iterative_verifier.go @@ -320,31 +320,12 @@ func (v *IterativeVerifier) GetHashes(db *sql.DB, schemaName, tableName, paginat return nil, err } - var paginationKeyStr string - switch paginationColumn.Type { - case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: - paginationKeyUint, err := rowData.GetUint64(0) - if err != nil { - return nil, err - } - paginationKeyStr = NewUint64Key(paginationKeyUint).String() - - case schema.TYPE_BINARY, schema.TYPE_STRING: - paginationKeyBytes, ok := rowData[0].([]byte) - if !ok { - return nil, fmt.Errorf("expected []byte for binary pagination key, got %T", rowData[0]) - } - paginationKeyStr = NewBinaryKey(paginationKeyBytes).String() - - default: - paginationKeyUint, err := rowData.GetUint64(0) - if err != nil { - return nil, err - } - paginationKeyStr = NewUint64Key(paginationKeyUint).String() + paginationKey, err := NewPaginationKeyFromRow(rowData, 0, paginationColumn) + if err != nil { + return nil, err } - resultSet[paginationKeyStr] = rowData[1].([]byte) + resultSet[paginationKey.String()] = rowData[1].([]byte) } return resultSet, nil } @@ -422,34 +403,11 @@ func (v *IterativeVerifier) iterateTableFingerprints(table *TableSchema, mismatc paginationColumn := table.GetPaginationColumn() for _, rowData := range batch.Values() { - switch paginationColumn.Type { - case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: - paginationKeyUint, err := rowData.GetUint64(batch.PaginationKeyIndex()) - if err != nil { - return err - } - paginationKeys = append(paginationKeys, paginationKeyUint) - - case schema.TYPE_BINARY, schema.TYPE_STRING: - paginationKeyInterface := rowData[batch.PaginationKeyIndex()] - var paginationKeyBytes []byte - switch v := paginationKeyInterface.(type) { - case []byte: - paginationKeyBytes = v - case string: - paginationKeyBytes = []byte(v) - default: - return fmt.Errorf("expected binary/string pagination key, got %T", paginationKeyInterface) - } - paginationKeys = append(paginationKeys, paginationKeyBytes) - - default: - paginationKeyUint, err := rowData.GetUint64(batch.PaginationKeyIndex()) - if err != nil { - return err - } - paginationKeys = append(paginationKeys, paginationKeyUint) + paginationKey, err := NewPaginationKeyFromRow(rowData, batch.PaginationKeyIndex(), paginationColumn) + if err != nil { + return err } + paginationKeys = append(paginationKeys, paginationKey.SQLValue()) } mismatchedPaginationKeys, err := v.compareFingerprints(paginationKeys, batch.TableSchema()) diff --git a/pagination_key.go b/pagination_key.go index 42152648..4906632e 100644 --- a/pagination_key.go +++ b/pagination_key.go @@ -11,12 +11,21 @@ import ( "github.com/go-mysql-org/go-mysql/schema" ) +// PaginationKey represents a cursor position for paginating through table rows. +// It abstracts over different primary key types (integers, UUIDs, binary data) +// to enable consistent batched iteration through tables. type PaginationKey interface { + // SQLValue returns the value to use in SQL WHERE clauses (e.g., WHERE id > ?). SQLValue() interface{} + // Compare returns -1, 0, or 1 if this key is less than, equal to, or greater than other. Compare(other PaginationKey) int + // NumericPosition returns a float64 approximation for progress tracking and estimation. NumericPosition() float64 + // String returns a human-readable representation for logging and debugging. String() string + // MarshalJSON serializes the key for state persistence and checkpointing. MarshalJSON() ([]byte, error) + // IsMax returns true if this key represents the maximum possible value for its type. IsMax() bool } @@ -207,7 +216,7 @@ func MaxPaginationKey(column *schema.TableColumn) PaginationKey { if size > 4096 { size = 4096 } - + maxBytes := make([]byte, size) for i := range maxBytes { maxBytes[i] = 0xFF @@ -217,3 +226,37 @@ func MaxPaginationKey(column *schema.TableColumn) PaginationKey { return NewUint64Key(math.MaxUint64) } } + +// NewPaginationKeyFromRow extracts a pagination key from a row at the given index. +// It determines the appropriate key type based on the column schema. +func NewPaginationKeyFromRow(rowData RowData, index int, column *schema.TableColumn) (PaginationKey, error) { + switch column.Type { + case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: + value, err := rowData.GetUint64(index) + if err != nil { + return nil, fmt.Errorf("failed to get uint64 pagination key: %w", err) + } + return NewUint64Key(value), nil + + case schema.TYPE_BINARY, schema.TYPE_STRING: + valueInterface := rowData[index] + var valueBytes []byte + switch v := valueInterface.(type) { + case []byte: + valueBytes = v + case string: + valueBytes = []byte(v) + default: + return nil, fmt.Errorf("expected binary pagination key to be []byte or string, got %T", valueInterface) + } + return NewBinaryKey(valueBytes), nil + + default: + // Fallback for other integer types + value, err := rowData.GetUint64(index) + if err != nil { + return nil, fmt.Errorf("failed to get pagination key: %w", err) + } + return NewUint64Key(value), nil + } +}