Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions drivers/mysql/internal/backfill.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,14 @@ func (m *MySQL) backfill(pool *protocol.WriterPool, stream protocol.Stream) erro
// Begin transaction with repeatable read isolation
return jdbc.WithIsolation(backfillCtx, m.client, func(tx *sql.Tx) error {
// Build query for the chunk
stmt := jdbc.MysqlChunkScanQuery(stream, pkColumn, chunk)
stmt := jdbc.MySQLChunkScanQuery(stream, pkColumn, chunk)
setter := jdbc.NewReader(backfillCtx, stmt, 0, func(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
return tx.QueryContext(ctx, query, args...)
if chunk.Min != nil && chunk.Max != nil {
return tx.QueryContext(ctx, query, chunk.Min, chunk.Max)
} else if chunk.Min != nil {
return tx.QueryContext(ctx, query, chunk.Min)
}
return tx.QueryContext(ctx, query, chunk.Max)
})
// Capture and process rows
return setter.Capture(func(rows *sql.Rows) error {
Expand Down Expand Up @@ -134,7 +139,7 @@ func (m *MySQL) splitChunks(stream protocol.Stream, chunks *types.Set[types.Chun
}

// Generate chunks based on range
query := jdbc.NextChunkEndQuery(stream, pkColumn, chunkSize)
query := jdbc.NextChunkEndQuery(stream, pkColumn)

currentVal := minVal
for {
Expand Down
19 changes: 12 additions & 7 deletions drivers/postgres/internal/backfill.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ import (
func (p *Postgres) backfill(pool *protocol.WriterPool, stream protocol.Stream) error {
backfillCtx := context.TODO()
var approxRowCount int64
approxRowCountQuery := jdbc.PostgresRowCountQuery(stream)
err := p.client.QueryRow(approxRowCountQuery).Scan(&approxRowCount)
approxRowCountQuery := jdbc.PostgresRowCountQuery()
err := p.client.QueryRow(approxRowCountQuery, stream.Name(), stream.Namespace()).Scan(&approxRowCount)
if err != nil {
return fmt.Errorf("failed to get approx row count: %s", err)
}
Expand Down Expand Up @@ -55,7 +55,12 @@ func (p *Postgres) backfill(pool *protocol.WriterPool, stream protocol.Stream) e
stmt := jdbc.PostgresChunkScanQuery(stream, splitColumn, chunk)

setter := jdbc.NewReader(backfillCtx, stmt, p.config.BatchSize, func(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
return tx.Query(query, args...)
if chunk.Min != nil && chunk.Max != nil {
return tx.Query(query, chunk.Min, chunk.Max)
} else if chunk.Min != nil {
return tx.Query(query, chunk.Min)
}
return tx.Query(query, chunk.Max)
})
batchStartTime := time.Now()
waitChannel := make(chan error, 1)
Expand Down Expand Up @@ -102,8 +107,8 @@ func (p *Postgres) backfill(pool *protocol.WriterPool, stream protocol.Stream) e
func (p *Postgres) splitTableIntoChunks(stream protocol.Stream) ([]types.Chunk, error) {
generateCTIDRanges := func(stream protocol.Stream) ([]types.Chunk, error) {
var relPages uint32
relPagesQuery := jdbc.PostgresRelPageCount(stream)
err := p.client.QueryRow(relPagesQuery).Scan(&relPages)
relPagesQuery := jdbc.PostgresRelPageCount()
err := p.client.QueryRow(relPagesQuery, stream.Name(), stream.Namespace()).Scan(&relPages)
if err != nil {
return nil, fmt.Errorf("failed to get relPages: %s", err)
}
Expand Down Expand Up @@ -193,8 +198,8 @@ func (p *Postgres) splitTableIntoChunks(stream protocol.Stream) ([]types.Chunk,

func (p *Postgres) nextChunkEnd(stream protocol.Stream, previousChunkEnd interface{}, splitColumn string) (interface{}, error) {
var chunkEnd interface{}
nextChunkEnd := jdbc.PostgresNextChunkEndQuery(stream, splitColumn, previousChunkEnd, p.config.BatchSize)
err := p.client.QueryRow(nextChunkEnd).Scan(&chunkEnd)
nextChunkEnd := jdbc.PostgresNextChunkEndQuery(stream, splitColumn)
err := p.client.QueryRow(nextChunkEnd, previousChunkEnd, p.config.BatchSize).Scan(&chunkEnd)
if err != nil {
return nil, fmt.Errorf("failed to query[%s] next chunk end: %s", nextChunkEnd, err)
}
Expand Down
79 changes: 59 additions & 20 deletions pkg/jdbc/jdbc.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,29 @@ func MinMaxQuery(stream protocol.Stream, column string) string {
}

// NextChunkEndQuery returns the query to calculate the next chunk boundary
func NextChunkEndQuery(stream protocol.Stream, column string, chunkSize int) string {
return fmt.Sprintf(`SELECT MAX(%[1]s) FROM (SELECT %[1]s FROM %[2]s.%[3]s WHERE %[1]s > ? ORDER BY %[1]s LIMIT %[4]d) AS subquery`, column, stream.Namespace(), stream.Name(), chunkSize)
// ?: is the filter value, ?: is the batch size
func NextChunkEndQuery(stream protocol.Stream, column string) string {
return fmt.Sprintf(`SELECT MAX(%[1]s) FROM (SELECT %[1]s FROM %[2]s.%[3]s WHERE %[1]s > ? ORDER BY %[1]s LIMIT ?) AS subquery`, column, stream.Namespace(), stream.Name())
}

// buildChunkCondition builds the condition for a chunk
func buildChunkCondition(filterColumn string, chunk types.Chunk) string {
// PostgresBuildChunkCondition builds the condition for a chunk
func PostgresBuildChunkCondition(filterColumn string, chunk types.Chunk) string {
if chunk.Min != nil && chunk.Max != nil {
return fmt.Sprintf("%s >= %v AND %s <= %v", filterColumn, chunk.Min, filterColumn, chunk.Max)
return fmt.Sprintf("%s >= $1 AND %s <= $2", filterColumn, filterColumn)
} else if chunk.Min != nil {
return fmt.Sprintf("%s >= %v", filterColumn, chunk.Min)
return fmt.Sprintf("%s >= $1", filterColumn)
}
return fmt.Sprintf("%s <= %v", filterColumn, chunk.Max)
return fmt.Sprintf("%s <= $1", filterColumn)
}

// MySQLBuildChunkCondition builds the condition for a chunk
func MySQLBuildChunkCondition(filterColumn string, chunk types.Chunk) string {
if chunk.Min != nil && chunk.Max != nil {
return fmt.Sprintf("%s >= ? AND %s <= ?", filterColumn, filterColumn)
} else if chunk.Min != nil {
return fmt.Sprintf("%s >= ?", filterColumn)
}
return fmt.Sprintf("%s <= ?", filterColumn)
}

// PostgreSQL-Specific Queries
Expand All @@ -43,13 +54,19 @@ func PostgresWithState(stream protocol.Stream) string {
}

// PostgresRowCountQuery returns the query to fetch the estimated row count in PostgreSQL
func PostgresRowCountQuery(stream protocol.Stream) string {
return fmt.Sprintf(`SELECT reltuples::bigint AS approx_row_count FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace WHERE c.relname = '%s' AND n.nspname = '%s';`, stream.Name(), stream.Namespace())
// args to be passed:
// $1: stream name,
// $2: stream namespace
func PostgresRowCountQuery() string {
return `SELECT reltuples::bigint AS approx_row_count FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace WHERE c.relname = $1 AND n.nspname = $2;`
}

// PostgresRelPageCount returns the query to fetch relation page count in PostgreSQL
func PostgresRelPageCount(stream protocol.Stream) string {
return fmt.Sprintf(`SELECT relpages FROM pg_class WHERE relname = '%s' AND relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = '%s')`, stream.Name(), stream.Namespace())
// args to be passed:
// $1: stream name,
// $2: stream namespace
func PostgresRelPageCount() string {
return `SELECT relpages FROM pg_class WHERE relname = $1 AND relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = $2);`
}

// PostgresWalLSNQuery returns the query to fetch the current WAL LSN in PostgreSQL
Expand All @@ -58,30 +75,41 @@ func PostgresWalLSNQuery() string {
}

// PostgresNextChunkEndQuery generates a SQL query to fetch the maximum value of a specified column
func PostgresNextChunkEndQuery(stream protocol.Stream, filterColumn string, filterValue interface{}, batchSize int) string {
return fmt.Sprintf(`SELECT MAX(%s) FROM (SELECT %s FROM "%s"."%s" WHERE %s > %v ORDER BY %s ASC LIMIT %d) AS T`, filterColumn, filterColumn, stream.Namespace(), stream.Name(), filterColumn, filterValue, filterColumn, batchSize)
// args to be passed:
// $1: filter value,
// $2: batch size
func PostgresNextChunkEndQuery(stream protocol.Stream, filterColumn string) string {
return fmt.Sprintf(`SELECT MAX(%s) FROM (SELECT %s FROM "%s"."%s" WHERE %s > $1 ORDER BY %s ASC LIMIT $2) AS T`, filterColumn, filterColumn, stream.Namespace(), stream.Name(), filterColumn, filterColumn)
}

// PostgresMinQuery returns the query to fetch the minimum value of a column in PostgreSQL
func PostgresMinQuery(stream protocol.Stream, filterColumn string, filterValue interface{}) string {
return fmt.Sprintf(`SELECT MIN(%s) FROM "%s"."%s" WHERE %s > %v`, filterColumn, stream.Namespace(), stream.Name(), filterColumn, filterValue)
// args to be passed:
// $1: filter value,
func PostgresMinQuery(stream protocol.Stream, filterColumn string) string {
return fmt.Sprintf(`SELECT MIN(%s) FROM "%s"."%s" WHERE %s > $1`, filterColumn, stream.Namespace(), stream.Name(), filterColumn)
}

// PostgresBuildSplitScanQuery builds a chunk scan query for PostgreSQL
// PostgresChunkScanQuery builds a chunk scan query for PostgreSQL
// args to be passed:
// Chunk.Min/Chunk.Max: filter value,
func PostgresChunkScanQuery(stream protocol.Stream, filterColumn string, chunk types.Chunk) string {
condition := buildChunkCondition(filterColumn, chunk)
condition := PostgresBuildChunkCondition(filterColumn, chunk)
return fmt.Sprintf(`SELECT * FROM "%s"."%s" WHERE %s`, stream.Namespace(), stream.Name(), condition)
}

// MySQL-Specific Queries

// MySQLWithoutState builds a chunk scan query for MySql
func MysqlChunkScanQuery(stream protocol.Stream, filterColumn string, chunk types.Chunk) string {
condition := buildChunkCondition(filterColumn, chunk)
// MySQLChunkScanQuery builds a chunk scan query for MySql
// args to be passed:
// Chunk.Min/Chunk.Max: filter value,
func MySQLChunkScanQuery(stream protocol.Stream, filterColumn string, chunk types.Chunk) string {
condition := MySQLBuildChunkCondition(filterColumn, chunk)
return fmt.Sprintf("SELECT * FROM `%s`.`%s` WHERE %s", stream.Namespace(), stream.Name(), condition)
}

// MySQLDiscoverTablesQuery returns the query to discover tables in a MySQL database
// Args:
// ?: schema name (string)
func MySQLDiscoverTablesQuery() string {
return `
SELECT
Expand All @@ -96,6 +124,9 @@ func MySQLDiscoverTablesQuery() string {
}

// MySQLTableSchemaQuery returns the query to fetch schema information for a table in MySQL
// Args:
// ?: schema name (string)
// ?: table name (string)
func MySQLTableSchemaQuery() string {
return `
SELECT
Expand All @@ -114,6 +145,8 @@ func MySQLTableSchemaQuery() string {
}

// MySQLPrimaryKeyQuery returns the query to fetch the primary key column of a table in MySQL
// Args:
// ?: table name (string)
func MySQLPrimaryKeyQuery() string {
return `
SELECT COLUMN_NAME
Expand All @@ -126,6 +159,8 @@ func MySQLPrimaryKeyQuery() string {
}

// MySQLTableRowsQuery returns the query to fetch the estimated row count of a table in MySQL
// Args:
// ?: table name (string)
func MySQLTableRowsQuery() string {
return `
SELECT TABLE_ROWS
Expand All @@ -141,6 +176,9 @@ func MySQLMasterStatusQuery() string {
}

// MySQLTableColumnsQuery returns the query to fetch column names of a table in MySQL
// Args:
// ?: schema name (string)
// ?: table name (string)
func MySQLTableColumnsQuery() string {
return `
SELECT COLUMN_NAME
Expand All @@ -149,6 +187,7 @@ func MySQLTableColumnsQuery() string {
ORDER BY ORDINAL_POSITION
`
}

func WithIsolation(ctx context.Context, client *sql.DB, fn func(tx *sql.Tx) error) error {
tx, err := client.BeginTx(ctx, &sql.TxOptions{
Isolation: sql.LevelRepeatableRead,
Expand Down
Loading