diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b023c7bd..4079b17b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ All notable changes to this project will be documented in this file. ### Added - Changelog. +- Support for pagination keys beyond UINT64 @milanatshopify #417 +- Pagination keys other than UINT64 must use binary collation @grodowski #422 ## [1.1.0] diff --git a/Makefile b/Makefile index 019e6b646..98c5939bd 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # Variables to be built into the binary -VERSION := 1.1.0 +VERSION := 1.2.0 # This variable can be overwritten by the caller DATETIME ?= $(shell date -u +%Y%m%d%H%M%S) diff --git a/batch_writer.go b/batch_writer.go index 5ed22cd8a..28dfb4bd6 100644 --- a/batch_writer.go +++ b/batch_writer.go @@ -56,12 +56,13 @@ func (w *BatchWriter) WriteRowBatch(batch *RowBatch) error { return nil } - startPaginationKeypos, err := values[0].GetUint64(batch.PaginationKeyIndex()) + paginationColumn := batch.TableSchema().GetPaginationColumn() + + startPaginationKeypos, err := NewPaginationKeyFromRow(values[0], batch.PaginationKeyIndex(), paginationColumn) if err != nil { return err } - - endPaginationKeypos, err := values[len(values)-1].GetUint64(batch.PaginationKeyIndex()) + endPaginationKeypos, err := NewPaginationKeyFromRow(values[len(values)-1], batch.PaginationKeyIndex(), paginationColumn) if err != nil { return err } @@ -78,12 +79,12 @@ func (w *BatchWriter) WriteRowBatch(batch *RowBatch) error { query, args, err := batch.AsSQLQuery(db, table) if err != nil { - return fmt.Errorf("during generating sql query at paginationKey %v -> %v: %v", startPaginationKeypos, endPaginationKeypos, err) + return fmt.Errorf("during generating sql query at paginationKey %s -> %s: %v", startPaginationKeypos.String(), endPaginationKeypos.String(), err) } stmt, err := w.stmtCache.StmtFor(w.DB, query) if err != nil { - return fmt.Errorf("during prepare query near paginationKey %v -> %v (%s): %v", startPaginationKeypos, endPaginationKeypos, query, err) + return fmt.Errorf("during prepare query near paginationKey %s -> %s (%s): %v", startPaginationKeypos.String(), endPaginationKeypos.String(), query, err) } tx, err := w.DB.Begin() @@ -94,14 +95,14 @@ func (w *BatchWriter) WriteRowBatch(batch *RowBatch) error { _, err = tx.Stmt(stmt).Exec(args...)
if err != nil { tx.Rollback() - return fmt.Errorf("during exec query near paginationKey %v -> %v (%s): %v", startPaginationKeypos, endPaginationKeypos, query, err) + return fmt.Errorf("during exec query near paginationKey %s -> %s (%s): %v", startPaginationKeypos.String(), endPaginationKeypos.String(), query, err) } if w.InlineVerifier != nil { mismatches, err := w.InlineVerifier.CheckFingerprintInline(tx, db, table, batch, w.EnforceInlineVerification) if err != nil { tx.Rollback() - return fmt.Errorf("during fingerprint checking for paginationKey %v -> %v (%s): %v", startPaginationKeypos, endPaginationKeypos, query, err) + return fmt.Errorf("during fingerprint checking for paginationKey %s -> %s (%s): %v", startPaginationKeypos.String(), endPaginationKeypos.String(), query, err) } if w.EnforceInlineVerification { @@ -119,7 +120,7 @@ func (w *BatchWriter) WriteRowBatch(batch *RowBatch) error { err = tx.Commit() if err != nil { tx.Rollback() - return fmt.Errorf("during commit near paginationKey %v -> %v (%s): %v", startPaginationKeypos, endPaginationKeypos, query, err) + return fmt.Errorf("during commit near paginationKey %s -> %s (%s): %v", startPaginationKeypos.String(), endPaginationKeypos.String(), query, err) } // Note that the state tracker expects us the track based on the original diff --git a/compression_verifier.go b/compression_verifier.go index 0efb0fd33..b70eb994a 100644 --- a/compression_verifier.go +++ b/compression_verifier.go @@ -49,6 +49,7 @@ func (e UnsupportedCompressionError) Error() string { type CompressionVerifier struct { logger *logrus.Entry + TableSchemaCache TableSchemaCache supportedAlgorithms map[string]struct{} tableColumnCompressions TableColumnCompressionConfig } @@ -59,32 +60,52 @@ type CompressionVerifier struct { // The GetCompressedHashes method checks if the existing table contains compressed data // and will apply the decompression algorithm to the applicable columns if necessary. 
// After the columns are decompressed, the hashes of the data are used to verify equality -func (c *CompressionVerifier) GetCompressedHashes(db *sql.DB, schema, table, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []uint64) (map[uint64][]byte, error) { +func (c *CompressionVerifier) GetCompressedHashes(db *sql.DB, schemaName, tableName, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []interface{}) (map[string][]byte, error) { c.logger.WithFields(logrus.Fields{ "tag": "compression_verifier", - "table": table, + "table": tableName, }).Info("decompressing table data before verification") - tableCompression := c.tableColumnCompressions[table] + tableCompression := c.tableColumnCompressions[tableName] // Extract the raw rows using SQL to be decompressed - rows, err := getRows(db, schema, table, paginationKeyColumn, columns, paginationKeys) + rows, err := getRows(db, schemaName, tableName, paginationKeyColumn, columns, paginationKeys) if err != nil { return nil, err } defer rows.Close() - // Decompress applicable columns and hash the resulting column values for comparison - resultSet := make(map[uint64][]byte) + table := c.TableSchemaCache.Get(schemaName, tableName) + if table == nil { + return nil, fmt.Errorf("table %s.%s not found in schema cache", schemaName, tableName) + } + paginationColumn := table.GetPaginationColumn() + resultSet := make(map[string][]byte) + for rows.Next() { rowData, err := ScanByteRow(rows, len(columns)+1) if err != nil { return nil, err } - paginationKey, err := strconv.ParseUint(string(rowData[0]), 10, 64) - if err != nil { - return nil, err + var paginationKeyStr string + switch paginationColumn.Type { + case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: + paginationKeyUint, err := strconv.ParseUint(string(rowData[0]), 10, 64) + if err != nil { + return nil, err + } + paginationKeyStr = NewUint64Key(paginationKeyUint).String() + + case schema.TYPE_BINARY, schema.TYPE_STRING: + paginationKeyStr = NewBinaryKey(rowData[0]).String() + + default: + paginationKeyUint, err := strconv.ParseUint(string(rowData[0]), 10, 64) + if err != nil { + return nil, err + } + paginationKeyStr = NewUint64Key(paginationKeyUint).String() } // Decompress the applicable columns and then hash them together @@ -95,7 +116,7 @@ func (c *CompressionVerifier) GetCompressedHashes(db *sql.DB, schema, table, pag for idx, column := range columns { if algorithm, ok := tableCompression[column.Name]; ok { // rowData contains the result of "SELECT paginationKeyColumn, * FROM ...", so idx+1 to get each column - decompressedColData, err := c.Decompress(table, column.Name, algorithm, rowData[idx+1]) + decompressedColData, err := c.Decompress(tableName, column.Name, algorithm, rowData[idx+1]) if err != nil { return nil, err } @@ -111,20 +132,20 @@ func (c *CompressionVerifier) GetCompressedHashes(db *sql.DB, schema, table, pag return nil, err } - resultSet[paginationKey] = decompressedRowHash + resultSet[paginationKeyStr] = decompressedRowHash } metrics.Gauge( "compression_verifier_decompress_rows", float64(len(resultSet)), - []MetricTag{{"table", table}}, + []MetricTag{{"table", tableName}}, 1.0, ) logrus.WithFields(logrus.Fields{ "tag": "compression_verifier", "rows": len(resultSet), - "table": table, + "table": tableName, }).Debug("decompressed rows will be compared") return resultSet, nil @@ -192,12 +213,13 @@ func (c *CompressionVerifier) verifyConfiguredCompression(tableColumnCompression // NewCompressionVerifier first checks the map for supported 
compression algorithms before // initializing and returning the initialized instance. -func NewCompressionVerifier(tableColumnCompressions TableColumnCompressionConfig) (*CompressionVerifier, error) { +func NewCompressionVerifier(tableColumnCompressions TableColumnCompressionConfig, tableSchemaCache TableSchemaCache) (*CompressionVerifier, error) { supportedAlgorithms := make(map[string]struct{}) supportedAlgorithms[CompressionSnappy] = struct{}{} compressionVerifier := &CompressionVerifier{ logger: logrus.WithField("tag", "compression_verifier"), + TableSchemaCache: tableSchemaCache, supportedAlgorithms: supportedAlgorithms, tableColumnCompressions: tableColumnCompressions, } @@ -209,7 +231,7 @@ func NewCompressionVerifier(tableColumnCompressions TableColumnCompressionConfig return compressionVerifier, nil } -func getRows(db *sql.DB, schema, table, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []uint64) (*sqlorig.Rows, error) { +func getRows(db *sql.DB, schema, table, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []interface{}) (*sqlorig.Rows, error) { quotedPaginationKey := QuoteField(paginationKeyColumn) sql, args, err := rowSelector(columns, paginationKeyColumn). From(QuotedTableNameFromString(schema, table)). diff --git a/config.go b/config.go index d9351f01a..e99ca086e 100644 --- a/config.go +++ b/config.go @@ -376,12 +376,17 @@ func (c ForceIndexConfig) IndexFor(schemaName, tableName string) string { // CascadingPaginationColumnConfig to configure pagination columns to be // used. The term `Cascading` to denote that greater specificity takes // precedence. +// +// IMPORTANT: All configured pagination columns must contain unique values. +// When specifying a FallbackColumn for tables with composite primary keys, +// ensure the column has a unique constraint to prevent data loss during migration. type CascadingPaginationColumnConfig struct { // PerTable has greatest specificity and takes precedence over the other options PerTable map[string]map[string]string // SchemaName => TableName => ColumnName // FallbackColumn is a global default to fallback to and is less specific than the - // default, which is the Primary Key + // default, which is the Primary Key. + // This column MUST have unique values (ideally a unique constraint) for data integrity. FallbackColumn string } @@ -727,10 +732,15 @@ type Config struct { // ForceIndexForVerification ForceIndexConfig - // Ghostferry requires a single numeric column to paginate over tables. Inferring that column is done in the following exact order: + // Ghostferry requires a single numeric or binary column to paginate over tables. Inferring that column is done in the following exact order: // 1. Use the PerTable pagination column, if configured for a table. Fail if we cannot find this column in the table. - // 2. Use the table's primary key column as the pagination column. Fail if the primary key is not numeric or is a composite key without a FallbackColumn specified. + // 2. Use the table's primary key column as the pagination column. Fail if the primary key is not numeric/binary or is a composite key without a FallbackColumn specified. // 3. Use the FallbackColumn pagination column, if configured. Fail if we cannot find this column in the table. + // + // IMPORTANT: The pagination column MUST contain unique values for data integrity. + // When using a FallbackColumn (typically "id") for tables with composite primary keys, this column must have a unique constraint. 
+ // The pagination algorithm uses WHERE pagination_key > last_key ORDER BY pagination_key LIMIT batch_size. + // If duplicate values exist, rows may be skipped during iteration, resulting in data loss during the migration. CascadingPaginationColumnConfig *CascadingPaginationColumnConfig // SkipTargetVerification is used to enable or disable target verification during moves. diff --git a/cursor.go b/cursor.go index 9a7a72ed1..c685f270c 100644 --- a/cursor.go +++ b/cursor.go @@ -38,7 +38,7 @@ type CursorConfig struct { Throttler Throttler ColumnsToSelect []string - BuildSelect func([]string, *TableSchema, uint64, uint64) (squirrel.SelectBuilder, error) + BuildSelect func([]string, *TableSchema, PaginationKey, uint64) (squirrel.SelectBuilder, error) // BatchSize is a pointer to the BatchSize in Config.UpdatableConfig which can be independently updated from this code. // Having it as a pointer allows the updated value to be read without needing additional code to copy the batch size value into the cursor config for each cursor we create. BatchSize *uint64 @@ -47,7 +47,7 @@ type CursorConfig struct { } // returns a new Cursor with an embedded copy of itself -func (c *CursorConfig) NewCursor(table *TableSchema, startPaginationKey, maxPaginationKey uint64) *Cursor { +func (c *CursorConfig) NewCursor(table *TableSchema, startPaginationKey, maxPaginationKey PaginationKey) *Cursor { return &Cursor{ CursorConfig: *c, Table: table, @@ -58,7 +58,7 @@ func (c *CursorConfig) NewCursor(table *TableSchema, startPaginationKey, maxPagi } // returns a new Cursor with an embedded copy of itself -func (c *CursorConfig) NewCursorWithoutRowLock(table *TableSchema, startPaginationKey, maxPaginationKey uint64) *Cursor { +func (c *CursorConfig) NewCursorWithoutRowLock(table *TableSchema, startPaginationKey, maxPaginationKey PaginationKey) *Cursor { cursor := c.NewCursor(table, startPaginationKey, maxPaginationKey) cursor.RowLock = false return cursor @@ -77,11 +77,11 @@ type Cursor struct { CursorConfig Table *TableSchema - MaxPaginationKey uint64 + MaxPaginationKey PaginationKey RowLock bool paginationKeyColumn *schema.TableColumn - lastSuccessfulPaginationKey uint64 + lastSuccessfulPaginationKey PaginationKey logger *logrus.Entry } @@ -96,10 +96,10 @@ func (c *Cursor) Each(f func(*RowBatch) error) error { c.ColumnsToSelect = []string{"*"} } - for c.lastSuccessfulPaginationKey < c.MaxPaginationKey { + for c.lastSuccessfulPaginationKey.Compare(c.MaxPaginationKey) < 0 { var tx SqlPreparerAndRollbacker var batch *RowBatch - var paginationKeypos uint64 + var paginationKeypos PaginationKey err := WithRetries(c.ReadRetries, 1*time.Second, c.logger, "fetch rows", func() (err error) { if c.Throttler != nil { @@ -137,9 +137,9 @@ func (c *Cursor) Each(f func(*RowBatch) error) error { break } - if paginationKeypos <= c.lastSuccessfulPaginationKey { + if paginationKeypos.Compare(c.lastSuccessfulPaginationKey) <= 0 { tx.Rollback() - err = fmt.Errorf("new paginationKeypos %d <= lastSuccessfulPaginationKey %d", paginationKeypos, c.lastSuccessfulPaginationKey) + err = fmt.Errorf("new paginationKeypos %s <= lastSuccessfulPaginationKey %s", paginationKeypos.String(), c.lastSuccessfulPaginationKey.String()) c.logger.WithError(err).Errorf("last successful paginationKey position did not advance") return err } @@ -159,7 +159,7 @@ func (c *Cursor) Each(f func(*RowBatch) error) error { return nil } -func (c *Cursor) Fetch(db SqlPreparer) (batch *RowBatch, paginationKeypos uint64, err error) { +func (c *Cursor) Fetch(db SqlPreparer) 
(batch *RowBatch, paginationKeypos PaginationKey, err error) { var selectBuilder squirrel.SelectBuilder batchSize := c.CursorConfig.GetBatchSize(c.Table.Schema, c.Table.Name) @@ -176,7 +176,7 @@ func (c *Cursor) Fetch(db SqlPreparer) (batch *RowBatch, paginationKeypos uint64 if c.RowLock { mySqlVersion, err := c.DB.QueryMySQLVersion() if err != nil { - return nil, 0, err + return nil, NewUint64Key(0), err } if strings.HasPrefix(mySqlVersion, "8.") { selectBuilder = selectBuilder.Suffix("FOR SHARE NOWAIT") @@ -261,9 +261,10 @@ func (c *Cursor) Fetch(db SqlPreparer) (batch *RowBatch, paginationKeypos uint64 } if len(batchData) > 0 { - paginationKeypos, err = batchData[len(batchData)-1].GetUint64(paginationKeyIndex) + lastRowData := batchData[len(batchData)-1] + paginationKeypos, err = NewPaginationKeyFromRow(lastRowData, paginationKeyIndex, c.paginationKeyColumn) if err != nil { - logger.WithError(err).Error("failed to get uint64 paginationKey value") + logger.WithError(err).Error("failed to get paginationKey value") return } } @@ -304,12 +305,12 @@ func ScanByteRow(rows *sqlorig.Rows, columnCount int) ([][]byte, error) { return values, err } -func DefaultBuildSelect(columns []string, table *TableSchema, lastPaginationKey, batchSize uint64) squirrel.SelectBuilder { +func DefaultBuildSelect(columns []string, table *TableSchema, lastPaginationKey PaginationKey, batchSize uint64) squirrel.SelectBuilder { quotedPaginationKey := QuoteField(table.GetPaginationColumn().Name) return squirrel.Select(columns...). From(QuotedTableName(table)). - Where(squirrel.Gt{quotedPaginationKey: lastPaginationKey}). + Where(squirrel.Gt{quotedPaginationKey: lastPaginationKey.SQLValue()}). Limit(batchSize). OrderBy(quotedPaginationKey) } diff --git a/data_iterator.go b/data_iterator.go index 5621a24e0..dee62cfd3 100644 --- a/data_iterator.go +++ b/data_iterator.go @@ -2,7 +2,6 @@ package ghostferry import ( "fmt" - "math" "sync" sql "github.com/Shopify/ghostferry/sqlwrapper" @@ -28,7 +27,7 @@ type DataIterator struct { type TableMaxPaginationKey struct { Table *TableSchema - MaxPaginationKey uint64 + MaxPaginationKey PaginationKey } func (d *DataIterator) Run(tables []*TableSchema) { @@ -86,15 +85,15 @@ func (d *DataIterator) Run(tables []*TableSchema) { return } - startPaginationKey := d.StateTracker.LastSuccessfulPaginationKey(table.String()) - if startPaginationKey == math.MaxUint64 { + startPaginationKey := d.StateTracker.LastSuccessfulPaginationKey(table.String(), table) + if startPaginationKey.IsMax() { err := fmt.Errorf("%v has been marked as completed but a table iterator has been spawned, this is likely a programmer error which resulted in the inconsistent starting state", table.String()) logger.WithError(err).Error("this is definitely a bug") d.ErrorHandler.Fatal("data_iterator", err) return } - cursor := d.CursorConfig.NewCursor(table, startPaginationKey, targetPaginationKeyInterface.(uint64)) + cursor := d.CursorConfig.NewCursor(table, startPaginationKey, targetPaginationKeyInterface.(PaginationKey)) if d.SelectFingerprint { if len(cursor.ColumnsToSelect) == 0 { cursor.ColumnsToSelect = []string{"*"} @@ -110,17 +109,18 @@ func (d *DataIterator) Run(tables []*TableSchema) { }, 1.0) if d.SelectFingerprint { - fingerprints := make(map[uint64][]byte) + fingerprints := make(map[string][]byte) rows := make([]RowData, batch.Size()) + paginationColumn := table.GetPaginationColumn() for i, rowData := range batch.Values() { - paginationKey, err := rowData.GetUint64(batch.PaginationKeyIndex()) + paginationKey, err 
:= NewPaginationKeyFromRow(rowData, batch.PaginationKeyIndex(), paginationColumn) if err != nil { logger.WithError(err).Error("failed to get paginationKey data") return err } - fingerprints[paginationKey] = rowData[len(rowData)-1].([]byte) + fingerprints[paginationKey.String()] = rowData[len(rowData)-1].([]byte) rows[i] = rowData[:len(rowData)-1] } diff --git a/data_iterator_sorter.go b/data_iterator_sorter.go index 4dc912869..1a2931069 100644 --- a/data_iterator_sorter.go +++ b/data_iterator_sorter.go @@ -8,13 +8,13 @@ import ( // DataIteratorSorter is an interface for the DataIterator to choose which order it will process table type DataIteratorSorter interface { - Sort(unorderedTables map[*TableSchema]uint64) ([]TableMaxPaginationKey, error) + Sort(unorderedTables map[*TableSchema]PaginationKey) ([]TableMaxPaginationKey, error) } // MaxPaginationKeySorter arranges table based on the MaxPaginationKey in DESC order type MaxPaginationKeySorter struct{} -func (s *MaxPaginationKeySorter) Sort(unorderedTables map[*TableSchema]uint64) ([]TableMaxPaginationKey, error) { +func (s *MaxPaginationKeySorter) Sort(unorderedTables map[*TableSchema]PaginationKey) ([]TableMaxPaginationKey, error) { orderedTables := make([]TableMaxPaginationKey, len(unorderedTables)) i := 0 @@ -24,7 +24,7 @@ func (s *MaxPaginationKeySorter) Sort(unorderedTables map[*TableSchema]uint64) ( } sort.Slice(orderedTables, func(i, j int) bool { - return orderedTables[i].MaxPaginationKey > orderedTables[j].MaxPaginationKey + return orderedTables[i].MaxPaginationKey.Compare(orderedTables[j].MaxPaginationKey) > 0 }) return orderedTables, nil @@ -35,7 +35,7 @@ type MaxTableSizeSorter struct { DataIterator *DataIterator } -func (s *MaxTableSizeSorter) Sort(unorderedTables map[*TableSchema]uint64) ([]TableMaxPaginationKey, error) { +func (s *MaxTableSizeSorter) Sort(unorderedTables map[*TableSchema]PaginationKey) ([]TableMaxPaginationKey, error) { orderedTables := []TableMaxPaginationKey{} tableNames := []string{} databaseSchemasSet := map[string]struct{}{} diff --git a/dml_events.go b/dml_events.go index 46695800f..7b8a00f1d 100644 --- a/dml_events.go +++ b/dml_events.go @@ -76,7 +76,7 @@ type DMLEvent interface { AsSQLString(string, string) (string, error) OldValues() RowData NewValues() RowData - PaginationKey() (uint64, error) + PaginationKey() (string, error) BinlogPosition() mysql.Position ResumableBinlogPosition() mysql.Position Annotation() (string, error) @@ -180,7 +180,7 @@ func (e *BinlogInsertEvent) AsSQLString(schemaName, tableName string) (string, e return query, nil } -func (e *BinlogInsertEvent) PaginationKey() (uint64, error) { +func (e *BinlogInsertEvent) PaginationKey() (string, error) { return paginationKeyFromEventData(e.table, e.newValues) } @@ -233,7 +233,7 @@ func (e *BinlogUpdateEvent) AsSQLString(schemaName, tableName string) (string, e return query, nil } -func (e *BinlogUpdateEvent) PaginationKey() (uint64, error) { +func (e *BinlogUpdateEvent) PaginationKey() (string, error) { return paginationKeyFromEventData(e.table, e.newValues) } @@ -274,7 +274,7 @@ func (e *BinlogDeleteEvent) AsSQLString(schemaName, tableName string) (string, e return query, nil } -func (e *BinlogDeleteEvent) PaginationKey() (uint64, error) { +func (e *BinlogDeleteEvent) PaginationKey() (string, error) { return paginationKeyFromEventData(e.table, e.oldValues) } @@ -571,10 +571,14 @@ func appendEscapedBuffer(buffer, value []byte, isJSON bool) []byte { return buffer } -func paginationKeyFromEventData(table *TableSchema, rowData RowData) 
(uint64, error) { +func paginationKeyFromEventData(table *TableSchema, rowData RowData) (string, error) { if err := verifyValuesHasTheSameLengthAsColumns(table, rowData); err != nil { - return 0, err + return "", err } - return rowData.GetUint64(table.GetPaginationKeyIndex()) + paginationKey, err := NewPaginationKeyFromRow(rowData, table.GetPaginationKeyIndex(), table.GetPaginationColumn()) + if err != nil { + return "", err + } + return paginationKey.String(), nil } diff --git a/ferry.go b/ferry.go index 3718931b9..28ab46efb 100644 --- a/ferry.go +++ b/ferry.go @@ -288,7 +288,7 @@ func (f *Ferry) NewIterativeVerifier() (*IterativeVerifier, error) { var compressionVerifier *CompressionVerifier if config.TableColumnCompression != nil { - compressionVerifier, err = NewCompressionVerifier(config.TableColumnCompression) + compressionVerifier, err = NewCompressionVerifier(config.TableColumnCompression, f.Tables) if err != nil { return nil, err } @@ -995,7 +995,7 @@ func (f *Ferry) Progress() *Progress { s.Tables = make(map[string]TableProgress) targetPaginationKeys := make(map[string]uint64) f.DataIterator.TargetPaginationKeys.Range(func(k, v interface{}) bool { - targetPaginationKeys[k.(string)] = v.(uint64) + targetPaginationKeys[k.(string)] = uint64(v.(PaginationKey).NumericPosition()) return true }) @@ -1009,7 +1009,7 @@ func (f *Ferry) Progress() *Progress { for _, table := range tables { var currentAction string tableName := table.String() - lastSuccessfulPaginationKey, foundInProgress := serializedState.LastSuccessfulPaginationKeys[tableName] + lastSuccessfulPaginationKeyInterface, foundInProgress := serializedState.LastSuccessfulPaginationKeys[tableName] if serializedState.CompletedTables[tableName] { currentAction = TableActionCompleted @@ -1022,6 +1022,11 @@ func (f *Ferry) Progress() *Progress { rowWrittenStats, _ := rowStatsWrittenPerTable[tableName] + var lastSuccessfulPaginationKey uint64 + if lastSuccessfulPaginationKeyInterface != nil { + lastSuccessfulPaginationKey = uint64(lastSuccessfulPaginationKeyInterface.NumericPosition()) + } + s.Tables[tableName] = TableProgress{ LastSuccessfulPaginationKey: lastSuccessfulPaginationKey, TargetPaginationKey: targetPaginationKeys[tableName], @@ -1041,7 +1046,7 @@ func (f *Ferry) Progress() *Progress { } for _, completedPaginationKey := range serializedState.LastSuccessfulPaginationKeys { - completedPaginationKeys += completedPaginationKey + completedPaginationKeys += uint64(completedPaginationKey.NumericPosition()) } var remainingPaginationKeys float64 = 0 diff --git a/filter.go b/filter.go index de6621429..9c367156e 100644 --- a/filter.go +++ b/filter.go @@ -12,10 +12,10 @@ type CopyFilter interface { // allowing for restricting copying to a subset of data. Returning an error // here will cause the query to be retried, until the retry limit is // reached, at which point the ferry will be aborted. BuildSelect is passed - // the columns to be selected, table being copied, the last primary key value + // the columns to be selected, table being copied, the last pagination key value // from the previous batch, and the batch size. Call DefaultBuildSelect to // generate the default query, which may be used as a starting point. - BuildSelect([]string, *TableSchema, uint64, uint64) (sq.SelectBuilder, error) + BuildSelect([]string, *TableSchema, PaginationKey, uint64) (sq.SelectBuilder, error) // ApplicableEvent is used to filter events for rows that have been // filtered in ConstrainSelect. 
ApplicableEvent should return true if the diff --git a/inline_verifier.go b/inline_verifier.go index 552c88e6e..df8e2055a 100644 --- a/inline_verifier.go +++ b/inline_verifier.go @@ -15,6 +15,7 @@ import ( sql "github.com/Shopify/ghostferry/sqlwrapper" + "github.com/go-mysql-org/go-mysql/schema" "github.com/golang/snappy" "github.com/sirupsen/logrus" ) @@ -56,7 +57,7 @@ type BinlogVerifyStore struct { currentRowCount uint64 // The number of rows in store currently. } -type BinlogVerifySerializedStore map[string]map[string]map[uint64]int +type BinlogVerifySerializedStore map[string]map[string]map[string]int func (s BinlogVerifySerializedStore) RowCount() uint64 { var v uint64 = 0 @@ -85,9 +86,9 @@ func (s BinlogVerifySerializedStore) Copy() BinlogVerifySerializedStore { copyS := make(BinlogVerifySerializedStore) for db, _ := range s { - copyS[db] = make(map[string]map[uint64]int) + copyS[db] = make(map[string]map[string]int) for table, _ := range s[db] { - copyS[db][table] = make(map[uint64]int) + copyS[db][table] = make(map[string]int) for paginationKey, count := range s[db][table] { copyS[db][table][paginationKey] = count } @@ -100,14 +101,14 @@ func (s BinlogVerifySerializedStore) Copy() BinlogVerifySerializedStore { type BinlogVerifyBatch struct { SchemaName string TableName string - PaginationKeys []uint64 + PaginationKeys []interface{} } func NewBinlogVerifyStore() *BinlogVerifyStore { return &BinlogVerifyStore{ - EmitLogPerRowsAdded: uint64(10000), // TODO: make this configurable + EmitLogPerRowsAdded: uint64(10000), mutex: &sync.Mutex{}, - store: make(map[string]map[string]map[uint64]int), + store: make(map[string]map[string]map[string]int), totalRowCount: uint64(0), currentRowCount: uint64(0), } @@ -123,18 +124,18 @@ func NewBinlogVerifyStoreFromSerialized(serialized BinlogVerifySerializedStore) return s } -func (s *BinlogVerifyStore) Add(table *TableSchema, paginationKey uint64) { +func (s *BinlogVerifyStore) Add(table *TableSchema, paginationKey string) { s.mutex.Lock() defer s.mutex.Unlock() _, exists := s.store[table.Schema] if !exists { - s.store[table.Schema] = make(map[string]map[uint64]int) + s.store[table.Schema] = make(map[string]map[string]int) } _, exists = s.store[table.Schema][table.Name] if !exists { - s.store[table.Schema][table.Name] = make(map[uint64]int) + s.store[table.Schema][table.Name] = make(map[string]int) } _, exists = s.store[table.Schema][table.Name][paginationKey] @@ -172,13 +173,15 @@ func (s *BinlogVerifyStore) RemoveVerifiedBatch(batch BinlogVerifyBatch) { } for _, paginationKey := range batch.PaginationKeys { - if _, exists = tableStore[paginationKey]; exists { - if tableStore[paginationKey] <= 1 { - // Even though this doesn't save as RAM, it will save space on the - // serialized output. 
- delete(tableStore, paginationKey) + paginationKeyStr, ok := paginationKey.(string) + if !ok { + continue + } + if _, exists = tableStore[paginationKeyStr]; exists { + if tableStore[paginationKeyStr] <= 1 { + delete(tableStore, paginationKeyStr) } else { - tableStore[paginationKey]-- + tableStore[paginationKeyStr]-- } s.currentRowCount-- } @@ -192,17 +195,17 @@ func (s *BinlogVerifyStore) Batches(batchsize int) []BinlogVerifyBatch { batches := make([]BinlogVerifyBatch, 0) for schemaName, _ := range s.store { for tableName, paginationKeySet := range s.store[schemaName] { - paginationKeyBatch := make([]uint64, 0, batchsize) + paginationKeyBatch := make([]interface{}, 0, batchsize) - for paginationKey, _ := range paginationKeySet { - paginationKeyBatch = append(paginationKeyBatch, paginationKey) + for paginationKeyStr, _ := range paginationKeySet { + paginationKeyBatch = append(paginationKeyBatch, paginationKeyStr) if len(paginationKeyBatch) >= batchsize { batches = append(batches, BinlogVerifyBatch{ SchemaName: schemaName, TableName: tableName, PaginationKeys: paginationKeyBatch, }) - paginationKeyBatch = make([]uint64, 0, batchsize) + paginationKeyBatch = make([]interface{}, 0, batchsize) } } @@ -247,7 +250,7 @@ const ( ) type InlineVerifierMismatches struct { - Pk uint64 + Pk string SourceChecksum string TargetChecksum string MismatchColumn string @@ -328,15 +331,15 @@ func (v *InlineVerifier) Result() (VerificationResultAndStatus, error) { func (v *InlineVerifier) CheckFingerprintInline(tx *sql.Tx, targetSchema, targetTable string, sourceBatch *RowBatch, enforceInlineVerification bool) ([]InlineVerifierMismatches, error) { table := sourceBatch.TableSchema() + paginationColumn := table.GetPaginationColumn() - paginationKeys := make([]uint64, len(sourceBatch.Values())) + paginationKeys := make([]interface{}, len(sourceBatch.Values())) for i, row := range sourceBatch.Values() { - paginationKey, err := row.GetUint64(sourceBatch.PaginationKeyIndex()) + paginationKey, err := NewPaginationKeyFromRow(row, sourceBatch.PaginationKeyIndex(), paginationColumn) if err != nil { return nil, err } - - paginationKeys[i] = paginationKey + paginationKeys[i] = paginationKey.SQLValue() } // Fetch target data @@ -347,15 +350,16 @@ func (v *InlineVerifier) CheckFingerprintInline(tx *sql.Tx, targetSchema, target // Fetch source data sourceFingerprints := sourceBatch.Fingerprints() - sourceDecompressedData := make(map[uint64]map[string][]byte) + sourceDecompressedData := make(map[string]map[string][]byte) for _, rowData := range sourceBatch.Values() { - paginationKey, err := rowData.GetUint64(sourceBatch.PaginationKeyIndex()) + paginationKey, err := NewPaginationKeyFromRow(rowData, sourceBatch.PaginationKeyIndex(), paginationColumn) if err != nil { return nil, err } + paginationKeyStr := paginationKey.String() - sourceDecompressedData[paginationKey] = make(map[string][]byte) + sourceDecompressedData[paginationKeyStr] = make(map[string][]byte) for idx, col := range table.Columns { var compressedData []byte var ok bool @@ -368,7 +372,7 @@ func (v *InlineVerifier) CheckFingerprintInline(tx *sql.Tx, targetSchema, target return nil, fmt.Errorf("cannot convert column %v to []byte", col.Name) } - sourceDecompressedData[paginationKey][col.Name], err = v.decompressData(table, col.Name, compressedData) + sourceDecompressedData[paginationKeyStr][col.Name], err = v.decompressData(table, col.Name, compressedData) } } @@ -468,7 +472,7 @@ func formatMismatches(mismatches map[string]map[string][]InlineVerifierMismatche 
messageBuf.WriteString(tableNameWithSchema) messageBuf.WriteString(" [PKs: ") for _, mismatch := range mismatches[schemaName][tableName] { - messageBuf.WriteString(strconv.FormatUint(mismatch.Pk, 10)) + messageBuf.WriteString(mismatch.Pk) messageBuf.WriteString(" (type: ") messageBuf.WriteString(string(mismatch.MismatchType)) if mismatch.SourceChecksum != "" { @@ -521,15 +525,15 @@ func (v *InlineVerifier) VerifyDuringCutover() (VerificationResult, error) { }, nil } -func (v *InlineVerifier) getFingerprintDataFromSourceDb(schemaName, tableName string, tx *sql.Tx, table *TableSchema, paginationKeys []uint64) (map[uint64][]byte, map[uint64]map[string][]byte, error) { +func (v *InlineVerifier) getFingerprintDataFromSourceDb(schemaName, tableName string, tx *sql.Tx, table *TableSchema, paginationKeys []interface{}) (map[string][]byte, map[string]map[string][]byte, error) { return v.getFingerprintDataFromDb(v.SourceDB, v.sourceStmtCache, schemaName, tableName, tx, table, paginationKeys) } -func (v *InlineVerifier) getFingerprintDataFromTargetDb(schemaName, tableName string, tx *sql.Tx, table *TableSchema, paginationKeys []uint64) (map[uint64][]byte, map[uint64]map[string][]byte, error) { +func (v *InlineVerifier) getFingerprintDataFromTargetDb(schemaName, tableName string, tx *sql.Tx, table *TableSchema, paginationKeys []interface{}) (map[string][]byte, map[string]map[string][]byte, error) { return v.getFingerprintDataFromDb(v.TargetDB, v.targetStmtCache, schemaName, tableName, tx, table, paginationKeys) } -func (v *InlineVerifier) getFingerprintDataFromDb(db *sql.DB, stmtCache *StmtCache, schemaName, tableName string, tx *sql.Tx, table *TableSchema, paginationKeys []uint64) (map[uint64][]byte, map[uint64]map[string][]byte, error) { +func (v *InlineVerifier) getFingerprintDataFromDb(db *sql.DB, stmtCache *StmtCache, schemaName, tableName string, tx *sql.Tx, table *TableSchema, paginationKeys []interface{}) (map[string][]byte, map[string]map[string][]byte, error) { fingerprintQuery := table.FingerprintQuery(schemaName, tableName, len(paginationKeys)) fingerprintStmt, err := stmtCache.StmtFor(db, fingerprintQuery) if err != nil { @@ -555,8 +559,9 @@ func (v *InlineVerifier) getFingerprintDataFromDb(db *sql.DB, stmtCache *StmtCac return nil, nil, err } - fingerprints := make(map[uint64][]byte) // paginationKey -> fingerprint - decompressedData := make(map[uint64]map[string][]byte) // paginationKey -> columnName -> decompressedData + fingerprints := make(map[string][]byte) + decompressedData := make(map[string]map[string][]byte) + paginationColumn := table.GetPaginationColumn() for rows.Next() { rowData, err := ScanByteRow(rows, len(columns)) @@ -564,20 +569,31 @@ func (v *InlineVerifier) getFingerprintDataFromDb(db *sql.DB, stmtCache *StmtCac return nil, nil, err } - paginationKey, err := strconv.ParseUint(string(rowData[0]), 10, 64) - if err != nil { - return nil, nil, err + var paginationKeyStr string + switch paginationColumn.Type { + case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: + paginationKeyUint, err := strconv.ParseUint(string(rowData[0]), 10, 64) + if err != nil { + return nil, nil, err + } + paginationKeyStr = NewUint64Key(paginationKeyUint).String() + + case schema.TYPE_BINARY, schema.TYPE_STRING: + paginationKeyStr = NewBinaryKey(rowData[0]).String() + + default: + paginationKeyUint, err := strconv.ParseUint(string(rowData[0]), 10, 64) + if err != nil { + return nil, nil, err + } + paginationKeyStr = NewUint64Key(paginationKeyUint).String() } - fingerprints[paginationKey] = 
rowData[1] - decompressedData[paginationKey] = make(map[string][]byte) + fingerprints[paginationKeyStr] = rowData[1] + decompressedData[paginationKeyStr] = make(map[string][]byte) - // Note that the FingerprintQuery returns the columns: paginationKey, fingerprint, - // compressedData1, compressedData2, ... - // If there are no compressed data, only 2 columns are returned and this - // loop will be skipped. for i := 2; i < len(columns); i++ { - decompressedData[paginationKey][columns[i]], err = v.decompressData(table, columns[i], rowData[i]) + decompressedData[paginationKeyStr][columns[i]], err = v.decompressData(table, columns[i], rowData[i]) if err != nil { return nil, nil, err } @@ -606,8 +622,8 @@ func (v *InlineVerifier) decompressData(table *TableSchema, column string, compr } } -func (v *InlineVerifier) compareHashes(source, target map[uint64][]byte) map[uint64]InlineVerifierMismatches { - mismatchSet := map[uint64]InlineVerifierMismatches{} +func (v *InlineVerifier) compareHashes(source, target map[string][]byte) map[string]InlineVerifierMismatches { + mismatchSet := map[string]InlineVerifierMismatches{} for paginationKey, targetHash := range target { sourceHash, exists := source[paginationKey] @@ -639,8 +655,8 @@ func (v *InlineVerifier) compareHashes(source, target map[uint64][]byte) map[uin return mismatchSet } -func compareDecompressedData(source, target map[uint64]map[string][]byte) map[uint64]InlineVerifierMismatches { - mismatchSet := map[uint64]InlineVerifierMismatches{} +func compareDecompressedData(source, target map[string]map[string][]byte) map[string]InlineVerifierMismatches { + mismatchSet := map[string]InlineVerifierMismatches{} for paginationKey, targetDecompressedColumns := range target { sourceDecompressedColumns, exists := source[paginationKey] @@ -704,7 +720,7 @@ func compareDecompressedData(source, target map[uint64]map[string][]byte) map[ui return mismatchSet } -func (v *InlineVerifier) compareHashesAndData(sourceHashes, targetHashes map[uint64][]byte, sourceData, targetData map[uint64]map[string][]byte) []InlineVerifierMismatches { +func (v *InlineVerifier) compareHashesAndData(sourceHashes, targetHashes map[string][]byte, sourceData, targetData map[string]map[string][]byte) []InlineVerifierMismatches { mismatches := v.compareHashes(sourceHashes, targetHashes) compressedMismatch := compareDecompressedData(sourceData, targetData) for paginationKey, mismatch := range compressedMismatch { @@ -812,8 +828,8 @@ func (v *InlineVerifier) verifyBinlogBatch(batch BinlogVerifyBatch) ([]InlineVer wg := &sync.WaitGroup{} wg.Add(2) - var sourceFingerprints map[uint64][]byte - var sourceDecompressedData map[uint64]map[string][]byte + var sourceFingerprints map[string][]byte + var sourceDecompressedData map[string]map[string][]byte var sourceErr error go func() { defer wg.Done() @@ -828,8 +844,8 @@ func (v *InlineVerifier) verifyBinlogBatch(batch BinlogVerifyBatch) ([]InlineVer }) }() - var targetFingerprints map[uint64][]byte - var targetDecompressedData map[uint64]map[string][]byte + var targetFingerprints map[string][]byte + var targetDecompressedData map[string]map[string][]byte var targetErr error go func() { defer wg.Done() diff --git a/inline_verifier_test.go b/inline_verifier_test.go index 1db37362b..58b5073d4 100644 --- a/inline_verifier_test.go +++ b/inline_verifier_test.go @@ -7,31 +7,31 @@ import ( ) func TestCompareDecompressedDataNoDifference(t *testing.T) { - source := map[uint64]map[string][]byte{ - 31: {"name": []byte("Leszek")}, + source := 
map[string]map[string][]byte{ + "31": {"name": []byte("Leszek")}, } - target := map[uint64]map[string][]byte{ - 31: {"name": []byte("Leszek")}, + target := map[string]map[string][]byte{ + "31": {"name": []byte("Leszek")}, } result := compareDecompressedData(source, target) - assert.Equal(t, map[uint64]InlineVerifierMismatches{}, result) + assert.Equal(t, map[string]InlineVerifierMismatches{}, result) } func TestCompareDecompressedDataContentDifference(t *testing.T) { - source := map[uint64]map[string][]byte{ - 1: {"name": []byte("Leszek")}, + source := map[string]map[string][]byte{ + "1": {"name": []byte("Leszek")}, } - target := map[uint64]map[string][]byte{ - 1: {"name": []byte("Steve")}, + target := map[string]map[string][]byte{ + "1": {"name": []byte("Steve")}, } result := compareDecompressedData(source, target) - assert.Equal(t, map[uint64]InlineVerifierMismatches{ - 1: { - Pk: 1, + assert.Equal(t, map[string]InlineVerifierMismatches{ + "1": { + Pk: "1", MismatchType: MismatchColumnValueDifference, MismatchColumn: "name", SourceChecksum: "e356a972989f87a1531252cfa2152797", @@ -41,25 +41,25 @@ func TestCompareDecompressedDataContentDifference(t *testing.T) { } func TestCompareDecompressedDataMissingTarget(t *testing.T) { - source := map[uint64]map[string][]byte{ - 1: {"name": []byte("Leszek")}, + source := map[string]map[string][]byte{ + "1": {"name": []byte("Leszek")}, } - target := map[uint64]map[string][]byte{} + target := map[string]map[string][]byte{} result := compareDecompressedData(source, target) - assert.Equal(t, map[uint64]InlineVerifierMismatches{1: {Pk: 1, MismatchType: MismatchRowMissingOnTarget}}, result) + assert.Equal(t, map[string]InlineVerifierMismatches{"1": {Pk: "1", MismatchType: MismatchRowMissingOnTarget}}, result) } func TestCompareDecompressedDataMissingSource(t *testing.T) { - source := map[uint64]map[string][]byte{} - target := map[uint64]map[string][]byte{ - 3: {"name": []byte("Leszek")}, + source := map[string]map[string][]byte{} + target := map[string]map[string][]byte{ + "3": {"name": []byte("Leszek")}, } result := compareDecompressedData(source, target) - assert.Equal(t, map[uint64]InlineVerifierMismatches{3: {Pk: 3, MismatchType: MismatchRowMissingOnSource}}, result) + assert.Equal(t, map[string]InlineVerifierMismatches{"3": {Pk: "3", MismatchType: MismatchRowMissingOnSource}}, result) } func TestFormatMismatch(t *testing.T) { @@ -67,7 +67,7 @@ func TestFormatMismatch(t *testing.T) { "default": { "users": { InlineVerifierMismatches{ - Pk: 1, + Pk: "1", MismatchType: MismatchRowMissingOnSource, }, }, @@ -84,17 +84,17 @@ func TestFormatMismatches(t *testing.T) { "default": { "users": { InlineVerifierMismatches{ - Pk: 1, + Pk: "1", MismatchType: MismatchRowMissingOnSource, }, InlineVerifierMismatches{ - Pk: 5, + Pk: "5", MismatchType: MismatchRowMissingOnTarget, }, }, "posts": { InlineVerifierMismatches{ - Pk: 9, + Pk: "9", MismatchType: MismatchColumnValueDifference, MismatchColumn: string("title"), SourceChecksum: "boo", @@ -103,7 +103,7 @@ func TestFormatMismatches(t *testing.T) { }, "attachments": { InlineVerifierMismatches{ - Pk: 7, + Pk: "7", MismatchType: MismatchColumnValueDifference, MismatchColumn: string("name"), SourceChecksum: "boo", diff --git a/iterative_verifier.go b/iterative_verifier.go index 114bf35e5..3476ebe60 100644 --- a/iterative_verifier.go +++ b/iterative_verifier.go @@ -4,7 +4,6 @@ import ( "bytes" "errors" "fmt" - "math" "strconv" "strings" "sync" @@ -18,17 +17,17 @@ import ( ) type ReverifyBatch struct { - PaginationKeys 
[]uint64 + PaginationKeys []interface{} Table TableIdentifier } type ReverifyEntry struct { - PaginationKey uint64 + PaginationKey string Table *TableSchema } type ReverifyStore struct { - MapStore map[TableIdentifier]map[uint64]struct{} + MapStore map[TableIdentifier]map[string]struct{} mapStoreMutex *sync.Mutex BatchStore []ReverifyBatch RowCount uint64 @@ -50,13 +49,14 @@ func (r *ReverifyStore) Add(entry ReverifyEntry) { r.mapStoreMutex.Lock() defer r.mapStoreMutex.Unlock() + paginationKeyStr := entry.PaginationKey tableId := NewTableIdentifierFromSchemaTable(entry.Table) if _, exists := r.MapStore[tableId]; !exists { - r.MapStore[tableId] = make(map[uint64]struct{}) + r.MapStore[tableId] = make(map[string]struct{}) } - if _, exists := r.MapStore[tableId][entry.PaginationKey]; !exists { - r.MapStore[tableId][entry.PaginationKey] = struct{}{} + if _, exists := r.MapStore[tableId][paginationKeyStr]; !exists { + r.MapStore[tableId][paginationKeyStr] = struct{}{} r.RowCount++ if r.RowCount%r.EmitLogPerRowCount == 0 { metrics.Gauge("iterative_verifier_store_rows", float64(r.RowCount), []MetricTag{}, 1.0) @@ -74,16 +74,16 @@ func (r *ReverifyStore) FlushAndBatchByTable(batchsize int) []ReverifyBatch { r.BatchStore = make([]ReverifyBatch, 0) for tableId, paginationKeySet := range r.MapStore { - paginationKeyBatch := make([]uint64, 0, batchsize) - for paginationKey, _ := range paginationKeySet { - paginationKeyBatch = append(paginationKeyBatch, paginationKey) - delete(paginationKeySet, paginationKey) + paginationKeyBatch := make([]interface{}, 0, batchsize) + for paginationKeyStr, _ := range paginationKeySet { + paginationKeyBatch = append(paginationKeyBatch, paginationKeyStr) + delete(paginationKeySet, paginationKeyStr) if len(paginationKeyBatch) >= batchsize { r.BatchStore = append(r.BatchStore, ReverifyBatch{ PaginationKeys: paginationKeyBatch, Table: tableId, }) - paginationKeyBatch = make([]uint64, 0, batchsize) + paginationKeyBatch = make([]interface{}, 0, batchsize) } } @@ -102,7 +102,7 @@ func (r *ReverifyStore) FlushAndBatchByTable(batchsize int) []ReverifyBatch { } func (r *ReverifyStore) flushStore() { - r.MapStore = make(map[TableIdentifier]map[uint64]struct{}) + r.MapStore = make(map[TableIdentifier]map[string]struct{}) r.RowCount = 0 } @@ -184,10 +184,10 @@ func (v *IterativeVerifier) Initialize() error { func (v *IterativeVerifier) VerifyOnce() (VerificationResult, error) { v.logger.Info("starting one-off verification of all tables") - err := v.iterateAllTables(func(paginationKey uint64, tableSchema *TableSchema) error { + err := v.iterateAllTables(func(paginationKey string, tableSchema *TableSchema) error { return VerificationResult{ DataCorrect: false, - Message: fmt.Sprintf("verification failed on table: %s for paginationKey: %d", tableSchema.String(), paginationKey), + Message: fmt.Sprintf("verification failed on table: %s for paginationKey: %s", tableSchema.String(), paginationKey), IncorrectTables: []string{tableSchema.String()}, } }) @@ -213,7 +213,7 @@ func (v *IterativeVerifier) VerifyBeforeCutover() error { v.BinlogStreamer.AddEventListener(v.binlogEventListener) v.logger.Debug("verifying all tables") - err := v.iterateAllTables(func(paginationKey uint64, tableSchema *TableSchema) error { + err := v.iterateAllTables(func(paginationKey string, tableSchema *TableSchema) error { v.reverifyStore.Add(ReverifyEntry{PaginationKey: paginationKey, Table: tableSchema}) return nil }) @@ -290,15 +290,12 @@ func (v *IterativeVerifier) Result() (VerificationResultAndStatus, error) { 
return v.verificationResultAndStatus, v.verificationErr } -func (v *IterativeVerifier) GetHashes(db *sql.DB, schema, table, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []uint64) (map[uint64][]byte, error) { - sql, args, err := GetMd5HashesSql(schema, table, paginationKeyColumn, columns, paginationKeys) +func (v *IterativeVerifier) GetHashes(db *sql.DB, schemaName, tableName, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []interface{}) (map[string][]byte, error) { + sql, args, err := GetMd5HashesSql(schemaName, tableName, paginationKeyColumn, columns, paginationKeys) if err != nil { return nil, err } - // This query must be a prepared query. If it is not, querying will use - // MySQL's plain text interface, which will scan all values into []uint8 - // if we give it []interface{}. stmt, err := db.Prepare(sql) if err != nil { return nil, err @@ -313,19 +310,22 @@ func (v *IterativeVerifier) GetHashes(db *sql.DB, schema, table, paginationKeyCo defer rows.Close() - resultSet := make(map[uint64][]byte) + table := v.TableSchemaCache.Get(schemaName, tableName) + paginationColumn := table.GetPaginationColumn() + resultSet := make(map[string][]byte) + for rows.Next() { rowData, err := ScanGenericRow(rows, 2) if err != nil { return nil, err } - paginationKey, err := rowData.GetUint64(0) + paginationKey, err := NewPaginationKeyFromRow(rowData, 0, paginationColumn) if err != nil { return nil, err } - resultSet[paginationKey] = rowData[1].([]byte) + resultSet[paginationKey.String()] = rowData[1].([]byte) } return resultSet, nil } @@ -363,7 +363,7 @@ func (v *IterativeVerifier) reverifyUntilStoreIsSmallEnough(maxIterations int) e return nil } -func (v *IterativeVerifier) iterateAllTables(mismatchedPaginationKeyFunc func(uint64, *TableSchema) error) error { +func (v *IterativeVerifier) iterateAllTables(mismatchedPaginationKeyFunc func(string, *TableSchema) error) error { pool := &WorkerPool{ Concurrency: v.Concurrency, Process: func(tableIndex int) (interface{}, error) { @@ -386,10 +386,10 @@ func (v *IterativeVerifier) iterateAllTables(mismatchedPaginationKeyFunc func(ui return err } -func (v *IterativeVerifier) iterateTableFingerprints(table *TableSchema, mismatchedPaginationKeyFunc func(uint64, *TableSchema) error) error { +func (v *IterativeVerifier) iterateTableFingerprints(table *TableSchema, mismatchedPaginationKeyFunc func(string, *TableSchema) error) error { // The cursor will stop iterating when it cannot find anymore rows, - // so it will not iterate until MaxUint64. - cursor := v.CursorConfig.NewCursorWithoutRowLock(table, 0, math.MaxUint64) + // so it will not iterate until MaxPaginationKey. + cursor := v.CursorConfig.NewCursorWithoutRowLock(table, MinPaginationKey(table.GetPaginationColumn()), MaxPaginationKey(table.GetPaginationColumn())) // It only needs the PaginationKeys, not the entire row. 
cursor.ColumnsToSelect = []string{fmt.Sprintf("`%s`", table.GetPaginationColumn().Name)} @@ -399,15 +399,15 @@ func (v *IterativeVerifier) iterateTableFingerprints(table *TableSchema, mismatc MetricTag{"source", "iterative_verifier_before_cutover"}, }, 1.0) - paginationKeys := make([]uint64, 0, batch.Size()) + paginationKeys := make([]interface{}, 0, batch.Size()) + paginationColumn := table.GetPaginationColumn() for _, rowData := range batch.Values() { - paginationKey, err := rowData.GetUint64(batch.PaginationKeyIndex()) + paginationKey, err := NewPaginationKeyFromRow(rowData, batch.PaginationKeyIndex(), paginationColumn) if err != nil { return err } - - paginationKeys = append(paginationKeys, paginationKey) + paginationKeys = append(paginationKeys, paginationKey.SQLValue()) } mismatchedPaginationKeys, err := v.compareFingerprints(paginationKeys, batch.TableSchema()) @@ -513,7 +513,7 @@ func (v *IterativeVerifier) verifyStore(sourceTag string, additionalTags []Metri return result, err } -func (v *IterativeVerifier) reverifyPaginationKeys(table *TableSchema, paginationKeys []uint64) (VerificationResult, []uint64, error) { +func (v *IterativeVerifier) reverifyPaginationKeys(table *TableSchema, paginationKeys []interface{}) (VerificationResult, []string, error) { mismatchedPaginationKeys, err := v.compareFingerprints(paginationKeys, table) if err != nil { return VerificationResult{}, mismatchedPaginationKeys, err @@ -523,14 +523,9 @@ func (v *IterativeVerifier) reverifyPaginationKeys(table *TableSchema, paginatio return NewCorrectVerificationResult(), mismatchedPaginationKeys, nil } - paginationKeyStrings := make([]string, len(mismatchedPaginationKeys)) - for idx, paginationKey := range mismatchedPaginationKeys { - paginationKeyStrings[idx] = strconv.FormatUint(paginationKey, 10) - } - return VerificationResult{ DataCorrect: false, - Message: fmt.Sprintf("verification failed on table: %s for paginationKeys: %s", table.String(), strings.Join(paginationKeyStrings, ",")), + Message: fmt.Sprintf("verification failed on table: %s for paginationKeys: %s", table.String(), strings.Join(mismatchedPaginationKeys, ",")), IncorrectTables: []string{table.String()}, }, mismatchedPaginationKeys, nil } @@ -582,7 +577,7 @@ func (v *IterativeVerifier) columnsToVerify(table *TableSchema) []schema.TableCo return columns } -func (v *IterativeVerifier) compareFingerprints(paginationKeys []uint64, table *TableSchema) ([]uint64, error) { +func (v *IterativeVerifier) compareFingerprints(paginationKeys []interface{}, table *TableSchema) ([]string, error) { targetDb := table.Schema if targetDbName, exists := v.DatabaseRewrites[targetDb]; exists { targetDb = targetDbName @@ -596,7 +591,7 @@ func (v *IterativeVerifier) compareFingerprints(paginationKeys []uint64, table * wg := &sync.WaitGroup{} wg.Add(2) - var sourceHashes map[uint64][]byte + var sourceHashes map[string][]byte var sourceErr error go func() { defer wg.Done() @@ -606,7 +601,7 @@ func (v *IterativeVerifier) compareFingerprints(paginationKeys []uint64, table * }) }() - var targetHashes map[uint64][]byte + var targetHashes map[string][]byte var targetErr error go func() { defer wg.Done() @@ -632,7 +627,7 @@ func (v *IterativeVerifier) compareFingerprints(paginationKeys []uint64, table * return mismatches, nil } -func (v *IterativeVerifier) compareCompressedHashes(targetDb, targetTable string, table *TableSchema, paginationKeys []uint64) ([]uint64, error) { +func (v *IterativeVerifier) compareCompressedHashes(targetDb, targetTable string, table *TableSchema, 
paginationKeys []interface{}) ([]string, error) { sourceHashes, err := v.CompressionVerifier.GetCompressedHashes(v.SourceDB, table.Schema, table.Name, table.GetPaginationColumn().Name, v.columnsToVerify(table), paginationKeys) if err != nil { return nil, err @@ -646,8 +641,8 @@ func (v *IterativeVerifier) compareCompressedHashes(targetDb, targetTable string return compareHashes(sourceHashes, targetHashes), nil } -func compareHashes(source, target map[uint64][]byte) []uint64 { - mismatchSet := map[uint64]struct{}{} +func compareHashes(source, target map[string][]byte) []string { + mismatchSet := map[string]struct{}{} for paginationKey, targetHash := range target { sourceHash, exists := source[paginationKey] @@ -663,7 +658,7 @@ func compareHashes(source, target map[uint64][]byte) []uint64 { } } - mismatches := make([]uint64, 0, len(mismatchSet)) + mismatches := make([]string, 0, len(mismatchSet)) for mismatch, _ := range mismatchSet { mismatches = append(mismatches, mismatch) } @@ -671,7 +666,7 @@ func compareHashes(source, target map[uint64][]byte) []uint64 { return mismatches } -func GetMd5HashesSql(schema, table, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []uint64) (string, []interface{}, error) { +func GetMd5HashesSql(schema, table, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []interface{}) (string, []interface{}, error) { quotedPaginationKey := QuoteField(paginationKeyColumn) return rowMd5Selector(columns, paginationKeyColumn). From(QuotedTableNameFromString(schema, table)). diff --git a/pagination_key.go b/pagination_key.go new file mode 100644 index 000000000..4906632e0 --- /dev/null +++ b/pagination_key.go @@ -0,0 +1,262 @@ +package ghostferry + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "encoding/json" + "fmt" + "math" + + "github.com/go-mysql-org/go-mysql/schema" +) + +// PaginationKey represents a cursor position for paginating through table rows. +// It abstracts over different primary key types (integers, UUIDs, binary data) +// to enable consistent batched iteration through tables. +type PaginationKey interface { + // SQLValue returns the value to use in SQL WHERE clauses (e.g., WHERE id > ?). + SQLValue() interface{} + // Compare returns -1, 0, or 1 if this key is less than, equal to, or greater than other. + Compare(other PaginationKey) int + // NumericPosition returns a float64 approximation for progress tracking and estimation. + NumericPosition() float64 + // String returns a human-readable representation for logging and debugging. + String() string + // MarshalJSON serializes the key for state persistence and checkpointing. + MarshalJSON() ([]byte, error) + // IsMax returns true if this key represents the maximum possible value for its type. 
+ IsMax() bool +} + +type Uint64Key uint64 + +func NewUint64Key(value uint64) Uint64Key { + return Uint64Key(value) +} + +func (k Uint64Key) SQLValue() interface{} { + return uint64(k) +} + +func (k Uint64Key) Compare(other PaginationKey) int { + otherKey, ok := other.(Uint64Key) + if !ok { + panic(fmt.Sprintf("cannot compare Uint64Key with %T", other)) + } + + if k < otherKey { + return -1 + } else if k > otherKey { + return 1 + } + return 0 +} + +func (k Uint64Key) NumericPosition() float64 { + return float64(k) +} + +func (k Uint64Key) String() string { + return fmt.Sprintf("%d", uint64(k)) +} + +func (k Uint64Key) IsMax() bool { + return k == Uint64Key(math.MaxUint64) +} + +func (k Uint64Key) MarshalJSON() ([]byte, error) { + return json.Marshal(uint64(k)) +} + +type BinaryKey []byte + +func NewBinaryKey(value []byte) BinaryKey { + clone := make([]byte, len(value)) + copy(clone, value) + return BinaryKey(clone) +} + +func (k BinaryKey) SQLValue() interface{} { + return []byte(k) +} + +func (k BinaryKey) Compare(other PaginationKey) int { + otherKey, ok := other.(BinaryKey) + if !ok { + panic(fmt.Sprintf("type mismatch: cannot compare BinaryKey with %T", other)) + } + return bytes.Compare(k, otherKey) +} + +// NumericPosition calculates a rough float position for progress tracking. +// +// Note: This method only uses the first 8 bytes of the binary key for progress calculation. +// This works well for timestamp-based keys like UUID v7 (where the first 48 bits are a timestamp), +// but progress may appear frozen when processing rows that differ only in bytes 9+. +// For random binary keys (like UUID v4), progress will be unpredictable. +// +// The core pagination algorithm (using Compare()) is unaffected and works correctly with any binary data. +func (k BinaryKey) NumericPosition() float64 { + if len(k) == 0 { + return 0.0 + } + + // Take up to the first 8 bytes to form a uint64 for estimation + var buf [8]byte + copy(buf[:], k) + + val := binary.BigEndian.Uint64(buf[:]) + return float64(val) +} + +func (k BinaryKey) String() string { + return hex.EncodeToString(k) +} + +func (k BinaryKey) IsMax() bool { + // We cannot know the true "Max" of a VARBINARY without knowing the length. + // However, for UUID(16), we can check for FF... 
+ if len(k) == 0 { + return false + } + for _, b := range k { + if b != 0xFF { + return false + } + } + return true +} + +func (k BinaryKey) MarshalJSON() ([]byte, error) { + return json.Marshal(hex.EncodeToString(k)) +} + +type encodedKey struct { + Type string `json:"type"` + Value json.RawMessage `json:"value"` +} + +func MarshalPaginationKey(k PaginationKey) ([]byte, error) { + var typeName string + var valBytes []byte + var err error + + switch t := k.(type) { + case Uint64Key: + typeName = "uint64" + valBytes, err = t.MarshalJSON() + case BinaryKey: + typeName = "binary" + valBytes, err = t.MarshalJSON() + default: + return nil, fmt.Errorf("unknown pagination key type: %T", k) + } + + if err != nil { + return nil, err + } + + return json.Marshal(encodedKey{ + Type: typeName, + Value: valBytes, + }) +} + +func UnmarshalPaginationKey(data []byte) (PaginationKey, error) { + var wrapper encodedKey + if err := json.Unmarshal(data, &wrapper); err != nil { + return nil, err + } + + switch wrapper.Type { + case "uint64": + var i uint64 + if err := json.Unmarshal(wrapper.Value, &i); err != nil { + return nil, err + } + return NewUint64Key(i), nil + case "binary": + var s string + if err := json.Unmarshal(wrapper.Value, &s); err != nil { + return nil, err + } + b, err := hex.DecodeString(s) + if err != nil { + return nil, err + } + return NewBinaryKey(b), nil + default: + return nil, fmt.Errorf("unknown key type: %s", wrapper.Type) + } +} + +func MinPaginationKey(column *schema.TableColumn) PaginationKey { + switch column.Type { + case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: + return NewUint64Key(0) + // Handle all potential binary/string types + case schema.TYPE_BINARY, schema.TYPE_STRING: + // The smallest value for any binary/string type is an empty slice. + // Even for fixed BINARY(N), starting at empty ensures we catch [0x00, ...] + return NewBinaryKey([]byte{}) + default: + return NewUint64Key(0) + } +} + +func MaxPaginationKey(column *schema.TableColumn) PaginationKey { + switch column.Type { + case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: + return NewUint64Key(math.MaxUint64) + case schema.TYPE_BINARY, schema.TYPE_STRING: + // SAFETY: Cap the size to prevent OOM on LONGBLOB (4GB). + // InnoDB index limit is 3072 bytes. 4KB is a safe upper bound for a PK. + size := column.MaxSize + if size > 4096 { + size = 4096 + } + + maxBytes := make([]byte, size) + for i := range maxBytes { + maxBytes[i] = 0xFF + } + return NewBinaryKey(maxBytes) + default: + return NewUint64Key(math.MaxUint64) + } +} + +// NewPaginationKeyFromRow extracts a pagination key from a row at the given index. +// It determines the appropriate key type based on the column schema. 
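// Before the extraction helper declared below, here is a minimal sketch of how
// the pieces above fit together in a copy loop. It is not Ghostferry's actual
// DataIterator: copyTableSketch and its fetchBatch callback are illustrative
// names only, standing in for the real batching code.
func copyTableSketch(
	column *schema.TableColumn,
	fetchBatch func(after interface{}) (lastKeyInBatch PaginationKey, rowCount int, err error),
) error {
	cursor := MinPaginationKey(column) // e.g. Uint64Key(0) or an empty BinaryKey
	max := MaxPaginationKey(column)

	for cursor.Compare(max) < 0 && !cursor.IsMax() {
		// SQLValue feeds directly into a `WHERE pagination_key > ?` clause.
		lastKeyInBatch, rowCount, err := fetchBatch(cursor.SQLValue())
		if err != nil {
			return err
		}
		if rowCount == 0 {
			return nil // nothing beyond the cursor; the table is exhausted
		}
		cursor = lastKeyInBatch
	}
	return nil
}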
+func NewPaginationKeyFromRow(rowData RowData, index int, column *schema.TableColumn) (PaginationKey, error) { + switch column.Type { + case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: + value, err := rowData.GetUint64(index) + if err != nil { + return nil, fmt.Errorf("failed to get uint64 pagination key: %w", err) + } + return NewUint64Key(value), nil + + case schema.TYPE_BINARY, schema.TYPE_STRING: + valueInterface := rowData[index] + var valueBytes []byte + switch v := valueInterface.(type) { + case []byte: + valueBytes = v + case string: + valueBytes = []byte(v) + default: + return nil, fmt.Errorf("expected binary pagination key to be []byte or string, got %T", valueInterface) + } + return NewBinaryKey(valueBytes), nil + + default: + // Fallback for other integer types + value, err := rowData.GetUint64(index) + if err != nil { + return nil, fmt.Errorf("failed to get pagination key: %w", err) + } + return NewUint64Key(value), nil + } +} diff --git a/row_batch.go b/row_batch.go index 4426fc127..7a865a6f6 100644 --- a/row_batch.go +++ b/row_batch.go @@ -9,7 +9,7 @@ type RowBatch struct { values []RowData paginationKeyIndex int table *TableSchema - fingerprints map[uint64][]byte + fingerprints map[string][]byte columns []string } @@ -55,7 +55,7 @@ func (e *RowBatch) TableSchema() *TableSchema { return e.table } -func (e *RowBatch) Fingerprints() map[uint64][]byte { +func (e *RowBatch) Fingerprints() map[string][]byte { return e.fingerprints } diff --git a/sharding/filter.go b/sharding/filter.go index 0f095b8ae..c5d17c9e5 100644 --- a/sharding/filter.go +++ b/sharding/filter.go @@ -33,7 +33,7 @@ type ShardedCopyFilter struct { missingShardingKeyIndexLogged sync.Map } -func (f *ShardedCopyFilter) BuildSelect(columns []string, table *ghostferry.TableSchema, lastPaginationKey, batchSize uint64) (sq.SelectBuilder, error) { +func (f *ShardedCopyFilter) BuildSelect(columns []string, table *ghostferry.TableSchema, lastPaginationKey ghostferry.PaginationKey, batchSize uint64) (sq.SelectBuilder, error) { quotedPaginationKey := "`" + table.GetPaginationColumn().Name + "`" quotedShardingKey := "`" + f.ShardingKey + "`" quotedTable := ghostferry.QuotedTableName(table) @@ -49,7 +49,7 @@ func (f *ShardedCopyFilter) BuildSelect(columns []string, table *ghostferry.Tabl return sq.Select(columns...). From(quotedTable + " USE INDEX (PRIMARY)"). Where(sq.Eq{quotedPaginationKey: f.ShardingValue}). - Where(sq.Gt{quotedPaginationKey: lastPaginationKey}), nil + Where(sq.Gt{quotedPaginationKey: lastPaginationKey.SQLValue()}), nil } joinTables, exists := f.JoinedTables[table.Name] @@ -90,7 +90,7 @@ func (f *ShardedCopyFilter) BuildSelect(columns []string, table *ghostferry.Tabl return sq.Select(columns...). From(quotedTable). - Join("("+selectPaginationKeys+") AS `batch` USING("+quotedPaginationKey+")", f.ShardingValue, lastPaginationKey), nil + Join("("+selectPaginationKeys+") AS `batch` USING("+quotedPaginationKey+")", f.ShardingValue, lastPaginationKey.SQLValue()), nil } // This is a "joined table". It is the only supported type of table that @@ -126,7 +126,7 @@ func (f *ShardedCopyFilter) BuildSelect(columns []string, table *ghostferry.Tabl pattern := "SELECT `%s` AS sharding_join_alias FROM `%s`.`%s` WHERE `%s` = ? AND `%s` > ?" 
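// Illustration only (not part of this filter): both key flavours flow through
// the same `?` placeholder, because SQLValue yields uint64 for integer keys and
// []byte for binary keys, and squirrel binds either as an ordinary argument.
// exampleCursorPredicate is a hypothetical helper, not used by the code here.
func exampleCursorPredicate(quotedColumn string, cursor ghostferry.PaginationKey) sq.Sqlizer {
	return sq.Gt{quotedColumn: cursor.SQLValue()}
}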
sql := fmt.Sprintf(pattern, joinTable.JoinColumn, table.Schema, joinTable.TableName, f.ShardingKey, joinTable.JoinColumn) clauses = append(clauses, sql) - args = append(args, f.ShardingValue, lastPaginationKey) + args = append(args, f.ShardingValue, lastPaginationKey.SQLValue()) } subquery := strings.Join(clauses, " UNION DISTINCT ") diff --git a/sharding/test/copy_filter_test.go b/sharding/test/copy_filter_test.go index d0e4ebd42..dab0dfed4 100644 --- a/sharding/test/copy_filter_test.go +++ b/sharding/test/copy_filter_test.go @@ -18,7 +18,7 @@ type CopyFilterTestSuite struct { suite.Suite shardingValue int64 - paginationKeyCursor uint64 + paginationKeyCursor ghostferry.PaginationKey normalTable, normalTable2, joinedTable, primaryKeyTable *ghostferry.TableSchema @@ -27,7 +27,7 @@ type CopyFilterTestSuite struct { func (t *CopyFilterTestSuite) SetupTest() { t.shardingValue = int64(1) - t.paginationKeyCursor = uint64(12345) + t.paginationKeyCursor = ghostferry.NewUint64Key(12345) columns := []schema.TableColumn{{Name: "id"}, {Name: "tenant_id"}, {Name: "data"}} t.normalTable = &ghostferry.TableSchema{ @@ -105,7 +105,7 @@ func (t *CopyFilterTestSuite) TestSelectsRegularTables() { sql, args, err := selectBuilder.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`normaltable` JOIN (SELECT `id` FROM `shard_1`.`normaltable` USE INDEX (`good_sharding_index`) WHERE `tenant_id` = ? AND `id` > ? ORDER BY `id` LIMIT 1024) AS `batch` USING(`id`)", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) } func (t *CopyFilterTestSuite) TestFallsBackToLessGoodIndex() { @@ -116,7 +116,7 @@ func (t *CopyFilterTestSuite) TestFallsBackToLessGoodIndex() { sql, args, err := selectBuilder.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`normaltable` JOIN (SELECT `id` FROM `shard_1`.`normaltable` USE INDEX (`less_good_sharding_index`) WHERE `tenant_id` = ? AND `id` > ? ORDER BY `id` LIMIT 1024) AS `batch` USING(`id`)", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) } func (t *CopyFilterTestSuite) TestFallsBackToIgnoredPrimaryIndex() { @@ -128,7 +128,7 @@ func (t *CopyFilterTestSuite) TestFallsBackToIgnoredPrimaryIndex() { sql, args, err := selectBuilder.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`normaltable` JOIN (SELECT `id` FROM `shard_1`.`normaltable` IGNORE INDEX (PRIMARY) WHERE `tenant_id` = ? AND `id` > ? ORDER BY `id` LIMIT 1024) AS `batch` USING(`id`)", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) } func (t *CopyFilterTestSuite) TestRemovesIndexHint() { @@ -139,7 +139,7 @@ func (t *CopyFilterTestSuite) TestRemovesIndexHint() { sql, args, err := selectBuilder.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`normaltable` JOIN (SELECT `id` FROM `shard_1`.`normaltable` WHERE `tenant_id` = ? AND `id` > ? 
ORDER BY `id` LIMIT 1024) AS `batch` USING(`id`)", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) } func (t *CopyFilterTestSuite) TestUsesForceIndex() { @@ -150,7 +150,7 @@ func (t *CopyFilterTestSuite) TestUsesForceIndex() { sql, args, err := selectBuilder.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`normaltable` JOIN (SELECT `id` FROM `shard_1`.`normaltable` FORCE INDEX (`good_sharding_index`) WHERE `tenant_id` = ? AND `id` > ? ORDER BY `id` LIMIT 1024) AS `batch` USING(`id`)", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) } func (t *CopyFilterTestSuite) TestUsesIndexHintThatIsNotLowercased() { @@ -161,7 +161,7 @@ func (t *CopyFilterTestSuite) TestUsesIndexHintThatIsNotLowercased() { sql, args, err := selectBuilder.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`normaltable` JOIN (SELECT `id` FROM `shard_1`.`normaltable` FORCE INDEX (`good_sharding_index`) WHERE `tenant_id` = ? AND `id` > ? ORDER BY `id` LIMIT 1024) AS `batch` USING(`id`)", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) } func (t *CopyFilterTestSuite) TestHigherSpecificityOfIndexHintingPerTable() { @@ -179,7 +179,7 @@ func (t *CopyFilterTestSuite) TestHigherSpecificityOfIndexHintingPerTable() { sql, args, err := selectBuilder1.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`normaltable` JOIN (SELECT `id` FROM `shard_1`.`normaltable` USE INDEX (`good_sharding_index`) WHERE `tenant_id` = ? AND `id` > ? ORDER BY `id` LIMIT 1024) AS `batch` USING(`id`)", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) selectBuilder2, err := t.filter.BuildSelect([]string{"*"}, t.normalTable2, t.paginationKeyCursor, 1024) t.Require().Nil(err) @@ -187,7 +187,7 @@ func (t *CopyFilterTestSuite) TestHigherSpecificityOfIndexHintingPerTable() { sql, args, err = selectBuilder2.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`normaltable2` JOIN (SELECT `id` FROM `shard_1`.`normaltable2` WHERE `tenant_id` = ? AND `id` > ? ORDER BY `id` LIMIT 1024) AS `batch` USING(`id`)", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) } func (t *CopyFilterTestSuite) TestHigherSpecificityOfIndexHintingPerTable2() { @@ -205,7 +205,7 @@ func (t *CopyFilterTestSuite) TestHigherSpecificityOfIndexHintingPerTable2() { sql, args, err := selectBuilder1.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`normaltable` JOIN (SELECT `id` FROM `shard_1`.`normaltable` WHERE `tenant_id` = ? AND `id` > ? 
ORDER BY `id` LIMIT 1024) AS `batch` USING(`id`)", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) selectBuilder2, err := t.filter.BuildSelect([]string{"*"}, t.normalTable2, t.paginationKeyCursor, 1024) t.Require().Nil(err) @@ -213,7 +213,7 @@ func (t *CopyFilterTestSuite) TestHigherSpecificityOfIndexHintingPerTable2() { sql, args, err = selectBuilder2.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`normaltable2` JOIN (SELECT `id` FROM `shard_1`.`normaltable2` FORCE INDEX (`good_sharding_index`) WHERE `tenant_id` = ? AND `id` > ? ORDER BY `id` LIMIT 1024) AS `batch` USING(`id`)", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) } func (t *CopyFilterTestSuite) TestIndexHintingPerTableWithNonExistentIndex() { @@ -236,7 +236,7 @@ func (t *CopyFilterTestSuite) TestIndexHintingPerTableWithNonExistentIndex() { sql, args, err := selectBuilder1.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`normaltable` JOIN (SELECT `id` FROM `shard_1`.`normaltable` WHERE `tenant_id` = ? AND `id` > ? ORDER BY `id` LIMIT 1024) AS `batch` USING(`id`)", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) selectBuilder2, err := t.filter.BuildSelect([]string{"*"}, t.normalTable2, t.paginationKeyCursor, 1024) t.Require().Nil(err) @@ -244,7 +244,7 @@ func (t *CopyFilterTestSuite) TestIndexHintingPerTableWithNonExistentIndex() { sql, args, err = selectBuilder2.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`normaltable2` JOIN (SELECT `id` FROM `shard_1`.`normaltable2` FORCE INDEX (`good_sharding_index`) WHERE `tenant_id` = ? AND `id` > ? ORDER BY `id` LIMIT 1024) AS `batch` USING(`id`)", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) } func (t *CopyFilterTestSuite) TestIndexHintingPerTableWithIndexOnTable() { @@ -262,7 +262,7 @@ func (t *CopyFilterTestSuite) TestIndexHintingPerTableWithIndexOnTable() { sql, args, err := selectBuilder1.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`normaltable` JOIN (SELECT `id` FROM `shard_1`.`normaltable` FORCE INDEX (`less_good_sharding_index`) WHERE `tenant_id` = ? AND `id` > ? ORDER BY `id` LIMIT 1024) AS `batch` USING(`id`)", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) } func (t *CopyFilterTestSuite) TestSelectsJoinedTables() { @@ -272,7 +272,7 @@ func (t *CopyFilterTestSuite) TestSelectsJoinedTables() { sql, args, err := selectBuilder.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`joinedtable` WHERE `joined_paginationKey` IN (SELECT * FROM (SELECT `joined_paginationKey1` AS sharding_join_alias FROM `shard_1`.`join1` WHERE `tenant_id` = ? AND `joined_paginationKey1` > ? UNION DISTINCT SELECT `joined_paginationKey2` AS sharding_join_alias FROM `shard_1`.`join2` WHERE `tenant_id` = ? AND `joined_paginationKey2` > ? 
ORDER BY sharding_join_alias LIMIT 1024) AS sharding_join_table) ORDER BY `joined_paginationKey`", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor, t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue(), t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) } func (t *CopyFilterTestSuite) TestSelectsPrimaryKeyTables() { @@ -282,7 +282,7 @@ func (t *CopyFilterTestSuite) TestSelectsPrimaryKeyTables() { sql, args, err := selectBuilder.ToSql() t.Require().Nil(err) t.Require().Equal("SELECT * FROM `shard_1`.`pkTable` USE INDEX (PRIMARY) WHERE `tenant_id` = ? AND `tenant_id` > ?", sql) - t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor}, args) + t.Require().Equal([]interface{}{t.shardingValue, t.paginationKeyCursor.SQLValue()}, args) } func (t *CopyFilterTestSuite) TestShardingValueTypes() { diff --git a/sharding/test/trivial_integration_test.go b/sharding/test/trivial_integration_test.go index 907753b9c..bd4c28982 100644 --- a/sharding/test/trivial_integration_test.go +++ b/sharding/test/trivial_integration_test.go @@ -2,6 +2,7 @@ package test import ( "math/rand" + "sync/atomic" "testing" sql "github.com/Shopify/ghostferry/sqlwrapper" @@ -83,6 +84,8 @@ func TestSelectiveCopyDataWithInsertLoadOnOtherTenants(t *testing.T) { } func TestSelectiveCopyDataWithInsertLoadOnAllTenants(t *testing.T) { + var firstInsert atomic.Bool + testcase := &testhelpers.IntegrationTestCase{ T: t, Ferry: selectiveFerry(int64(2)), @@ -93,7 +96,11 @@ func TestSelectiveCopyDataWithInsertLoadOnAllTenants(t *testing.T) { Tables: []string{"gftest.table1"}, ExtraInsertData: func(tableName string, vals map[string]interface{}) { - vals["tenant_id"] = rand.Intn(3) + if firstInsert.CompareAndSwap(false, true) { + vals["tenant_id"] = 2 + } else { + vals["tenant_id"] = rand.Intn(3) + } }, }, } diff --git a/state_tracker.go b/state_tracker.go index 760481a80..d885c33f5 100644 --- a/state_tracker.go +++ b/state_tracker.go @@ -2,7 +2,7 @@ package ghostferry import ( "container/ring" - "math" + "encoding/json" "sync" "time" @@ -34,7 +34,7 @@ type SerializableState struct { GhostferryVersion string LastKnownTableSchemaCache TableSchemaCache - LastSuccessfulPaginationKeys map[string]uint64 + LastSuccessfulPaginationKeys map[string]PaginationKey CompletedTables map[string]bool LastWrittenBinlogPosition mysql.Position BinlogVerifyStore BinlogVerifySerializedStore @@ -42,6 +42,53 @@ type SerializableState struct { LastStoredBinlogPositionForTargetVerifier mysql.Position } +func (s *SerializableState) MarshalJSON() ([]byte, error) { + // Create an alias to avoid infinite recursion, but change the map type + type Alias SerializableState + aux := &struct { + LastSuccessfulPaginationKeys map[string]json.RawMessage + *Alias + }{ + Alias: (*Alias)(s), + LastSuccessfulPaginationKeys: make(map[string]json.RawMessage), + } + + for k, v := range s.LastSuccessfulPaginationKeys { + b, err := MarshalPaginationKey(v) + if err != nil { + return nil, err + } + aux.LastSuccessfulPaginationKeys[k] = b + } + + return json.Marshal(aux) +} + +func (s *SerializableState) UnmarshalJSON(data []byte) error { + type Alias SerializableState + aux := &struct { + LastSuccessfulPaginationKeys map[string]json.RawMessage + *Alias + }{ + Alias: (*Alias)(s), + } + + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + s.LastSuccessfulPaginationKeys = make(map[string]PaginationKey) + for k, v := range 
aux.LastSuccessfulPaginationKeys { + pk, err := UnmarshalPaginationKey(v) + if err != nil { + return err + } + s.LastSuccessfulPaginationKeys[k] = pk + } + + return nil +} + func (s *SerializableState) MinSourceBinlogPosition() mysql.Position { nilPosition := mysql.Position{} if s.LastWrittenBinlogPosition == nilPosition { @@ -61,7 +108,7 @@ func (s *SerializableState) MinSourceBinlogPosition() mysql.Position { // For tracking the speed of the copy type PaginationKeyPositionLog struct { - Position uint64 + Position float64 At time.Time } @@ -92,7 +139,7 @@ type StateTracker struct { lastStoredBinlogPositionForInlineVerifier mysql.Position lastStoredBinlogPositionForTargetVerifier mysql.Position - lastSuccessfulPaginationKeys map[string]uint64 + lastSuccessfulPaginationKeys map[string]PaginationKey completedTables map[string]bool // TODO: Performance tracking should be refactored out of the state tracker, @@ -106,7 +153,7 @@ func NewStateTracker(speedLogCount int) *StateTracker { BinlogRWMutex: &sync.RWMutex{}, CopyRWMutex: &sync.RWMutex{}, - lastSuccessfulPaginationKeys: make(map[string]uint64), + lastSuccessfulPaginationKeys: make(map[string]PaginationKey), completedTables: make(map[string]bool), iterationSpeedLog: newSpeedLogRing(speedLogCount), rowStatsWrittenPerTable: make(map[string]RowStats), @@ -146,11 +193,16 @@ func (s *StateTracker) UpdateLastResumableBinlogPositionForTargetVerifier(pos my s.lastStoredBinlogPositionForTargetVerifier = pos } -func (s *StateTracker) UpdateLastSuccessfulPaginationKey(table string, paginationKey uint64, rowStats RowStats) { +func (s *StateTracker) UpdateLastSuccessfulPaginationKey(table string, paginationKey PaginationKey, rowStats RowStats) { s.CopyRWMutex.Lock() defer s.CopyRWMutex.Unlock() - deltaPaginationKey := paginationKey - s.lastSuccessfulPaginationKeys[table] + var deltaPaginationKey float64 + if lastKey, exists := s.lastSuccessfulPaginationKeys[table]; exists { + deltaPaginationKey = paginationKey.NumericPosition() - lastKey.NumericPosition() + } else { + deltaPaginationKey = paginationKey.NumericPosition() + } s.lastSuccessfulPaginationKeys[table] = paginationKey // TODO: this code is intentionally left here so it is kind of crappy and @@ -174,18 +226,18 @@ func (s *StateTracker) RowStatsWrittenPerTable() map[string]RowStats { return d } -func (s *StateTracker) LastSuccessfulPaginationKey(table string) uint64 { +func (s *StateTracker) LastSuccessfulPaginationKey(table string, tableSchema *TableSchema) PaginationKey { s.CopyRWMutex.RLock() defer s.CopyRWMutex.RUnlock() _, found := s.completedTables[table] if found { - return math.MaxUint64 + return MaxPaginationKey(tableSchema.GetPaginationColumn()) } paginationKey, found := s.lastSuccessfulPaginationKeys[table] if !found { - return 0 + return MinPaginationKey(tableSchema.GetPaginationColumn()) } return paginationKey @@ -240,7 +292,7 @@ func (s *StateTracker) updateRowStatsForTable(table string, rowStats RowStats) { } } -func (s *StateTracker) updateSpeedLog(deltaPaginationKey uint64) { +func (s *StateTracker) updateSpeedLog(deltaPaginationKey float64) { if s.iterationSpeedLog == nil { return } @@ -263,7 +315,7 @@ func (s *StateTracker) Serialize(lastKnownTableSchemaCache TableSchemaCache, bin state := &SerializableState{ GhostferryVersion: VersionString, LastKnownTableSchemaCache: lastKnownTableSchemaCache, - LastSuccessfulPaginationKeys: make(map[string]uint64), + LastSuccessfulPaginationKeys: make(map[string]PaginationKey), CompletedTables: make(map[string]bool), 
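// For reference, the envelope produced by MarshalPaginationKey above makes a
// serialized state dump with mixed key types look roughly like this
// (illustrative table names and values):
//
//	"LastSuccessfulPaginationKeys": {
//	  "db.integer_table": {"type": "uint64", "value": 42},
//	  "db.uuid_table":    {"type": "binary", "value": "018f3e4c5a6b7c8d9eafb0c1d2e3f405"}
//	}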
LastWrittenBinlogPosition: s.lastWrittenBinlogPosition, LastStoredBinlogPositionForInlineVerifier: s.lastStoredBinlogPositionForInlineVerifier, diff --git a/table_schema_cache.go b/table_schema_cache.go index ca5b1df81..2aeb5cfff 100644 --- a/table_schema_cache.go +++ b/table_schema_cache.go @@ -126,8 +126,8 @@ func QuotedTableNameFromString(database, table string) string { return fmt.Sprintf("`%s`.`%s`", database, table) } -func MaxPaginationKeys(db *sql.DB, tables []*TableSchema, logger *logrus.Entry) (map[*TableSchema]uint64, []*TableSchema, error) { - tablesWithData := make(map[*TableSchema]uint64) +func MaxPaginationKeys(db *sql.DB, tables []*TableSchema, logger *logrus.Entry) (map[*TableSchema]PaginationKey, []*TableSchema, error) { + tablesWithData := make(map[*TableSchema]PaginationKey) emptyTables := make([]*TableSchema, 0, len(tables)) for _, table := range tables { @@ -257,6 +257,11 @@ func NonNumericPaginationKeyError(schema, table, paginationKey string) error { return fmt.Errorf("Pagination Key `%s` for %s is non-numeric", paginationKey, QuotedTableNameFromString(schema, table)) } +// NonBinaryCollationError exported to facilitate black box testing +func NonBinaryCollationError(schema, table, paginationKey, collation string) error { + return fmt.Errorf("Pagination Key `%s` for %s has non-binary collation '%s'. Binary columns (BINARY, VARBINARY) or string columns with binary collation (e.g., utf8mb4_bin) are required to ensure consistent ordering between MySQL and Ghostferry", paginationKey, QuotedTableNameFromString(schema, table), collation) +} + func (t *TableSchema) paginationKeyColumn(cascadingPaginationColumnConfig *CascadingPaginationColumnConfig) (*schema.TableColumn, int, error) { var err error var paginationKeyColumn *schema.TableColumn @@ -277,8 +282,25 @@ func (t *TableSchema) paginationKeyColumn(cascadingPaginationColumnConfig *Casca err = NonExistingPaginationKeyError(t.Schema, t.Name) } - if paginationKeyColumn != nil && paginationKeyColumn.Type != schema.TYPE_NUMBER && paginationKeyColumn.Type != schema.TYPE_MEDIUM_INT { - return nil, -1, NonNumericPaginationKeyError(t.Schema, t.Name, paginationKeyColumn.Name) + if paginationKeyColumn != nil { + isNumber := paginationKeyColumn.Type == schema.TYPE_NUMBER || paginationKeyColumn.Type == schema.TYPE_MEDIUM_INT + isBinary := paginationKeyColumn.Type == schema.TYPE_BINARY || + paginationKeyColumn.Type == schema.TYPE_STRING + + if !isNumber && !isBinary { + return nil, -1, NonNumericPaginationKeyError(t.Schema, t.Name, paginationKeyColumn.Name) + } + + // For string types (VARCHAR, CHAR), validate that the collation is binary + // BINARY and VARBINARY types don't have collations and are always binary-safe + // Related PR comment with integration test proof: https://github.com/Shopify/ghostferry/pull/417#discussion_r2619684805 + if paginationKeyColumn.Type == schema.TYPE_STRING && paginationKeyColumn.Collation != "" { + // Binary collations end with "_bin" (e.g., utf8mb4_bin, latin1_bin) + // BINARY type has empty collation and is handled above + if !strings.HasSuffix(paginationKeyColumn.Collation, "_bin") { + return nil, -1, NonBinaryCollationError(t.Schema, t.Name, paginationKeyColumn.Name, paginationKeyColumn.Collation) + } + } } return paginationKeyColumn, paginationKeyIndex, err @@ -398,7 +420,7 @@ func showTablesFrom(c *sql.DB, dbname string) ([]string, error) { return tables, nil } -func maxPaginationKey(db *sql.DB, table *TableSchema) (uint64, bool, error) { +func maxPaginationKey(db *sql.DB, table *TableSchema) 
(PaginationKey, bool, error) { primaryKeyColumn := table.GetPaginationColumn() paginationKeyName := QuoteField(primaryKeyColumn.Name) query, args, err := sq. @@ -409,18 +431,51 @@ func maxPaginationKey(db *sql.DB, table *TableSchema) (uint64, bool, error) { ToSql() if err != nil { - return 0, false, err + return nil, false, err } - var maxPaginationKey uint64 - err = db.QueryRow(query, args...).Scan(&maxPaginationKey) + var result PaginationKey + switch primaryKeyColumn.Type { + case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT: + var value uint64 + err = db.QueryRow(query, args...).Scan(&value) + result = NewUint64Key(value) + + case schema.TYPE_BINARY, schema.TYPE_STRING: + // Scan into interface{} to handle both []byte and string from driver + var val interface{} + err = db.QueryRow(query, args...).Scan(&val) + if err != nil { + break + } + + var binValue []byte + switch v := val.(type) { + case []byte: + binValue = v + case string: + binValue = []byte(v) + default: + err = fmt.Errorf("expected binary/string for max key, got %T", val) + } + + if err == nil { + result = NewBinaryKey(binValue) + } + + default: + // Fallback + var value uint64 + err = db.QueryRow(query, args...).Scan(&value) + result = NewUint64Key(value) + } switch { case err == sqlorig.ErrNoRows: - return 0, false, nil + return nil, false, nil case err != nil: - return 0, false, err + return nil, false, err default: - return maxPaginationKey, true, nil + return result, true, nil } } diff --git a/target_verifier.go b/target_verifier.go index 1ffe99eb0..2ce7fa04b 100644 --- a/target_verifier.go +++ b/target_verifier.go @@ -42,7 +42,7 @@ func (t *TargetVerifier) BinlogEventListener(evs []DMLEvent) error { if err != nil { return err } - return fmt.Errorf("row data with paginationKey %d on `%s`.`%s` has been corrupted by a change directly performed in the target at binlog file: %s and position: %d", paginationKey, ev.Database(), ev.Table(), ev.BinlogPosition().Name, ev.BinlogPosition().Pos) + return fmt.Errorf("row data with paginationKey %s on `%s`.`%s` has been corrupted by a change directly performed in the target at binlog file: %s and position: %d", paginationKey, ev.Database(), ev.Table(), ev.BinlogPosition().Name, ev.BinlogPosition().Pos) } } diff --git a/test/go/data_iterator_sorter_test.go b/test/go/data_iterator_sorter_test.go index 8a32f9a50..503bb31ab 100644 --- a/test/go/data_iterator_sorter_test.go +++ b/test/go/data_iterator_sorter_test.go @@ -32,7 +32,7 @@ var DBTableMap = map[string]string{ type DataIteratorSorterTestSuite struct { *testhelpers.GhostferryUnitTestSuite - unsortedTables map[*ghostferry.TableSchema]uint64 + unsortedTables map[*ghostferry.TableSchema]ghostferry.PaginationKey dataIterator *ghostferry.DataIterator } @@ -48,11 +48,11 @@ func (t *DataIteratorSorterTestSuite) SetupTest() { } tables, _ := ghostferry.LoadTables(t.Ferry.SourceDB, tableFilter, nil, nil, nil, nil) - t.unsortedTables = make(map[*ghostferry.TableSchema]uint64, len(tables)) + t.unsortedTables = make(map[*ghostferry.TableSchema]ghostferry.PaginationKey, len(tables)) i := 0 for _,f := range tables.AsSlice() { maxPaginationKey := uint64(100_000 - i) - t.unsortedTables[f] = maxPaginationKey + t.unsortedTables[f] = ghostferry.NewUint64Key(maxPaginationKey) i++ } @@ -83,7 +83,7 @@ func (t *DataIteratorSorterTestSuite) TestOrderMaxPaginationKeys() { copy(expectedTables, sortedTables) sort.Slice(expectedTables, func(i, j int) bool { - return sortedTables[i].MaxPaginationKey > sortedTables[j].MaxPaginationKey + return 
sortedTables[i].MaxPaginationKey.Compare(sortedTables[j].MaxPaginationKey) > 0 }) t.Require().Equal(len(t.unsortedTables), len(sortedTables)) diff --git a/test/go/inline_verifier_test.go b/test/go/inline_verifier_test.go index ee97e3ea5..d707d583f 100644 --- a/test/go/inline_verifier_test.go +++ b/test/go/inline_verifier_test.go @@ -9,18 +9,18 @@ import ( func newMockBinlogVerifySerializedStore() ghostferry.BinlogVerifySerializedStore { s := make(ghostferry.BinlogVerifySerializedStore) - s["db"] = map[string]map[uint64]int{ - "table1": map[uint64]int{ - 3: 1, - 10: 2, - 30: 3, + s["db"] = map[string]map[string]int{ + "table1": map[string]int{ + "3": 1, + "10": 2, + "30": 3, }, } - s["db2"] = map[string]map[uint64]int{ - "table2": map[uint64]int{ - 4: 1, - 20: 2, - 40: 1, + s["db2"] = map[string]map[string]int{ + "table2": map[string]int{ + "4": 1, + "20": 2, + "40": 1, }, } return s @@ -39,7 +39,7 @@ func TestBinlogVerifySerializedStoreCopy(t *testing.T) { s := newMockBinlogVerifySerializedStore() s2 := s.Copy() - s2["db"]["table1"][3] += 1 + s2["db"]["table1"]["3"] += 1 r.Equal(uint64(10), s.RowCount()) r.Equal(uint64(11), s2.RowCount()) diff --git a/test/go/iterative_verifier_integration_test.go b/test/go/iterative_verifier_integration_test.go index c0e087635..facc74b66 100644 --- a/test/go/iterative_verifier_integration_test.go +++ b/test/go/iterative_verifier_integration_test.go @@ -14,8 +14,12 @@ import ( func TestHashesSql(t *testing.T) { columns := []schema.TableColumn{schema.TableColumn{Name: "id"}, schema.TableColumn{Name: "data"}, schema.TableColumn{Name: "float_col", Type: schema.TYPE_FLOAT}} paginationKeys := []uint64{1, 5, 42} + paginationKeysInterface := make([]interface{}, len(paginationKeys)) + for i, pk := range paginationKeys { + paginationKeysInterface[i] = pk + } - sql, args, err := ghostferry.GetMd5HashesSql("gftest", "test_table", "id", columns, paginationKeys) + sql, args, err := ghostferry.GetMd5HashesSql("gftest", "test_table", "id", columns, paginationKeysInterface) assert.Nil(t, err) assert.Equal(t, "SELECT `id`, MD5(CONCAT(MD5(COALESCE(`id`, 'NULL')),MD5(COALESCE(`data`, 'NULL')),MD5(COALESCE((if (`float_col` = '-0', 0, `float_col`)), 'NULL')))) "+ diff --git a/test/go/iterative_verifier_test.go b/test/go/iterative_verifier_test.go index f48008d5e..487b24fb9 100644 --- a/test/go/iterative_verifier_test.go +++ b/test/go/iterative_verifier_test.go @@ -3,6 +3,7 @@ package test import ( "fmt" "sort" + "strconv" "testing" "time" @@ -31,7 +32,7 @@ func (t *IterativeVerifierTestSuite) SetupTest() { tableCompressions[testhelpers.TestCompressedTable1Name] = make(map[string]string) tableCompressions[testhelpers.TestCompressedTable1Name][testhelpers.TestCompressedColumn1Name] = ghostferry.CompressionSnappy - compressionVerifier, err := ghostferry.NewCompressionVerifier(tableCompressions) + compressionVerifier, err := ghostferry.NewCompressionVerifier(tableCompressions, t.Ferry.Tables) if err != nil { t.FailNow(err.Error()) } @@ -223,13 +224,13 @@ func (t *IterativeVerifierTestSuite) TestChangingDataChangesHash() { func (t *IterativeVerifierTestSuite) TestDeduplicatesHashes() { t.InsertRow(42, "foo") - hashes, err := t.verifier.GetHashes(t.db, t.table.Schema, t.table.Name, t.table.GetPaginationColumn().Name, t.table.Columns, []uint64{42, 42}) + hashes, err := t.verifier.GetHashes(t.db, t.table.Schema, t.table.Name, t.table.GetPaginationColumn().Name, t.table.Columns, []interface{}{uint64(42), uint64(42)}) t.Require().Nil(err) t.Require().Equal(1, len(hashes)) } func (t 
*IterativeVerifierTestSuite) TestDoesntReturnHashIfRecordDoesntExist() { - hashes, err := t.verifier.GetHashes(t.db, t.table.Schema, t.table.Name, t.table.GetPaginationColumn().Name, t.table.Columns, []uint64{42, 42}) + hashes, err := t.verifier.GetHashes(t.db, t.table.Schema, t.table.Name, t.table.GetPaginationColumn().Name, t.table.Columns, []interface{}{uint64(42), uint64(42)}) t.Require().Nil(err) t.Require().Equal(0, len(hashes)) } @@ -347,14 +348,20 @@ func (t *IterativeVerifierTestSuite) DeleteRow(id int) { } func (t *IterativeVerifierTestSuite) GetHashes(ids []uint64) []string { - hashes, err := t.verifier.GetHashes(t.db, t.table.Schema, t.table.Name, t.table.GetPaginationColumn().Name, t.table.Columns, ids) + paginationKeys := make([]interface{}, len(ids)) + for i, id := range ids { + paginationKeys[i] = id + } + + hashes, err := t.verifier.GetHashes(t.db, t.table.Schema, t.table.Name, t.table.GetPaginationColumn().Name, t.table.Columns, paginationKeys) t.Require().Nil(err) t.Require().Equal(len(hashes), len(ids)) res := make([]string, len(ids)) for idx, id := range ids { - hash, ok := hashes[id] + paginationKeyStr := ghostferry.NewUint64Key(id).String() + hash, ok := hashes[paginationKeyStr] t.Require().True(ok) t.Require().True(len(hash) > 0) @@ -376,6 +383,9 @@ func (t *IterativeVerifierTestSuite) reloadTables() { t.Ferry.Tables = tables t.verifier.Tables = tables.AsSlice() t.verifier.TableSchemaCache = tables + if t.verifier.CompressionVerifier != nil { + t.verifier.CompressionVerifier.TableSchemaCache = tables + } t.table = tables.Get(testhelpers.TestSchemaName, testhelpers.TestTable1Name) t.Require().NotNil(t.table) @@ -394,19 +404,21 @@ func (t *ReverifyStoreTestSuite) SetupTest() { func (t *ReverifyStoreTestSuite) TestAddEntryIntoReverifyStoreWillDeduplicate() { paginationKey1 := uint64(100) paginationKey2 := uint64(101) + paginationKey1Str := ghostferry.NewUint64Key(paginationKey1).String() + paginationKey2Str := ghostferry.NewUint64Key(paginationKey2).String() table1 := &ghostferry.TableSchema{Table: &schema.Table{Schema: "gftest", Name: "table1"}} - t.store.Add(ghostferry.ReverifyEntry{PaginationKey: paginationKey1, Table: table1}) - t.store.Add(ghostferry.ReverifyEntry{PaginationKey: paginationKey1, Table: table1}) - t.store.Add(ghostferry.ReverifyEntry{PaginationKey: paginationKey1, Table: table1}) - t.store.Add(ghostferry.ReverifyEntry{PaginationKey: paginationKey2, Table: table1}) - t.store.Add(ghostferry.ReverifyEntry{PaginationKey: paginationKey2, Table: table1}) + t.store.Add(ghostferry.ReverifyEntry{PaginationKey: paginationKey1Str, Table: table1}) + t.store.Add(ghostferry.ReverifyEntry{PaginationKey: paginationKey1Str, Table: table1}) + t.store.Add(ghostferry.ReverifyEntry{PaginationKey: paginationKey1Str, Table: table1}) + t.store.Add(ghostferry.ReverifyEntry{PaginationKey: paginationKey2Str, Table: table1}) + t.store.Add(ghostferry.ReverifyEntry{PaginationKey: paginationKey2Str, Table: table1}) t.Require().Equal(uint64(2), t.store.RowCount) t.Require().Equal(1, len(t.store.MapStore)) t.Require().Equal( - map[uint64]struct{}{ - paginationKey1: struct{}{}, - paginationKey2: struct{}{}, + map[string]struct{}{ + paginationKey1Str: struct{}{}, + paginationKey2Str: struct{}{}, }, t.store.MapStore[ghostferry.TableIdentifier{"gftest", "table1"}], ) @@ -417,13 +429,15 @@ func (t *ReverifyStoreTestSuite) TestFlushAndBatchByTableWillCreateReverifyBatch table1 := &ghostferry.TableSchema{Table: &schema.Table{Schema: "gftest", Name: "table1"}} table2 := 
&ghostferry.TableSchema{Table: &schema.Table{Schema: "gftest", Name: "table2"}} for i := uint64(100); i < 155; i++ { - t.store.Add(ghostferry.ReverifyEntry{PaginationKey: i, Table: table1}) + paginationKeyStr := ghostferry.NewUint64Key(i).String() + t.store.Add(ghostferry.ReverifyEntry{PaginationKey: paginationKeyStr, Table: table1}) expectedTable1PaginationKeys = append(expectedTable1PaginationKeys, i) } expectedTable2PaginationKeys := make([]uint64, 0, 45) for i := uint64(200); i < 245; i++ { - t.store.Add(ghostferry.ReverifyEntry{PaginationKey: i, Table: table2}) + paginationKeyStr := ghostferry.NewUint64Key(i).String() + t.store.Add(ghostferry.ReverifyEntry{PaginationKey: paginationKeyStr, Table: table2}) expectedTable2PaginationKeys = append(expectedTable2PaginationKeys, i) } @@ -446,8 +460,11 @@ func (t *ReverifyStoreTestSuite) TestFlushAndBatchByTableWillCreateReverifyBatch actualTable1PaginationKeys := make([]uint64, 0) for _, batch := range table1Batches { - for _, paginationKey := range batch.PaginationKeys { - actualTable1PaginationKeys = append(actualTable1PaginationKeys, paginationKey) + for _, paginationKeyInterface := range batch.PaginationKeys { + paginationKeyStr := paginationKeyInterface.(string) + paginationKeyUint, err := strconv.ParseUint(paginationKeyStr, 10, 64) + t.Require().Nil(err) + actualTable1PaginationKeys = append(actualTable1PaginationKeys, paginationKeyUint) } } @@ -456,8 +473,11 @@ func (t *ReverifyStoreTestSuite) TestFlushAndBatchByTableWillCreateReverifyBatch actualTable2PaginationKeys := make([]uint64, 0) for _, batch := range table2Batches { - for _, paginationKey := range batch.PaginationKeys { - actualTable2PaginationKeys = append(actualTable2PaginationKeys, paginationKey) + for _, paginationKeyInterface := range batch.PaginationKeys { + paginationKeyStr := paginationKeyInterface.(string) + paginationKeyUint, err := strconv.ParseUint(paginationKeyStr, 10, 64) + t.Require().Nil(err) + actualTable2PaginationKeys = append(actualTable2PaginationKeys, paginationKeyUint) } } diff --git a/test/go/pagination_key_test.go b/test/go/pagination_key_test.go new file mode 100644 index 000000000..c32c05a3f --- /dev/null +++ b/test/go/pagination_key_test.go @@ -0,0 +1,505 @@ +package test + +import ( + "encoding/hex" + "encoding/json" + "math" + "testing" + + "github.com/Shopify/ghostferry" + "github.com/go-mysql-org/go-mysql/schema" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUint64Key_SQLValue(t *testing.T) { + key := ghostferry.NewUint64Key(12345) + assert.Equal(t, uint64(12345), key.SQLValue()) +} + +func TestUint64Key_Compare(t *testing.T) { + tests := []struct { + name string + key1 ghostferry.Uint64Key + key2 ghostferry.Uint64Key + expected int + }{ + {"less than", ghostferry.NewUint64Key(100), ghostferry.NewUint64Key(200), -1}, + {"equal", ghostferry.NewUint64Key(100), ghostferry.NewUint64Key(100), 0}, + {"greater than", ghostferry.NewUint64Key(200), ghostferry.NewUint64Key(100), 1}, + {"zero vs non-zero", ghostferry.NewUint64Key(0), ghostferry.NewUint64Key(1), -1}, + {"max uint64", ghostferry.NewUint64Key(math.MaxUint64), ghostferry.NewUint64Key(math.MaxUint64-1), 1}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.key1.Compare(tt.key2) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestUint64Key_ComparePanicsOnTypeMismatch(t *testing.T) { + key1 := ghostferry.NewUint64Key(100) + key2 := ghostferry.NewBinaryKey([]byte{0x01, 0x02}) + + assert.Panics(t, 
func() { + key1.Compare(key2) + }) +} + +func TestUint64Key_NumericPosition(t *testing.T) { + tests := []struct { + value uint64 + expected float64 + }{ + {0, 0.0}, + {100, 100.0}, + {math.MaxUint64, float64(math.MaxUint64)}, + } + + for _, tt := range tests { + t.Run("", func(t *testing.T) { + key := ghostferry.NewUint64Key(tt.value) + assert.Equal(t, tt.expected, key.NumericPosition()) + }) + } +} + +func TestUint64Key_String(t *testing.T) { + tests := []struct { + value uint64 + expected string + }{ + {0, "0"}, + {12345, "12345"}, + {math.MaxUint64, "18446744073709551615"}, + } + + for _, tt := range tests { + t.Run(tt.expected, func(t *testing.T) { + key := ghostferry.NewUint64Key(tt.value) + assert.Equal(t, tt.expected, key.String()) + }) + } +} + +func TestUint64Key_IsMax(t *testing.T) { + assert.True(t, ghostferry.NewUint64Key(math.MaxUint64).IsMax()) + assert.False(t, ghostferry.NewUint64Key(math.MaxUint64-1).IsMax()) + assert.False(t, ghostferry.NewUint64Key(0).IsMax()) +} + +func TestUint64Key_MarshalJSON(t *testing.T) { + key := ghostferry.NewUint64Key(12345) + data, err := key.MarshalJSON() + require.NoError(t, err) + assert.Equal(t, "12345", string(data)) +} + +func TestBinaryKey_NewBinaryKeyClones(t *testing.T) { + original := []byte{0x01, 0x02, 0x03} + key := ghostferry.NewBinaryKey(original) + + original[0] = 0xFF + + assert.Equal(t, []byte{0x01, 0x02, 0x03}, []byte(key)) +} + +func TestBinaryKey_SQLValue(t *testing.T) { + original := []byte{0x01, 0x02, 0x03} + key := ghostferry.NewBinaryKey(original) + assert.Equal(t, original, key.SQLValue()) +} + +func TestBinaryKey_Compare(t *testing.T) { + tests := []struct { + name string + key1 ghostferry.BinaryKey + key2 ghostferry.BinaryKey + expected int + }{ + { + "less than", + ghostferry.NewBinaryKey([]byte{0x01, 0x02}), + ghostferry.NewBinaryKey([]byte{0x01, 0x03}), + -1, + }, + { + "equal", + ghostferry.NewBinaryKey([]byte{0x01, 0x02}), + ghostferry.NewBinaryKey([]byte{0x01, 0x02}), + 0, + }, + { + "greater than", + ghostferry.NewBinaryKey([]byte{0x02, 0x01}), + ghostferry.NewBinaryKey([]byte{0x01, 0x02}), + 1, + }, + { + "empty vs non-empty", + ghostferry.NewBinaryKey([]byte{}), + ghostferry.NewBinaryKey([]byte{0x01}), + -1, + }, + { + "different lengths", + ghostferry.NewBinaryKey([]byte{0x01}), + ghostferry.NewBinaryKey([]byte{0x01, 0x00}), + -1, + }, + { + "UUID comparison", + ghostferry.NewBinaryKey([]byte{ + 0x01, 0x8f, 0x3e, 0x4c, 0x5a, 0x6b, 0x7c, 0x8d, + 0x9e, 0xaf, 0xb0, 0xc1, 0xd2, 0xe3, 0xf4, 0x05, + }), + ghostferry.NewBinaryKey([]byte{ + 0x01, 0x8f, 0x3e, 0x4c, 0x5a, 0x6b, 0x7c, 0x8d, + 0x9e, 0xaf, 0xb0, 0xc1, 0xd2, 0xe3, 0xf4, 0x06, + }), + -1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.key1.Compare(tt.key2) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestBinaryKey_ComparePanicsOnTypeMismatch(t *testing.T) { + key1 := ghostferry.NewBinaryKey([]byte{0x01, 0x02}) + key2 := ghostferry.NewUint64Key(100) + + assert.Panics(t, func() { + key1.Compare(key2) + }) +} + +func TestBinaryKey_NumericPosition(t *testing.T) { + tests := []struct { + name string + bytes []byte + expected float64 + }{ + { + "empty", + []byte{}, + 0.0, + }, + { + "single byte", + []byte{0x01}, + float64(0x0100000000000000), + }, + { + "8 bytes", + []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, + float64(0x0102030405060708), + }, + { + "more than 8 bytes uses first 8", + []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a}, + float64(0x0102030405060708), + }, + 
{ + "UUIDv7 timestamp ordering", + []byte{0x01, 0x8f, 0x3e, 0x4c, 0x5a, 0x6b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + float64(0x018f3e4c5a6b0000), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + key := ghostferry.NewBinaryKey(tt.bytes) + assert.Equal(t, tt.expected, key.NumericPosition()) + }) + } +} + +func TestBinaryKey_NumericPosition_Monotonic(t *testing.T) { + key1 := ghostferry.NewBinaryKey([]byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) + key2 := ghostferry.NewBinaryKey([]byte{0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) + + assert.True(t, key1.NumericPosition() < key2.NumericPosition()) +} + +func TestBinaryKey_String(t *testing.T) { + tests := []struct { + name string + bytes []byte + expected string + }{ + {"empty", []byte{}, ""}, + {"single byte", []byte{0x01}, "01"}, + {"multiple bytes", []byte{0x01, 0x02, 0x03}, "010203"}, + {"UUID", []byte{ + 0x01, 0x8f, 0x3e, 0x4c, 0x5a, 0x6b, 0x7c, 0x8d, + 0x9e, 0xaf, 0xb0, 0xc1, 0xd2, 0xe3, 0xf4, 0x05, + }, "018f3e4c5a6b7c8d9eafb0c1d2e3f405"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + key := ghostferry.NewBinaryKey(tt.bytes) + assert.Equal(t, tt.expected, key.String()) + }) + } +} + +func TestBinaryKey_IsMax(t *testing.T) { + tests := []struct { + name string + bytes []byte + expected bool + }{ + {"empty is not max", []byte{}, false}, + {"all FF is max", []byte{0xFF, 0xFF, 0xFF, 0xFF}, true}, + {"UUID(16) all FF is max", []byte{ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + }, true}, + {"one non-FF byte is not max", []byte{0xFF, 0xFE, 0xFF, 0xFF}, false}, + {"zero is not max", []byte{0x00}, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + key := ghostferry.NewBinaryKey(tt.bytes) + assert.Equal(t, tt.expected, key.IsMax()) + }) + } +} + +func TestBinaryKey_MarshalJSON(t *testing.T) { + key := ghostferry.NewBinaryKey([]byte{0x01, 0x02, 0x03}) + data, err := key.MarshalJSON() + require.NoError(t, err) + assert.Equal(t, `"010203"`, string(data)) +} + +func TestMarshalPaginationKey_Uint64(t *testing.T) { + key := ghostferry.NewUint64Key(12345) + data, err := ghostferry.MarshalPaginationKey(key) + require.NoError(t, err) + + var result map[string]interface{} + err = json.Unmarshal(data, &result) + require.NoError(t, err) + + assert.Equal(t, "uint64", result["type"]) + assert.Equal(t, float64(12345), result["value"]) +} + +func TestMarshalPaginationKey_Binary(t *testing.T) { + key := ghostferry.NewBinaryKey([]byte{0x01, 0x02, 0x03}) + data, err := ghostferry.MarshalPaginationKey(key) + require.NoError(t, err) + + var result map[string]interface{} + err = json.Unmarshal(data, &result) + require.NoError(t, err) + + assert.Equal(t, "binary", result["type"]) + assert.Equal(t, "010203", result["value"]) +} + +func TestUnmarshalPaginationKey_Uint64(t *testing.T) { + data := []byte(`{"type":"uint64","value":12345}`) + key, err := ghostferry.UnmarshalPaginationKey(data) + require.NoError(t, err) + + uint64Key, ok := key.(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(12345), uint64(uint64Key)) +} + +func TestUnmarshalPaginationKey_Binary(t *testing.T) { + data := []byte(`{"type":"binary","value":"010203"}`) + key, err := ghostferry.UnmarshalPaginationKey(data) + require.NoError(t, err) + + binaryKey, ok := key.(ghostferry.BinaryKey) + require.True(t, ok) + assert.Equal(t, []byte{0x01, 0x02, 0x03}, []byte(binaryKey)) +} + +func 
TestUnmarshalPaginationKey_InvalidType(t *testing.T) { + data := []byte(`{"type":"invalid","value":"something"}`) + _, err := ghostferry.UnmarshalPaginationKey(data) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unknown key type") +} + +func TestUnmarshalPaginationKey_InvalidJSON(t *testing.T) { + data := []byte(`{invalid json}`) + _, err := ghostferry.UnmarshalPaginationKey(data) + assert.Error(t, err) +} + +func TestUnmarshalPaginationKey_InvalidBinaryHex(t *testing.T) { + data := []byte(`{"type":"binary","value":"ZZZZ"}`) + _, err := ghostferry.UnmarshalPaginationKey(data) + assert.Error(t, err) +} + +func TestPaginationKey_RoundTrip_Uint64(t *testing.T) { + original := ghostferry.NewUint64Key(98765) + + marshaled, err := ghostferry.MarshalPaginationKey(original) + require.NoError(t, err) + + unmarshaled, err := ghostferry.UnmarshalPaginationKey(marshaled) + require.NoError(t, err) + + assert.Equal(t, original, unmarshaled) +} + +func TestPaginationKey_RoundTrip_Binary(t *testing.T) { + original := ghostferry.NewBinaryKey([]byte{0xDE, 0xAD, 0xBE, 0xEF}) + + marshaled, err := ghostferry.MarshalPaginationKey(original) + require.NoError(t, err) + + unmarshaled, err := ghostferry.UnmarshalPaginationKey(marshaled) + require.NoError(t, err) + + assert.Equal(t, original, unmarshaled) +} + +func TestMinPaginationKey_Numeric(t *testing.T) { + column := &schema.TableColumn{ + Name: "id", + Type: schema.TYPE_NUMBER, + } + + minKey := ghostferry.MinPaginationKey(column) + uint64Key, ok := minKey.(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(0), uint64(uint64Key)) +} + +func TestMinPaginationKey_MediumInt(t *testing.T) { + column := &schema.TableColumn{ + Name: "id", + Type: schema.TYPE_MEDIUM_INT, + } + + minKey := ghostferry.MinPaginationKey(column) + uint64Key, ok := minKey.(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(0), uint64(uint64Key)) +} + +func TestMinPaginationKey_Binary(t *testing.T) { + column := &schema.TableColumn{ + Name: "uuid", + Type: schema.TYPE_BINARY, + } + + minKey := ghostferry.MinPaginationKey(column) + binaryKey, ok := minKey.(ghostferry.BinaryKey) + require.True(t, ok) + assert.Equal(t, []byte{}, []byte(binaryKey)) +} + +func TestMinPaginationKey_String(t *testing.T) { + column := &schema.TableColumn{ + Name: "key", + Type: schema.TYPE_STRING, + } + + minKey := ghostferry.MinPaginationKey(column) + binaryKey, ok := minKey.(ghostferry.BinaryKey) + require.True(t, ok) + assert.Equal(t, []byte{}, []byte(binaryKey)) +} + + +func TestMaxPaginationKey_Numeric(t *testing.T) { + column := &schema.TableColumn{ + Name: "id", + Type: schema.TYPE_NUMBER, + } + + maxKey := ghostferry.MaxPaginationKey(column) + uint64Key, ok := maxKey.(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(math.MaxUint64), uint64(uint64Key)) +} + +func TestMaxPaginationKey_MediumInt(t *testing.T) { + column := &schema.TableColumn{ + Name: "id", + Type: schema.TYPE_MEDIUM_INT, + } + + maxKey := ghostferry.MaxPaginationKey(column) + uint64Key, ok := maxKey.(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(math.MaxUint64), uint64(uint64Key)) +} + +func TestMaxPaginationKey_Binary_UUID16(t *testing.T) { + column := &schema.TableColumn{ + Name: "uuid", + Type: schema.TYPE_BINARY, + MaxSize: 16, + } + + maxKey := ghostferry.MaxPaginationKey(column) + binaryKey, ok := maxKey.(ghostferry.BinaryKey) + require.True(t, ok) + assert.Equal(t, 16, len(binaryKey)) + + for _, b := range binaryKey { + assert.Equal(t, 
byte(0xFF), b) + } + assert.True(t, binaryKey.IsMax()) +} + +func TestMaxPaginationKey_Binary_LargeSize(t *testing.T) { + column := &schema.TableColumn{ + Name: "large", + Type: schema.TYPE_STRING, + MaxSize: 100000, + } + + maxKey := ghostferry.MaxPaginationKey(column) + binaryKey, ok := maxKey.(ghostferry.BinaryKey) + require.True(t, ok) + assert.Equal(t, 4096, len(binaryKey)) +} + +func TestMaxPaginationKey_DefaultToNumeric(t *testing.T) { + column := &schema.TableColumn{ + Name: "id", + Type: 999, + } + + maxKey := ghostferry.MaxPaginationKey(column) + uint64Key, ok := maxKey.(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(math.MaxUint64), uint64(uint64Key)) +} + +func TestPaginationKey_CrossTypeComparison_UUIDv7Ordering(t *testing.T) { + uuidBytes1, _ := hex.DecodeString("018f3e4c5a6b7c8d9eafb0c1d2e3f405") + uuidBytes2, _ := hex.DecodeString("018f3e4c5a6c7c8d9eafb0c1d2e3f405") + uuidBytes3, _ := hex.DecodeString("018f3e4c5b6b7c8d9eafb0c1d2e3f405") + + key1 := ghostferry.NewBinaryKey(uuidBytes1) + key2 := ghostferry.NewBinaryKey(uuidBytes2) + key3 := ghostferry.NewBinaryKey(uuidBytes3) + + assert.Equal(t, -1, key1.Compare(key2)) + assert.Equal(t, -1, key1.Compare(key3)) + assert.Equal(t, -1, key2.Compare(key3)) + + assert.True(t, key1.NumericPosition() < key2.NumericPosition()) + assert.True(t, key2.NumericPosition() < key3.NumericPosition()) +} diff --git a/test/go/state_serialization_test.go b/test/go/state_serialization_test.go new file mode 100644 index 000000000..bc1d28318 --- /dev/null +++ b/test/go/state_serialization_test.go @@ -0,0 +1,331 @@ +package test + +import ( + "encoding/json" + "testing" + + "github.com/Shopify/ghostferry" + "github.com/go-mysql-org/go-mysql/mysql" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSerializableState_MarshalJSON_EmptyState(t *testing.T) { + state := &ghostferry.SerializableState{ + GhostferryVersion: "test-version", + LastSuccessfulPaginationKeys: make(map[string]ghostferry.PaginationKey), + CompletedTables: make(map[string]bool), + } + + data, err := json.Marshal(state) + require.NoError(t, err) + assert.NotEmpty(t, data) + + var decoded ghostferry.SerializableState + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + assert.Equal(t, "test-version", decoded.GhostferryVersion) + assert.Empty(t, decoded.LastSuccessfulPaginationKeys) + assert.Empty(t, decoded.CompletedTables) +} + +func TestSerializableState_MarshalJSON_WithUint64Keys(t *testing.T) { + state := &ghostferry.SerializableState{ + GhostferryVersion: "test-version", + LastSuccessfulPaginationKeys: map[string]ghostferry.PaginationKey{ + "db.table1": ghostferry.NewUint64Key(100), + "db.table2": ghostferry.NewUint64Key(200), + "db.table3": ghostferry.NewUint64Key(300), + }, + CompletedTables: map[string]bool{ + "db.table4": true, + }, + } + + data, err := json.Marshal(state) + require.NoError(t, err) + + var decoded ghostferry.SerializableState + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + assert.Equal(t, "test-version", decoded.GhostferryVersion) + assert.Len(t, decoded.LastSuccessfulPaginationKeys, 3) + + key1, ok := decoded.LastSuccessfulPaginationKeys["db.table1"].(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(100), uint64(key1)) + + key2, ok := decoded.LastSuccessfulPaginationKeys["db.table2"].(ghostferry.Uint64Key) + require.True(t, ok) + assert.Equal(t, uint64(200), uint64(key2)) + + key3, ok := 
diff --git a/test/go/state_serialization_test.go b/test/go/state_serialization_test.go
new file mode 100644
index 000000000..bc1d28318
--- /dev/null
+++ b/test/go/state_serialization_test.go
@@ -0,0 +1,331 @@
+package test
+
+import (
+    "encoding/json"
+    "testing"
+
+    "github.com/Shopify/ghostferry"
+    "github.com/go-mysql-org/go-mysql/mysql"
+    "github.com/stretchr/testify/assert"
+    "github.com/stretchr/testify/require"
+)
+
+func TestSerializableState_MarshalJSON_EmptyState(t *testing.T) {
+    state := &ghostferry.SerializableState{
+        GhostferryVersion:            "test-version",
+        LastSuccessfulPaginationKeys: make(map[string]ghostferry.PaginationKey),
+        CompletedTables:              make(map[string]bool),
+    }
+
+    data, err := json.Marshal(state)
+    require.NoError(t, err)
+    assert.NotEmpty(t, data)
+
+    var decoded ghostferry.SerializableState
+    err = json.Unmarshal(data, &decoded)
+    require.NoError(t, err)
+
+    assert.Equal(t, "test-version", decoded.GhostferryVersion)
+    assert.Empty(t, decoded.LastSuccessfulPaginationKeys)
+    assert.Empty(t, decoded.CompletedTables)
+}
+
+func TestSerializableState_MarshalJSON_WithUint64Keys(t *testing.T) {
+    state := &ghostferry.SerializableState{
+        GhostferryVersion: "test-version",
+        LastSuccessfulPaginationKeys: map[string]ghostferry.PaginationKey{
+            "db.table1": ghostferry.NewUint64Key(100),
+            "db.table2": ghostferry.NewUint64Key(200),
+            "db.table3": ghostferry.NewUint64Key(300),
+        },
+        CompletedTables: map[string]bool{
+            "db.table4": true,
+        },
+    }
+
+    data, err := json.Marshal(state)
+    require.NoError(t, err)
+
+    var decoded ghostferry.SerializableState
+    err = json.Unmarshal(data, &decoded)
+    require.NoError(t, err)
+
+    assert.Equal(t, "test-version", decoded.GhostferryVersion)
+    assert.Len(t, decoded.LastSuccessfulPaginationKeys, 3)
+
+    key1, ok := decoded.LastSuccessfulPaginationKeys["db.table1"].(ghostferry.Uint64Key)
+    require.True(t, ok)
+    assert.Equal(t, uint64(100), uint64(key1))
+
+    key2, ok := decoded.LastSuccessfulPaginationKeys["db.table2"].(ghostferry.Uint64Key)
+    require.True(t, ok)
+    assert.Equal(t, uint64(200), uint64(key2))
+
+    key3, ok := decoded.LastSuccessfulPaginationKeys["db.table3"].(ghostferry.Uint64Key)
+    require.True(t, ok)
+    assert.Equal(t, uint64(300), uint64(key3))
+
+    assert.True(t, decoded.CompletedTables["db.table4"])
+}
+
+func TestSerializableState_MarshalJSON_WithBinaryKeys(t *testing.T) {
+    uuid1 := []byte{0x01, 0x8f, 0x3e, 0x4c, 0x5a, 0x6b, 0x7c, 0x8d, 0x9e, 0xaf, 0xb0, 0xc1, 0xd2, 0xe3, 0xf4, 0x01}
+    uuid2 := []byte{0x01, 0x8f, 0x3e, 0x4c, 0x5a, 0x6b, 0x7c, 0x8d, 0x9e, 0xaf, 0xb0, 0xc1, 0xd2, 0xe3, 0xf4, 0x02}
+
+    state := &ghostferry.SerializableState{
+        GhostferryVersion: "test-version",
+        LastSuccessfulPaginationKeys: map[string]ghostferry.PaginationKey{
+            "db.uuid_table1": ghostferry.NewBinaryKey(uuid1),
+            "db.uuid_table2": ghostferry.NewBinaryKey(uuid2),
+        },
+        CompletedTables: make(map[string]bool),
+    }
+
+    data, err := json.Marshal(state)
+    require.NoError(t, err)
+
+    var decoded ghostferry.SerializableState
+    err = json.Unmarshal(data, &decoded)
+    require.NoError(t, err)
+
+    assert.Equal(t, "test-version", decoded.GhostferryVersion)
+    assert.Len(t, decoded.LastSuccessfulPaginationKeys, 2)
+
+    key1, ok := decoded.LastSuccessfulPaginationKeys["db.uuid_table1"].(ghostferry.BinaryKey)
+    require.True(t, ok)
+    assert.Equal(t, uuid1, []byte(key1))
+
+    key2, ok := decoded.LastSuccessfulPaginationKeys["db.uuid_table2"].(ghostferry.BinaryKey)
+    require.True(t, ok)
+    assert.Equal(t, uuid2, []byte(key2))
+}
+
+func TestSerializableState_MarshalJSON_WithMixedKeys(t *testing.T) {
+    uuid := []byte{0x01, 0x8f, 0x3e, 0x4c, 0x5a, 0x6b, 0x7c, 0x8d, 0x9e, 0xaf, 0xb0, 0xc1, 0xd2, 0xe3, 0xf4, 0x01}
+
+    state := &ghostferry.SerializableState{
+        GhostferryVersion: "test-version",
+        LastSuccessfulPaginationKeys: map[string]ghostferry.PaginationKey{
+            "db.numeric_table": ghostferry.NewUint64Key(12345),
+            "db.uuid_table":    ghostferry.NewBinaryKey(uuid),
+            "db.varchar_table": ghostferry.NewBinaryKey([]byte("some_key")),
+            "db.bigint_table":  ghostferry.NewUint64Key(999999999),
+        },
+        CompletedTables: map[string]bool{
+            "db.completed_table": true,
+        },
+    }
+
+    data, err := json.Marshal(state)
+    require.NoError(t, err)
+
+    var decoded ghostferry.SerializableState
+    err = json.Unmarshal(data, &decoded)
+    require.NoError(t, err)
+
+    assert.Equal(t, "test-version", decoded.GhostferryVersion)
+    assert.Len(t, decoded.LastSuccessfulPaginationKeys, 4)
+
+    numericKey, ok := decoded.LastSuccessfulPaginationKeys["db.numeric_table"].(ghostferry.Uint64Key)
+    require.True(t, ok)
+    assert.Equal(t, uint64(12345), uint64(numericKey))
+
+    uuidKey, ok := decoded.LastSuccessfulPaginationKeys["db.uuid_table"].(ghostferry.BinaryKey)
+    require.True(t, ok)
+    assert.Equal(t, uuid, []byte(uuidKey))
+
+    varcharKey, ok := decoded.LastSuccessfulPaginationKeys["db.varchar_table"].(ghostferry.BinaryKey)
+    require.True(t, ok)
+    assert.Equal(t, []byte("some_key"), []byte(varcharKey))
+
+    bigintKey, ok := decoded.LastSuccessfulPaginationKeys["db.bigint_table"].(ghostferry.Uint64Key)
+    require.True(t, ok)
+    assert.Equal(t, uint64(999999999), uint64(bigintKey))
+
+    assert.True(t, decoded.CompletedTables["db.completed_table"])
+}
+
+func TestSerializableState_MarshalJSON_WithBinlogPosition(t *testing.T) {
+    state := &ghostferry.SerializableState{
+        GhostferryVersion: "test-version",
+        LastSuccessfulPaginationKeys: map[string]ghostferry.PaginationKey{
+            "db.table1": ghostferry.NewUint64Key(100),
+        },
+        CompletedTables: make(map[string]bool),
+        LastWrittenBinlogPosition: mysql.Position{
+            Name: "mysql-bin.000123",
+            Pos:  456789,
+        },
+        LastStoredBinlogPositionForInlineVerifier: mysql.Position{
+            Name: "mysql-bin.000122",
+            Pos:  123456,
+        },
+        LastStoredBinlogPositionForTargetVerifier: mysql.Position{
+            Name: "mysql-bin.000121",
+            Pos:  987654,
+        },
+    }
+
+    data, err := json.Marshal(state)
+    require.NoError(t, err)
+
+    var decoded ghostferry.SerializableState
+    err = json.Unmarshal(data, &decoded)
+    require.NoError(t, err)
+
+    assert.Equal(t, "mysql-bin.000123", decoded.LastWrittenBinlogPosition.Name)
+    assert.Equal(t, uint32(456789), decoded.LastWrittenBinlogPosition.Pos)
+
+    assert.Equal(t, "mysql-bin.000122", decoded.LastStoredBinlogPositionForInlineVerifier.Name)
+    assert.Equal(t, uint32(123456), decoded.LastStoredBinlogPositionForInlineVerifier.Pos)
+
+    assert.Equal(t, "mysql-bin.000121", decoded.LastStoredBinlogPositionForTargetVerifier.Name)
+    assert.Equal(t, uint32(987654), decoded.LastStoredBinlogPositionForTargetVerifier.Pos)
+}
+
+func TestSerializableState_UnmarshalJSON_CorruptedData(t *testing.T) {
+    corruptedJSON := `{
+        "GhostferryVersion": "test-version",
+        "LastSuccessfulPaginationKeys": {
+            "db.table1": {"type": "invalid_type", "value": 123}
+        }
+    }`
+
+    var decoded ghostferry.SerializableState
+    err := json.Unmarshal([]byte(corruptedJSON), &decoded)
+    assert.Error(t, err)
+}
+
+func TestSerializableState_RoundTrip_LargeState(t *testing.T) {
+    uuid1 := []byte{0x01, 0x8f, 0x3e, 0x4c, 0x5a, 0x6b, 0x7c, 0x8d, 0x9e, 0xaf, 0xb0, 0xc1, 0xd2, 0xe3, 0xf4, 0x01}
+    uuid2 := []byte{0x01, 0x8f, 0x3e, 0x4c, 0x5a, 0x6b, 0x7c, 0x8d, 0x9e, 0xaf, 0xb0, 0xc1, 0xd2, 0xe3, 0xf4, 0x02}
+
+    state := &ghostferry.SerializableState{
+        GhostferryVersion: "test-version-1.2.3",
+        LastSuccessfulPaginationKeys: map[string]ghostferry.PaginationKey{
+            "prod.users":     ghostferry.NewUint64Key(1000000),
+            "prod.orders":    ghostferry.NewUint64Key(5000000),
+            "prod.products":  ghostferry.NewUint64Key(250000),
+            "prod.sessions":  ghostferry.NewBinaryKey(uuid1),
+            "prod.api_keys":  ghostferry.NewBinaryKey(uuid2),
+            "staging.users":  ghostferry.NewUint64Key(500),
+            "staging.orders": ghostferry.NewUint64Key(1000),
+        },
+        CompletedTables: map[string]bool{
+            "prod.old_table1":   true,
+            "prod.old_table2":   true,
+            "staging.old_table": true,
+        },
+        LastWrittenBinlogPosition: mysql.Position{
+            Name: "mysql-bin.001234",
+            Pos:  987654321,
+        },
+        LastStoredBinlogPositionForInlineVerifier: mysql.Position{
+            Name: "mysql-bin.001233",
+            Pos:  123456789,
+        },
+        LastStoredBinlogPositionForTargetVerifier: mysql.Position{
+            Name: "mysql-bin.001232",
+            Pos:  111222333,
+        },
+    }
+
+    data, err := json.Marshal(state)
+    require.NoError(t, err)
+
+    var decoded ghostferry.SerializableState
+    err = json.Unmarshal(data, &decoded)
+    require.NoError(t, err)
+
+    assert.Equal(t, state.GhostferryVersion, decoded.GhostferryVersion)
+    assert.Len(t, decoded.LastSuccessfulPaginationKeys, 7)
+    assert.Len(t, decoded.CompletedTables, 3)
+
+    usersKey, ok := decoded.LastSuccessfulPaginationKeys["prod.users"].(ghostferry.Uint64Key)
+    require.True(t, ok)
+    assert.Equal(t, uint64(1000000), uint64(usersKey))
+
+    sessionsKey, ok := decoded.LastSuccessfulPaginationKeys["prod.sessions"].(ghostferry.BinaryKey)
+    require.True(t, ok)
+    assert.Equal(t, uuid1, []byte(sessionsKey))
+
+    assert.Equal(t, state.LastWrittenBinlogPosition, decoded.LastWrittenBinlogPosition)
+    assert.Equal(t, state.LastStoredBinlogPositionForInlineVerifier, decoded.LastStoredBinlogPositionForInlineVerifier)
+    assert.Equal(t, state.LastStoredBinlogPositionForTargetVerifier, decoded.LastStoredBinlogPositionForTargetVerifier)
+
+    for tableName := range state.CompletedTables {
+        assert.True(t, decoded.CompletedTables[tableName])
+    }
+}
+
+func TestSerializableState_JSONStructure(t *testing.T) {
+    uuid := []byte{0xDE, 0xAD, 0xBE, 0xEF}
+    state := &ghostferry.SerializableState{
+        GhostferryVersion: "test",
+        LastSuccessfulPaginationKeys: map[string]ghostferry.PaginationKey{
+            "db.table1": ghostferry.NewUint64Key(123),
+            "db.table2": ghostferry.NewBinaryKey(uuid),
+        },
+        CompletedTables: make(map[string]bool),
+    }
+
+    data, err := json.Marshal(state)
+    require.NoError(t, err)
+
+    var raw map[string]interface{}
+    err = json.Unmarshal(data, &raw)
+    require.NoError(t, err)
+
+    keys, ok := raw["LastSuccessfulPaginationKeys"].(map[string]interface{})
+    require.True(t, ok)
+
+    table1Data := keys["db.table1"].(map[string]interface{})
+    assert.Equal(t, "uint64", table1Data["type"])
+    assert.Equal(t, float64(123), table1Data["value"])
+
+    table2Data := keys["db.table2"].(map[string]interface{})
+    assert.Equal(t, "binary", table2Data["type"])
+    assert.Equal(t, "deadbeef", table2Data["value"])
+}
+
+func TestSerializableState_EmptyBinaryKey(t *testing.T) {
+    state := &ghostferry.SerializableState{
+        GhostferryVersion: "test",
+        LastSuccessfulPaginationKeys: map[string]ghostferry.PaginationKey{
+            "db.table": ghostferry.NewBinaryKey([]byte{}),
+        },
+        CompletedTables: make(map[string]bool),
+    }
+
+    data, err := json.Marshal(state)
+    require.NoError(t, err)
+
+    var decoded ghostferry.SerializableState
+    err = json.Unmarshal(data, &decoded)
+    require.NoError(t, err)
+
+    key, ok := decoded.LastSuccessfulPaginationKeys["db.table"].(ghostferry.BinaryKey)
+    require.True(t, ok)
+    assert.Equal(t, []byte{}, []byte(key))
+}
+
+func TestSerializableState_ZeroUint64Key(t *testing.T) {
+    state := &ghostferry.SerializableState{
+        GhostferryVersion: "test",
+        LastSuccessfulPaginationKeys: map[string]ghostferry.PaginationKey{
+            "db.table": ghostferry.NewUint64Key(0),
+        },
+        CompletedTables: make(map[string]bool),
+    }
+
+    data, err := json.Marshal(state)
+    require.NoError(t, err)
+
+    var decoded ghostferry.SerializableState
+    err = json.Unmarshal(data, &decoded)
+    require.NoError(t, err)
+
+    key, ok := decoded.LastSuccessfulPaginationKeys["db.table"].(ghostferry.Uint64Key)
+    require.True(t, ok)
+    assert.Equal(t, uint64(0), uint64(key))
+}
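
Aside (editorial illustration, not part of this patch): TestSerializableState_JSONStructure above pins the serialized shape of a pagination key to a small tagged object, {"type":"uint64","value":123} for numeric keys and {"type":"binary","value":"deadbeef"} for hex-encoded binary keys, which is also what the Ruby integration tests later read via dumped_state[...]["value"]. The sketch below only reproduces that envelope with invented names (paginationKeyJSON, encodeUint64Key, encodeBinaryKey); it is not ghostferry's actual MarshalJSON implementation.

package main

import (
    "encoding/hex"
    "encoding/json"
    "fmt"
)

// paginationKeyJSON mirrors the envelope the tests assert: a "type" tag telling
// a resumed run how to rebuild the key, plus a "value" that is either a JSON
// number or a hex string.
type paginationKeyJSON struct {
    Type  string      `json:"type"`
    Value interface{} `json:"value"`
}

func encodeUint64Key(v uint64) ([]byte, error) {
    return json.Marshal(paginationKeyJSON{Type: "uint64", Value: v})
}

func encodeBinaryKey(b []byte) ([]byte, error) {
    return json.Marshal(paginationKeyJSON{Type: "binary", Value: hex.EncodeToString(b)})
}

func main() {
    u, _ := encodeUint64Key(123)
    b, _ := encodeBinaryKey([]byte{0xDE, 0xAD, 0xBE, 0xEF})
    fmt.Println(string(u)) // {"type":"uint64","value":123}
    fmt.Println(string(b)) // {"type":"binary","value":"deadbeef"}
}
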
diff --git a/test/go/table_schema_cache_test.go b/test/go/table_schema_cache_test.go
index fc00a5406..fc2bd18fc 100644
--- a/test/go/table_schema_cache_test.go
+++ b/test/go/table_schema_cache_test.go
@@ -86,18 +86,19 @@ func (this *TableSchemaCacheTestSuite) TestLoadTablesWithoutFiltering() {
     }
 }
 
-func (this *TableSchemaCacheTestSuite) TestLoadTablesRejectTablesWithoutNumericPK() {
+func (this *TableSchemaCacheTestSuite) TestLoadTablesAcceptTablesWithVarcharPK() {
     table := "test_table_4"
     paginationColumn := "id"
 
-    query := fmt.Sprintf("CREATE TABLE %s.%s (%s varchar(20) not null, data TEXT, primary key(%s))", testhelpers.TestSchemaName, table, paginationColumn, paginationColumn)
+    // Use binary collation to ensure consistent ordering between MySQL and Go
+    query := fmt.Sprintf("CREATE TABLE %s.%s (%s varchar(20) COLLATE utf8mb4_bin not null, data TEXT, primary key(%s))", testhelpers.TestSchemaName, table, paginationColumn, paginationColumn)
     _, err := this.Ferry.SourceDB.Exec(query)
     this.Require().Nil(err)
 
-    _, err = ghostferry.LoadTables(this.Ferry.SourceDB, this.tableFilter, nil, nil, nil, nil)
+    tableSchemaCache, err := ghostferry.LoadTables(this.Ferry.SourceDB, this.tableFilter, nil, nil, nil, nil)
 
-    this.Require().NotNil(err)
-    this.Require().EqualError(err, ghostferry.NonNumericPaginationKeyError(testhelpers.TestSchemaName, table, paginationColumn).Error())
-    this.Require().Contains(err.Error(), table)
+    this.Require().Nil(err)
+    this.Require().Contains(tableSchemaCache, testhelpers.TestSchemaName+"."+table)
+    this.Require().Equal(paginationColumn, tableSchemaCache[testhelpers.TestSchemaName+"."+table].GetPaginationColumn().Name)
 }
 
 func (this *TableSchemaCacheTestSuite) TestLoadTablesRejectTablesWithoutNumericPKWithMediumInt() {
     table := "pagination_by_column_medium_int_pk"
@@ -496,6 +497,64 @@ func (suite *TableSchemaCacheTestSuite) TestLoadTablesLoadsVisibleIndexes() {
     suite.Require().Equal("index_on_name_visible", tables["gftest.test_table_1"].Indexes[1].Name)
 }
 
+func (this *TableSchemaCacheTestSuite) TestVarcharWithNonBinaryCollationFails() {
+    // Create a table with VARCHAR PRIMARY KEY using non-binary collation
+    _, err := this.Ferry.SourceDB.Exec(fmt.Sprintf("CREATE TABLE %s.collation_test (id VARCHAR(10) COLLATE utf8mb4_unicode_ci NOT NULL PRIMARY KEY, data TEXT)", testhelpers.TestSchemaName))
+    this.Require().Nil(err)
+    defer this.Ferry.SourceDB.Exec(fmt.Sprintf("DROP TABLE %s.collation_test", testhelpers.TestSchemaName))
+
+    tableFilter := &testhelpers.TestTableFilter{
+        DbsFunc:    testhelpers.DbApplicabilityFilter([]string{testhelpers.TestSchemaName}),
+        TablesFunc: nil,
+    }
+
+    _, err = ghostferry.LoadTables(this.Ferry.SourceDB, tableFilter, nil, nil, nil, &ghostferry.CascadingPaginationColumnConfig{})
+    this.Require().NotNil(err)
+    this.Require().Contains(err.Error(), "non-binary collation")
+    this.Require().Contains(err.Error(), "utf8mb4_unicode_ci")
+}
+
+func (this *TableSchemaCacheTestSuite) TestVarcharWithBinaryCollationPasses() {
+    // Create a table with VARCHAR PRIMARY KEY using binary collation
+    _, err := this.Ferry.SourceDB.Exec(fmt.Sprintf("CREATE TABLE %s.collation_test (id VARCHAR(10) COLLATE utf8mb4_bin NOT NULL PRIMARY KEY, data TEXT)", testhelpers.TestSchemaName))
+    this.Require().Nil(err)
+    defer this.Ferry.SourceDB.Exec(fmt.Sprintf("DROP TABLE %s.collation_test", testhelpers.TestSchemaName))
+
+    tableFilter := &testhelpers.TestTableFilter{
+        DbsFunc:    testhelpers.DbApplicabilityFilter([]string{testhelpers.TestSchemaName}),
+        TablesFunc: nil,
+    }
+
+    tableSchemaCache, err := ghostferry.LoadTables(this.Ferry.SourceDB, tableFilter, nil, nil, nil, &ghostferry.CascadingPaginationColumnConfig{})
+    this.Require().Nil(err)
+    this.Require().NotNil(tableSchemaCache)
+
+    table := tableSchemaCache.Get(testhelpers.TestSchemaName, "collation_test")
+    this.Require().NotNil(table)
+    this.Require().Equal("id", table.PaginationKeyColumn.Name)
+    this.Require().Equal("utf8mb4_bin", table.PaginationKeyColumn.Collation)
+}
+
+func (this *TableSchemaCacheTestSuite) TestVarbinaryPasses() {
+    // Create a table with VARBINARY PRIMARY KEY (always binary-safe)
+    _, err := this.Ferry.SourceDB.Exec(fmt.Sprintf("CREATE TABLE %s.collation_test (id VARBINARY(16) NOT NULL PRIMARY KEY, data TEXT)", testhelpers.TestSchemaName))
+    this.Require().Nil(err)
+    defer this.Ferry.SourceDB.Exec(fmt.Sprintf("DROP TABLE %s.collation_test", testhelpers.TestSchemaName))
+
+    tableFilter := &testhelpers.TestTableFilter{
+        DbsFunc:    testhelpers.DbApplicabilityFilter([]string{testhelpers.TestSchemaName}),
+        TablesFunc: nil,
+    }
+
+    tableSchemaCache, err := ghostferry.LoadTables(this.Ferry.SourceDB, tableFilter, nil, nil, nil, &ghostferry.CascadingPaginationColumnConfig{})
+    this.Require().Nil(err)
+    this.Require().NotNil(tableSchemaCache)
+
+    table := tableSchemaCache.Get(testhelpers.TestSchemaName, "collation_test")
+    this.Require().NotNil(table)
+    this.Require().Equal("id", table.PaginationKeyColumn.Name)
+}
+
 func TestTableSchemaCache(t *testing.T) {
     testhelpers.SetupTest()
     suite.Run(t, &TableSchemaCacheTestSuite{GhostferryUnitTestSuite: &testhelpers.GhostferryUnitTestSuite{}})
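
Aside (editorial illustration, not part of this patch): the collation tests above enforce the rule from this release that non-UINT64 pagination keys must use a binary collation or a binary column type. The motivation is that MySQL orders rows by the column's collation while a byte-valued cursor in Go orders keys byte-wise; under a case-insensitive collation such as utf8mb4_unicode_ci the two orderings can disagree, so pagination ranges could skip or repeat rows. A minimal standard-library illustration of the mismatch (not ghostferry code):

package main

import (
    "bytes"
    "fmt"
    "strings"
)

func main() {
    // Byte-wise ordering (what a binary collation gives you): 'B' (0x42)
    // sorts before 'a' (0x61).
    fmt.Println(bytes.Compare([]byte("B"), []byte("a"))) // -1

    // A case-insensitive collation instead treats "a" and "A" as equal, so
    // MySQL's ORDER BY and a byte-wise cursor could walk the same table in
    // different orders.
    fmt.Println(strings.EqualFold("a", "A")) // true
}
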
diff --git a/test/helpers/db_helper.rb b/test/helpers/db_helper.rb
index 68e6468c9..fbcda8d87 100644
--- a/test/helpers/db_helper.rb
+++ b/test/helpers/db_helper.rb
@@ -1,5 +1,7 @@
+# coding: utf-8
 require "logger"
 require "mysql2"
+require "securerandom"
 
 module DbHelper
   ALPHANUMERICS = ("0".."9").to_a + ("a".."z").to_a + ("A".."Z").to_a
@@ -9,6 +11,7 @@ module DbHelper
 
   DEFAULT_DB = "gftest"
   DEFAULT_TABLE = "test_table_1"
+  UUID_TABLE = "test_uuid_table"
 
   class Mysql2::Client
     alias_method :query_without_maginalia, :query
@@ -42,6 +45,7 @@ def self.rand_data(length: 32)
   end
 
   DEFAULT_FULL_TABLE_NAME = full_table_name(DEFAULT_DB, DEFAULT_TABLE)
+  UUID_FULL_TABLE_NAME = full_table_name(DEFAULT_DB, UUID_TABLE)
 
   def full_table_name(db, table)
     DbHelper.full_table_name(db, table)
@@ -160,6 +164,49 @@ def seed_simple_database_with_single_table
     seed_random_data(target_db, number_of_rows: 0)
   end
 
+  def generate_uuid_bytes
+    [SecureRandom.uuid.delete("-")].pack("H32")
+  end
+
+  def seed_uuid_data(connection, database_name: DEFAULT_DB, table_name: UUID_TABLE, number_of_rows: 1111)
+    dbtable = full_table_name(database_name, table_name)
+
+    connection.query("CREATE DATABASE IF NOT EXISTS #{database_name}")
+    connection.query("CREATE TABLE IF NOT EXISTS #{dbtable} (id VARBINARY(16) NOT NULL, data TEXT, PRIMARY KEY(id))")
+
+    return if number_of_rows == 0
+
+    insert_statement = connection.prepare("INSERT INTO #{dbtable} (id, data) VALUES (?, ?)")
+    transaction(connection) do
+      number_of_rows.times do
+        uuid_bytes = generate_uuid_bytes
+        data = rand_data
+        insert_statement.execute(uuid_bytes, data)
+      end
+    end
+  end
+
+  def seed_simple_database_with_uuid_table
+    max_rows = 1111
+    seed_uuid_data(source_db, number_of_rows: max_rows)
+
+    num_holes = 140
+    result = source_db.query("SELECT id FROM #{UUID_FULL_TABLE_NAME} ORDER BY id LIMIT #{num_holes}")
+
+    holes_ids = []
+    result.each do |row|
+      holes_ids << row["id"]
+    end
+
+    unless holes_ids.empty?
+      sqlargs = (["?"]*holes_ids.length).join(",")
+      delete_statement = source_db.prepare("DELETE FROM #{UUID_FULL_TABLE_NAME} WHERE id IN (#{sqlargs})")
+      delete_statement.execute(*holes_ids)
+    end
+
+    seed_uuid_data(target_db, number_of_rows: 0)
+  end
+
   # Get some overall metrics like CHECKSUM, row count, sample row from tables.
   # Generally used for test validation.
   def source_and_target_table_metrics(tables: [DEFAULT_FULL_TABLE_NAME])
@@ -184,11 +231,11 @@ def table_metric(conn, table, sample_id: nil)
 
     if sample_id.nil?
       result = conn.query("SELECT * FROM #{table} ORDER BY RAND() LIMIT 1")
-      metrics[:sample_row] = result.first
     else
-      result = conn.query("SELECT * FROM #{table} WHERE id = #{sample_id} LIMIT 1")
-      metrics[:sample_row] = result.first
+      stmt = conn.prepare("SELECT * FROM #{table} WHERE id = ? LIMIT 1")
+      result = stmt.execute(sample_id)
     end
+    metrics[:sample_row] = result.first
 
     metrics
   end
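
Aside (editorial illustration, not part of this patch): the Ruby helper above seeds a VARBINARY(16) pagination column with the raw bytes of a random UUID. A rough Go equivalent using only the standard library is sketched below (newRandomKey is an invented name; ghostferry does not ship this helper). The hex encoding at the end matches how a binary key surfaces in the dumped state's "value" field.

package main

import (
    "crypto/rand"
    "encoding/hex"
    "fmt"
    "log"
)

// newRandomKey returns 16 random bytes, suitable for inserting into a
// VARBINARY(16) pagination column, much like generate_uuid_bytes above packs
// a random UUID into raw bytes.
func newRandomKey() ([]byte, error) {
    key := make([]byte, 16)
    if _, err := rand.Read(key); err != nil {
        return nil, err
    }
    return key, nil
}

func main() {
    key, err := newRandomKey()
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println(hex.EncodeToString(key))
}
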
diff --git a/test/integration/interrupt_resume_test.rb b/test/integration/interrupt_resume_test.rb
index fde868efe..2ed3f23c9 100644
--- a/test/integration/interrupt_resume_test.rb
+++ b/test/integration/interrupt_resume_test.rb
@@ -24,7 +24,7 @@ def test_interrupt_resume_without_writes_to_source_to_check_target_state_when_in
     result = target_db.query("SELECT MAX(id) AS max_id FROM #{DEFAULT_FULL_TABLE_NAME}")
     last_successful_id = result.first["max_id"]
     assert last_successful_id > 0
-    assert_equal last_successful_id, dumped_state["LastSuccessfulPaginationKeys"]["#{DEFAULT_DB}.#{DEFAULT_TABLE}"]
+    assert_equal last_successful_id, dumped_state["LastSuccessfulPaginationKeys"]["#{DEFAULT_DB}.#{DEFAULT_TABLE}"]["value"]
   end
 
   def test_interrupt_and_resume_without_last_known_schema_cache
@@ -553,7 +553,7 @@ def test_issue_149_correct
     dumped_state = ghostferry.run_expecting_interrupt
     assert_basic_fields_exist_in_dumped_state(dumped_state)
 
-    last_pk = dumped_state["LastSuccessfulPaginationKeys"]["#{DEFAULT_DB}.#{DEFAULT_TABLE}"]
+    last_pk = dumped_state["LastSuccessfulPaginationKeys"]["#{DEFAULT_DB}.#{DEFAULT_TABLE}"]["value"]
     assert last_pk > 200
 
     # We need to rewind the state backwards, and then change that row on the
@@ -573,7 +573,7 @@ def test_issue_149_correct
     data_changed = source_db.query("SELECT data FROM #{DEFAULT_FULL_TABLE_NAME} WHERE id = #{id_to_change}").first["data"]
     assert_equal "changed", data_changed
 
-    dumped_state["LastSuccessfulPaginationKeys"]["#{DEFAULT_DB}.#{DEFAULT_TABLE}"] = id_to_change - 1
+    dumped_state["LastSuccessfulPaginationKeys"]["#{DEFAULT_DB}.#{DEFAULT_TABLE}"]["value"] = id_to_change - 1
 
     ghostferry = new_ghostferry(MINIMAL_GHOSTFERRY, config: { verifier_type: "Inline" })
     changed_row_copied = false
@@ -623,7 +623,7 @@ def test_issue_149_corrupted
     dumped_state = ghostferry.run_expecting_interrupt
     assert_basic_fields_exist_in_dumped_state(dumped_state)
 
-    last_pk = dumped_state["LastSuccessfulPaginationKeys"]["#{DEFAULT_DB}.#{DEFAULT_TABLE}"]
+    last_pk = dumped_state["LastSuccessfulPaginationKeys"]["#{DEFAULT_DB}.#{DEFAULT_TABLE}"]["value"]
     assert last_pk > 200
 
     # This should be similar to test_issue_149_correct, except we force the
@@ -641,7 +641,7 @@ def test_issue_149_corrupted
     data_corrupted = target_db.query("SELECT data FROM #{DEFAULT_FULL_TABLE_NAME} WHERE id = #{id_to_change}").first["data"]
     assert_equal "corrupted", data_corrupted
 
-    dumped_state["LastSuccessfulPaginationKeys"]["#{DEFAULT_DB}.#{DEFAULT_TABLE}"] = id_to_change - 1
+    dumped_state["LastSuccessfulPaginationKeys"]["#{DEFAULT_DB}.#{DEFAULT_TABLE}"]["value"] = id_to_change - 1
 
     ghostferry = new_ghostferry(MINIMAL_GHOSTFERRY, config: { verifier_type: "Inline" })
     changed_row_copied = false
@@ -680,4 +680,128 @@ def test_issue_149_corrupted
     assert expectation, "error message: #{error_message.inspect}, didn't start with #{predicate.inspect}"
   end
+
+  def test_interrupt_resume_without_writes_to_source_with_uuid_table
+    seed_simple_database_with_uuid_table
+
+    ghostferry = new_ghostferry(MINIMAL_GHOSTFERRY)
+
+    ghostferry.on_status(Ghostferry::Status::AFTER_ROW_COPY) do
+      ghostferry.send_signal("TERM")
+    end
+
+    dumped_state = ghostferry.run_expecting_interrupt
+    assert_basic_fields_exist_in_dumped_state(dumped_state)
+
+    result = target_db.query("SELECT COUNT(*) AS cnt FROM #{UUID_FULL_TABLE_NAME}")
+    count = result.first["cnt"]
+    assert_equal 200, count
+
+    result = target_db.query("SELECT MAX(id) AS max_id FROM #{UUID_FULL_TABLE_NAME}")
+    last_successful_id_bytes = result.first["max_id"]
+    assert last_successful_id_bytes.length > 0
+
+    last_key_in_state = dumped_state["LastSuccessfulPaginationKeys"]["#{DEFAULT_DB}.#{UUID_TABLE}"]["value"]
+    assert_equal last_successful_id_bytes.unpack1("H*"), last_key_in_state
+  end
+
+  def test_interrupt_and_resume_without_last_known_schema_cache_with_uuid_table
+    seed_simple_database_with_uuid_table
+
+    ghostferry = new_ghostferry(MINIMAL_GHOSTFERRY)
+
+    ghostferry.on_status(Ghostferry::Status::AFTER_ROW_COPY) do
+      ghostferry.send_signal("TERM")
+    end
+
+    dumped_state = ghostferry.run_expecting_interrupt
+    assert_basic_fields_exist_in_dumped_state(dumped_state)
+    dumped_state["LastKnownTableSchemaCache"] = nil
+
+    ghostferry = new_ghostferry(MINIMAL_GHOSTFERRY)
+
+    ghostferry.run(dumped_state)
+
+    assert_uuid_table_is_identical
+  end
+
+  def test_interrupt_resume_with_writes_to_source_with_uuid_table
+    seed_simple_database_with_uuid_table
+
+    datawriter = new_source_datawriter
+    ghostferry = new_ghostferry_with_interrupt_after_row_copy(MINIMAL_GHOSTFERRY, after_batches_written: 2)
+
+    start_datawriter_with_ghostferry(datawriter, ghostferry)
+
+    dumped_state = ghostferry.run_expecting_interrupt
+    assert_basic_fields_exist_in_dumped_state(dumped_state)
+
+    ghostferry = new_ghostferry(MINIMAL_GHOSTFERRY)
+
+    stop_datawriter_during_cutover(datawriter, ghostferry)
+
+    ghostferry.run(dumped_state)
+
+    assert_uuid_table_is_identical
+  end
+
+  def test_interrupt_resume_idempotence_with_uuid_table
+    seed_simple_database_with_uuid_table
+
+    ghostferry = new_ghostferry_with_interrupt_after_row_copy(MINIMAL_GHOSTFERRY)
+    dumped_state = ghostferry.run_expecting_interrupt
+
+    ghostferry = new_ghostferry(MINIMAL_GHOSTFERRY)
+    ghostferry.run_with_logs(dumped_state)
+
+    assert_uuid_table_is_identical
+
+    ghostferry.run_with_logs(dumped_state)
+
+    assert_uuid_table_is_identical
+
+    assert_ghostferry_completed(ghostferry, times: 2)
+  end
+
+  def test_interrupt_resume_inline_verifier_with_uuid_table
+    seed_simple_database_with_single_table
+    seed_simple_database_with_uuid_table
+
+    datawriter = new_source_datawriter
+    ghostferry = new_ghostferry(MINIMAL_GHOSTFERRY, config: { verifier_type: "Inline" })
+
+    start_datawriter_with_ghostferry(datawriter, ghostferry)
+
+    batches_written = 0
+    ghostferry.on_status(Ghostferry::Status::AFTER_ROW_COPY) do
+      batches_written += 1
+      if batches_written >= 5
+        ghostferry.term_and_wait_for_exit
+      end
+    end
+
+    dumped_state = ghostferry.run_expecting_interrupt
+    assert_basic_fields_exist_in_dumped_state(dumped_state)
+    refute_nil dumped_state["BinlogVerifyStore"]
+    refute_nil dumped_state["BinlogVerifyStore"]["gftest"]
+    refute_nil dumped_state["BinlogVerifyStore"]["gftest"]["test_table_1"]
+
+    ghostferry = new_ghostferry(MINIMAL_GHOSTFERRY, config: { verifier_type: "Inline" })
+
+    verification_ran = false
+    incorrect_tables = []
+    ghostferry.on_status(Ghostferry::Status::VERIFIED) do |*tables|
+      verification_ran = true
+      incorrect_tables = tables
+    end
+
+    stop_datawriter_during_cutover(datawriter, ghostferry)
+
+    ghostferry.run(dumped_state)
+
+    assert verification_ran
+    assert_equal 0, incorrect_tables.length
+    assert_test_table_is_identical
+    assert_uuid_table_is_identical
+  end
 end
diff --git a/test/test_helper.rb b/test/test_helper.rb
index a3a5ba5af..d7247860b 100644
--- a/test/test_helper.rb
+++ b/test/test_helper.rb
@@ -92,8 +92,8 @@ def new_ghostferry_with_interrupt_after_row_copy(filepath, config: {}, after_bat
     g
   end
 
-  def new_source_datawriter(*args)
-    dw = DataWriter.new(source_db_config, *args, logger: @log_capturer.logger)
+  def new_source_datawriter(*args, **kwargs)
+    dw = DataWriter.new(source_db_config, *args, **kwargs, logger: @log_capturer.logger)
     @datawriter_instances << dw
     dw
   end
@@ -179,6 +198,25 @@ def assert_test_table_is_identical
     )
   end
 
+  def assert_uuid_table_is_identical
+    source, target = source_and_target_table_metrics(tables: [UUID_FULL_TABLE_NAME])
+
+    assert source[UUID_FULL_TABLE_NAME][:row_count] > 0
+    assert target[UUID_FULL_TABLE_NAME][:row_count] > 0
+
+    assert_equal(
+      source[UUID_FULL_TABLE_NAME][:row_count],
+      target[UUID_FULL_TABLE_NAME][:row_count],
+      "source and target row count don't match",
+    )
+
+    assert_equal(
+      source[UUID_FULL_TABLE_NAME][:checksum],
+      target[UUID_FULL_TABLE_NAME][:checksum],
+      "source and target checksum don't match",
+    )
+  end
+
   # Use this method to assert the validity of the structure of the dumped
   # state.
   #