Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 7 additions & 58 deletions batch_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (

sql "github.com/Shopify/ghostferry/sqlwrapper"

"github.com/go-mysql-org/go-mysql/schema"
"github.com/sirupsen/logrus"
)

Expand Down Expand Up @@ -57,65 +56,15 @@ func (w *BatchWriter) WriteRowBatch(batch *RowBatch) error {
return nil
}

var startPaginationKeypos, endPaginationKeypos PaginationKey
var err error

paginationColumn := batch.TableSchema().GetPaginationColumn()

switch paginationColumn.Type {
case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT:
var startValue, endValue uint64
startValue, err = values[0].GetUint64(batch.PaginationKeyIndex())
if err != nil {
return err
}
endValue, err = values[len(values)-1].GetUint64(batch.PaginationKeyIndex())
if err != nil {
return err
}
startPaginationKeypos = NewUint64Key(startValue)
endPaginationKeypos = NewUint64Key(endValue)

case schema.TYPE_BINARY, schema.TYPE_STRING:
startValueInterface := values[0][batch.PaginationKeyIndex()]
endValueInterface := values[len(values)-1][batch.PaginationKeyIndex()]

getBytes := func(val interface{}) ([]byte, error) {
switch v := val.(type) {
case []byte:
return v, nil
case string:
return []byte(v), nil
default:
return nil, fmt.Errorf("expected binary/string pagination key, got %T", val)
}
}

startValue, err := getBytes(startValueInterface)
if err != nil {
return err
}

endValue, err := getBytes(endValueInterface)
if err != nil {
return err
}

startPaginationKeypos = NewBinaryKey(startValue)
endPaginationKeypos = NewBinaryKey(endValue)

default:
var startValue, endValue uint64
startValue, err = values[0].GetUint64(batch.PaginationKeyIndex())
if err != nil {
return err
}
endValue, err = values[len(values)-1].GetUint64(batch.PaginationKeyIndex())
if err != nil {
return err
}
startPaginationKeypos = NewUint64Key(startValue)
endPaginationKeypos = NewUint64Key(endValue)
startPaginationKeypos, err := NewPaginationKeyFromRow(values[0], batch.PaginationKeyIndex(), paginationColumn)
if err != nil {
return err
}
endPaginationKeypos, err := NewPaginationKeyFromRow(values[len(values)-1], batch.PaginationKeyIndex(), paginationColumn)
if err != nil {
return err
}

db := batch.TableSchema().Schema
Expand Down
41 changes: 4 additions & 37 deletions cursor.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,43 +262,10 @@ func (c *Cursor) Fetch(db SqlPreparer) (batch *RowBatch, paginationKeypos Pagina

if len(batchData) > 0 {
lastRowData := batchData[len(batchData)-1]

switch c.paginationKeyColumn.Type {
case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT:
var value uint64
value, err = lastRowData.GetUint64(paginationKeyIndex)
if err != nil {
logger.WithError(err).Error("failed to get uint64 paginationKey value")
return
}
paginationKeypos = NewUint64Key(value)

case schema.TYPE_BINARY, schema.TYPE_STRING:
valueInterface := lastRowData[paginationKeyIndex]

var valueBytes []byte
switch v := valueInterface.(type) {
case []byte:
valueBytes = v
case string:
valueBytes = []byte(v)
default:
err = fmt.Errorf("expected binary pagination key to be []byte or string, got %T", valueInterface)
logger.WithError(err).Error("failed to get binary paginationKey value")
return
}

paginationKeypos = NewBinaryKey(valueBytes)

default:
// Fallback for other integer types
var value uint64
value, err = lastRowData.GetUint64(paginationKeyIndex)
if err != nil {
logger.WithError(err).Error("failed to get uint64 paginationKey value")
return
}
paginationKeypos = NewUint64Key(value)
paginationKeypos, err = NewPaginationKeyFromRow(lastRowData, paginationKeyIndex, c.paginationKeyColumn)
if err != nil {
logger.WithError(err).Error("failed to get paginationKey value")
return
}
}

Expand Down
38 changes: 5 additions & 33 deletions data_iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (

sql "github.com/Shopify/ghostferry/sqlwrapper"

"github.com/go-mysql-org/go-mysql/schema"
"github.com/sirupsen/logrus"
)

Expand Down Expand Up @@ -115,40 +114,13 @@ func (d *DataIterator) Run(tables []*TableSchema) {
paginationColumn := table.GetPaginationColumn()

for i, rowData := range batch.Values() {
var paginationKeyStr string

switch paginationColumn.Type {
case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT:
paginationKeyUint, err := rowData.GetUint64(batch.PaginationKeyIndex())
if err != nil {
logger.WithError(err).Error("failed to get uint64 paginationKey data")
return err
}
paginationKeyStr = NewUint64Key(paginationKeyUint).String()

case schema.TYPE_BINARY, schema.TYPE_STRING:
paginationKeyInterface := rowData[batch.PaginationKeyIndex()]
var paginationKeyBytes []byte
switch v := paginationKeyInterface.(type) {
case []byte:
paginationKeyBytes = v
case string:
paginationKeyBytes = []byte(v)
default:
return fmt.Errorf("expected binary/string pagination key, got %T", paginationKeyInterface)
}
paginationKeyStr = NewBinaryKey(paginationKeyBytes).String()

default:
paginationKeyUint, err := rowData.GetUint64(batch.PaginationKeyIndex())
if err != nil {
logger.WithError(err).Error("failed to get paginationKey data")
return err
}
paginationKeyStr = NewUint64Key(paginationKeyUint).String()
paginationKey, err := NewPaginationKeyFromRow(rowData, batch.PaginationKeyIndex(), paginationColumn)
if err != nil {
logger.WithError(err).Error("failed to get paginationKey data")
return err
}

fingerprints[paginationKeyStr] = rowData[len(rowData)-1].([]byte)
fingerprints[paginationKey.String()] = rowData[len(rowData)-1].([]byte)
rows[i] = rowData[:len(rowData)-1]
}

Expand Down
34 changes: 4 additions & 30 deletions dml_events.go
Original file line number Diff line number Diff line change
Expand Up @@ -576,35 +576,9 @@ func paginationKeyFromEventData(table *TableSchema, rowData RowData) (string, er
return "", err
}

paginationColumn := table.GetPaginationColumn()
paginationKeyIndex := table.GetPaginationKeyIndex()

switch paginationColumn.Type {
case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT:
paginationKeyUint, err := rowData.GetUint64(paginationKeyIndex)
if err != nil {
return "", err
}
return NewUint64Key(paginationKeyUint).String(), nil

case schema.TYPE_BINARY, schema.TYPE_STRING:
paginationKeyInterface := rowData[paginationKeyIndex]
var paginationKeyBytes []byte
switch v := paginationKeyInterface.(type) {
case []byte:
paginationKeyBytes = v
case string:
paginationKeyBytes = []byte(v)
default:
return "", fmt.Errorf("expected binary/string pagination key, got %T", paginationKeyInterface)
}
return NewBinaryKey(paginationKeyBytes).String(), nil

default:
paginationKeyUint, err := rowData.GetUint64(paginationKeyIndex)
if err != nil {
return "", err
}
return NewUint64Key(paginationKeyUint).String(), nil
paginationKey, err := NewPaginationKeyFromRow(rowData, table.GetPaginationKeyIndex(), table.GetPaginationColumn())
if err != nil {
return "", err
}
return paginationKey.String(), nil
}
64 changes: 8 additions & 56 deletions inline_verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,34 +335,11 @@ func (v *InlineVerifier) CheckFingerprintInline(tx *sql.Tx, targetSchema, target

paginationKeys := make([]interface{}, len(sourceBatch.Values()))
for i, row := range sourceBatch.Values() {
switch paginationColumn.Type {
case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT:
paginationKeyUint, err := row.GetUint64(sourceBatch.PaginationKeyIndex())
if err != nil {
return nil, err
}
paginationKeys[i] = paginationKeyUint

case schema.TYPE_BINARY, schema.TYPE_STRING:
paginationKeyInterface := row[sourceBatch.PaginationKeyIndex()]
var paginationKeyBytes []byte
switch v := paginationKeyInterface.(type) {
case []byte:
paginationKeyBytes = v
case string:
paginationKeyBytes = []byte(v)
default:
return nil, fmt.Errorf("expected binary/string pagination key, got %T", paginationKeyInterface)
}
paginationKeys[i] = paginationKeyBytes

default:
paginationKeyUint, err := row.GetUint64(sourceBatch.PaginationKeyIndex())
if err != nil {
return nil, err
}
paginationKeys[i] = paginationKeyUint
paginationKey, err := NewPaginationKeyFromRow(row, sourceBatch.PaginationKeyIndex(), paginationColumn)
if err != nil {
return nil, err
}
paginationKeys[i] = paginationKey.SQLValue()
}

// Fetch target data
Expand All @@ -376,36 +353,11 @@ func (v *InlineVerifier) CheckFingerprintInline(tx *sql.Tx, targetSchema, target
sourceDecompressedData := make(map[string]map[string][]byte)

for _, rowData := range sourceBatch.Values() {
var paginationKeyStr string

switch paginationColumn.Type {
case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT:
paginationKeyUint, err := rowData.GetUint64(sourceBatch.PaginationKeyIndex())
if err != nil {
return nil, err
}
paginationKeyStr = NewUint64Key(paginationKeyUint).String()

case schema.TYPE_BINARY, schema.TYPE_STRING:
paginationKeyInterface := rowData[sourceBatch.PaginationKeyIndex()]
var paginationKeyBytes []byte
switch v := paginationKeyInterface.(type) {
case []byte:
paginationKeyBytes = v
case string:
paginationKeyBytes = []byte(v)
default:
return nil, fmt.Errorf("expected binary/string pagination key, got %T", paginationKeyInterface)
}
paginationKeyStr = NewBinaryKey(paginationKeyBytes).String()

default:
paginationKeyUint, err := rowData.GetUint64(sourceBatch.PaginationKeyIndex())
if err != nil {
return nil, err
}
paginationKeyStr = NewUint64Key(paginationKeyUint).String()
paginationKey, err := NewPaginationKeyFromRow(rowData, sourceBatch.PaginationKeyIndex(), paginationColumn)
if err != nil {
return nil, err
}
paginationKeyStr := paginationKey.String()

sourceDecompressedData[paginationKeyStr] = make(map[string][]byte)
for idx, col := range table.Columns {
Expand Down
58 changes: 8 additions & 50 deletions iterative_verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,31 +320,12 @@ func (v *IterativeVerifier) GetHashes(db *sql.DB, schemaName, tableName, paginat
return nil, err
}

var paginationKeyStr string
switch paginationColumn.Type {
case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT:
paginationKeyUint, err := rowData.GetUint64(0)
if err != nil {
return nil, err
}
paginationKeyStr = NewUint64Key(paginationKeyUint).String()

case schema.TYPE_BINARY, schema.TYPE_STRING:
paginationKeyBytes, ok := rowData[0].([]byte)
if !ok {
return nil, fmt.Errorf("expected []byte for binary pagination key, got %T", rowData[0])
}
paginationKeyStr = NewBinaryKey(paginationKeyBytes).String()

default:
paginationKeyUint, err := rowData.GetUint64(0)
if err != nil {
return nil, err
}
paginationKeyStr = NewUint64Key(paginationKeyUint).String()
paginationKey, err := NewPaginationKeyFromRow(rowData, 0, paginationColumn)
if err != nil {
return nil, err
}

resultSet[paginationKeyStr] = rowData[1].([]byte)
resultSet[paginationKey.String()] = rowData[1].([]byte)
}
return resultSet, nil
}
Expand Down Expand Up @@ -422,34 +403,11 @@ func (v *IterativeVerifier) iterateTableFingerprints(table *TableSchema, mismatc
paginationColumn := table.GetPaginationColumn()

for _, rowData := range batch.Values() {
switch paginationColumn.Type {
case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT:
paginationKeyUint, err := rowData.GetUint64(batch.PaginationKeyIndex())
if err != nil {
return err
}
paginationKeys = append(paginationKeys, paginationKeyUint)

case schema.TYPE_BINARY, schema.TYPE_STRING:
paginationKeyInterface := rowData[batch.PaginationKeyIndex()]
var paginationKeyBytes []byte
switch v := paginationKeyInterface.(type) {
case []byte:
paginationKeyBytes = v
case string:
paginationKeyBytes = []byte(v)
default:
return fmt.Errorf("expected binary/string pagination key, got %T", paginationKeyInterface)
}
paginationKeys = append(paginationKeys, paginationKeyBytes)

default:
paginationKeyUint, err := rowData.GetUint64(batch.PaginationKeyIndex())
if err != nil {
return err
}
paginationKeys = append(paginationKeys, paginationKeyUint)
paginationKey, err := NewPaginationKeyFromRow(rowData, batch.PaginationKeyIndex(), paginationColumn)
if err != nil {
return err
}
paginationKeys = append(paginationKeys, paginationKey.SQLValue())
}

mismatchedPaginationKeys, err := v.compareFingerprints(paginationKeys, batch.TableSchema())
Expand Down
Loading