From d14037fd58ab97dbd7023dda1c0563d64e676e30 Mon Sep 17 00:00:00 2001 From: wk989898 Date: Mon, 22 Jun 2026 04:06:04 +0000 Subject: [PATCH 01/10] init Signed-off-by: wk989898 --- pkg/config/sink.go | 3 +- pkg/sink/codec/builder.go | 3 + pkg/sink/codec/common/config.go | 10 +- pkg/sink/codec/common/config_test.go | 33 ++ pkg/sink/codec/debezium/avro.go | 511 +++++++++++++++++++++++++++ pkg/sink/codec/debezium/avro_test.go | 131 +++++++ pkg/sink/codec/debezium/encoder.go | 54 ++- 7 files changed, 741 insertions(+), 4 deletions(-) create mode 100644 pkg/sink/codec/common/config_test.go create mode 100644 pkg/sink/codec/debezium/avro.go create mode 100644 pkg/sink/codec/debezium/avro_test.go diff --git a/pkg/config/sink.go b/pkg/config/sink.go index c0f9342953..753f470c14 100644 --- a/pkg/config/sink.go +++ b/pkg/config/sink.go @@ -144,7 +144,8 @@ type SinkConfig struct { DispatchRules []*DispatchRule `toml:"dispatchers" json:"dispatchers,omitempty"` ColumnSelectors []*ColumnSelector `toml:"column-selectors" json:"column-selectors,omitempty"` - // SchemaRegistry is only available when the downstream is MQ using avro protocol. + // SchemaRegistry is only available when the downstream is MQ using avro protocol + // or debezium protocol with Confluent Avro encoding. SchemaRegistry *string `toml:"schema-registry" json:"schema-registry,omitempty"` // EncoderConcurrency is only available when the downstream is MQ. EncoderConcurrency *int `toml:"encoder-concurrency" json:"encoder-concurrency,omitempty"` diff --git a/pkg/sink/codec/builder.go b/pkg/sink/codec/builder.go index 8a17146921..b826898115 100644 --- a/pkg/sink/codec/builder.go +++ b/pkg/sink/codec/builder.go @@ -40,6 +40,9 @@ func NewEventEncoder(ctx context.Context, cfg *common.Config) (common.EventEncod case config.ProtocolCanalJSON: return canal.NewJSONRowEventEncoder(ctx, cfg) case config.ProtocolDebezium: + if cfg.AvroConfluentSchemaRegistry != "" { + return debezium.NewAvroBatchEncoder(ctx, cfg, config.GetGlobalServerConfig().ClusterID) + } return debezium.NewBatchEncoder(cfg, config.GetGlobalServerConfig().ClusterID), nil case config.ProtocolSimple: return simple.NewEncoder(ctx, cfg) diff --git a/pkg/sink/codec/common/config.go b/pkg/sink/codec/common/config.go index 83486f3da3..599d61f95b 100644 --- a/pkg/sink/codec/common/config.go +++ b/pkg/sink/codec/common/config.go @@ -55,7 +55,8 @@ type Config struct { OutputRowKey bool - // avro only + // avro only, except AvroConfluentSchemaRegistry is also used by debezium + // protocol when Confluent Avro encoding is enabled. AvroConfluentSchemaRegistry string AvroDecimalHandlingMode string AvroBigintUnsignedHandlingMode string @@ -410,6 +411,13 @@ func (c *Config) Validate() error { } } + if c.Protocol == config.ProtocolDebezium && c.AvroGlueSchemaRegistry != nil { + return errors.ErrCodecInvalidConfig.GenWithStack( + `Debezium protocol only supports "%s" for Confluent Avro Schema Registry`, + codecOPTAvroSchemaRegistry, + ) + } + if c.MaxMessageBytes <= 0 { return errors.ErrCodecInvalidConfig.Wrap( errors.Errorf("invalid max-message-bytes %d", c.MaxMessageBytes), diff --git a/pkg/sink/codec/common/config_test.go b/pkg/sink/codec/common/config_test.go new file mode 100644 index 0000000000..682504c9d4 --- /dev/null +++ b/pkg/sink/codec/common/config_test.go @@ -0,0 +1,33 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package common + +import ( + "testing" + + "github.com/pingcap/ticdc/pkg/config" + "github.com/stretchr/testify/require" +) + +func TestDebeziumSchemaRegistryConfig(t *testing.T) { + t.Parallel() + + cfg := NewConfig(config.ProtocolDebezium) + cfg.AvroConfluentSchemaRegistry = "http://127.0.0.1:8081" + require.NoError(t, cfg.Validate()) + + cfg = NewConfig(config.ProtocolDebezium) + cfg.AvroGlueSchemaRegistry = &config.GlueSchemaRegistryConfig{} + require.ErrorContains(t, cfg.Validate(), `Debezium protocol only supports "schema-registry"`) +} diff --git a/pkg/sink/codec/debezium/avro.go b/pkg/sink/codec/debezium/avro.go new file mode 100644 index 0000000000..169d683075 --- /dev/null +++ b/pkg/sink/codec/debezium/avro.go @@ -0,0 +1,511 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package debezium + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "strings" + + "github.com/linkedin/goavro/v2" + commonEvent "github.com/pingcap/ticdc/pkg/common/event" + "github.com/pingcap/ticdc/pkg/errors" + "github.com/pingcap/ticdc/pkg/sink/codec/common" +) + +const ( + debeziumAvroKeySchemaSuffix = "-key" + debeziumAvroValueSchemaSuffix = "-value" +) + +type debeziumAvroMessage struct { + Schema *debeziumConnectSchema `json:"schema"` + Payload any `json:"payload"` +} + +type debeziumConnectSchema struct { + Type string `json:"type"` + Optional bool `json:"optional"` + Name string `json:"name"` + Version int `json:"version"` + Field string `json:"field"` + Fields []*debeziumConnectSchema `json:"fields"` + Items *debeziumConnectSchema `json:"items"` + Parameters map[string]string `json:"parameters"` +} + +type debeziumAvroSchemaConverter struct { + definedNames map[string]struct{} +} + +func (d *BatchEncoder) appendAvroRowChangedEvent( + ctx context.Context, + topic string, + e *commonEvent.RowEvent, +) error { + keyBuf := bytes.Buffer{} + if err := d.codec.EncodeKey(e, &keyBuf); err != nil { + return errors.Trace(err) + } + + valueBuf := bytes.Buffer{} + if err := d.codec.EncodeValue(e, &valueBuf); err != nil { + return errors.Trace(err) + } + + message, err := d.encodeAvroMessage( + ctx, + topic, + keyBuf.Bytes(), + valueBuf.Bytes(), + e.TableInfo.GetUpdateTS(), + ) + if err != nil { + return err + } + message.Callback = e.Callback + message.IncRowsCount() + + d.messages = append(d.messages, message) + return nil +} + +func (d *BatchEncoder) encodeAvroMessage( + ctx context.Context, + topic string, + keyJSON []byte, + valueJSON []byte, + schemaVersion uint64, +) (*common.Message, error) { + key, err := d.encodeAvroPayload( + ctx, + topic, + debeziumAvroKeySchemaSuffix, + keyJSON, + schemaVersion, + ) + if err != nil { + return nil, err + } + + value, err := d.encodeAvroPayload( + ctx, + topic, + debeziumAvroValueSchemaSuffix, + valueJSON, + schemaVersion, + ) + if err != nil { + return nil, err + } + + return common.NewMsg(key, value), nil +} + +func (d *BatchEncoder) encodeAvroPayload( + ctx context.Context, + topic string, + subjectSuffix string, + data []byte, + schemaVersion uint64, +) ([]byte, error) { + message, err := unmarshalDebeziumAvroMessage(data) + if err != nil { + return nil, err + } + if message.Schema == nil { + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("schema is missing") + } + + converter := newDebeziumAvroSchemaConverter() + avroSchema, err := converter.toAvroSchema(message.Schema, "") + if err != nil { + return nil, err + } + schemaBytes, err := json.Marshal(avroSchema) + if err != nil { + return nil, errors.WrapError(errors.ErrAvroMarshalFailed, err) + } + + subject := debeziumAvroSubject(topic, subjectSuffix, message.Schema.Name) + avroCodec, header, err := d.schemaM.GetCachedOrRegister( + ctx, + subject, + schemaVersion, + func() (string, error) { + return string(schemaBytes), nil + }, + ) + if err != nil { + return nil, errors.Trace(err) + } + + native, err := converter.toNative(message.Schema, message.Payload, "") + if err != nil { + return nil, err + } + binaryData, err := avroCodec.BinaryFromNative(nil, native) + if err != nil { + return nil, errors.WrapError(errors.ErrAvroEncodeToBinary, err) + } + + result := make([]byte, 0, len(header)+len(binaryData)) + result = append(result, header...) + result = append(result, binaryData...) + return result, nil +} + +func unmarshalDebeziumAvroMessage(data []byte) (*debeziumAvroMessage, error) { + decoder := json.NewDecoder(bytes.NewReader(data)) + decoder.UseNumber() + + var message debeziumAvroMessage + if err := decoder.Decode(&message); err != nil { + return nil, errors.WrapError(errors.ErrDebeziumInvalidMessage, err) + } + return &message, nil +} + +func newDebeziumAvroSchemaConverter() *debeziumAvroSchemaConverter { + return &debeziumAvroSchemaConverter{ + definedNames: make(map[string]struct{}), + } +} + +func debeziumAvroSubject(topic string, subjectSuffix string, schemaName string) string { + if topic != "" { + return topic + subjectSuffix + } + if schemaName != "" { + return schemaName + } + return "debezium" + subjectSuffix +} + +func (c *debeziumAvroSchemaConverter) toAvroSchema( + schema *debeziumConnectSchema, + fallbackName string, +) (any, error) { + if schema == nil { + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("schema is nil") + } + + switch schema.Type { + case "struct": + fullName := avroFullName(schema.Name, fallbackName) + if _, exists := c.definedNames[fullName]; exists { + return fullName, nil + } + c.definedNames[fullName] = struct{}{} + + name, namespace := splitAvroFullName(fullName) + record := map[string]any{ + "type": "record", + "name": name, + "fields": make([]any, 0, len(schema.Fields)), + } + if namespace != "" { + record["namespace"] = namespace + } + addConnectMetadata(record, schema) + + fields := record["fields"].([]any) + for _, fieldSchema := range schema.Fields { + fieldName := avroFieldName(fieldSchema.Field) + fieldType, err := c.toAvroSchema(fieldSchema, fieldName) + if err != nil { + return nil, err + } + + field := map[string]any{ + "name": fieldName, + "type": fieldType, + } + if fieldSchema.Optional { + field["type"] = []any{"null", fieldType} + field["default"] = nil + } + fields = append(fields, field) + } + record["fields"] = fields + return record, nil + case "array": + if schema.Items == nil { + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("array schema is missing items") + } + items, err := c.toAvroSchema(schema.Items, fallbackName+"Item") + if err != nil { + return nil, err + } + arraySchema := map[string]any{ + "type": "array", + "items": items, + } + addConnectMetadata(arraySchema, schema) + return arraySchema, nil + default: + avroType, err := connectPrimitiveToAvro(schema.Type) + if err != nil { + return nil, err + } + if !hasConnectMetadata(schema) && schema.Type != "int8" && schema.Type != "int16" { + return avroType, nil + } + primitive := map[string]any{ + "type": avroType, + } + if schema.Type == "int8" || schema.Type == "int16" { + primitive["connect.type"] = schema.Type + } + addConnectMetadata(primitive, schema) + return primitive, nil + } +} + +func (c *debeziumAvroSchemaConverter) toNative( + schema *debeziumConnectSchema, + value any, + fallbackName string, +) (any, error) { + if value == nil { + return nil, nil + } + + switch schema.Type { + case "struct": + valueMap, ok := value.(map[string]any) + if !ok { + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("struct payload is not an object") + } + + native := make(map[string]any, len(schema.Fields)) + for _, fieldSchema := range schema.Fields { + fieldName := avroFieldName(fieldSchema.Field) + rawValue := valueMap[fieldSchema.Field] + if rawValue == nil && fieldSchema.Field != fieldName { + rawValue = valueMap[fieldName] + } + + fieldValue, err := c.toNative(fieldSchema, rawValue, fieldName) + if err != nil { + return nil, err + } + if fieldSchema.Optional { + if fieldValue == nil { + native[fieldName] = nil + } else { + native[fieldName] = goavro.Union( + avroUnionBranchName(fieldSchema, fieldName), + fieldValue, + ) + } + } else { + native[fieldName] = fieldValue + } + } + return native, nil + case "array": + values, ok := value.([]any) + if !ok { + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("array payload is not an array") + } + native := make([]any, 0, len(values)) + for _, item := range values { + itemValue, err := c.toNative(schema.Items, item, fallbackName+"Item") + if err != nil { + return nil, err + } + if schema.Items.Optional && itemValue != nil { + itemValue = goavro.Union( + avroUnionBranchName(schema.Items, fallbackName+"Item"), + itemValue, + ) + } + native = append(native, itemValue) + } + return native, nil + case "boolean": + v, ok := value.(bool) + if !ok { + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("boolean payload is invalid") + } + return v, nil + case "string": + v, ok := value.(string) + if !ok { + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("string payload is invalid") + } + return v, nil + case "bytes": + v, ok := value.(string) + if !ok { + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("bytes payload is invalid") + } + data, err := base64.StdEncoding.DecodeString(v) + if err != nil { + return nil, errors.WrapError(errors.ErrDebeziumInvalidMessage, err) + } + return data, nil + case "int8", "int16", "int32": + v, err := numberToInt64(value) + if err != nil { + return nil, err + } + return int32(v), nil + case "int64": + return numberToInt64(value) + case "float": + v, err := numberToFloat64(value) + if err != nil { + return nil, err + } + return float32(v), nil + case "double": + return numberToFloat64(value) + default: + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("unsupported schema type " + schema.Type) + } +} + +func connectPrimitiveToAvro(connectType string) (string, error) { + switch connectType { + case "boolean": + return "boolean", nil + case "string": + return "string", nil + case "bytes": + return "bytes", nil + case "int8", "int16", "int32": + return "int", nil + case "int64": + return "long", nil + case "float": + return "float", nil + case "double": + return "double", nil + default: + return "", errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("unsupported schema type " + connectType) + } +} + +func addConnectMetadata(avroSchema map[string]any, schema *debeziumConnectSchema) { + if schema.Name != "" { + avroSchema["connect.name"] = schema.Name + } + if schema.Version != 0 { + avroSchema["connect.version"] = schema.Version + } + if len(schema.Parameters) != 0 { + avroSchema["connect.parameters"] = schema.Parameters + } +} + +func hasConnectMetadata(schema *debeziumConnectSchema) bool { + return schema.Name != "" || schema.Version != 0 || len(schema.Parameters) != 0 +} + +func avroFullName(connectName string, fallbackName string) string { + if connectName != "" { + return connectName + } + if fallbackName != "" { + return avroFieldName(fallbackName) + } + return "ConnectDefault" +} + +func splitAvroFullName(fullName string) (name string, namespace string) { + idx := strings.LastIndex(fullName, ".") + if idx < 0 { + return avroFieldName(fullName), "" + } + return avroFieldName(fullName[idx+1:]), fullName[:idx] +} + +func avroFieldName(field string) string { + return common.SanitizeName(field) +} + +func avroUnionBranchName(schema *debeziumConnectSchema, fallbackName string) string { + switch schema.Type { + case "struct": + return avroFullName(schema.Name, fallbackName) + case "array": + return "array" + case "int8", "int16", "int32": + return "int" + case "int64": + return "long" + case "float": + return "float" + case "double": + return "double" + default: + return schema.Type + } +} + +func numberToInt64(value any) (int64, error) { + switch v := value.(type) { + case json.Number: + i, err := v.Int64() + if err == nil { + return i, nil + } + f, err := v.Float64() + if err != nil { + return 0, errors.WrapError(errors.ErrDebeziumInvalidMessage, err) + } + return int64(f), nil + case int: + return int64(v), nil + case int32: + return int64(v), nil + case int64: + return v, nil + case uint64: + return int64(v), nil + case float64: + return int64(v), nil + default: + return 0, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("number payload is invalid") + } +} + +func numberToFloat64(value any) (float64, error) { + switch v := value.(type) { + case json.Number: + f, err := v.Float64() + if err != nil { + return 0, errors.WrapError(errors.ErrDebeziumInvalidMessage, err) + } + return f, nil + case int: + return float64(v), nil + case int32: + return float64(v), nil + case int64: + return float64(v), nil + case uint64: + return float64(v), nil + case float32: + return float64(v), nil + case float64: + return v, nil + default: + return 0, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("number payload is invalid") + } +} diff --git a/pkg/sink/codec/debezium/avro_test.go b/pkg/sink/codec/debezium/avro_test.go new file mode 100644 index 0000000000..29c664c6f1 --- /dev/null +++ b/pkg/sink/codec/debezium/avro_test.go @@ -0,0 +1,131 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package debezium + +import ( + "context" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "net/http" + "testing" + "time" + + "github.com/pingcap/ticdc/downstreamadapter/sink/columnselector" + commonEvent "github.com/pingcap/ticdc/pkg/common/event" + "github.com/pingcap/ticdc/pkg/config" + "github.com/pingcap/ticdc/pkg/sink/codec/avro" + "github.com/pingcap/ticdc/pkg/sink/codec/common" + "github.com/stretchr/testify/require" +) + +func TestDebeziumConfluentAvroEncodeRowEvent(t *testing.T) { + ctx := context.Background() + _, err := avro.SetupEncoderAndSchemaRegistry4Testing( + ctx, + common.NewConfig(config.ProtocolAvro), + ) + require.NoError(t, err) + defer avro.TeardownEncoderAndSchemaRegistry4Testing() + + helper := NewSQLTestHelper(t, "foo", ` + create table foo( + id int primary key, + name varchar(16), + v bigint null + )`) + defer helper.Close() + + dmls := helper.helper.DML2Event("test", "foo", "insert into foo values (1, 'alice', null)") + row, ok := dmls.GetNextRow() + require.True(t, ok) + + cfg := common.NewConfig(config.ProtocolDebezium) + cfg.AvroConfluentSchemaRegistry = "http://127.0.0.1:8081" + cfg.DebeziumDisableSchema = true + cfg.TimeZone = time.UTC + + encoder, err := NewAvroBatchEncoder(ctx, cfg, "dbserver1") + require.NoError(t, err) + require.NoError(t, encoder.AppendRowChangedEvent(ctx, "dbserver1.test.foo", &commonEvent.RowEvent{ + TableInfo: helper.tableInfo, + CommitTs: 1, + Event: row, + ColumnSelector: columnselector.NewDefaultColumnSelector(), + Callback: func() {}, + })) + + messages := encoder.Build() + require.Len(t, messages, 1) + require.Equal(t, byte(0), messages[0].Key[0]) + require.Equal(t, byte(0), messages[0].Value[0]) + + key := decodeConfluentAvroForTest(t, messages[0].Key) + require.Equal(t, int32(1), key["id"]) + + value := decodeConfluentAvroForTest(t, messages[0].Value) + require.Equal(t, "c", value["op"]) + require.Nil(t, value["before"]) + + afterUnion, ok := value["after"].(map[string]any) + require.True(t, ok) + after, ok := afterUnion["dbserver1.test.foo.Value"].(map[string]any) + require.True(t, ok) + require.Equal(t, int32(1), after["id"]) + name, ok := after["name"].(map[string]any) + require.True(t, ok) + require.Equal(t, "alice", name["string"]) + require.Nil(t, after["v"]) + + source, ok := value["source"].(map[string]any) + require.True(t, ok) + require.Equal(t, "test", source["db"]) + table, ok := source["table"].(map[string]any) + require.True(t, ok) + require.Equal(t, "foo", table["string"]) + require.Equal(t, "dbserver1", source["name"]) +} + +func decodeConfluentAvroForTest(t *testing.T, data []byte) map[string]any { + t.Helper() + + require.GreaterOrEqual(t, len(data), 5) + require.Equal(t, byte(0), data[0]) + schemaID := int(binary.BigEndian.Uint32(data[1:5])) + binaryData := data[5:] + + resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:8081/schemas/ids/%d", schemaID)) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + var schemaResp struct { + Schema string `json:"schema"` + } + require.NoError(t, json.Unmarshal(body, &schemaResp)) + + codec, err := avro.GenCodec(schemaResp.Schema) + require.NoError(t, err) + + native, _, err := codec.NativeFromBinary(binaryData) + require.NoError(t, err) + + result, ok := native.(map[string]any) + require.True(t, ok) + return result +} diff --git a/pkg/sink/codec/debezium/encoder.go b/pkg/sink/codec/debezium/encoder.go index c0f8c3d07a..80deca2d44 100644 --- a/pkg/sink/codec/debezium/encoder.go +++ b/pkg/sink/codec/debezium/encoder.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/log" commonEvent "github.com/pingcap/ticdc/pkg/common/event" "github.com/pingcap/ticdc/pkg/errors" + "github.com/pingcap/ticdc/pkg/sink/codec/avro" "github.com/pingcap/ticdc/pkg/sink/codec/common" "go.uber.org/zap" ) @@ -31,6 +32,8 @@ type BatchEncoder struct { config *common.Config codec *dbzCodec + + schemaM avro.SchemaManager } // EncodeCheckpointEvent implements the RowEventEncoder interface @@ -44,6 +47,15 @@ func (d *BatchEncoder) EncodeCheckpointEvent(ts uint64) (*common.Message, error) if err != nil { return nil, errors.Trace(err) } + if d.schemaM != nil { + return d.encodeAvroMessage( + context.Background(), + "", + keyMap.Bytes(), + valueBuf.Bytes(), + 0, + ) + } key, err := common.Compress( d.config.ChangefeedID, d.config.LargeMessageHandle.LargeMessageHandleCompression, @@ -66,10 +78,14 @@ func (d *BatchEncoder) EncodeCheckpointEvent(ts uint64) (*common.Message, error) // AppendRowChangedEvent implements the RowEventEncoder interface func (d *BatchEncoder) AppendRowChangedEvent( - _ context.Context, - _ string, + ctx context.Context, + topic string, e *commonEvent.RowEvent, ) error { + if d.schemaM != nil { + return d.appendAvroRowChangedEvent(ctx, topic, e) + } + var key []byte var value []byte var err error @@ -103,6 +119,15 @@ func (d *BatchEncoder) EncodeDDLEvent(e *commonEvent.DDLEvent) (*common.Message, } return nil, errors.Trace(err) } + if d.schemaM != nil { + return d.encodeAvroMessage( + context.Background(), + "", + keyMap.Bytes(), + valueBuf.Bytes(), + 0, + ) + } key, err := common.Compress( d.config.ChangefeedID, d.config.LargeMessageHandle.LargeMessageHandleCompression, @@ -180,3 +205,28 @@ func NewBatchEncoder(c *common.Config, clusterID string) common.EventEncoder { } return batch } + +func NewAvroBatchEncoder( + ctx context.Context, + c *common.Config, + clusterID string, +) (common.EventEncoder, error) { + schemaM, err := avro.NewConfluentSchemaManager(ctx, c.AvroConfluentSchemaRegistry, nil) + if err != nil { + return nil, errors.Trace(err) + } + + codecConfig := *c + codecConfig.DebeziumDisableSchema = false + batch := &BatchEncoder{ + messages: nil, + config: c, + codec: &dbzCodec{ + config: &codecConfig, + clusterID: clusterID, + nowFunc: time.Now, + }, + schemaM: schemaM, + } + return batch, nil +} From 3f4d30265ba8f9719404facb19ae957b1c780291 Mon Sep 17 00:00:00 2001 From: wk989898 Date: Mon, 22 Jun 2026 06:58:45 +0000 Subject: [PATCH 02/10] add test Signed-off-by: wk989898 --- .../debezium_avro/data/prepare.sql | 83 +++ .../debezium_avro/data/workload.sql | 75 +++ tests/integration_tests/debezium_avro/run.sh | 65 ++ .../debezium_avro/verify/main.go | 627 ++++++++++++++++++ tests/integration_tests/run_light_it_in_ci.sh | 2 +- 5 files changed, 851 insertions(+), 1 deletion(-) create mode 100644 tests/integration_tests/debezium_avro/data/prepare.sql create mode 100644 tests/integration_tests/debezium_avro/data/workload.sql create mode 100644 tests/integration_tests/debezium_avro/run.sh create mode 100644 tests/integration_tests/debezium_avro/verify/main.go diff --git a/tests/integration_tests/debezium_avro/data/prepare.sql b/tests/integration_tests/debezium_avro/data/prepare.sql new file mode 100644 index 0000000000..f4c762f03e --- /dev/null +++ b/tests/integration_tests/debezium_avro/data/prepare.sql @@ -0,0 +1,83 @@ +DROP DATABASE IF EXISTS test; +CREATE DATABASE test; +USE test; + +CREATE TABLE tp_int ( + id INT AUTO_INCREMENT, + c_tinyint TINYINT NULL, + c_smallint SMALLINT NULL, + c_mediumint MEDIUMINT NULL, + c_int INT NULL, + c_bigint BIGINT NULL, + PRIMARY KEY (id) +); + +CREATE TABLE tp_unsigned_int ( + id INT AUTO_INCREMENT, + c_unsigned_tinyint TINYINT UNSIGNED NULL, + c_unsigned_smallint SMALLINT UNSIGNED NULL, + c_unsigned_mediumint MEDIUMINT UNSIGNED NULL, + c_unsigned_int INT UNSIGNED NULL, + c_unsigned_bigint BIGINT UNSIGNED NULL, + PRIMARY KEY (id) +); + +CREATE TABLE tp_real ( + id INT AUTO_INCREMENT, + c_float FLOAT NULL, + c_double DOUBLE NULL, + c_decimal DECIMAL NULL, + c_decimal_2 DECIMAL(10, 4) NULL, + PRIMARY KEY (id) +); + +CREATE TABLE tp_time ( + id INT AUTO_INCREMENT, + c_date DATE NULL, + c_datetime DATETIME NULL, + c_timestamp TIMESTAMP NULL, + c_time TIME NULL, + c_year YEAR NULL, + PRIMARY KEY (id) +); + +CREATE TABLE tp_text ( + id INT AUTO_INCREMENT, + c_tinytext TINYTEXT NULL, + c_text TEXT NULL, + c_mediumtext MEDIUMTEXT NULL, + c_longtext LONGTEXT NULL, + PRIMARY KEY (id) +); + +CREATE TABLE tp_blob ( + id INT AUTO_INCREMENT, + c_tinyblob TINYBLOB NULL, + c_blob BLOB NULL, + c_mediumblob MEDIUMBLOB NULL, + c_longblob LONGBLOB NULL, + PRIMARY KEY (id) +); + +CREATE TABLE tp_char_binary ( + id INT AUTO_INCREMENT, + c_char CHAR(16) NULL, + c_varchar VARCHAR(16) NULL, + c_binary BINARY(16) NULL, + c_varbinary VARBINARY(16) NULL, + PRIMARY KEY (id) +); + +CREATE TABLE tp_other ( + id INT AUTO_INCREMENT, + c_enum ENUM ('a', 'b', 'c') NULL, + c_set SET ('a', 'b', 'c') NULL, + c_bit BIT(64) NULL, + c_json JSON NULL, + PRIMARY KEY (id) +); + +CREATE TABLE tp_account ( + id INT PRIMARY KEY, + account_id INT NOT NULL +); diff --git a/tests/integration_tests/debezium_avro/data/workload.sql b/tests/integration_tests/debezium_avro/data/workload.sql new file mode 100644 index 0000000000..9931a83409 --- /dev/null +++ b/tests/integration_tests/debezium_avro/data/workload.sql @@ -0,0 +1,75 @@ +USE test; + +INSERT INTO tp_int() VALUES (); +INSERT INTO tp_int(c_tinyint, c_smallint, c_mediumint, c_int, c_bigint) +VALUES (1, 2, 3, 4, 5); +INSERT INTO tp_int(c_tinyint, c_smallint, c_mediumint, c_int, c_bigint) +VALUES (127, 32767, 8388607, 2147483647, 9223372036854775807); +INSERT INTO tp_int(c_tinyint, c_smallint, c_mediumint, c_int, c_bigint) +VALUES (-128, -32768, -8388608, -2147483648, -9223372036854775808); +UPDATE tp_int SET c_int = 0, c_tinyint = 0 WHERE id = 2; +DELETE FROM tp_int WHERE id = 2; + +INSERT INTO tp_unsigned_int() VALUES (); +INSERT INTO tp_unsigned_int( + c_unsigned_tinyint, + c_unsigned_smallint, + c_unsigned_mediumint, + c_unsigned_int, + c_unsigned_bigint +) VALUES (1, 2, 3, 4, 5); +INSERT INTO tp_unsigned_int( + c_unsigned_tinyint, + c_unsigned_smallint, + c_unsigned_mediumint, + c_unsigned_int, + c_unsigned_bigint +) VALUES (255, 65535, 16777215, 4294967295, 18446744073709551615); +UPDATE tp_unsigned_int SET c_unsigned_int = 0, c_unsigned_tinyint = 0 WHERE id = 3; +DELETE FROM tp_unsigned_int WHERE id = 3; + +INSERT INTO tp_real() VALUES (); +INSERT INTO tp_real(c_float, c_double, c_decimal, c_decimal_2) +VALUES (2020.0202, 2020.0303, 2020.0404, 2021.1208); +INSERT INTO tp_real(c_float, c_double, c_decimal, c_decimal_2) +VALUES (-2.7182818284, -3.1415926, -8000, -179394.233); +UPDATE tp_real SET c_double = 2.333 WHERE id = 2; + +INSERT INTO tp_time() VALUES (); +INSERT INTO tp_time(c_date, c_datetime, c_timestamp, c_time, c_year) +VALUES ('2020-02-20', '2020-02-20 02:20:20', '2020-02-20 02:20:20', '02:20:20', '2020'); +INSERT INTO tp_time(c_date, c_datetime, c_timestamp, c_time, c_year) +VALUES ('2022-02-22', '2022-02-22 22:22:22', '2020-02-20 02:20:20', '02:20:20', '2021'); +UPDATE tp_time SET c_year = '2022' WHERE id = 2; + +INSERT INTO tp_text() VALUES (); +INSERT INTO tp_text(c_tinytext, c_text, c_mediumtext, c_longtext) +VALUES ('89504E470D0A1A0A', '89504E470D0A1A0A', '89504E470D0A1A0A', '89504E470D0A1A0A'); +INSERT INTO tp_text(c_tinytext, c_text, c_mediumtext, c_longtext) +VALUES ('89504E470D0A1A0B', '89504E470D0A1A0B', '89504E470D0A1A0B', '89504E470D0A1A0B'); +UPDATE tp_text SET c_text = '89504E470D0A1A0B' WHERE id = 2; + +INSERT INTO tp_blob() VALUES (); +INSERT INTO tp_blob(c_tinyblob, c_blob, c_mediumblob, c_longblob) +VALUES (x'89504E470D0A1A0A', x'89504E470D0A1A0A', x'89504E470D0A1A0A', x'89504E470D0A1A0A'); +INSERT INTO tp_blob(c_tinyblob, c_blob, c_mediumblob, c_longblob) +VALUES (x'89504E470D0A1A0B', x'89504E470D0A1A0B', x'89504E470D0A1A0B', x'89504E470D0A1A0B'); +UPDATE tp_blob SET c_blob = x'89504E470D0A1A0B' WHERE id = 2; + +INSERT INTO tp_char_binary() VALUES (); +INSERT INTO tp_char_binary(c_char, c_varchar, c_binary, c_varbinary) +VALUES ('89504E470D0A1A0A', '89504E470D0A1A0A', x'89504E470D0A1A0A', x'89504E470D0A1A0A'); +INSERT INTO tp_char_binary(c_char, c_varchar, c_binary, c_varbinary) +VALUES ('89504E470D0A1A0B', '89504E470D0A1A0B', x'89504E470D0A1A0B', x'89504E470D0A1A0B'); +UPDATE tp_char_binary SET c_varchar = '89504E470D0A1A0B' WHERE id = 2; + +INSERT INTO tp_other() VALUES (); +INSERT INTO tp_other(c_enum, c_set, c_bit, c_json) +VALUES ('a', 'a,b', b'1000001', '{"key1":"value1","key2":"value2"}'); +INSERT INTO tp_other(c_enum, c_set, c_bit, c_json) +VALUES ('b', 'b,c', b'1000001', '{"key1":"value1","key2":"value2","key3":"123"}'); +UPDATE tp_other SET c_enum = 'c' WHERE id = 3; + +INSERT INTO tp_account VALUES (12, 34); +UPDATE tp_account SET account_id = 35 WHERE id = 12; +DELETE FROM tp_account WHERE id = 12; diff --git a/tests/integration_tests/debezium_avro/run.sh b/tests/integration_tests/debezium_avro/run.sh new file mode 100644 index 0000000000..fccabb10ab --- /dev/null +++ b/tests/integration_tests/debezium_avro/run.sh @@ -0,0 +1,65 @@ +#!/bin/bash + +set -e + +CUR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +source $CUR/../_utils/test_prepare +WORK_DIR=$OUT_DIR/$TEST_NAME +CDC_BINARY=cdc.test +SINK_TYPE=$1 + +function start_schema_registry() { + if ! curl -o /dev/null -s "http://127.0.0.1:8088"; then + echo 'Starting schema registry...' + ./bin/bin/schema-registry-start -daemon ./bin/etc/schema-registry/schema-registry.properties + local i=0 + while ! curl -o /dev/null -s "http://127.0.0.1:8088"; do + i=$((i + 1)) + if [ "$i" -gt 30 ]; then + echo 'Failed to start schema registry' + exit 1 + fi + sleep 2 + done + fi + + curl -X PUT -H "Content-Type: application/vnd.schemaregistry.v1+json" --data '{"compatibility": "NONE"}' http://127.0.0.1:8088/config +} + +function run() { + if [ "$SINK_TYPE" != "kafka" ]; then + return + fi + + rm -rf "$WORK_DIR" && mkdir -p "$WORK_DIR" + + start_schema_registry + start_tidb_cluster --workdir "$WORK_DIR" + run_sql_file "$CUR/data/prepare.sql" "$UP_TIDB_HOST" "$UP_TIDB_PORT" + + start_ts=$(run_cdc_cli_tso_query "$UP_PD_HOST_1" "$UP_PD_PORT_1") + + run_cdc_server --workdir "$WORK_DIR" --binary "$CDC_BINARY" + + TOPIC_NAME="ticdc-debezium-avro-$RANDOM" + SINK_URI="kafka://127.0.0.1:9092/$TOPIC_NAME?protocol=debezium&enable-tidb-extension=true&partition-num=1&kafka-version=${KAFKA_VERSION}&max-message-bytes=10485760" + schema_registry_uri="http://127.0.0.1:8088" + changefeed_id="debezium-avro-$RANDOM" + + cdc_cli_changefeed create --start-ts="$start_ts" --sink-uri="$SINK_URI" -c "$changefeed_id" --schema-registry="$schema_registry_uri" + run_sql_file "$CUR/data/workload.sql" "$UP_TIDB_HOST" "$UP_TIDB_PORT" + + GO111MODULE=on go run ./tests/integration_tests/debezium_avro/verify \ + --topic "$TOPIC_NAME" \ + --kafka-addr "127.0.0.1:9092" \ + --schema-registry "$schema_registry_uri" \ + --timeout "120s" \ + 2>&1 | tee "$WORK_DIR/debezium_avro_verify.log" + + cleanup_process "$CDC_BINARY" +} + +trap 'stop_test $WORK_DIR' EXIT +run "$@" +check_logs "$WORK_DIR" +echo "[$(date)] <<<<<< run test case $TEST_NAME success! >>>>>>" diff --git a/tests/integration_tests/debezium_avro/verify/main.go b/tests/integration_tests/debezium_avro/verify/main.go new file mode 100644 index 0000000000..bf0305b258 --- /dev/null +++ b/tests/integration_tests/debezium_avro/verify/main.go @@ -0,0 +1,627 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "encoding/binary" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "math" + "net/http" + "net/url" + "os" + "sort" + "strings" + "time" + + "github.com/linkedin/goavro/v2" + "github.com/segmentio/kafka-go" +) + +var requiredChecks = []string{ + "op_insert", + "op_update", + "op_delete", + "tp_int_insert", + "tp_int_update", + "tp_int_delete", + "tp_unsigned_normal_insert", + "tp_unsigned_max_insert", + "tp_real_insert", + "tp_real_update", + "tp_time_insert", + "tp_time_update", + "tp_text_update", + "tp_blob_update", + "tp_char_binary_update", + "tp_other_update", + "tp_account_delete", +} + +type schemaRegistryClient struct { + baseURL string + client *http.Client + codecs map[int]*goavro.Codec +} + +type rowEvent struct { + offset int64 + op string + table string + key map[string]any + before map[string]any + after map[string]any +} + +type coverage struct { + checks map[string]bool + tables map[string]bool + count int +} + +func main() { + if err := run(); err != nil { + fmt.Fprintf(os.Stderr, "verify debezium avro failed: %v\n", err) + os.Exit(1) + } +} + +func run() error { + var ( + topic = flag.String("topic", "", "Kafka topic to verify") + kafkaAddr = flag.String("kafka-addr", "127.0.0.1:9092", "Kafka broker address") + schemaRegistry = flag.String("schema-registry", "http://127.0.0.1:8088", "Confluent Schema Registry URI") + timeout = flag.Duration("timeout", 60*time.Second, "Time to wait for the row events") + ) + flag.Parse() + + if *topic == "" { + return errors.New("topic is required") + } + + ctx, cancel := context.WithTimeout(context.Background(), *timeout) + defer cancel() + + registry := &schemaRegistryClient{ + baseURL: strings.TrimRight(*schemaRegistry, "/"), + client: &http.Client{Timeout: 10 * time.Second}, + codecs: make(map[int]*goavro.Codec), + } + + reader := kafka.NewReader(kafka.ReaderConfig{ + Brokers: []string{*kafkaAddr}, + Topic: *topic, + Partition: 0, + MinBytes: 1, + MaxBytes: 10e6, + MaxWait: 500 * time.Millisecond, + ReadBackoffMin: 100 * time.Millisecond, + ReadBackoffMax: time.Second, + }) + defer reader.Close() + if err := reader.SetOffset(kafka.FirstOffset); err != nil { + return fmt.Errorf("set kafka reader offset: %w", err) + } + + result := &coverage{ + checks: make(map[string]bool), + tables: make(map[string]bool), + } + var lastErr error + for { + message, err := reader.ReadMessage(ctx) + if err != nil { + if ctx.Err() != nil { + if lastErr != nil { + return fmt.Errorf("timed out waiting for Debezium Avro row events on topic %s; coverage: %s; last error: %w", *topic, result.summary(), lastErr) + } + return fmt.Errorf("timed out waiting for Debezium Avro row events on topic %s; coverage: %s", *topic, result.summary()) + } + lastErr = err + continue + } + + event, err := decodeRowEvent(ctx, registry, message) + if err != nil { + lastErr = err + continue + } + if event == nil { + continue + } + if err := result.observe(event); err != nil { + return err + } + if result.done() { + if err := registry.ensureSubject(ctx, *topic+"-key"); err != nil { + return err + } + if err := registry.ensureSubject(ctx, *topic+"-value"); err != nil { + return err + } + fmt.Printf("verified %d Debezium Confluent Avro row events from topic %s; coverage: %s\n", result.count, *topic, result.summary()) + return nil + } + } +} + +func decodeRowEvent( + ctx context.Context, + registry *schemaRegistryClient, + message kafka.Message, +) (*rowEvent, error) { + key, err := decodeConfluentAvro(ctx, registry, message.Key) + if err != nil { + return nil, fmt.Errorf("decode key at offset %d: %w", message.Offset, err) + } + value, err := decodeConfluentAvro(ctx, registry, message.Value) + if err != nil { + return nil, fmt.Errorf("decode value at offset %d: %w", message.Offset, err) + } + + op, ok := value["op"].(string) + if !ok || (op != "c" && op != "u" && op != "d") { + return nil, nil + } + + source, ok := asMap(value["source"]) + if !ok { + return nil, fmt.Errorf("source is not a record at offset %d: %T", message.Offset, value["source"]) + } + if source["db"] != "test" { + return nil, fmt.Errorf("unexpected source db at offset %d: %v", message.Offset, source["db"]) + } + table, ok := stringValue(source["table"]) + if !ok { + return nil, fmt.Errorf("source table is not a string at offset %d: %v", message.Offset, source["table"]) + } + + event := &rowEvent{ + offset: message.Offset, + op: op, + table: table, + key: key, + } + event.before, _ = asMap(value["before"]) + event.after, _ = asMap(value["after"]) + return event, nil +} + +func (c *coverage) observe(event *rowEvent) error { + c.count++ + c.tables[event.table] = true + + switch event.op { + case "c": + c.checks["op_insert"] = true + if event.before != nil || event.after == nil { + return fmt.Errorf("invalid insert shape at offset %d", event.offset) + } + case "u": + c.checks["op_update"] = true + if event.before == nil || event.after == nil { + return fmt.Errorf("invalid update shape at offset %d", event.offset) + } + case "d": + c.checks["op_delete"] = true + if event.before == nil || event.after != nil { + return fmt.Errorf("invalid delete shape at offset %d", event.offset) + } + } + + id, ok := intValue(event.key["id"]) + if !ok { + return fmt.Errorf("key id is not an integer at offset %d: %v", event.offset, event.key["id"]) + } + + switch event.table { + case "tp_int": + return c.observeInt(event, id) + case "tp_unsigned_int": + return c.observeUnsignedInt(event, id) + case "tp_real": + return c.observeReal(event, id) + case "tp_time": + return c.observeTime(event, id) + case "tp_text": + return c.observeText(event, id) + case "tp_blob": + return c.observeBlob(event, id) + case "tp_char_binary": + return c.observeCharBinary(event, id) + case "tp_other": + return c.observeOther(event, id) + case "tp_account": + return c.observeAccount(event, id) + default: + return fmt.Errorf("unexpected table %s at offset %d", event.table, event.offset) + } +} + +func (c *coverage) observeInt(event *rowEvent, id int64) error { + switch { + case event.op == "c" && id == 2: + if err := expectInt(event.after, "c_tinyint", 1); err != nil { + return err + } + if err := expectInt(event.after, "c_smallint", 2); err != nil { + return err + } + if err := expectInt(event.after, "c_mediumint", 3); err != nil { + return err + } + if err := expectInt(event.after, "c_int", 4); err != nil { + return err + } + if err := expectInt(event.after, "c_bigint", 5); err != nil { + return err + } + c.checks["tp_int_insert"] = true + case event.op == "u" && id == 2: + if err := expectInt(event.before, "c_int", 4); err != nil { + return err + } + if err := expectInt(event.after, "c_int", 0); err != nil { + return err + } + c.checks["tp_int_update"] = true + case event.op == "d" && id == 2: + if err := expectInt(event.before, "c_int", 0); err != nil { + return err + } + c.checks["tp_int_delete"] = true + } + return nil +} + +func (c *coverage) observeUnsignedInt(event *rowEvent, id int64) error { + if event.op != "c" { + return nil + } + switch id { + case 2: + if err := expectInt(event.after, "c_unsigned_tinyint", 1); err != nil { + return err + } + if err := expectInt(event.after, "c_unsigned_bigint", 5); err != nil { + return err + } + c.checks["tp_unsigned_normal_insert"] = true + case 3: + if err := expectInt(event.after, "c_unsigned_tinyint", 255); err != nil { + return err + } + if err := expectInt(event.after, "c_unsigned_int", 4294967295); err != nil { + return err + } + if err := expectInt(event.after, "c_unsigned_bigint", -1); err != nil { + return err + } + c.checks["tp_unsigned_max_insert"] = true + } + return nil +} + +func (c *coverage) observeReal(event *rowEvent, id int64) error { + switch { + case event.op == "c" && id == 2: + if err := expectFloat(event.after, "c_double", 2020.0303, 0.000001); err != nil { + return err + } + if err := expectFloat(event.after, "c_decimal", 2020, 0.000001); err != nil { + return err + } + if err := expectFloat(event.after, "c_decimal_2", 2021.1208, 0.000001); err != nil { + return err + } + c.checks["tp_real_insert"] = true + case event.op == "u" && id == 2: + if err := expectFloat(event.before, "c_double", 2020.0303, 0.000001); err != nil { + return err + } + if err := expectFloat(event.after, "c_double", 2.333, 0.000001); err != nil { + return err + } + c.checks["tp_real_update"] = true + } + return nil +} + +func (c *coverage) observeTime(event *rowEvent, id int64) error { + switch { + case event.op == "c" && id == 2: + if err := expectInt(event.after, "c_date", 18312); err != nil { + return err + } + if err := expectString(event.after, "c_timestamp", "2020-02-20T02:20:20Z"); err != nil { + return err + } + if err := expectInt(event.after, "c_year", 2020); err != nil { + return err + } + c.checks["tp_time_insert"] = true + case event.op == "u" && id == 2: + if err := expectInt(event.before, "c_year", 2020); err != nil { + return err + } + if err := expectInt(event.after, "c_year", 2022); err != nil { + return err + } + c.checks["tp_time_update"] = true + } + return nil +} + +func (c *coverage) observeText(event *rowEvent, id int64) error { + if event.op == "u" && id == 2 { + if err := expectString(event.after, "c_text", "89504E470D0A1A0B"); err != nil { + return err + } + c.checks["tp_text_update"] = true + } + return nil +} + +func (c *coverage) observeBlob(event *rowEvent, id int64) error { + if event.op == "u" && id == 2 { + if err := expectString(event.after, "c_blob", "iVBORw0KGgs="); err != nil { + return err + } + c.checks["tp_blob_update"] = true + } + return nil +} + +func (c *coverage) observeCharBinary(event *rowEvent, id int64) error { + if event.op == "u" && id == 2 { + if err := expectString(event.after, "c_varchar", "89504E470D0A1A0B"); err != nil { + return err + } + c.checks["tp_char_binary_update"] = true + } + return nil +} + +func (c *coverage) observeOther(event *rowEvent, id int64) error { + if event.op == "u" && id == 3 { + if err := expectString(event.before, "c_enum", "b"); err != nil { + return err + } + if err := expectString(event.after, "c_enum", "c"); err != nil { + return err + } + if err := expectString(event.after, "c_set", "b,c"); err != nil { + return err + } + jsonValue, ok := stringValue(event.after["c_json"]) + if !ok { + return fmt.Errorf("unexpected c_json value: %v", event.after["c_json"]) + } + var jsonObject map[string]any + if err := json.Unmarshal([]byte(jsonValue), &jsonObject); err != nil { + return fmt.Errorf("decode c_json value: %w", err) + } + if jsonObject["key3"] != "123" { + return fmt.Errorf("unexpected c_json key3: %v", jsonObject["key3"]) + } + c.checks["tp_other_update"] = true + } + return nil +} + +func (c *coverage) observeAccount(event *rowEvent, id int64) error { + if event.op == "d" && id == 12 { + if err := expectInt(event.before, "account_id", 35); err != nil { + return err + } + c.checks["tp_account_delete"] = true + } + return nil +} + +func (c *coverage) done() bool { + for _, check := range requiredChecks { + if !c.checks[check] { + return false + } + } + return true +} + +func (c *coverage) summary() string { + var checks []string + for check := range c.checks { + checks = append(checks, check) + } + sort.Strings(checks) + var tables []string + for table := range c.tables { + tables = append(tables, table) + } + sort.Strings(tables) + return fmt.Sprintf("events=%d tables=%v checks=%v", c.count, tables, checks) +} + +func decodeConfluentAvro( + ctx context.Context, + registry *schemaRegistryClient, + data []byte, +) (map[string]any, error) { + if len(data) < 5 { + return nil, fmt.Errorf("message is shorter than Confluent Avro header: %d bytes", len(data)) + } + if data[0] != 0 { + return nil, fmt.Errorf("unexpected Confluent Avro magic byte: %d", data[0]) + } + + schemaID := int(binary.BigEndian.Uint32(data[1:5])) + codec, err := registry.codecByID(ctx, schemaID) + if err != nil { + return nil, err + } + native, _, err := codec.NativeFromBinary(data[5:]) + if err != nil { + return nil, err + } + result, ok := native.(map[string]any) + if !ok { + return nil, fmt.Errorf("decoded Avro payload is %T, expected record", native) + } + return result, nil +} + +func (c *schemaRegistryClient) codecByID(ctx context.Context, id int) (*goavro.Codec, error) { + if codec, ok := c.codecs[id]; ok { + return codec, nil + } + + body, err := c.get(ctx, fmt.Sprintf("/schemas/ids/%d", id)) + if err != nil { + return nil, err + } + var response struct { + Schema string `json:"schema"` + } + if err := json.Unmarshal(body, &response); err != nil { + return nil, err + } + if response.Schema == "" { + return nil, fmt.Errorf("schema registry returned empty schema for id %d", id) + } + + codec, err := goavro.NewCodecWithOptions(response.Schema, &goavro.CodecOption{EnableStringNull: false}) + if err != nil { + return nil, err + } + c.codecs[id] = codec + return codec, nil +} + +func (c *schemaRegistryClient) ensureSubject(ctx context.Context, subject string) error { + if _, err := c.get(ctx, "/subjects/"+url.PathEscape(subject)+"/versions/latest"); err != nil { + return fmt.Errorf("schema registry subject %s is missing: %w", subject, err) + } + return nil +} + +func (c *schemaRegistryClient) get(ctx context.Context, path string) ([]byte, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+path, nil) + if err != nil { + return nil, err + } + + response, err := c.client.Do(request) + if err != nil { + return nil, err + } + defer response.Body.Close() + + body, err := io.ReadAll(io.LimitReader(response.Body, 1<<20)) + if err != nil { + return nil, err + } + if response.StatusCode != http.StatusOK { + return nil, fmt.Errorf("GET %s returned %s: %s", path, response.Status, strings.TrimSpace(string(body))) + } + return body, nil +} + +func expectInt(row map[string]any, field string, expected int64) error { + value, ok := intValue(row[field]) + if !ok { + return fmt.Errorf("%s is not an integer: %v", field, row[field]) + } + if value != expected { + return fmt.Errorf("unexpected %s: got %d, want %d", field, value, expected) + } + return nil +} + +func expectFloat(row map[string]any, field string, expected float64, tolerance float64) error { + value, ok := floatValue(row[field]) + if !ok { + return fmt.Errorf("%s is not a float: %v", field, row[field]) + } + if math.Abs(value-expected) > tolerance { + return fmt.Errorf("unexpected %s: got %f, want %f", field, value, expected) + } + return nil +} + +func expectString(row map[string]any, field string, expected string) error { + value, ok := stringValue(row[field]) + if !ok { + return fmt.Errorf("%s is not a string: %v", field, row[field]) + } + if value != expected { + return fmt.Errorf("unexpected %s: got %q, want %q", field, value, expected) + } + return nil +} + +func asMap(value any) (map[string]any, bool) { + unwrapped := unwrapUnion(value) + result, ok := unwrapped.(map[string]any) + return result, ok +} + +func stringValue(value any) (string, bool) { + v, ok := unwrapUnion(value).(string) + return v, ok +} + +func intValue(value any) (int64, bool) { + switch v := unwrapUnion(value).(type) { + case int: + return int64(v), true + case int32: + return int64(v), true + case int64: + return v, true + case uint64: + return int64(v), true + default: + return 0, false + } +} + +func floatValue(value any) (float64, bool) { + switch v := unwrapUnion(value).(type) { + case float32: + return float64(v), true + case float64: + return v, true + case int32: + return float64(v), true + case int64: + return float64(v), true + default: + return 0, false + } +} + +func unwrapUnion(value any) any { + result, ok := value.(map[string]any) + if !ok || len(result) != 1 { + return value + } + for _, branchValue := range result { + return branchValue + } + return value +} diff --git a/tests/integration_tests/run_light_it_in_ci.sh b/tests/integration_tests/run_light_it_in_ci.sh index 1dd75a780f..c16268cdea 100755 --- a/tests/integration_tests/run_light_it_in_ci.sh +++ b/tests/integration_tests/run_light_it_in_ci.sh @@ -99,7 +99,7 @@ kafka_groups=( # G13 'cli_with_auth fail_over_ddl_N maintainer_failover_when_operator' # G14 - 'kafka_simple_basic avro_basic fail_over_ddl_O update_changefeed_check_config' + 'kafka_simple_basic avro_basic debezium_avro fail_over_ddl_O update_changefeed_check_config' # G15 'kafka_simple_basic_avro split_region autorandom gc_safepoint kafka_log_info' ) From 4e9e16c3806a5416ee58922bc7cb082c36f4a688 Mon Sep 17 00:00:00 2001 From: wk989898 Date: Mon, 22 Jun 2026 07:09:42 +0000 Subject: [PATCH 03/10] update Signed-off-by: wk989898 --- .../debezium_avro/verify/main.go | 121 +++++++++++------- 1 file changed, 73 insertions(+), 48 deletions(-) diff --git a/tests/integration_tests/debezium_avro/verify/main.go b/tests/integration_tests/debezium_avro/verify/main.go index bf0305b258..bfe9c036c1 100644 --- a/tests/integration_tests/debezium_avro/verify/main.go +++ b/tests/integration_tests/debezium_avro/verify/main.go @@ -29,8 +29,8 @@ import ( "strings" "time" + "github.com/IBM/sarama" "github.com/linkedin/goavro/v2" - "github.com/segmentio/kafka-go" ) var requiredChecks = []string{ @@ -68,6 +68,12 @@ type rowEvent struct { after map[string]any } +type kafkaMessage struct { + offset int64 + key []byte + value []byte +} + type coverage struct { checks map[string]bool tables map[string]bool @@ -103,59 +109,78 @@ func run() error { codecs: make(map[int]*goavro.Codec), } - reader := kafka.NewReader(kafka.ReaderConfig{ - Brokers: []string{*kafkaAddr}, - Topic: *topic, - Partition: 0, - MinBytes: 1, - MaxBytes: 10e6, - MaxWait: 500 * time.Millisecond, - ReadBackoffMin: 100 * time.Millisecond, - ReadBackoffMax: time.Second, - }) - defer reader.Close() - if err := reader.SetOffset(kafka.FirstOffset); err != nil { - return fmt.Errorf("set kafka reader offset: %w", err) + saramaConfig := sarama.NewConfig() + saramaConfig.Version = sarama.V2_4_0_0 + saramaConfig.Consumer.Return.Errors = true + saramaConfig.Net.DialTimeout = 10 * time.Second + saramaConfig.Net.ReadTimeout = 10 * time.Second + saramaConfig.Net.WriteTimeout = 10 * time.Second + + consumer, err := sarama.NewConsumer([]string{*kafkaAddr}, saramaConfig) + if err != nil { + return fmt.Errorf("create kafka consumer: %w", err) + } + defer consumer.Close() + + partitionConsumer, err := consumer.ConsumePartition(*topic, 0, sarama.OffsetOldest) + if err != nil { + return fmt.Errorf("consume kafka partition: %w", err) } + defer partitionConsumer.Close() result := &coverage{ checks: make(map[string]bool), tables: make(map[string]bool), } var lastErr error + messagesCh := partitionConsumer.Messages() + errorsCh := partitionConsumer.Errors() for { - message, err := reader.ReadMessage(ctx) - if err != nil { - if ctx.Err() != nil { + select { + case <-ctx.Done(): + if lastErr != nil { + return fmt.Errorf("timed out waiting for Debezium Avro row events on topic %s; coverage: %s; last error: %w", *topic, result.summary(), lastErr) + } + return fmt.Errorf("timed out waiting for Debezium Avro row events on topic %s; coverage: %s", *topic, result.summary()) + case err, ok := <-errorsCh: + if ok { + lastErr = err + } else { + errorsCh = nil + } + case message, ok := <-messagesCh: + if !ok { if lastErr != nil { - return fmt.Errorf("timed out waiting for Debezium Avro row events on topic %s; coverage: %s; last error: %w", *topic, result.summary(), lastErr) + return fmt.Errorf("kafka message channel closed for topic %s; coverage: %s; last error: %w", *topic, result.summary(), lastErr) } - return fmt.Errorf("timed out waiting for Debezium Avro row events on topic %s; coverage: %s", *topic, result.summary()) + return fmt.Errorf("kafka message channel closed for topic %s; coverage: %s", *topic, result.summary()) } - lastErr = err - continue - } - event, err := decodeRowEvent(ctx, registry, message) - if err != nil { - lastErr = err - continue - } - if event == nil { - continue - } - if err := result.observe(event); err != nil { - return err - } - if result.done() { - if err := registry.ensureSubject(ctx, *topic+"-key"); err != nil { - return err + event, err := decodeRowEvent(ctx, registry, kafkaMessage{ + offset: message.Offset, + key: message.Key, + value: message.Value, + }) + if err != nil { + lastErr = err + continue + } + if event == nil { + continue } - if err := registry.ensureSubject(ctx, *topic+"-value"); err != nil { + if err := result.observe(event); err != nil { return err } - fmt.Printf("verified %d Debezium Confluent Avro row events from topic %s; coverage: %s\n", result.count, *topic, result.summary()) - return nil + if result.done() { + if err := registry.ensureSubject(ctx, *topic+"-key"); err != nil { + return err + } + if err := registry.ensureSubject(ctx, *topic+"-value"); err != nil { + return err + } + fmt.Printf("verified %d Debezium Confluent Avro row events from topic %s; coverage: %s\n", result.count, *topic, result.summary()) + return nil + } } } } @@ -163,15 +188,15 @@ func run() error { func decodeRowEvent( ctx context.Context, registry *schemaRegistryClient, - message kafka.Message, + message kafkaMessage, ) (*rowEvent, error) { - key, err := decodeConfluentAvro(ctx, registry, message.Key) + key, err := decodeConfluentAvro(ctx, registry, message.key) if err != nil { - return nil, fmt.Errorf("decode key at offset %d: %w", message.Offset, err) + return nil, fmt.Errorf("decode key at offset %d: %w", message.offset, err) } - value, err := decodeConfluentAvro(ctx, registry, message.Value) + value, err := decodeConfluentAvro(ctx, registry, message.value) if err != nil { - return nil, fmt.Errorf("decode value at offset %d: %w", message.Offset, err) + return nil, fmt.Errorf("decode value at offset %d: %w", message.offset, err) } op, ok := value["op"].(string) @@ -181,18 +206,18 @@ func decodeRowEvent( source, ok := asMap(value["source"]) if !ok { - return nil, fmt.Errorf("source is not a record at offset %d: %T", message.Offset, value["source"]) + return nil, fmt.Errorf("source is not a record at offset %d: %T", message.offset, value["source"]) } if source["db"] != "test" { - return nil, fmt.Errorf("unexpected source db at offset %d: %v", message.Offset, source["db"]) + return nil, fmt.Errorf("unexpected source db at offset %d: %v", message.offset, source["db"]) } table, ok := stringValue(source["table"]) if !ok { - return nil, fmt.Errorf("source table is not a string at offset %d: %v", message.Offset, source["table"]) + return nil, fmt.Errorf("source table is not a string at offset %d: %v", message.offset, source["table"]) } event := &rowEvent{ - offset: message.Offset, + offset: message.offset, op: op, table: table, key: key, From 0c4c082d1b1f045df2974c4974745591ac8a32f2 Mon Sep 17 00:00:00 2001 From: wk989898 Date: Mon, 22 Jun 2026 10:15:38 +0000 Subject: [PATCH 04/10] update Signed-off-by: wk989898 --- pkg/sink/codec/builder.go | 3 + pkg/sink/codec/debezium/avro.go | 10 + pkg/sink/codec/debezium/avro_decoder.go | 578 ++++++++++++++++ pkg/sink/codec/debezium/avro_test.go | 92 +++ pkg/sink/codec/debezium/codec.go | 12 + .../debezium_avro/conf/diff_config.toml | 29 + tests/integration_tests/debezium_avro/run.sh | 15 +- .../debezium_avro/verify/main.go | 652 ------------------ 8 files changed, 731 insertions(+), 660 deletions(-) create mode 100644 pkg/sink/codec/debezium/avro_decoder.go create mode 100644 tests/integration_tests/debezium_avro/conf/diff_config.toml delete mode 100644 tests/integration_tests/debezium_avro/verify/main.go diff --git a/pkg/sink/codec/builder.go b/pkg/sink/codec/builder.go index b826898115..ea9d8da21c 100644 --- a/pkg/sink/codec/builder.go +++ b/pkg/sink/codec/builder.go @@ -69,6 +69,9 @@ func NewEventDecoder( case config.ProtocolSimple: return simple.NewDecoder(ctx, codecConfig, upstreamTiDB) case config.ProtocolDebezium: + if codecConfig.AvroConfluentSchemaRegistry != "" { + return debezium.NewAvroDecoder(ctx, codecConfig, idx, upstreamTiDB) + } return debezium.NewDecoder(codecConfig, idx, upstreamTiDB), nil default: } diff --git a/pkg/sink/codec/debezium/avro.go b/pkg/sink/codec/debezium/avro.go index 169d683075..8b4326fb81 100644 --- a/pkg/sink/codec/debezium/avro.go +++ b/pkg/sink/codec/debezium/avro.go @@ -29,6 +29,9 @@ import ( const ( debeziumAvroKeySchemaSuffix = "-key" debeziumAvroValueSchemaSuffix = "-value" + + debeziumAvroConnectFieldKey = "connect.field" + debeziumAvroTiDBTypeKey = "tidb_type" ) type debeziumAvroMessage struct { @@ -45,6 +48,7 @@ type debeziumConnectSchema struct { Fields []*debeziumConnectSchema `json:"fields"` Items *debeziumConnectSchema `json:"items"` Parameters map[string]string `json:"parameters"` + TiDBType string `json:"tidb_type"` } type debeziumAvroSchemaConverter struct { @@ -234,6 +238,12 @@ func (c *debeziumAvroSchemaConverter) toAvroSchema( "name": fieldName, "type": fieldType, } + if fieldSchema.Field != "" { + field[debeziumAvroConnectFieldKey] = fieldSchema.Field + } + if fieldSchema.TiDBType != "" { + field[debeziumAvroTiDBTypeKey] = fieldSchema.TiDBType + } if fieldSchema.Optional { field["type"] = []any{"null", fieldType} field["default"] = nil diff --git a/pkg/sink/codec/debezium/avro_decoder.go b/pkg/sink/codec/debezium/avro_decoder.go new file mode 100644 index 0000000000..3669cf2d39 --- /dev/null +++ b/pkg/sink/codec/debezium/avro_decoder.go @@ -0,0 +1,578 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package debezium + +import ( + "context" + "database/sql" + "encoding/binary" + "encoding/json" + "io" + "net/http" + "strconv" + "strings" + "sync" + + "github.com/linkedin/goavro/v2" + "github.com/pingcap/log" + commonEvent "github.com/pingcap/ticdc/pkg/common/event" + "github.com/pingcap/ticdc/pkg/errors" + codecavro "github.com/pingcap/ticdc/pkg/sink/codec/avro" + "github.com/pingcap/ticdc/pkg/sink/codec/common" + "go.uber.org/zap" +) + +const confluentAvroHeaderLen = 5 + +type avroDecoder struct { + ctx context.Context + registryURL string + httpClient *http.Client + inner *decoder + + mu sync.RWMutex + schemas map[int]*registeredDebeziumAvroSchema +} + +type registeredDebeziumAvroSchema struct { + schema any + namedSchemas map[string]any + codec *goavro.Codec +} + +// NewAvroDecoder returns a Debezium decoder for Confluent Avro wire-format +// messages. It decodes the Avro payload and then delegates Debezium event +// semantics to the JSON decoder. +func NewAvroDecoder( + ctx context.Context, + config *common.Config, + idx int, + db *sql.DB, +) (common.Decoder, error) { + registryURL := strings.TrimRight(config.AvroConfluentSchemaRegistry, "/") + if registryURL == "" { + return nil, errors.ErrAvroSchemaAPIError.GenWithStackByArgs("schema registry URI is empty") + } + + return &avroDecoder{ + ctx: ctx, + registryURL: registryURL, + httpClient: http.DefaultClient, + inner: NewDecoder(config, idx, db).(*decoder), + schemas: make(map[int]*registeredDebeziumAvroSchema), + }, nil +} + +func (d *avroDecoder) AddKeyValue(key, value []byte) { + keyJSON, err := d.toDebeziumJSON(key) + if err != nil { + log.Panic("decode Debezium Avro key failed", zap.Error(err), zap.Int("keySize", len(key))) + } + valueJSON, err := d.toDebeziumJSON(value) + if err != nil { + log.Panic("decode Debezium Avro value failed", zap.Error(err), zap.Int("valueSize", len(value))) + } + d.inner.AddKeyValue(keyJSON, valueJSON) +} + +func (d *avroDecoder) HasNext() (common.MessageType, bool) { + return d.inner.HasNext() +} + +func (d *avroDecoder) NextResolvedEvent() uint64 { + return d.inner.NextResolvedEvent() +} + +func (d *avroDecoder) NextDMLEvent() *commonEvent.DMLEvent { + return d.inner.NextDMLEvent() +} + +func (d *avroDecoder) NextDDLEvent() *commonEvent.DDLEvent { + return d.inner.NextDDLEvent() +} + +func (d *avroDecoder) toDebeziumJSON(data []byte) ([]byte, error) { + payload, schema, err := d.decodeConfluentAvroMessage(data) + if err != nil { + return nil, err + } + message := map[string]any{ + "schema": schema, + "payload": payload, + } + result, err := json.Marshal(message) + if err != nil { + return nil, errors.WrapError(errors.ErrDebeziumInvalidMessage, err) + } + return result, nil +} + +func (d *avroDecoder) decodeConfluentAvroMessage(data []byte) (any, map[string]any, error) { + if len(data) == 0 { + return nil, nil, errors.ErrDebeziumEmptyValueMessage.GenWithStackByArgs() + } + if len(data) < confluentAvroHeaderLen { + return nil, nil, errors.ErrAvroInvalidMessage.GenWithStackByArgs("confluent header is too short") + } + if data[0] != 0 { + return nil, nil, errors.ErrAvroInvalidMessage.GenWithStackByArgs("invalid confluent magic byte") + } + + schemaID := int(binary.BigEndian.Uint32(data[1:confluentAvroHeaderLen])) + registeredSchema, err := d.getSchema(schemaID) + if err != nil { + return nil, nil, err + } + + native, _, err := registeredSchema.codec.NativeFromBinary(data[confluentAvroHeaderLen:]) + if err != nil { + return nil, nil, errors.WrapError(errors.ErrAvroInvalidMessage, err) + } + + payload, err := avroNativeToConnectPayload( + registeredSchema.schema, + native, + registeredSchema.namedSchemas, + ) + if err != nil { + return nil, nil, err + } + schema, err := avroSchemaToConnectSchema( + registeredSchema.schema, + "", + nil, + registeredSchema.namedSchemas, + ) + if err != nil { + return nil, nil, err + } + return payload, schema, nil +} + +func (d *avroDecoder) getSchema(schemaID int) (*registeredDebeziumAvroSchema, error) { + d.mu.RLock() + schema, ok := d.schemas[schemaID] + d.mu.RUnlock() + if ok { + return schema, nil + } + + uri := d.registryURL + "/schemas/ids/" + strconv.Itoa(schemaID) + req, err := http.NewRequestWithContext(d.ctx, http.MethodGet, uri, nil) + if err != nil { + return nil, errors.WrapError(errors.ErrAvroSchemaAPIError, err) + } + req.Header.Add( + "Accept", + "application/vnd.schemaregistry.v1+json, application/vnd.schemaregistry+json, "+ + "application/json", + ) + + resp, err := d.httpClient.Do(req) + if err != nil { + return nil, errors.WrapError(errors.ErrAvroSchemaAPIError, err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errors.WrapError(errors.ErrAvroSchemaAPIError, err) + } + if resp.StatusCode != http.StatusOK { + return nil, errors.ErrAvroSchemaAPIError.GenWithStackByArgs( + "failed to query schema id " + strconv.Itoa(schemaID)) + } + + var lookupResp struct { + Schema string `json:"schema"` + } + if err := json.Unmarshal(body, &lookupResp); err != nil { + return nil, errors.WrapError(errors.ErrAvroSchemaAPIError, err) + } + + codec, err := codecavro.GenCodec(lookupResp.Schema) + if err != nil { + return nil, errors.WrapError(errors.ErrAvroSchemaAPIError, err) + } + + decoder := json.NewDecoder(strings.NewReader(lookupResp.Schema)) + decoder.UseNumber() + var schemaDef any + if err := decoder.Decode(&schemaDef); err != nil { + return nil, errors.WrapError(errors.ErrAvroSchemaAPIError, err) + } + + namedSchemas := make(map[string]any) + collectAvroNamedSchemas(schemaDef, namedSchemas) + + schema = ®isteredDebeziumAvroSchema{ + schema: schemaDef, + namedSchemas: namedSchemas, + codec: codec, + } + d.mu.Lock() + d.schemas[schemaID] = schema + d.mu.Unlock() + return schema, nil +} + +func avroNativeToConnectPayload(schema any, value any, namedSchemas map[string]any) (any, error) { + switch typedSchema := schema.(type) { + case []any: + if value == nil { + return nil, nil + } + branchSchema, branchValue, err := avroUnionBranch(typedSchema, value) + if err != nil { + return nil, err + } + return avroNativeToConnectPayload(branchSchema, branchValue, namedSchemas) + case map[string]any: + rawType, ok := typedSchema["type"] + if !ok { + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("avro schema is missing type") + } + if unionType, ok := rawType.([]any); ok { + return avroNativeToConnectPayload(unionType, value, namedSchemas) + } + typeName, ok := rawType.(string) + if !ok { + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("avro schema type is invalid") + } + switch typeName { + case "record": + valueMap, ok := value.(map[string]any) + if !ok { + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("avro record payload is invalid") + } + fields, ok := typedSchema["fields"].([]any) + if !ok { + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("avro record schema is missing fields") + } + result := make(map[string]any, len(fields)) + for _, rawField := range fields { + field, ok := rawField.(map[string]any) + if !ok { + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("avro field schema is invalid") + } + avroFieldName, ok := field["name"].(string) + if !ok { + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("avro field is missing name") + } + connectFieldName := avroConnectFieldName(field, avroFieldName) + fieldValue, err := avroNativeToConnectPayload( + field["type"], + valueMap[avroFieldName], + namedSchemas, + ) + if err != nil { + return nil, err + } + result[connectFieldName] = fieldValue + } + return result, nil + case "array": + items, ok := typedSchema["items"] + if !ok { + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("avro array schema is missing items") + } + values, ok := value.([]any) + if !ok { + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("avro array payload is invalid") + } + result := make([]any, 0, len(values)) + for _, item := range values { + itemValue, err := avroNativeToConnectPayload(items, item, namedSchemas) + if err != nil { + return nil, err + } + result = append(result, itemValue) + } + return result, nil + default: + return value, nil + } + case string: + if namedSchema, ok := namedSchemas[typedSchema]; ok { + return avroNativeToConnectPayload(namedSchema, value, namedSchemas) + } + return value, nil + default: + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("avro schema is invalid") + } +} + +func avroSchemaToConnectSchema( + schema any, + fieldName string, + fieldMeta map[string]any, + namedSchemas map[string]any, +) (map[string]any, error) { + switch typedSchema := schema.(type) { + case []any: + branchSchema, _, err := avroNonNullUnionBranch(typedSchema) + if err != nil { + return nil, err + } + connectSchema, err := avroSchemaToConnectSchema( + branchSchema, + fieldName, + fieldMeta, + namedSchemas, + ) + if err != nil { + return nil, err + } + connectSchema["optional"] = true + return connectSchema, nil + case map[string]any: + rawType, ok := typedSchema["type"] + if !ok { + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("avro schema is missing type") + } + if unionType, ok := rawType.([]any); ok { + return avroSchemaToConnectSchema(unionType, fieldName, fieldMeta, namedSchemas) + } + typeName, ok := rawType.(string) + if !ok { + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("avro schema type is invalid") + } + switch typeName { + case "record": + connectSchema := newConnectSchema("struct", false, fieldName, typedSchema, fieldMeta) + fields, ok := typedSchema["fields"].([]any) + if !ok { + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("avro record schema is missing fields") + } + connectFields := make([]any, 0, len(fields)) + for _, rawField := range fields { + field, ok := rawField.(map[string]any) + if !ok { + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("avro field schema is invalid") + } + avroFieldName, ok := field["name"].(string) + if !ok { + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("avro field is missing name") + } + fieldSchema, err := avroSchemaToConnectSchema( + field["type"], + avroConnectFieldName(field, avroFieldName), + field, + namedSchemas, + ) + if err != nil { + return nil, err + } + connectFields = append(connectFields, fieldSchema) + } + connectSchema["fields"] = connectFields + return connectSchema, nil + case "array": + connectSchema := newConnectSchema("array", false, fieldName, typedSchema, fieldMeta) + items, ok := typedSchema["items"] + if !ok { + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("avro array schema is missing items") + } + connectItems, err := avroSchemaToConnectSchema(items, "", nil, namedSchemas) + if err != nil { + return nil, err + } + connectSchema["items"] = connectItems + return connectSchema, nil + default: + connectType, err := avroPrimitiveToConnectType(typeName, typedSchema) + if err != nil { + return nil, err + } + return newConnectSchema(connectType, false, fieldName, typedSchema, fieldMeta), nil + } + case string: + if namedSchema, ok := namedSchemas[typedSchema]; ok { + return avroSchemaToConnectSchema(namedSchema, fieldName, fieldMeta, namedSchemas) + } + connectType, err := avroPrimitiveToConnectType(typedSchema, nil) + if err != nil { + return nil, err + } + return newConnectSchema(connectType, false, fieldName, nil, fieldMeta), nil + default: + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("avro schema is invalid") + } +} + +func collectAvroNamedSchemas(schema any, namedSchemas map[string]any) { + switch typedSchema := schema.(type) { + case []any: + for _, branch := range typedSchema { + collectAvroNamedSchemas(branch, namedSchemas) + } + case map[string]any: + rawType := typedSchema["type"] + if unionType, ok := rawType.([]any); ok { + collectAvroNamedSchemas(unionType, namedSchemas) + return + } + typeName, _ := rawType.(string) + switch typeName { + case "record": + name := avroBranchName(typedSchema) + if name != "" { + namedSchemas[name] = typedSchema + } + fields, _ := typedSchema["fields"].([]any) + for _, rawField := range fields { + field, ok := rawField.(map[string]any) + if !ok { + continue + } + collectAvroNamedSchemas(field["type"], namedSchemas) + } + case "array": + collectAvroNamedSchemas(typedSchema["items"], namedSchemas) + } + } +} + +func newConnectSchema( + connectType string, + optional bool, + fieldName string, + schemaMeta map[string]any, + fieldMeta map[string]any, +) map[string]any { + connectSchema := map[string]any{ + "type": connectType, + "optional": optional, + } + if fieldName != "" { + connectSchema["field"] = fieldName + } + addConnectSchemaMetadata(connectSchema, schemaMeta) + addConnectFieldMetadata(connectSchema, fieldMeta) + return connectSchema +} + +func addConnectSchemaMetadata(connectSchema map[string]any, schemaMeta map[string]any) { + if schemaMeta == nil { + return + } + if name, ok := schemaMeta["connect.name"].(string); ok && name != "" { + connectSchema["name"] = name + } + if version, ok := schemaMeta["connect.version"]; ok { + connectSchema["version"] = version + } + if parameters, ok := schemaMeta["connect.parameters"].(map[string]any); ok { + connectSchema["parameters"] = parameters + } + if tidbType, ok := schemaMeta[debeziumAvroTiDBTypeKey].(string); ok && tidbType != "" { + connectSchema[debeziumAvroTiDBTypeKey] = tidbType + } +} + +func addConnectFieldMetadata(connectSchema map[string]any, fieldMeta map[string]any) { + if fieldMeta == nil { + return + } + if tidbType, ok := fieldMeta[debeziumAvroTiDBTypeKey].(string); ok && tidbType != "" { + connectSchema[debeziumAvroTiDBTypeKey] = tidbType + } +} + +func avroPrimitiveToConnectType(avroType string, schemaMeta map[string]any) (string, error) { + if schemaMeta != nil { + if connectType, ok := schemaMeta["connect.type"].(string); ok && connectType != "" { + return connectType, nil + } + } + switch avroType { + case "boolean": + return "boolean", nil + case "string": + return "string", nil + case "bytes": + return "bytes", nil + case "int": + return "int32", nil + case "long": + return "int64", nil + case "float": + return "float", nil + case "double": + return "double", nil + default: + return "", errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("unsupported avro type " + avroType) + } +} + +func avroUnionBranch(union []any, value any) (any, any, error) { + if value == nil { + return nil, nil, nil + } + if branchValueMap, ok := value.(map[string]any); ok && len(branchValueMap) == 1 { + for branchName, branchValue := range branchValueMap { + for _, branchSchema := range union { + if avroBranchName(branchSchema) == branchName { + return branchSchema, branchValue, nil + } + } + } + } + + branchSchema, _, err := avroNonNullUnionBranch(union) + if err != nil { + return nil, nil, err + } + return branchSchema, value, nil +} + +func avroNonNullUnionBranch(union []any) (any, bool, error) { + for _, branch := range union { + if avroBranchName(branch) != "null" { + return branch, true, nil + } + } + return nil, false, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("avro union has no non-null branch") +} + +func avroBranchName(schema any) string { + switch typedSchema := schema.(type) { + case string: + return typedSchema + case map[string]any: + typeName, _ := typedSchema["type"].(string) + switch typeName { + case "record": + name, _ := typedSchema["name"].(string) + namespace, _ := typedSchema["namespace"].(string) + if namespace != "" && name != "" { + return namespace + "." + name + } + return name + case "array": + return "array" + default: + return typeName + } + default: + return "" + } +} + +func avroConnectFieldName(field map[string]any, fallback string) string { + if fieldName, ok := field[debeziumAvroConnectFieldKey].(string); ok && fieldName != "" { + return fieldName + } + return fallback +} diff --git a/pkg/sink/codec/debezium/avro_test.go b/pkg/sink/codec/debezium/avro_test.go index 29c664c6f1..e2f2b3fdaa 100644 --- a/pkg/sink/codec/debezium/avro_test.go +++ b/pkg/sink/codec/debezium/avro_test.go @@ -98,6 +98,98 @@ func TestDebeziumConfluentAvroEncodeRowEvent(t *testing.T) { require.Equal(t, "dbserver1", source["name"]) } +func TestDebeziumConfluentAvroDecodeRowEvent(t *testing.T) { + ctx := context.Background() + _, err := avro.SetupEncoderAndSchemaRegistry4Testing( + ctx, + common.NewConfig(config.ProtocolAvro), + ) + require.NoError(t, err) + defer avro.TeardownEncoderAndSchemaRegistry4Testing() + + helper := NewSQLTestHelper(t, "foo", ` + create table foo( + id int primary key, + name varchar(16), + v bigint null + )`) + defer helper.Close() + + dmls := helper.helper.DML2Event("test", "foo", "insert into foo values (1, 'alice', null)") + row, ok := dmls.GetNextRow() + require.True(t, ok) + + cfg := common.NewConfig(config.ProtocolDebezium) + cfg.AvroConfluentSchemaRegistry = "http://127.0.0.1:8081" + cfg.EnableTiDBExtension = true + cfg.TimeZone = time.UTC + + encoder, err := NewAvroBatchEncoder(ctx, cfg, "dbserver1") + require.NoError(t, err) + require.NoError(t, encoder.AppendRowChangedEvent(ctx, "dbserver1.test.foo", &commonEvent.RowEvent{ + TableInfo: helper.tableInfo, + CommitTs: 1, + Event: row, + ColumnSelector: columnselector.NewDefaultColumnSelector(), + Callback: func() {}, + })) + + messages := encoder.Build() + require.Len(t, messages, 1) + + decoder, err := NewAvroDecoder(ctx, cfg, 0, nil) + require.NoError(t, err) + decoder.AddKeyValue(messages[0].Key, messages[0].Value) + + messageType, hasNext := decoder.HasNext() + require.True(t, hasNext) + require.Equal(t, common.MessageTypeRow, messageType) + + decoded := decoder.NextDMLEvent() + require.Equal(t, "test", decoded.TableInfo.GetSchemaName()) + require.Equal(t, "foo", decoded.TableInfo.GetTableName()) + + change, ok := decoded.GetNextRow() + require.True(t, ok) + common.CompareRow(t, row, helper.tableInfo, change, decoded.TableInfo) +} + +func TestDebeziumConfluentAvroDecodeDDLEvent(t *testing.T) { + ctx := context.Background() + _, err := avro.SetupEncoderAndSchemaRegistry4Testing( + ctx, + common.NewConfig(config.ProtocolAvro), + ) + require.NoError(t, err) + defer avro.TeardownEncoderAndSchemaRegistry4Testing() + + cfg := common.NewConfig(config.ProtocolDebezium) + cfg.AvroConfluentSchemaRegistry = "http://127.0.0.1:8081" + cfg.EnableTiDBExtension = true + cfg.TimeZone = time.UTC + + encoder, err := NewAvroBatchEncoder(ctx, cfg, "dbserver1") + require.NoError(t, err) + + routedDDL := common.NewRoutedDDLEvent4Test() + message, err := encoder.EncodeDDLEvent(routedDDL) + require.NoError(t, err) + require.NotNil(t, message) + + decoder, err := NewAvroDecoder(ctx, cfg, 0, nil) + require.NoError(t, err) + decoder.AddKeyValue(message.Key, message.Value) + + messageType, hasNext := decoder.HasNext() + require.True(t, hasNext) + require.Equal(t, common.MessageTypeDDL, messageType) + + decoded := decoder.NextDDLEvent() + require.Equal(t, "target_db", decoded.SchemaName) + require.Equal(t, "target_table", decoded.TableName) + require.Equal(t, routedDDL.Query, decoded.Query) +} + func decodeConfluentAvroForTest(t *testing.T, data []byte) map[string]any { t.Helper() diff --git a/pkg/sink/codec/debezium/codec.go b/pkg/sink/codec/debezium/codec.go index 9f7853bb33..81707b6a99 100644 --- a/pkg/sink/codec/debezium/codec.go +++ b/pkg/sink/codec/debezium/codec.go @@ -843,6 +843,18 @@ func (c *dbzCodec) writeSourceSchema(writer *util.JSONWriter) { writer.WriteBoolField("optional", true) writer.WriteStringField("field", "query") }) + if c.config.EnableTiDBExtension { + writer.WriteObjectElement(func() { + writer.WriteStringField("type", "int64") + writer.WriteBoolField("optional", false) + writer.WriteStringField("field", "commit_ts") + }) + writer.WriteObjectElement(func() { + writer.WriteStringField("type", "string") + writer.WriteBoolField("optional", false) + writer.WriteStringField("field", "cluster_id") + }) + } }) writer.WriteBoolField("optional", false) writer.WriteStringField("name", "io.debezium.connector.mysql.Source") diff --git a/tests/integration_tests/debezium_avro/conf/diff_config.toml b/tests/integration_tests/debezium_avro/conf/diff_config.toml new file mode 100644 index 0000000000..c241f97f87 --- /dev/null +++ b/tests/integration_tests/debezium_avro/conf/diff_config.toml @@ -0,0 +1,29 @@ +# diff Configuration. + +check-thread-count = 4 + +export-fix-sql = true + +check-struct-only = false + +[task] +output-dir = "/tmp/tidb_cdc_test/debezium_avro/output" + +source-instances = ["mysql1"] + +target-instance = "tidb0" + +target-check-tables = ["test.?*"] + +[data-sources] +[data-sources.mysql1] +host = "127.0.0.1" +port = 4000 +user = "root" +password = "" + +[data-sources.tidb0] +host = "127.0.0.1" +port = 3306 +user = "root" +password = "" diff --git a/tests/integration_tests/debezium_avro/run.sh b/tests/integration_tests/debezium_avro/run.sh index fccabb10ab..bfb1ceaede 100644 --- a/tests/integration_tests/debezium_avro/run.sh +++ b/tests/integration_tests/debezium_avro/run.sh @@ -35,7 +35,6 @@ function run() { start_schema_registry start_tidb_cluster --workdir "$WORK_DIR" - run_sql_file "$CUR/data/prepare.sql" "$UP_TIDB_HOST" "$UP_TIDB_PORT" start_ts=$(run_cdc_cli_tso_query "$UP_PD_HOST_1" "$UP_PD_PORT_1") @@ -47,14 +46,14 @@ function run() { changefeed_id="debezium-avro-$RANDOM" cdc_cli_changefeed create --start-ts="$start_ts" --sink-uri="$SINK_URI" -c "$changefeed_id" --schema-registry="$schema_registry_uri" - run_sql_file "$CUR/data/workload.sql" "$UP_TIDB_HOST" "$UP_TIDB_PORT" + sleep 5 # wait for changefeed to start + run_kafka_consumer "$WORK_DIR" "$SINK_URI" "" "$schema_registry_uri" - GO111MODULE=on go run ./tests/integration_tests/debezium_avro/verify \ - --topic "$TOPIC_NAME" \ - --kafka-addr "127.0.0.1:9092" \ - --schema-registry "$schema_registry_uri" \ - --timeout "120s" \ - 2>&1 | tee "$WORK_DIR/debezium_avro_verify.log" + run_sql_file "$CUR/data/prepare.sql" "$UP_TIDB_HOST" "$UP_TIDB_PORT" + run_sql_file "$CUR/data/workload.sql" "$UP_TIDB_HOST" "$UP_TIDB_PORT" + run_sql "CREATE TABLE test.finish_mark (id int primary key);" "$UP_TIDB_HOST" "$UP_TIDB_PORT" + check_table_exists test.finish_mark "$DOWN_TIDB_HOST" "$DOWN_TIDB_PORT" 200 + check_sync_diff "$WORK_DIR" "$CUR/conf/diff_config.toml" cleanup_process "$CDC_BINARY" } diff --git a/tests/integration_tests/debezium_avro/verify/main.go b/tests/integration_tests/debezium_avro/verify/main.go deleted file mode 100644 index bfe9c036c1..0000000000 --- a/tests/integration_tests/debezium_avro/verify/main.go +++ /dev/null @@ -1,652 +0,0 @@ -// Copyright 2026 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "context" - "encoding/binary" - "encoding/json" - "errors" - "flag" - "fmt" - "io" - "math" - "net/http" - "net/url" - "os" - "sort" - "strings" - "time" - - "github.com/IBM/sarama" - "github.com/linkedin/goavro/v2" -) - -var requiredChecks = []string{ - "op_insert", - "op_update", - "op_delete", - "tp_int_insert", - "tp_int_update", - "tp_int_delete", - "tp_unsigned_normal_insert", - "tp_unsigned_max_insert", - "tp_real_insert", - "tp_real_update", - "tp_time_insert", - "tp_time_update", - "tp_text_update", - "tp_blob_update", - "tp_char_binary_update", - "tp_other_update", - "tp_account_delete", -} - -type schemaRegistryClient struct { - baseURL string - client *http.Client - codecs map[int]*goavro.Codec -} - -type rowEvent struct { - offset int64 - op string - table string - key map[string]any - before map[string]any - after map[string]any -} - -type kafkaMessage struct { - offset int64 - key []byte - value []byte -} - -type coverage struct { - checks map[string]bool - tables map[string]bool - count int -} - -func main() { - if err := run(); err != nil { - fmt.Fprintf(os.Stderr, "verify debezium avro failed: %v\n", err) - os.Exit(1) - } -} - -func run() error { - var ( - topic = flag.String("topic", "", "Kafka topic to verify") - kafkaAddr = flag.String("kafka-addr", "127.0.0.1:9092", "Kafka broker address") - schemaRegistry = flag.String("schema-registry", "http://127.0.0.1:8088", "Confluent Schema Registry URI") - timeout = flag.Duration("timeout", 60*time.Second, "Time to wait for the row events") - ) - flag.Parse() - - if *topic == "" { - return errors.New("topic is required") - } - - ctx, cancel := context.WithTimeout(context.Background(), *timeout) - defer cancel() - - registry := &schemaRegistryClient{ - baseURL: strings.TrimRight(*schemaRegistry, "/"), - client: &http.Client{Timeout: 10 * time.Second}, - codecs: make(map[int]*goavro.Codec), - } - - saramaConfig := sarama.NewConfig() - saramaConfig.Version = sarama.V2_4_0_0 - saramaConfig.Consumer.Return.Errors = true - saramaConfig.Net.DialTimeout = 10 * time.Second - saramaConfig.Net.ReadTimeout = 10 * time.Second - saramaConfig.Net.WriteTimeout = 10 * time.Second - - consumer, err := sarama.NewConsumer([]string{*kafkaAddr}, saramaConfig) - if err != nil { - return fmt.Errorf("create kafka consumer: %w", err) - } - defer consumer.Close() - - partitionConsumer, err := consumer.ConsumePartition(*topic, 0, sarama.OffsetOldest) - if err != nil { - return fmt.Errorf("consume kafka partition: %w", err) - } - defer partitionConsumer.Close() - - result := &coverage{ - checks: make(map[string]bool), - tables: make(map[string]bool), - } - var lastErr error - messagesCh := partitionConsumer.Messages() - errorsCh := partitionConsumer.Errors() - for { - select { - case <-ctx.Done(): - if lastErr != nil { - return fmt.Errorf("timed out waiting for Debezium Avro row events on topic %s; coverage: %s; last error: %w", *topic, result.summary(), lastErr) - } - return fmt.Errorf("timed out waiting for Debezium Avro row events on topic %s; coverage: %s", *topic, result.summary()) - case err, ok := <-errorsCh: - if ok { - lastErr = err - } else { - errorsCh = nil - } - case message, ok := <-messagesCh: - if !ok { - if lastErr != nil { - return fmt.Errorf("kafka message channel closed for topic %s; coverage: %s; last error: %w", *topic, result.summary(), lastErr) - } - return fmt.Errorf("kafka message channel closed for topic %s; coverage: %s", *topic, result.summary()) - } - - event, err := decodeRowEvent(ctx, registry, kafkaMessage{ - offset: message.Offset, - key: message.Key, - value: message.Value, - }) - if err != nil { - lastErr = err - continue - } - if event == nil { - continue - } - if err := result.observe(event); err != nil { - return err - } - if result.done() { - if err := registry.ensureSubject(ctx, *topic+"-key"); err != nil { - return err - } - if err := registry.ensureSubject(ctx, *topic+"-value"); err != nil { - return err - } - fmt.Printf("verified %d Debezium Confluent Avro row events from topic %s; coverage: %s\n", result.count, *topic, result.summary()) - return nil - } - } - } -} - -func decodeRowEvent( - ctx context.Context, - registry *schemaRegistryClient, - message kafkaMessage, -) (*rowEvent, error) { - key, err := decodeConfluentAvro(ctx, registry, message.key) - if err != nil { - return nil, fmt.Errorf("decode key at offset %d: %w", message.offset, err) - } - value, err := decodeConfluentAvro(ctx, registry, message.value) - if err != nil { - return nil, fmt.Errorf("decode value at offset %d: %w", message.offset, err) - } - - op, ok := value["op"].(string) - if !ok || (op != "c" && op != "u" && op != "d") { - return nil, nil - } - - source, ok := asMap(value["source"]) - if !ok { - return nil, fmt.Errorf("source is not a record at offset %d: %T", message.offset, value["source"]) - } - if source["db"] != "test" { - return nil, fmt.Errorf("unexpected source db at offset %d: %v", message.offset, source["db"]) - } - table, ok := stringValue(source["table"]) - if !ok { - return nil, fmt.Errorf("source table is not a string at offset %d: %v", message.offset, source["table"]) - } - - event := &rowEvent{ - offset: message.offset, - op: op, - table: table, - key: key, - } - event.before, _ = asMap(value["before"]) - event.after, _ = asMap(value["after"]) - return event, nil -} - -func (c *coverage) observe(event *rowEvent) error { - c.count++ - c.tables[event.table] = true - - switch event.op { - case "c": - c.checks["op_insert"] = true - if event.before != nil || event.after == nil { - return fmt.Errorf("invalid insert shape at offset %d", event.offset) - } - case "u": - c.checks["op_update"] = true - if event.before == nil || event.after == nil { - return fmt.Errorf("invalid update shape at offset %d", event.offset) - } - case "d": - c.checks["op_delete"] = true - if event.before == nil || event.after != nil { - return fmt.Errorf("invalid delete shape at offset %d", event.offset) - } - } - - id, ok := intValue(event.key["id"]) - if !ok { - return fmt.Errorf("key id is not an integer at offset %d: %v", event.offset, event.key["id"]) - } - - switch event.table { - case "tp_int": - return c.observeInt(event, id) - case "tp_unsigned_int": - return c.observeUnsignedInt(event, id) - case "tp_real": - return c.observeReal(event, id) - case "tp_time": - return c.observeTime(event, id) - case "tp_text": - return c.observeText(event, id) - case "tp_blob": - return c.observeBlob(event, id) - case "tp_char_binary": - return c.observeCharBinary(event, id) - case "tp_other": - return c.observeOther(event, id) - case "tp_account": - return c.observeAccount(event, id) - default: - return fmt.Errorf("unexpected table %s at offset %d", event.table, event.offset) - } -} - -func (c *coverage) observeInt(event *rowEvent, id int64) error { - switch { - case event.op == "c" && id == 2: - if err := expectInt(event.after, "c_tinyint", 1); err != nil { - return err - } - if err := expectInt(event.after, "c_smallint", 2); err != nil { - return err - } - if err := expectInt(event.after, "c_mediumint", 3); err != nil { - return err - } - if err := expectInt(event.after, "c_int", 4); err != nil { - return err - } - if err := expectInt(event.after, "c_bigint", 5); err != nil { - return err - } - c.checks["tp_int_insert"] = true - case event.op == "u" && id == 2: - if err := expectInt(event.before, "c_int", 4); err != nil { - return err - } - if err := expectInt(event.after, "c_int", 0); err != nil { - return err - } - c.checks["tp_int_update"] = true - case event.op == "d" && id == 2: - if err := expectInt(event.before, "c_int", 0); err != nil { - return err - } - c.checks["tp_int_delete"] = true - } - return nil -} - -func (c *coverage) observeUnsignedInt(event *rowEvent, id int64) error { - if event.op != "c" { - return nil - } - switch id { - case 2: - if err := expectInt(event.after, "c_unsigned_tinyint", 1); err != nil { - return err - } - if err := expectInt(event.after, "c_unsigned_bigint", 5); err != nil { - return err - } - c.checks["tp_unsigned_normal_insert"] = true - case 3: - if err := expectInt(event.after, "c_unsigned_tinyint", 255); err != nil { - return err - } - if err := expectInt(event.after, "c_unsigned_int", 4294967295); err != nil { - return err - } - if err := expectInt(event.after, "c_unsigned_bigint", -1); err != nil { - return err - } - c.checks["tp_unsigned_max_insert"] = true - } - return nil -} - -func (c *coverage) observeReal(event *rowEvent, id int64) error { - switch { - case event.op == "c" && id == 2: - if err := expectFloat(event.after, "c_double", 2020.0303, 0.000001); err != nil { - return err - } - if err := expectFloat(event.after, "c_decimal", 2020, 0.000001); err != nil { - return err - } - if err := expectFloat(event.after, "c_decimal_2", 2021.1208, 0.000001); err != nil { - return err - } - c.checks["tp_real_insert"] = true - case event.op == "u" && id == 2: - if err := expectFloat(event.before, "c_double", 2020.0303, 0.000001); err != nil { - return err - } - if err := expectFloat(event.after, "c_double", 2.333, 0.000001); err != nil { - return err - } - c.checks["tp_real_update"] = true - } - return nil -} - -func (c *coverage) observeTime(event *rowEvent, id int64) error { - switch { - case event.op == "c" && id == 2: - if err := expectInt(event.after, "c_date", 18312); err != nil { - return err - } - if err := expectString(event.after, "c_timestamp", "2020-02-20T02:20:20Z"); err != nil { - return err - } - if err := expectInt(event.after, "c_year", 2020); err != nil { - return err - } - c.checks["tp_time_insert"] = true - case event.op == "u" && id == 2: - if err := expectInt(event.before, "c_year", 2020); err != nil { - return err - } - if err := expectInt(event.after, "c_year", 2022); err != nil { - return err - } - c.checks["tp_time_update"] = true - } - return nil -} - -func (c *coverage) observeText(event *rowEvent, id int64) error { - if event.op == "u" && id == 2 { - if err := expectString(event.after, "c_text", "89504E470D0A1A0B"); err != nil { - return err - } - c.checks["tp_text_update"] = true - } - return nil -} - -func (c *coverage) observeBlob(event *rowEvent, id int64) error { - if event.op == "u" && id == 2 { - if err := expectString(event.after, "c_blob", "iVBORw0KGgs="); err != nil { - return err - } - c.checks["tp_blob_update"] = true - } - return nil -} - -func (c *coverage) observeCharBinary(event *rowEvent, id int64) error { - if event.op == "u" && id == 2 { - if err := expectString(event.after, "c_varchar", "89504E470D0A1A0B"); err != nil { - return err - } - c.checks["tp_char_binary_update"] = true - } - return nil -} - -func (c *coverage) observeOther(event *rowEvent, id int64) error { - if event.op == "u" && id == 3 { - if err := expectString(event.before, "c_enum", "b"); err != nil { - return err - } - if err := expectString(event.after, "c_enum", "c"); err != nil { - return err - } - if err := expectString(event.after, "c_set", "b,c"); err != nil { - return err - } - jsonValue, ok := stringValue(event.after["c_json"]) - if !ok { - return fmt.Errorf("unexpected c_json value: %v", event.after["c_json"]) - } - var jsonObject map[string]any - if err := json.Unmarshal([]byte(jsonValue), &jsonObject); err != nil { - return fmt.Errorf("decode c_json value: %w", err) - } - if jsonObject["key3"] != "123" { - return fmt.Errorf("unexpected c_json key3: %v", jsonObject["key3"]) - } - c.checks["tp_other_update"] = true - } - return nil -} - -func (c *coverage) observeAccount(event *rowEvent, id int64) error { - if event.op == "d" && id == 12 { - if err := expectInt(event.before, "account_id", 35); err != nil { - return err - } - c.checks["tp_account_delete"] = true - } - return nil -} - -func (c *coverage) done() bool { - for _, check := range requiredChecks { - if !c.checks[check] { - return false - } - } - return true -} - -func (c *coverage) summary() string { - var checks []string - for check := range c.checks { - checks = append(checks, check) - } - sort.Strings(checks) - var tables []string - for table := range c.tables { - tables = append(tables, table) - } - sort.Strings(tables) - return fmt.Sprintf("events=%d tables=%v checks=%v", c.count, tables, checks) -} - -func decodeConfluentAvro( - ctx context.Context, - registry *schemaRegistryClient, - data []byte, -) (map[string]any, error) { - if len(data) < 5 { - return nil, fmt.Errorf("message is shorter than Confluent Avro header: %d bytes", len(data)) - } - if data[0] != 0 { - return nil, fmt.Errorf("unexpected Confluent Avro magic byte: %d", data[0]) - } - - schemaID := int(binary.BigEndian.Uint32(data[1:5])) - codec, err := registry.codecByID(ctx, schemaID) - if err != nil { - return nil, err - } - native, _, err := codec.NativeFromBinary(data[5:]) - if err != nil { - return nil, err - } - result, ok := native.(map[string]any) - if !ok { - return nil, fmt.Errorf("decoded Avro payload is %T, expected record", native) - } - return result, nil -} - -func (c *schemaRegistryClient) codecByID(ctx context.Context, id int) (*goavro.Codec, error) { - if codec, ok := c.codecs[id]; ok { - return codec, nil - } - - body, err := c.get(ctx, fmt.Sprintf("/schemas/ids/%d", id)) - if err != nil { - return nil, err - } - var response struct { - Schema string `json:"schema"` - } - if err := json.Unmarshal(body, &response); err != nil { - return nil, err - } - if response.Schema == "" { - return nil, fmt.Errorf("schema registry returned empty schema for id %d", id) - } - - codec, err := goavro.NewCodecWithOptions(response.Schema, &goavro.CodecOption{EnableStringNull: false}) - if err != nil { - return nil, err - } - c.codecs[id] = codec - return codec, nil -} - -func (c *schemaRegistryClient) ensureSubject(ctx context.Context, subject string) error { - if _, err := c.get(ctx, "/subjects/"+url.PathEscape(subject)+"/versions/latest"); err != nil { - return fmt.Errorf("schema registry subject %s is missing: %w", subject, err) - } - return nil -} - -func (c *schemaRegistryClient) get(ctx context.Context, path string) ([]byte, error) { - request, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+path, nil) - if err != nil { - return nil, err - } - - response, err := c.client.Do(request) - if err != nil { - return nil, err - } - defer response.Body.Close() - - body, err := io.ReadAll(io.LimitReader(response.Body, 1<<20)) - if err != nil { - return nil, err - } - if response.StatusCode != http.StatusOK { - return nil, fmt.Errorf("GET %s returned %s: %s", path, response.Status, strings.TrimSpace(string(body))) - } - return body, nil -} - -func expectInt(row map[string]any, field string, expected int64) error { - value, ok := intValue(row[field]) - if !ok { - return fmt.Errorf("%s is not an integer: %v", field, row[field]) - } - if value != expected { - return fmt.Errorf("unexpected %s: got %d, want %d", field, value, expected) - } - return nil -} - -func expectFloat(row map[string]any, field string, expected float64, tolerance float64) error { - value, ok := floatValue(row[field]) - if !ok { - return fmt.Errorf("%s is not a float: %v", field, row[field]) - } - if math.Abs(value-expected) > tolerance { - return fmt.Errorf("unexpected %s: got %f, want %f", field, value, expected) - } - return nil -} - -func expectString(row map[string]any, field string, expected string) error { - value, ok := stringValue(row[field]) - if !ok { - return fmt.Errorf("%s is not a string: %v", field, row[field]) - } - if value != expected { - return fmt.Errorf("unexpected %s: got %q, want %q", field, value, expected) - } - return nil -} - -func asMap(value any) (map[string]any, bool) { - unwrapped := unwrapUnion(value) - result, ok := unwrapped.(map[string]any) - return result, ok -} - -func stringValue(value any) (string, bool) { - v, ok := unwrapUnion(value).(string) - return v, ok -} - -func intValue(value any) (int64, bool) { - switch v := unwrapUnion(value).(type) { - case int: - return int64(v), true - case int32: - return int64(v), true - case int64: - return v, true - case uint64: - return int64(v), true - default: - return 0, false - } -} - -func floatValue(value any) (float64, bool) { - switch v := unwrapUnion(value).(type) { - case float32: - return float64(v), true - case float64: - return v, true - case int32: - return float64(v), true - case int64: - return float64(v), true - default: - return 0, false - } -} - -func unwrapUnion(value any) any { - result, ok := value.(map[string]any) - if !ok || len(result) != 1 { - return value - } - for _, branchValue := range result { - return branchValue - } - return value -} From d47b0e8d5f97e1e644b6558ffd34e5fc91083d24 Mon Sep 17 00:00:00 2001 From: wk989898 Date: Tue, 23 Jun 2026 04:17:49 +0000 Subject: [PATCH 05/10] fix Signed-off-by: wk989898 --- pkg/sink/codec/debezium/avro_decoder.go | 3 ++ pkg/sink/codec/debezium/avro_test.go | 43 +++++++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/pkg/sink/codec/debezium/avro_decoder.go b/pkg/sink/codec/debezium/avro_decoder.go index 3669cf2d39..4a476fccc7 100644 --- a/pkg/sink/codec/debezium/avro_decoder.go +++ b/pkg/sink/codec/debezium/avro_decoder.go @@ -287,6 +287,9 @@ func avroNativeToConnectPayload(schema any, value any, namedSchemas map[string]a if !ok { return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("avro array schema is missing items") } + if value == nil { + return []any{}, nil + } values, ok := value.([]any) if !ok { return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("avro array payload is invalid") diff --git a/pkg/sink/codec/debezium/avro_test.go b/pkg/sink/codec/debezium/avro_test.go index e2f2b3fdaa..0ee7991cf5 100644 --- a/pkg/sink/codec/debezium/avro_test.go +++ b/pkg/sink/codec/debezium/avro_test.go @@ -28,6 +28,7 @@ import ( "github.com/pingcap/ticdc/pkg/config" "github.com/pingcap/ticdc/pkg/sink/codec/avro" "github.com/pingcap/ticdc/pkg/sink/codec/common" + timodel "github.com/pingcap/tidb/pkg/meta/model" "github.com/stretchr/testify/require" ) @@ -190,6 +191,48 @@ func TestDebeziumConfluentAvroDecodeDDLEvent(t *testing.T) { require.Equal(t, routedDDL.Query, decoded.Query) } +func TestDebeziumConfluentAvroDecodeSchemaDDLEvent(t *testing.T) { + ctx := context.Background() + _, err := avro.SetupEncoderAndSchemaRegistry4Testing( + ctx, + common.NewConfig(config.ProtocolAvro), + ) + require.NoError(t, err) + defer avro.TeardownEncoderAndSchemaRegistry4Testing() + + cfg := common.NewConfig(config.ProtocolDebezium) + cfg.AvroConfluentSchemaRegistry = "http://127.0.0.1:8081" + cfg.EnableTiDBExtension = true + cfg.TimeZone = time.UTC + + encoder, err := NewAvroBatchEncoder(ctx, cfg, "dbserver1") + require.NoError(t, err) + + ddl := &commonEvent.DDLEvent{ + Version: commonEvent.DDLEventVersion1, + Type: byte(timodel.ActionCreateSchema), + SchemaName: "test", + Query: "CREATE DATABASE `test`", + FinishedTs: 100, + } + message, err := encoder.EncodeDDLEvent(ddl) + require.NoError(t, err) + require.NotNil(t, message) + + decoder, err := NewAvroDecoder(ctx, cfg, 0, nil) + require.NoError(t, err) + decoder.AddKeyValue(message.Key, message.Value) + + messageType, hasNext := decoder.HasNext() + require.True(t, hasNext) + require.Equal(t, common.MessageTypeDDL, messageType) + + decoded := decoder.NextDDLEvent() + require.Equal(t, "test", decoded.SchemaName) + require.Empty(t, decoded.TableName) + require.Equal(t, ddl.Query, decoded.Query) +} + func decodeConfluentAvroForTest(t *testing.T, data []byte) map[string]any { t.Helper() From dec3aade7ccb62a8ee26c2e0fa9d2b482d425dae Mon Sep 17 00:00:00 2001 From: wk989898 Date: Tue, 23 Jun 2026 06:50:05 +0000 Subject: [PATCH 06/10] fix Signed-off-by: wk989898 --- pkg/sink/codec/debezium/avro_decoder.go | 52 +++++++- pkg/sink/codec/debezium/avro_test.go | 170 ++++++++++++++++++++++++ 2 files changed, 219 insertions(+), 3 deletions(-) diff --git a/pkg/sink/codec/debezium/avro_decoder.go b/pkg/sink/codec/debezium/avro_decoder.go index 4a476fccc7..06bcb6a0e9 100644 --- a/pkg/sink/codec/debezium/avro_decoder.go +++ b/pkg/sink/codec/debezium/avro_decoder.go @@ -271,9 +271,17 @@ func avroNativeToConnectPayload(schema any, value any, namedSchemas map[string]a return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("avro field is missing name") } connectFieldName := avroConnectFieldName(field, avroFieldName) + rawValue, exists := valueMap[avroFieldName] + if !exists && connectFieldName != avroFieldName { + rawValue, exists = valueMap[connectFieldName] + } + if !exists { + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs( + "avro record payload is missing field " + avroFieldName) + } fieldValue, err := avroNativeToConnectPayload( field["type"], - valueMap[avroFieldName], + rawValue, namedSchemas, ) if err != nil { @@ -523,8 +531,14 @@ func avroUnionBranch(union []any, value any) (any, any, error) { if value == nil { return nil, nil, nil } + var wrappedBranchName string + var wrappedBranchValue any + hasWrappedBranch := false if branchValueMap, ok := value.(map[string]any); ok && len(branchValueMap) == 1 { for branchName, branchValue := range branchValueMap { + wrappedBranchName = branchName + wrappedBranchValue = branchValue + hasWrappedBranch = true for _, branchSchema := range union { if avroBranchName(branchSchema) == branchName { return branchSchema, branchValue, nil @@ -533,19 +547,32 @@ func avroUnionBranch(union []any, value any) (any, any, error) { } } - branchSchema, _, err := avroNonNullUnionBranch(union) + branchSchema, isSingleNonNullBranch, err := avroNonNullUnionBranch(union) if err != nil { return nil, nil, err } + if hasWrappedBranch && + isSingleNonNullBranch && + avroShortBranchName(branchSchema) == wrappedBranchName { + return branchSchema, wrappedBranchValue, nil + } return branchSchema, value, nil } func avroNonNullUnionBranch(union []any) (any, bool, error) { + var result any + count := 0 for _, branch := range union { if avroBranchName(branch) != "null" { - return branch, true, nil + if count == 0 { + result = branch + } + count++ } } + if count > 0 { + return result, count == 1, nil + } return nil, false, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("avro union has no non-null branch") } @@ -573,6 +600,25 @@ func avroBranchName(schema any) string { } } +func avroShortBranchName(schema any) string { + switch typedSchema := schema.(type) { + case string: + if idx := strings.LastIndex(typedSchema, "."); idx >= 0 { + return typedSchema[idx+1:] + } + return typedSchema + case map[string]any: + typeName, _ := typedSchema["type"].(string) + if typeName == "record" { + name, _ := typedSchema["name"].(string) + return name + } + return avroBranchName(schema) + default: + return "" + } +} + func avroConnectFieldName(field map[string]any, fallback string) string { if fieldName, ok := field[debeziumAvroConnectFieldKey].(string); ok && fieldName != "" { return fieldName diff --git a/pkg/sink/codec/debezium/avro_test.go b/pkg/sink/codec/debezium/avro_test.go index 0ee7991cf5..3f600f1526 100644 --- a/pkg/sink/codec/debezium/avro_test.go +++ b/pkg/sink/codec/debezium/avro_test.go @@ -155,6 +155,176 @@ func TestDebeziumConfluentAvroDecodeRowEvent(t *testing.T) { common.CompareRow(t, row, helper.tableInfo, change, decoded.TableInfo) } +func TestDebeziumConfluentAvroDecodeAccountDMLEvents(t *testing.T) { + ctx := context.Background() + _, err := avro.SetupEncoderAndSchemaRegistry4Testing( + ctx, + common.NewConfig(config.ProtocolAvro), + ) + require.NoError(t, err) + defer avro.TeardownEncoderAndSchemaRegistry4Testing() + + helper := NewSQLTestHelper(t, "tp_account", ` + create table tp_account( + id int primary key, + account_id int not null + )`) + defer helper.Close() + + insertDML := helper.helper.DML2Event("test", "tp_account", + "insert into tp_account values (12, 34)") + updateDML, _ := helper.helper.DML2UpdateEvent("test", "tp_account", + "insert into tp_account values (13, 34)", + "update tp_account set account_id = 35 where id = 13") + deleteDML := helper.helper.DML2DeleteEvent("test", "tp_account", + "insert into tp_account values (14, 34)", + "delete from tp_account where id = 14") + + cfg := common.NewConfig(config.ProtocolDebezium) + cfg.AvroConfluentSchemaRegistry = "http://127.0.0.1:8081" + cfg.EnableTiDBExtension = true + cfg.TimeZone = time.UTC + + encoder, err := NewAvroBatchEncoder(ctx, cfg, "dbserver1") + require.NoError(t, err) + + rows := make([]commonEvent.RowChange, 0, 3) + for _, dml := range []*commonEvent.DMLEvent{insertDML, updateDML, deleteDML} { + row, ok := dml.GetNextRow() + if !ok { + continue + } + rows = append(rows, row) + require.NoError(t, encoder.AppendRowChangedEvent(ctx, "dbserver1.test.tp_account", &commonEvent.RowEvent{ + TableInfo: helper.tableInfo, + CommitTs: 1, + Event: row, + ColumnSelector: columnselector.NewDefaultColumnSelector(), + Callback: func() {}, + })) + } + require.Len(t, rows, 3) + + messages := encoder.Build() + require.Len(t, messages, 3) + for idx, message := range messages { + decoder, err := NewAvroDecoder(ctx, cfg, 0, nil) + require.NoError(t, err) + decoder.AddKeyValue(message.Key, message.Value) + + messageType, hasNext := decoder.HasNext() + require.True(t, hasNext) + require.Equal(t, common.MessageTypeRow, messageType) + + decoded := decoder.NextDMLEvent() + require.Equal(t, "test", decoded.TableInfo.GetSchemaName()) + require.Equal(t, "tp_account", decoded.TableInfo.GetTableName()) + + change, ok := decoded.GetNextRow() + require.True(t, ok) + common.CompareRow(t, rows[idx], helper.tableInfo, change, decoded.TableInfo) + } +} + +func TestDebeziumConfluentAvroDecodeShortNamedUnionBranch(t *testing.T) { + valueSchema := map[string]any{ + "type": "record", + "name": "Value", + "namespace": "dbserver1.test.tp_account", + "fields": []any{ + map[string]any{ + "name": "id", + "type": "int", + debeziumAvroConnectFieldKey: "id", + }, + map[string]any{ + "name": "account_id", + "type": "int", + debeziumAvroConnectFieldKey: "account_id", + }, + }, + } + namedSchemas := map[string]any{ + "dbserver1.test.tp_account.Value": valueSchema, + } + + payload, err := avroNativeToConnectPayload( + []any{"null", "dbserver1.test.tp_account.Value"}, + map[string]any{ + "Value": map[string]any{ + "id": int32(12), + "account_id": int32(34), + }, + }, + namedSchemas, + ) + require.NoError(t, err) + require.Equal(t, map[string]any{ + "id": int32(12), + "account_id": int32(34), + }, payload) +} + +func TestDebeziumConfluentAvroDecodeSingleFieldUnionRecord(t *testing.T) { + valueSchema := map[string]any{ + "type": "record", + "name": "Value", + "namespace": "dbserver1.test.only_pk", + "fields": []any{ + map[string]any{ + "name": "id", + "type": "int", + debeziumAvroConnectFieldKey: "id", + }, + }, + } + namedSchemas := map[string]any{ + "dbserver1.test.only_pk.Value": valueSchema, + } + + payload, err := avroNativeToConnectPayload( + []any{"null", "dbserver1.test.only_pk.Value"}, + map[string]any{ + "id": int32(12), + }, + namedSchemas, + ) + require.NoError(t, err) + require.Equal(t, map[string]any{ + "id": int32(12), + }, payload) +} + +func TestDebeziumConfluentAvroDecodeMissingRecordField(t *testing.T) { + valueSchema := map[string]any{ + "type": "record", + "name": "Value", + "namespace": "dbserver1.test.tp_account", + "fields": []any{ + map[string]any{ + "name": "id", + "type": "int", + debeziumAvroConnectFieldKey: "id", + }, + map[string]any{ + "name": "account_id", + "type": "int", + debeziumAvroConnectFieldKey: "account_id", + }, + }, + } + + _, err := avroNativeToConnectPayload( + valueSchema, + map[string]any{ + "id": int32(12), + }, + nil, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "avro record payload is missing field account_id") +} + func TestDebeziumConfluentAvroDecodeDDLEvent(t *testing.T) { ctx := context.Background() _, err := avro.SetupEncoderAndSchemaRegistry4Testing( From 79d77801f466052a546afb9326f68270dcc546e7 Mon Sep 17 00:00:00 2001 From: wk989898 Date: Tue, 23 Jun 2026 09:14:54 +0000 Subject: [PATCH 07/10] refactor Signed-off-by: wk989898 --- api/v2/changefeed.go | 3 +- cmd/kafka-consumer/writer.go | 96 +++++- cmd/kafka-consumer/writer_test.go | 40 +++ downstreamadapter/sink/helper/helper.go | 2 +- downstreamadapter/sink/kafka/helper.go | 3 +- pkg/config/sink.go | 2 +- pkg/config/sink_protocol.go | 5 + pkg/config/sink_protocol_test.go | 12 + pkg/sink/codec/builder.go | 10 +- pkg/sink/codec/common/config.go | 69 ++++- pkg/sink/codec/common/config_test.go | 15 +- pkg/sink/codec/debezium/avro.go | 53 ++++ pkg/sink/codec/debezium/avro_decoder.go | 104 ++++++- pkg/sink/codec/debezium/avro_test.go | 240 +++++++++++---- pkg/sink/codec/debezium/codec.go | 283 +++++++++++++----- pkg/sink/codec/debezium/decoder.go | 16 + pkg/sink/codec/debezium/encoder.go | 7 + .../debezium_avro/data/ddl.sql | 3 + .../debezium_avro/data/post_ddl_workload.sql | 8 + .../debezium_avro/data/prepare.sql | 82 +---- .../debezium_avro/data/workload.sql | 80 +---- tests/integration_tests/debezium_avro/run.sh | 23 +- 22 files changed, 839 insertions(+), 317 deletions(-) create mode 100644 tests/integration_tests/debezium_avro/data/ddl.sql create mode 100644 tests/integration_tests/debezium_avro/data/post_ddl_workload.sql diff --git a/api/v2/changefeed.go b/api/v2/changefeed.go index 954ee75082..d2a76176bc 100644 --- a/api/v2/changefeed.go +++ b/api/v2/changefeed.go @@ -1770,7 +1770,8 @@ func verifyTable4MQ( return nil } - eventRouter, err := eventrouter.NewEventRouter(replicaConfig.Sink, topic, config.IsPulsarScheme(scheme), protocol == config.ProtocolAvro) + isAvroLike := protocol == config.ProtocolAvro || protocol == config.ProtocolDebeziumAvro + eventRouter, err := eventrouter.NewEventRouter(replicaConfig.Sink, topic, config.IsPulsarScheme(scheme), isAvroLike) if err != nil { return err } diff --git a/cmd/kafka-consumer/writer.go b/cmd/kafka-consumer/writer.go index 88c058f2cc..e68cf4d829 100644 --- a/cmd/kafka-consumer/writer.go +++ b/cmd/kafka-consumer/writer.go @@ -119,7 +119,8 @@ func newWriter(ctx context.Context, o *option) *writer { w.progresses[i] = newPartitionProgress(int32(i), decoder) } - eventRouter, err := eventrouter.NewEventRouter(o.sinkConfig, o.topic, false, o.protocol == config.ProtocolAvro) + isAvroLike := o.protocol == config.ProtocolAvro || o.protocol == config.ProtocolDebeziumAvro + eventRouter, err := eventrouter.NewEventRouter(o.sinkConfig, o.topic, false, isAvroLike) if err != nil { log.Panic("initialize the event router failed", zap.Any("protocol", o.protocol), zap.Any("topic", o.topic), @@ -356,6 +357,83 @@ func (w *writer) flushDMLEventsByWatermark(ctx context.Context) error { } } +func (w *writer) flushPartitionDMLEvents( + ctx context.Context, + progress *partitionProgress, + watermark uint64, +) error { + var ( + done = make(chan struct{}, 1) + + total int + flushed atomic.Int64 + ) + + resolvedEvents := make([]*event.DMLEvent, 0) + resolvedGroups := make([]struct { + group *util.EventsGroup + maxCommitTs uint64 + }, 0) + for _, group := range progress.eventsGroup { + before := len(resolvedEvents) + resolvedEvents = group.ResolveInto(watermark, resolvedEvents) + resolvedCount := len(resolvedEvents) - before + if resolvedCount == 0 { + continue + } + + resolvedGroups = append(resolvedGroups, struct { + group *util.EventsGroup + maxCommitTs uint64 + }{ + group: group, + maxCommitTs: resolvedEvents[len(resolvedEvents)-1].GetCommitTs(), + }) + total += resolvedCount + } + if total == 0 { + return nil + } + for _, e := range resolvedEvents { + e.AddPostFlushFunc(func() { + if flushed.Inc() == int64(total) { + close(done) + } + }) + w.mysqlSink.AddDMLEvent(e) + log.Debug("flush partition DML event", zap.Int64("tableID", e.GetTableID()), + zap.Uint64("commitTs", e.GetCommitTs()), zap.Any("startTs", e.GetStartTs())) + } + + log.Info("flush partition DML events", zap.Int32("partition", progress.partition), + zap.Uint64("watermark", watermark), zap.Int("total", total)) + start := time.Now() + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return context.Cause(ctx) + case <-done: + log.Info("flush partition DML events done", zap.Int32("partition", progress.partition), + zap.Uint64("watermark", watermark), zap.Int("total", total), + zap.Duration("duration", time.Since(start))) + progress.updateWatermark(watermark, progress.watermarkOffset) + for _, item := range resolvedGroups { + if item.maxCommitTs > item.group.AppliedWatermark { + item.group.AppliedWatermark = item.maxCommitTs + } + } + return nil + case <-ticker.C: + log.Warn("partition DML events cannot be flushed in time", + zap.Int32("partition", progress.partition), + zap.Uint64("watermark", watermark), + zap.Int("total", total), zap.Int64("flushed", flushed.Load())) + } + } +} + // WriteMessage is to decode kafka message to event. // return true if the message is flushed to the downstream. // return error if flush messages failed. @@ -428,6 +506,7 @@ func (w *writer) WriteMessage(ctx context.Context, message *kafka.Message) bool break } + maxCommitTs := row.GetCommitTs() w.appendRow2Group(row, progress, offset) counter++ for { @@ -436,6 +515,9 @@ func (w *writer) WriteMessage(ctx context.Context, message *kafka.Message) bool break } row = progress.decoder.NextDMLEvent() + if row.GetCommitTs() > maxCommitTs { + maxCommitTs = row.GetCommitTs() + } w.appendRow2Group(row, progress, offset) counter++ } @@ -451,6 +533,15 @@ func (w *writer) WriteMessage(ctx context.Context, message *kafka.Message) bool zap.Int("maxBatchSize", w.maxBatchSize), zap.Int("actualBatchSize", counter), zap.Int32("partition", partition), zap.Any("offset", offset)) } + if w.protocol == config.ProtocolDebeziumAvro { + progress.watermarkOffset = offset + if err := w.flushPartitionDMLEvents(ctx, progress, maxCommitTs); err != nil { + log.Panic("flush debezium avro dml events failed", zap.Error(err), + zap.Int32("partition", partition), zap.Any("offset", offset), + zap.Uint64("watermark", maxCommitTs)) + } + return true + } default: log.Panic("unknown message type", zap.Any("messageType", messageType), zap.Int32("partition", partition), zap.Any("offset", offset)) @@ -539,7 +630,8 @@ func (w *writer) onDDL(ddl *event.DDLEvent) { return } switch w.protocol { - case config.ProtocolCanalJSON, config.ProtocolOpen, config.ProtocolAvro, config.ProtocolSimple, config.ProtocolDebezium: + case config.ProtocolCanalJSON, config.ProtocolOpen, config.ProtocolAvro, config.ProtocolSimple, + config.ProtocolDebezium, config.ProtocolDebeziumAvro: default: return } diff --git a/cmd/kafka-consumer/writer_test.go b/cmd/kafka-consumer/writer_test.go index ec9aa34a43..d54027d44e 100644 --- a/cmd/kafka-consumer/writer_test.go +++ b/cmd/kafka-consumer/writer_test.go @@ -324,3 +324,43 @@ func TestAppendRow2Group_DoesNotDropCommitTsFallbackBeforeApplied(t *testing.T) resolved = group.ResolveInto(150, resolvedEvents) require.Empty(t, resolved) } + +func TestFlushPartitionDMLEventsFlushesWithoutResolved(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + s := sinkmock.NewMockSink(ctrl) + + flushedCommitTs := make([]uint64, 0) + s.EXPECT().AddDMLEvent(gomock.Any()).Do(func(e *commonEvent.DMLEvent) { + flushedCommitTs = append(flushedCommitTs, e.CommitTs) + e.PostFlush() + }).AnyTimes() + + w := &writer{mysqlSink: s} + progress := &partitionProgress{ + partition: 0, + eventsGroup: map[int64]*util.EventsGroup{1: util.NewEventsGroup(0, 1)}, + } + newDMLEvent := func(commitTs uint64) *commonEvent.DMLEvent { + return &commonEvent.DMLEvent{ + PhysicalTableID: 1, + CommitTs: commitTs, + RowTypes: []common.RowType{common.RowTypeUpdate}, + Rows: chunk.NewChunkWithCapacity(nil, 0), + TableInfo: &common.TableInfo{ + TableName: common.TableName{Schema: "test", Table: "t"}, + }, + } + } + group := progress.eventsGroup[1] + group.Append(newDMLEvent(100), false) + group.Append(newDMLEvent(200), false) + + require.NoError(t, w.flushPartitionDMLEvents(ctx, progress, 150)) + require.Equal(t, []uint64{100}, flushedCommitTs) + require.Equal(t, uint64(100), group.AppliedWatermark) + + remaining := group.ResolveInto(300, nil) + require.Len(t, remaining, 1) + require.Equal(t, uint64(200), remaining[0].CommitTs) +} diff --git a/downstreamadapter/sink/helper/helper.go b/downstreamadapter/sink/helper/helper.go index 47cd949220..11dd7b85de 100644 --- a/downstreamadapter/sink/helper/helper.go +++ b/downstreamadapter/sink/helper/helper.go @@ -100,7 +100,7 @@ func GetProtocol(protocolStr string) (config.Protocol, error) { // GetFileExtension returns the extension for specific protocol func GetFileExtension(protocol config.Protocol) string { switch protocol { - case config.ProtocolAvro, config.ProtocolCanalJSON, config.ProtocolMaxwell, + case config.ProtocolAvro, config.ProtocolDebeziumAvro, config.ProtocolCanalJSON, config.ProtocolMaxwell, config.ProtocolOpen, config.ProtocolSimple: return ".json" case config.ProtocolCraft: diff --git a/downstreamadapter/sink/kafka/helper.go b/downstreamadapter/sink/kafka/helper.go index de6ce2fb71..4b83310595 100644 --- a/downstreamadapter/sink/kafka/helper.go +++ b/downstreamadapter/sink/kafka/helper.go @@ -77,8 +77,9 @@ func newKafkaSinkComponentWithFactory(ctx context.Context, return kafkaComponent, protocol, errors.WrapError(errors.ErrKafkaNewProducer, err) } + isAvroLike := protocol == config.ProtocolAvro || protocol == config.ProtocolDebeziumAvro kafkaComponent.eventRouter, err = eventrouter.NewEventRouter( - sinkConfig, topic, false, protocol == config.ProtocolAvro) + sinkConfig, topic, false, isAvroLike) if err != nil { return kafkaComponent, protocol, errors.Trace(err) } diff --git a/pkg/config/sink.go b/pkg/config/sink.go index 753f470c14..dca4d50572 100644 --- a/pkg/config/sink.go +++ b/pkg/config/sink.go @@ -966,7 +966,7 @@ func (s *SinkConfig) ValidateProtocol(scheme string) error { if s.OpenProtocol != nil { outputOldValue = s.OpenProtocol.OutputOldValue } - case ProtocolDebezium: + case ProtocolDebezium, ProtocolDebeziumAvro: if s.Debezium != nil { outputOldValue = s.Debezium.OutputOldValue } diff --git a/pkg/config/sink_protocol.go b/pkg/config/sink_protocol.go index c9b9a7fd5a..113d66eab9 100644 --- a/pkg/config/sink_protocol.go +++ b/pkg/config/sink_protocol.go @@ -41,6 +41,7 @@ const ( ProtocolCsv ProtocolDebezium ProtocolSimple + ProtocolDebeziumAvro ) // IsBatchEncode returns whether the protocol is a batch encoder. @@ -71,6 +72,8 @@ func ParseSinkProtocolFromString(protocol string) (Protocol, error) { return ProtocolCsv, nil case "debezium": return ProtocolDebezium, nil + case "debezium-avro": + return ProtocolDebeziumAvro, nil case "simple": return ProtocolSimple, nil default: @@ -101,6 +104,8 @@ func (p Protocol) String() string { return "debezium" case ProtocolSimple: return "simple" + case ProtocolDebeziumAvro: + return "debezium-avro" default: panic("unreachable") } diff --git a/pkg/config/sink_protocol_test.go b/pkg/config/sink_protocol_test.go index 27e7106b3c..05b8525fda 100644 --- a/pkg/config/sink_protocol_test.go +++ b/pkg/config/sink_protocol_test.go @@ -62,6 +62,10 @@ func TestParseSinkProtocolFromString(t *testing.T) { protocol: "open-protocol", expectedProtocolEnum: ProtocolOpen, }, + { + protocol: "debezium-avro", + expectedProtocolEnum: ProtocolDebeziumAvro, + }, } for _, tc := range testCases { @@ -109,6 +113,10 @@ func TestString(t *testing.T) { protocolEnum: ProtocolOpen, expectedProtocol: "open-protocol", }, + { + protocolEnum: ProtocolDebeziumAvro, + expectedProtocol: "debezium-avro", + }, } for _, tc := range testCases { @@ -151,6 +159,10 @@ func TestIsBatchEncoder(t *testing.T) { protocolEnum: ProtocolOpen, expect: true, }, + { + protocolEnum: ProtocolDebeziumAvro, + expect: false, + }, } for _, tc := range testCases { diff --git a/pkg/sink/codec/builder.go b/pkg/sink/codec/builder.go index ea9d8da21c..e003710bfb 100644 --- a/pkg/sink/codec/builder.go +++ b/pkg/sink/codec/builder.go @@ -40,10 +40,9 @@ func NewEventEncoder(ctx context.Context, cfg *common.Config) (common.EventEncod case config.ProtocolCanalJSON: return canal.NewJSONRowEventEncoder(ctx, cfg) case config.ProtocolDebezium: - if cfg.AvroConfluentSchemaRegistry != "" { - return debezium.NewAvroBatchEncoder(ctx, cfg, config.GetGlobalServerConfig().ClusterID) - } return debezium.NewBatchEncoder(cfg, config.GetGlobalServerConfig().ClusterID), nil + case config.ProtocolDebeziumAvro: + return debezium.NewAvroBatchEncoder(ctx, cfg, config.GetGlobalServerConfig().ClusterID) case config.ProtocolSimple: return simple.NewEncoder(ctx, cfg) default: @@ -69,10 +68,9 @@ func NewEventDecoder( case config.ProtocolSimple: return simple.NewDecoder(ctx, codecConfig, upstreamTiDB) case config.ProtocolDebezium: - if codecConfig.AvroConfluentSchemaRegistry != "" { - return debezium.NewAvroDecoder(ctx, codecConfig, idx, upstreamTiDB) - } return debezium.NewDecoder(codecConfig, idx, upstreamTiDB), nil + case config.ProtocolDebeziumAvro: + return debezium.NewAvroDecoder(ctx, codecConfig, idx, upstreamTiDB) default: } log.Panic("Protocol not supported", zap.Any("Protocol", codecConfig.Protocol)) diff --git a/pkg/sink/codec/common/config.go b/pkg/sink/codec/common/config.go index 599d61f95b..bb3d9fcbc6 100644 --- a/pkg/sink/codec/common/config.go +++ b/pkg/sink/codec/common/config.go @@ -55,7 +55,7 @@ type Config struct { OutputRowKey bool - // avro only, except AvroConfluentSchemaRegistry is also used by debezium + // avro and debezium-avro only // protocol when Confluent Avro encoding is enabled. AvroConfluentSchemaRegistry string AvroDecimalHandlingMode string @@ -237,9 +237,10 @@ func (c *Config) Apply(sinkURI *url.URL, sinkConfig *config.SinkConfig) error { sinkConfig.KafkaConfig.GlueSchemaRegistryConfig != nil { c.AvroGlueSchemaRegistry = sinkConfig.KafkaConfig.GlueSchemaRegistryConfig } - if c.Protocol == config.ProtocolAvro && util.GetOrZero(sinkConfig.ForceReplicate) { + if (c.Protocol == config.ProtocolAvro || c.Protocol == config.ProtocolDebeziumAvro) && + util.GetOrZero(sinkConfig.ForceReplicate) { return errors.ErrCodecInvalidConfig.GenWithStack( - `force-replicate must be disabled, when using avro protocol`) + `force-replicate must be disabled, when using avro or debezium-avro protocol`) } if sinkConfig != nil { @@ -354,30 +355,57 @@ func (c *Config) WithChangefeedID(id common.ChangeFeedID) *Config { // Validate the Config func (c *Config) Validate() error { if c.EnableTiDBExtension && - (c.Protocol != config.ProtocolCanalJSON && c.Protocol != config.ProtocolAvro && c.Protocol != config.ProtocolDebezium) { + (c.Protocol != config.ProtocolCanalJSON && c.Protocol != config.ProtocolAvro && + c.Protocol != config.ProtocolDebezium && c.Protocol != config.ProtocolDebeziumAvro) { log.Warn("ignore invalid config, enable-tidb-extension"+ - "only supports canal-json/avro/debezium protocol", + "only supports canal-json/avro/debezium/debezium-avro protocol", zap.Bool("enableTidbExtension", c.EnableTiDBExtension), zap.String("protocol", c.Protocol.String())) } - if c.Protocol == config.ProtocolAvro { + if c.Protocol == config.ProtocolDebezium && + (c.AvroConfluentSchemaRegistry != "" || c.AvroGlueSchemaRegistry != nil) { + return errors.ErrCodecInvalidConfig.GenWithStack( + `Debezium protocol does not support schema registry; use protocol "debezium-avro"`, + ) + } + + if c.Protocol == config.ProtocolAvro || c.Protocol == config.ProtocolDebeziumAvro { if c.AvroConfluentSchemaRegistry != "" && c.AvroGlueSchemaRegistry != nil { + protocol := "Avro" + if c.Protocol == config.ProtocolDebeziumAvro { + protocol = "Debezium Avro" + } return errors.ErrCodecInvalidConfig.GenWithStack( - `Avro protocol requires only one of "%s" or "%s" to specify the schema registry`, + `%s protocol requires only one of "%s" or "%s" to specify the schema registry`, + protocol, codecOPTAvroSchemaRegistry, coderOPTAvroGlueSchemaRegistry, ) } if c.AvroConfluentSchemaRegistry == "" && c.AvroGlueSchemaRegistry == nil { + protocol := "Avro" + if c.Protocol == config.ProtocolDebeziumAvro { + protocol = "Debezium Avro" + } return errors.ErrCodecInvalidConfig.GenWithStack( - `Avro protocol requires parameter "%s" or "%s" to specify the schema registry`, + `%s protocol requires parameter "%s" or "%s" to specify the schema registry`, + protocol, codecOPTAvroSchemaRegistry, coderOPTAvroGlueSchemaRegistry, ) } + if c.Protocol == config.ProtocolDebeziumAvro && c.AvroGlueSchemaRegistry != nil { + return errors.ErrCodecInvalidConfig.GenWithStack( + `Debezium Avro protocol only supports "%s" for Confluent Avro Schema Registry`, + codecOPTAvroSchemaRegistry, + ) + } + } + + if c.Protocol == config.ProtocolAvro { if c.AvroDecimalHandlingMode != DecimalHandlingModePrecise && c.AvroDecimalHandlingMode != DecimalHandlingModeString { return errors.ErrCodecInvalidConfig.GenWithStack( @@ -411,11 +439,26 @@ func (c *Config) Validate() error { } } - if c.Protocol == config.ProtocolDebezium && c.AvroGlueSchemaRegistry != nil { - return errors.ErrCodecInvalidConfig.GenWithStack( - `Debezium protocol only supports "%s" for Confluent Avro Schema Registry`, - codecOPTAvroSchemaRegistry, - ) + if c.Protocol == config.ProtocolDebeziumAvro { + if c.AvroDecimalHandlingMode != DecimalHandlingModePrecise && + c.AvroDecimalHandlingMode != DecimalHandlingModeString { + return errors.ErrCodecInvalidConfig.GenWithStack( + `%s value could only be "%s" or "%s"`, + codecOPTAvroDecimalHandlingMode, + DecimalHandlingModeString, + DecimalHandlingModePrecise, + ) + } + + if c.AvroBigintUnsignedHandlingMode != BigintUnsignedHandlingModeLong && + c.AvroBigintUnsignedHandlingMode != BigintUnsignedHandlingModeString { + return errors.ErrCodecInvalidConfig.GenWithStack( + `%s value could only be "%s" or "%s"`, + codecOPTAvroBigintUnsignedHandlingMode, + BigintUnsignedHandlingModeLong, + BigintUnsignedHandlingModeString, + ) + } } if c.MaxMessageBytes <= 0 { diff --git a/pkg/sink/codec/common/config_test.go b/pkg/sink/codec/common/config_test.go index 682504c9d4..474ed689d6 100644 --- a/pkg/sink/codec/common/config_test.go +++ b/pkg/sink/codec/common/config_test.go @@ -20,14 +20,21 @@ import ( "github.com/stretchr/testify/require" ) -func TestDebeziumSchemaRegistryConfig(t *testing.T) { +func TestDebeziumAvroSchemaRegistryConfig(t *testing.T) { t.Parallel() - cfg := NewConfig(config.ProtocolDebezium) + cfg := NewConfig(config.ProtocolDebeziumAvro) cfg.AvroConfluentSchemaRegistry = "http://127.0.0.1:8081" require.NoError(t, cfg.Validate()) - cfg = NewConfig(config.ProtocolDebezium) + cfg = NewConfig(config.ProtocolDebeziumAvro) + require.ErrorContains(t, cfg.Validate(), `Debezium Avro protocol requires parameter "schema-registry"`) + + cfg = NewConfig(config.ProtocolDebeziumAvro) cfg.AvroGlueSchemaRegistry = &config.GlueSchemaRegistryConfig{} - require.ErrorContains(t, cfg.Validate(), `Debezium protocol only supports "schema-registry"`) + require.ErrorContains(t, cfg.Validate(), `Debezium Avro protocol only supports "schema-registry"`) + + cfg = NewConfig(config.ProtocolDebezium) + cfg.AvroConfluentSchemaRegistry = "http://127.0.0.1:8081" + require.ErrorContains(t, cfg.Validate(), `Debezium protocol does not support schema registry`) } diff --git a/pkg/sink/codec/debezium/avro.go b/pkg/sink/codec/debezium/avro.go index 8b4326fb81..42c6e2c122 100644 --- a/pkg/sink/codec/debezium/avro.go +++ b/pkg/sink/codec/debezium/avro.go @@ -18,6 +18,8 @@ import ( "context" "encoding/base64" "encoding/json" + "math/big" + "strconv" "strings" "github.com/linkedin/goavro/v2" @@ -32,6 +34,7 @@ const ( debeziumAvroConnectFieldKey = "connect.field" debeziumAvroTiDBTypeKey = "tidb_type" + debeziumAvroDecimalName = "org.apache.kafka.connect.data.Decimal" ) type debeziumAvroMessage struct { @@ -267,6 +270,21 @@ func (c *debeziumAvroSchemaConverter) toAvroSchema( addConnectMetadata(arraySchema, schema) return arraySchema, nil default: + if isDebeziumAvroDecimalSchema(schema) { + precision, scale, err := decimalSchemaPrecisionAndScale(schema) + if err != nil { + return nil, err + } + decimalSchema := map[string]any{ + "type": "bytes", + "logicalType": "decimal", + "precision": precision, + "scale": scale, + } + addConnectMetadata(decimalSchema, schema) + return decimalSchema, nil + } + avroType, err := connectPrimitiveToAvro(schema.Type) if err != nil { return nil, err @@ -360,6 +378,18 @@ func (c *debeziumAvroSchemaConverter) toNative( } return v, nil case "bytes": + if isDebeziumAvroDecimalSchema(schema) { + v, ok := value.(string) + if !ok { + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("decimal payload is invalid") + } + rat, ok := new(big.Rat).SetString(v) + if !ok { + return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("decimal payload is invalid") + } + return rat, nil + } + v, ok := value.(string) if !ok { return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("bytes payload is invalid") @@ -411,6 +441,25 @@ func connectPrimitiveToAvro(connectType string) (string, error) { } } +func isDebeziumAvroDecimalSchema(schema *debeziumConnectSchema) bool { + return schema.Type == "bytes" && schema.Name == debeziumAvroDecimalName +} + +func decimalSchemaPrecisionAndScale(schema *debeziumConnectSchema) (int, int, error) { + if schema.Parameters == nil { + return 0, 0, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("decimal schema is missing parameters") + } + precision, err := strconv.Atoi(schema.Parameters["precision"]) + if err != nil { + return 0, 0, errors.WrapError(errors.ErrDebeziumInvalidMessage, err) + } + scale, err := strconv.Atoi(schema.Parameters["scale"]) + if err != nil { + return 0, 0, errors.WrapError(errors.ErrDebeziumInvalidMessage, err) + } + return precision, scale, nil +} + func addConnectMetadata(avroSchema map[string]any, schema *debeziumConnectSchema) { if schema.Name != "" { avroSchema["connect.name"] = schema.Name @@ -450,6 +499,10 @@ func avroFieldName(field string) string { } func avroUnionBranchName(schema *debeziumConnectSchema, fallbackName string) string { + if isDebeziumAvroDecimalSchema(schema) { + return "bytes.decimal" + } + switch schema.Type { case "struct": return avroFullName(schema.Name, fallbackName) diff --git a/pkg/sink/codec/debezium/avro_decoder.go b/pkg/sink/codec/debezium/avro_decoder.go index 06bcb6a0e9..8c91ec813d 100644 --- a/pkg/sink/codec/debezium/avro_decoder.go +++ b/pkg/sink/codec/debezium/avro_decoder.go @@ -19,6 +19,7 @@ import ( "encoding/binary" "encoding/json" "io" + "math/big" "net/http" "strconv" "strings" @@ -275,6 +276,9 @@ func avroNativeToConnectPayload(schema any, value any, namedSchemas map[string]a if !exists && connectFieldName != avroFieldName { rawValue, exists = valueMap[connectFieldName] } + if !exists { + rawValue, exists = avroMissingFieldValue(field) + } if !exists { return nil, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs( "avro record payload is missing field " + avroFieldName) @@ -311,6 +315,11 @@ func avroNativeToConnectPayload(schema any, value any, namedSchemas map[string]a result = append(result, itemValue) } return result, nil + case "bytes": + if avroSchemaIsDecimal(typedSchema) { + return avroDecimalNativeToString(typedSchema, value) + } + return value, nil default: return value, nil } @@ -440,6 +449,10 @@ func collectAvroNamedSchemas(schema any, namedSchemas map[string]any) { name := avroBranchName(typedSchema) if name != "" { namedSchemas[name] = typedSchema + shortName := avroShortBranchName(name) + if _, exists := namedSchemas[shortName]; shortName != "" && !exists { + namedSchemas[shortName] = typedSchema + } } fields, _ := typedSchema["fields"].([]any) for _, rawField := range fields { @@ -527,6 +540,50 @@ func avroPrimitiveToConnectType(avroType string, schemaMeta map[string]any) (str } } +func avroSchemaIsDecimal(schema map[string]any) bool { + typeName, _ := schema["type"].(string) + logicalType, _ := schema["logicalType"].(string) + return typeName == "bytes" && logicalType == "decimal" +} + +func avroDecimalNativeToString(schema map[string]any, value any) (string, error) { + scale, err := avroDecimalScale(schema) + if err != nil { + return "", err + } + switch v := value.(type) { + case *big.Rat: + return v.FloatString(scale), nil + case big.Rat: + return v.FloatString(scale), nil + case string: + return v, nil + default: + return "", errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("decimal payload is invalid") + } +} + +func avroDecimalScale(schema map[string]any) (int, error) { + switch scale := schema["scale"].(type) { + case float64: + return int(scale), nil + case int: + return scale, nil + case int32: + return int(scale), nil + case int64: + return int(scale), nil + case json.Number: + value, err := scale.Int64() + if err != nil { + return 0, errors.WrapError(errors.ErrDebeziumInvalidMessage, err) + } + return int(value), nil + default: + return 0, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("decimal schema is missing scale") + } +} + func avroUnionBranch(union []any, value any) (any, any, error) { if value == nil { return nil, nil, nil @@ -553,7 +610,7 @@ func avroUnionBranch(union []any, value any) (any, any, error) { } if hasWrappedBranch && isSingleNonNullBranch && - avroShortBranchName(branchSchema) == wrappedBranchName { + avroShortBranchName(branchSchema) == avroShortBranchName(wrappedBranchName) { return branchSchema, wrappedBranchValue, nil } return branchSchema, value, nil @@ -593,6 +650,9 @@ func avroBranchName(schema any) string { case "array": return "array" default: + if avroSchemaIsDecimal(typedSchema) { + return "bytes.decimal" + } return typeName } default: @@ -619,6 +679,48 @@ func avroShortBranchName(schema any) string { } } +func avroFieldAllowsMissing(field map[string]any) bool { + if _, hasDefault := field["default"]; hasDefault { + return true + } + return avroSchemaAllowsNull(field["type"]) +} + +func avroMissingFieldValue(field map[string]any) (any, bool) { + if avroFieldAllowsMissing(field) { + return nil, true + } + if avroSchemaIsArray(field["type"]) { + return []any{}, true + } + return nil, false +} + +func avroSchemaIsArray(schema any) bool { + switch typedSchema := schema.(type) { + case map[string]any: + typeName, _ := typedSchema["type"].(string) + return typeName == "array" + case string: + return typedSchema == "array" + default: + return false + } +} + +func avroSchemaAllowsNull(schema any) bool { + union, ok := schema.([]any) + if !ok { + return false + } + for _, branch := range union { + if avroBranchName(branch) == "null" { + return true + } + } + return false +} + func avroConnectFieldName(field map[string]any, fallback string) string { if fieldName, ok := field[debeziumAvroConnectFieldKey].(string); ok && fieldName != "" { return fieldName diff --git a/pkg/sink/codec/debezium/avro_test.go b/pkg/sink/codec/debezium/avro_test.go index 3f600f1526..45680ff413 100644 --- a/pkg/sink/codec/debezium/avro_test.go +++ b/pkg/sink/codec/debezium/avro_test.go @@ -28,7 +28,6 @@ import ( "github.com/pingcap/ticdc/pkg/config" "github.com/pingcap/ticdc/pkg/sink/codec/avro" "github.com/pingcap/ticdc/pkg/sink/codec/common" - timodel "github.com/pingcap/tidb/pkg/meta/model" "github.com/stretchr/testify/require" ) @@ -45,16 +44,21 @@ func TestDebeziumConfluentAvroEncodeRowEvent(t *testing.T) { create table foo( id int primary key, name varchar(16), + bin varbinary(16), + price decimal(10, 4), + ubig bigint unsigned, v bigint null )`) defer helper.Close() - dmls := helper.helper.DML2Event("test", "foo", "insert into foo values (1, 'alice', null)") + dmls := helper.helper.DML2Event("test", "foo", + "insert into foo values (1, 'alice', x'010203', 12.3400, 18446744073709551615, null)") row, ok := dmls.GetNextRow() require.True(t, ok) - cfg := common.NewConfig(config.ProtocolDebezium) + cfg := common.NewConfig(config.ProtocolDebeziumAvro) cfg.AvroConfluentSchemaRegistry = "http://127.0.0.1:8081" + cfg.AvroBigintUnsignedHandlingMode = common.BigintUnsignedHandlingModeString cfg.DebeziumDisableSchema = true cfg.TimeZone = time.UTC @@ -74,29 +78,38 @@ func TestDebeziumConfluentAvroEncodeRowEvent(t *testing.T) { require.Equal(t, byte(0), messages[0].Value[0]) key := decodeConfluentAvroForTest(t, messages[0].Key) - require.Equal(t, int32(1), key["id"]) + require.Equal(t, int32(1), unwrapAvroUnionForTest(t, key["id"], "int")) value := decodeConfluentAvroForTest(t, messages[0].Value) require.Equal(t, "c", value["op"]) require.Nil(t, value["before"]) + require.NotContains(t, value, "transaction") + require.IsType(t, int64(0), value["ts_ms"]) afterUnion, ok := value["after"].(map[string]any) require.True(t, ok) - after, ok := afterUnion["dbserver1.test.foo.Value"].(map[string]any) + after, ok := afterUnion["dbserver1.test.foo"].(map[string]any) require.True(t, ok) - require.Equal(t, int32(1), after["id"]) - name, ok := after["name"].(map[string]any) - require.True(t, ok) - require.Equal(t, "alice", name["string"]) + require.Equal(t, int32(1), unwrapAvroUnionForTest(t, after["id"], "int")) + require.Equal(t, "alice", unwrapAvroUnionForTest(t, after["name"], "string")) + require.Equal(t, []byte{1, 2, 3}, unwrapAvroUnionForTest(t, after["bin"], "bytes")) + require.Equal(t, "18446744073709551615", unwrapAvroUnionForTest(t, after["ubig"], "string")) require.Nil(t, after["v"]) source, ok := value["source"].(map[string]any) require.True(t, ok) require.Equal(t, "test", source["db"]) - table, ok := source["table"].(map[string]any) - require.True(t, ok) - require.Equal(t, "foo", table["string"]) + require.Equal(t, "foo", source["table"]) + require.Nil(t, source["snapshot"]) + require.Nil(t, source["thread"]) require.Equal(t, "dbserver1", source["name"]) + + valueSchema := decodeConfluentAvroSchemaForTest(t, messages[0].Value) + require.Contains(t, valueSchema, `"name":"fooEnvelope"`) + require.Contains(t, valueSchema, `"name":"foo"`) + require.Contains(t, valueSchema, `"name":"Source"`) + require.Contains(t, valueSchema, `"logicalType":"decimal"`) + require.NotContains(t, valueSchema, `"field":"transaction"`) } func TestDebeziumConfluentAvroDecodeRowEvent(t *testing.T) { @@ -112,24 +125,30 @@ func TestDebeziumConfluentAvroDecodeRowEvent(t *testing.T) { create table foo( id int primary key, name varchar(16), + bin varbinary(16), + price decimal(10, 4), + ubig bigint unsigned, v bigint null )`) defer helper.Close() - dmls := helper.helper.DML2Event("test", "foo", "insert into foo values (1, 'alice', null)") + dmls := helper.helper.DML2Event("test", "foo", + "insert into foo values (1, 'alice', x'010203', 12.3400, 18446744073709551615, null)") row, ok := dmls.GetNextRow() require.True(t, ok) - cfg := common.NewConfig(config.ProtocolDebezium) + cfg := common.NewConfig(config.ProtocolDebeziumAvro) cfg.AvroConfluentSchemaRegistry = "http://127.0.0.1:8081" + cfg.AvroBigintUnsignedHandlingMode = common.BigintUnsignedHandlingModeString cfg.EnableTiDBExtension = true cfg.TimeZone = time.UTC + commitTs := uint64(123) encoder, err := NewAvroBatchEncoder(ctx, cfg, "dbserver1") require.NoError(t, err) require.NoError(t, encoder.AppendRowChangedEvent(ctx, "dbserver1.test.foo", &commonEvent.RowEvent{ TableInfo: helper.tableInfo, - CommitTs: 1, + CommitTs: commitTs, Event: row, ColumnSelector: columnselector.NewDefaultColumnSelector(), Callback: func() {}, @@ -147,6 +166,7 @@ func TestDebeziumConfluentAvroDecodeRowEvent(t *testing.T) { require.Equal(t, common.MessageTypeRow, messageType) decoded := decoder.NextDMLEvent() + require.Equal(t, commitTs, decoded.CommitTs) require.Equal(t, "test", decoded.TableInfo.GetSchemaName()) require.Equal(t, "foo", decoded.TableInfo.GetTableName()) @@ -180,7 +200,7 @@ func TestDebeziumConfluentAvroDecodeAccountDMLEvents(t *testing.T) { "insert into tp_account values (14, 34)", "delete from tp_account where id = 14") - cfg := common.NewConfig(config.ProtocolDebezium) + cfg := common.NewConfig(config.ProtocolDebeziumAvro) cfg.AvroConfluentSchemaRegistry = "http://127.0.0.1:8081" cfg.EnableTiDBExtension = true cfg.TimeZone = time.UTC @@ -265,6 +285,56 @@ func TestDebeziumConfluentAvroDecodeShortNamedUnionBranch(t *testing.T) { }, payload) } +func TestDebeziumConfluentAvroDecodeFullNamedWrapperForShortUnionBranch(t *testing.T) { + valueSchema := map[string]any{ + "type": "record", + "name": "Value", + "namespace": "default.test.tp_account", + "fields": []any{ + map[string]any{ + "name": "id", + "type": "int", + debeziumAvroConnectFieldKey: "id", + }, + map[string]any{ + "name": "account_id", + "type": "int", + debeziumAvroConnectFieldKey: "account_id", + }, + }, + } + envelopeSchema := map[string]any{ + "type": "record", + "name": "Envelope", + "namespace": "default.test.tp_account", + "fields": []any{ + map[string]any{ + "name": "after", + "type": []any{"null", "Value"}, + }, + }, + } + namedSchemas := map[string]any{} + collectAvroNamedSchemas(valueSchema, namedSchemas) + collectAvroNamedSchemas(envelopeSchema, namedSchemas) + + payload, err := avroNativeToConnectPayload( + []any{"null", "Value"}, + map[string]any{ + "default.test.tp_account.Value": map[string]any{ + "id": int32(12), + "account_id": int32(34), + }, + }, + namedSchemas, + ) + require.NoError(t, err) + require.Equal(t, map[string]any{ + "id": int32(12), + "account_id": int32(34), + }, payload) +} + func TestDebeziumConfluentAvroDecodeSingleFieldUnionRecord(t *testing.T) { valueSchema := map[string]any{ "type": "record", @@ -325,7 +395,69 @@ func TestDebeziumConfluentAvroDecodeMissingRecordField(t *testing.T) { require.Contains(t, err.Error(), "avro record payload is missing field account_id") } -func TestDebeziumConfluentAvroDecodeDDLEvent(t *testing.T) { +func TestDebeziumConfluentAvroDecodeMissingOptionalRecordField(t *testing.T) { + valueSchema := map[string]any{ + "type": "record", + "name": "Table", + "namespace": "io.debezium.connector.schema", + "fields": []any{ + map[string]any{ + "name": "defaultCharsetName", + "type": []any{"null", "string"}, + "default": nil, + }, + map[string]any{ + "name": "columns", + "type": map[string]any{ + "type": "array", + "items": "string", + }, + }, + }, + } + + payload, err := avroNativeToConnectPayload( + valueSchema, + map[string]any{ + "columns": []any{"id"}, + }, + nil, + ) + require.NoError(t, err) + require.Equal(t, map[string]any{ + "defaultCharsetName": nil, + "columns": []any{"id"}, + }, payload) +} + +func TestDebeziumConfluentAvroDecodeMissingArrayRecordField(t *testing.T) { + valueSchema := map[string]any{ + "type": "record", + "name": "Table", + "namespace": "io.debezium.connector.schema", + "fields": []any{ + map[string]any{ + "name": "columns", + "type": map[string]any{ + "type": "array", + "items": "string", + }, + }, + }, + } + + payload, err := avroNativeToConnectPayload( + valueSchema, + map[string]any{}, + nil, + ) + require.NoError(t, err) + require.Equal(t, map[string]any{ + "columns": []any{}, + }, payload) +} + +func TestDebeziumConfluentAvroDoesNotEncodeDDLEvent(t *testing.T) { ctx := context.Background() _, err := avro.SetupEncoderAndSchemaRegistry4Testing( ctx, @@ -334,7 +466,7 @@ func TestDebeziumConfluentAvroDecodeDDLEvent(t *testing.T) { require.NoError(t, err) defer avro.TeardownEncoderAndSchemaRegistry4Testing() - cfg := common.NewConfig(config.ProtocolDebezium) + cfg := common.NewConfig(config.ProtocolDebeziumAvro) cfg.AvroConfluentSchemaRegistry = "http://127.0.0.1:8081" cfg.EnableTiDBExtension = true cfg.TimeZone = time.UTC @@ -345,23 +477,10 @@ func TestDebeziumConfluentAvroDecodeDDLEvent(t *testing.T) { routedDDL := common.NewRoutedDDLEvent4Test() message, err := encoder.EncodeDDLEvent(routedDDL) require.NoError(t, err) - require.NotNil(t, message) - - decoder, err := NewAvroDecoder(ctx, cfg, 0, nil) - require.NoError(t, err) - decoder.AddKeyValue(message.Key, message.Value) - - messageType, hasNext := decoder.HasNext() - require.True(t, hasNext) - require.Equal(t, common.MessageTypeDDL, messageType) - - decoded := decoder.NextDDLEvent() - require.Equal(t, "target_db", decoded.SchemaName) - require.Equal(t, "target_table", decoded.TableName) - require.Equal(t, routedDDL.Query, decoded.Query) + require.Nil(t, message) } -func TestDebeziumConfluentAvroDecodeSchemaDDLEvent(t *testing.T) { +func TestDebeziumConfluentAvroDoesNotEncodeCheckpointEvent(t *testing.T) { ctx := context.Background() _, err := avro.SetupEncoderAndSchemaRegistry4Testing( ctx, @@ -370,40 +489,43 @@ func TestDebeziumConfluentAvroDecodeSchemaDDLEvent(t *testing.T) { require.NoError(t, err) defer avro.TeardownEncoderAndSchemaRegistry4Testing() - cfg := common.NewConfig(config.ProtocolDebezium) + cfg := common.NewConfig(config.ProtocolDebeziumAvro) cfg.AvroConfluentSchemaRegistry = "http://127.0.0.1:8081" cfg.EnableTiDBExtension = true + cfg.AvroEnableWatermark = true cfg.TimeZone = time.UTC encoder, err := NewAvroBatchEncoder(ctx, cfg, "dbserver1") require.NoError(t, err) - ddl := &commonEvent.DDLEvent{ - Version: commonEvent.DDLEventVersion1, - Type: byte(timodel.ActionCreateSchema), - SchemaName: "test", - Query: "CREATE DATABASE `test`", - FinishedTs: 100, - } - message, err := encoder.EncodeDDLEvent(ddl) + message, err := encoder.EncodeCheckpointEvent(100) require.NoError(t, err) - require.NotNil(t, message) + require.Nil(t, message) +} - decoder, err := NewAvroDecoder(ctx, cfg, 0, nil) +func decodeConfluentAvroForTest(t *testing.T, data []byte) map[string]any { + t.Helper() + + schema, binaryData := decodeConfluentAvroEnvelopeForTest(t, data) + codec, err := avro.GenCodec(schema) require.NoError(t, err) - decoder.AddKeyValue(message.Key, message.Value) - messageType, hasNext := decoder.HasNext() - require.True(t, hasNext) - require.Equal(t, common.MessageTypeDDL, messageType) + native, _, err := codec.NativeFromBinary(binaryData) + require.NoError(t, err) + + result, ok := native.(map[string]any) + require.True(t, ok) + return result +} - decoded := decoder.NextDDLEvent() - require.Equal(t, "test", decoded.SchemaName) - require.Empty(t, decoded.TableName) - require.Equal(t, ddl.Query, decoded.Query) +func decodeConfluentAvroSchemaForTest(t *testing.T, data []byte) string { + t.Helper() + + schema, _ := decodeConfluentAvroEnvelopeForTest(t, data) + return schema } -func decodeConfluentAvroForTest(t *testing.T, data []byte) map[string]any { +func decodeConfluentAvroEnvelopeForTest(t *testing.T, data []byte) (string, []byte) { t.Helper() require.GreaterOrEqual(t, len(data), 5) @@ -424,13 +546,15 @@ func decodeConfluentAvroForTest(t *testing.T, data []byte) map[string]any { } require.NoError(t, json.Unmarshal(body, &schemaResp)) - codec, err := avro.GenCodec(schemaResp.Schema) - require.NoError(t, err) + return schemaResp.Schema, binaryData +} - native, _, err := codec.NativeFromBinary(binaryData) - require.NoError(t, err) +func unwrapAvroUnionForTest(t *testing.T, value any, branch string) any { + t.Helper() - result, ok := native.(map[string]any) + union, ok := value.(map[string]any) + require.True(t, ok) + result, ok := union[branch] require.True(t, ok) return result } diff --git a/pkg/sink/codec/debezium/codec.go b/pkg/sink/codec/debezium/codec.go index 81707b6a99..4346f8c1e7 100644 --- a/pkg/sink/codec/debezium/codec.go +++ b/pkg/sink/codec/debezium/codec.go @@ -25,6 +25,7 @@ import ( "github.com/pingcap/log" commonType "github.com/pingcap/ticdc/pkg/common" commonEvent "github.com/pingcap/ticdc/pkg/common/event" + "github.com/pingcap/ticdc/pkg/config" "github.com/pingcap/ticdc/pkg/errors" "github.com/pingcap/ticdc/pkg/sink/codec/common" "github.com/pingcap/ticdc/pkg/util" @@ -42,6 +43,73 @@ type dbzCodec struct { nowFunc func() time.Time } +func (c *dbzCodec) isDebeziumAvro() bool { + return c.config.Protocol == config.ProtocolDebeziumAvro +} + +func (c *dbzCodec) debeziumAvroNamespace(schema string) string { + return fmt.Sprintf("%s.%s", + common.SanitizeName(c.clusterID), + common.SanitizeName(schema)) +} + +func (c *dbzCodec) debeziumAvroTableName(table string) string { + return common.SanitizeName(table) +} + +func (c *dbzCodec) keySchemaName(schema string, table string) string { + if c.isDebeziumAvro() { + return fmt.Sprintf("%s.%sKey", + c.debeziumAvroNamespace(schema), + c.debeziumAvroTableName(table)) + } + return fmt.Sprintf("%s.Key", getSchemaTopicName(c.clusterID, schema, table)) +} + +func (c *dbzCodec) envelopeSchemaName(schema string, table string) string { + if c.isDebeziumAvro() { + return fmt.Sprintf("%s.%sEnvelope", + c.debeziumAvroNamespace(schema), + c.debeziumAvroTableName(table)) + } + return fmt.Sprintf("%s.Envelope", getSchemaTopicName(c.clusterID, schema, table)) +} + +func (c *dbzCodec) valueSchemaName(schema string, table string) string { + if c.isDebeziumAvro() { + return fmt.Sprintf("%s.%s", + c.debeziumAvroNamespace(schema), + c.debeziumAvroTableName(table)) + } + return fmt.Sprintf("%s.Value", getSchemaTopicName(c.clusterID, schema, table)) +} + +func (c *dbzCodec) sourceSchemaName(schema string) string { + if c.isDebeziumAvro() { + return fmt.Sprintf("%s.Source", c.debeziumAvroNamespace(schema)) + } + return "io.debezium.connector.mysql.Source" +} + +func decimalPrecisionAndScale(ft *types.FieldType) (int, int) { + defaultPrecision, defaultScale := mysql.GetDefaultFieldLengthAndDecimal(ft.GetType()) + precision, scale := ft.GetFlen(), ft.GetDecimal() + if precision == -1 { + precision = defaultPrecision + } + if scale == -1 { + scale = defaultScale + } + return precision, scale +} + +func (c *dbzCodec) columnOptional(ft *types.FieldType) bool { + if c.isDebeziumAvro() { + return true + } + return !mysql.HasNotNullFlag(ft.GetFlag()) +} + func (c *dbzCodec) writeDebeziumFieldValues( writer *util.JSONWriter, fieldName string, @@ -119,14 +187,14 @@ func (c *dbzCodec) writeDebeziumFieldSchema( } if n == 1 { writer.WriteStringField("type", "boolean") - writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteBoolField("optional", c.columnOptional(ft)) writer.WriteStringField("field", colName) if col.GetDefaultValue() != nil { writer.WriteBoolField("default", v != 0) // bool } } else { writer.WriteStringField("type", "bytes") - writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteBoolField("optional", c.columnOptional(ft)) writer.WriteStringField("name", "io.debezium.data.Bits") writer.WriteIntField("version", 1) writer.WriteObjectField("parameters", func() { @@ -139,15 +207,19 @@ func (c *dbzCodec) writeDebeziumFieldSchema( } case mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeBlob: - writer.WriteStringField("type", "string") - writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + if c.isDebeziumAvro() && mysql.HasBinaryFlag(ft.GetFlag()) { + writer.WriteStringField("type", "bytes") + } else { + writer.WriteStringField("type", "string") + } + writer.WriteBoolField("optional", c.columnOptional(ft)) writer.WriteStringField("field", colName) if col.GetDefaultValue() != nil { writer.WriteAnyField("default", col.GetDefaultValue()) } case mysql.TypeEnum: writer.WriteStringField("type", "string") - writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteBoolField("optional", c.columnOptional(ft)) writer.WriteStringField("name", "io.debezium.data.Enum") writer.WriteIntField("version", 1) writer.WriteObjectField("parameters", func() { @@ -164,7 +236,7 @@ func (c *dbzCodec) writeDebeziumFieldSchema( } case mysql.TypeSet: writer.WriteStringField("type", "string") - writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteBoolField("optional", c.columnOptional(ft)) writer.WriteStringField("name", "io.debezium.data.EnumSet") writer.WriteIntField("version", 1) writer.WriteObjectField("parameters", func() { @@ -176,7 +248,7 @@ func (c *dbzCodec) writeDebeziumFieldSchema( } case mysql.TypeDate, mysql.TypeNewDate: writer.WriteStringField("type", "int32") - writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteBoolField("optional", c.columnOptional(ft)) writer.WriteStringField("name", "io.debezium.time.Date") writer.WriteIntField("version", 1) writer.WriteStringField("field", colName) @@ -206,7 +278,7 @@ func (c *dbzCodec) writeDebeziumFieldSchema( } case mysql.TypeDatetime: writer.WriteStringField("type", "int64") - writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteBoolField("optional", c.columnOptional(ft)) if ft.GetDecimal() <= 3 { writer.WriteStringField("name", "io.debezium.time.Timestamp") } else { @@ -251,7 +323,7 @@ func (c *dbzCodec) writeDebeziumFieldSchema( } case mysql.TypeTimestamp: writer.WriteStringField("type", "string") - writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteBoolField("optional", c.columnOptional(ft)) writer.WriteStringField("name", "io.debezium.time.ZonedTimestamp") writer.WriteIntField("version", 1) writer.WriteStringField("field", colName) @@ -293,7 +365,7 @@ func (c *dbzCodec) writeDebeziumFieldSchema( } case mysql.TypeDuration: writer.WriteStringField("type", "int64") - writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteBoolField("optional", c.columnOptional(ft)) writer.WriteStringField("name", "io.debezium.time.MicroTime") writer.WriteIntField("version", 1) writer.WriteStringField("field", colName) @@ -310,7 +382,7 @@ func (c *dbzCodec) writeDebeziumFieldSchema( } case mysql.TypeJSON: writer.WriteStringField("type", "string") - writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteBoolField("optional", c.columnOptional(ft)) writer.WriteStringField("name", "io.debezium.data.Json") writer.WriteIntField("version", 1) writer.WriteStringField("field", colName) @@ -319,7 +391,7 @@ func (c *dbzCodec) writeDebeziumFieldSchema( } case mysql.TypeTiny: // TINYINT writer.WriteStringField("type", "int16") - writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteBoolField("optional", c.columnOptional(ft)) writer.WriteStringField("field", colName) if col.GetDefaultValue() != nil { v, ok := col.GetDefaultValue().(string) @@ -338,7 +410,7 @@ func (c *dbzCodec) writeDebeziumFieldSchema( } else { writer.WriteStringField("type", "int16") } - writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteBoolField("optional", c.columnOptional(ft)) writer.WriteStringField("field", colName) if col.GetDefaultValue() != nil { v, ok := col.GetDefaultValue().(string) @@ -353,7 +425,7 @@ func (c *dbzCodec) writeDebeziumFieldSchema( } case mysql.TypeInt24: // MEDIUMINT writer.WriteStringField("type", "int32") - writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteBoolField("optional", c.columnOptional(ft)) writer.WriteStringField("field", colName) if col.GetDefaultValue() != nil { v, ok := col.GetDefaultValue().(string) @@ -372,7 +444,7 @@ func (c *dbzCodec) writeDebeziumFieldSchema( } else { writer.WriteStringField("type", "int32") } - writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteBoolField("optional", c.columnOptional(ft)) writer.WriteStringField("field", colName) if col.GetDefaultValue() != nil { v, ok := col.GetDefaultValue().(string) @@ -386,8 +458,14 @@ func (c *dbzCodec) writeDebeziumFieldSchema( writer.WriteFloat64Field("default", floatV) } case mysql.TypeLonglong: // BIGINT - writer.WriteStringField("type", "int64") - writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + if c.isDebeziumAvro() && + mysql.HasUnsignedFlag(ft.GetFlag()) && + c.config.AvroBigintUnsignedHandlingMode == common.BigintUnsignedHandlingModeString { + writer.WriteStringField("type", "string") + } else { + writer.WriteStringField("type", "int64") + } + writer.WriteBoolField("optional", c.columnOptional(ft)) writer.WriteStringField("field", colName) if col.GetDefaultValue() != nil { v, ok := col.GetDefaultValue().(string) @@ -406,7 +484,7 @@ func (c *dbzCodec) writeDebeziumFieldSchema( } else { writer.WriteStringField("type", "float") } - writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteBoolField("optional", c.columnOptional(ft)) writer.WriteStringField("field", colName) if col.GetDefaultValue() != nil { v, ok := col.GetDefaultValue().(string) @@ -419,11 +497,11 @@ func (c *dbzCodec) writeDebeziumFieldSchema( } writer.WriteFloat64Field("default", floatV) } - case mysql.TypeDouble, mysql.TypeNewDecimal: + case mysql.TypeDouble: // https://dev.mysql.com/doc/refman/8.4/en/numeric-types.html // MySQL also treats REAL as a synonym for DOUBLE PRECISION (a nonstandard variation), unless the REAL_AS_FLOAT SQL mode is enabled. writer.WriteStringField("type", "double") - writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteBoolField("optional", c.columnOptional(ft)) writer.WriteStringField("field", colName) if col.GetDefaultValue() != nil { v, ok := col.GetDefaultValue().(string) @@ -436,9 +514,42 @@ func (c *dbzCodec) writeDebeziumFieldSchema( } writer.WriteFloat64Field("default", floatV) } + case mysql.TypeNewDecimal: + if c.isDebeziumAvro() && + c.config.AvroDecimalHandlingMode == common.DecimalHandlingModePrecise { + precision, scale := decimalPrecisionAndScale(ft) + writer.WriteStringField("type", "bytes") + writer.WriteStringField("name", "org.apache.kafka.connect.data.Decimal") + writer.WriteObjectField("parameters", func() { + writer.WriteStringField("precision", strconv.Itoa(precision)) + writer.WriteStringField("scale", strconv.Itoa(scale)) + }) + } else if c.isDebeziumAvro() && + c.config.AvroDecimalHandlingMode == common.DecimalHandlingModeString { + writer.WriteStringField("type", "string") + } else { + writer.WriteStringField("type", "double") + } + writer.WriteBoolField("optional", c.columnOptional(ft)) + writer.WriteStringField("field", colName) + if col.GetDefaultValue() != nil { + v, ok := col.GetDefaultValue().(string) + if !ok { + return + } + if c.isDebeziumAvro() { + writer.WriteStringField("default", v) + return + } + floatV, err := strconv.ParseFloat(v, 64) + if err != nil { + return + } + writer.WriteFloat64Field("default", floatV) + } case mysql.TypeYear: writer.WriteStringField("type", "int32") - writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteBoolField("optional", c.columnOptional(ft)) writer.WriteStringField("name", "io.debezium.time.Year") writer.WriteIntField("version", 1) writer.WriteStringField("field", colName) @@ -462,7 +573,7 @@ func (c *dbzCodec) writeDebeziumFieldSchema( } case mysql.TypeTiDBVectorFloat32: writer.WriteStringField("type", "string") - writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteBoolField("optional", c.columnOptional(ft)) writer.WriteStringField("name", "io.debezium.data.TiDBVectorFloat32") writer.WriteStringField("field", colName) if col.GetDefaultValue() != nil { @@ -554,6 +665,10 @@ func (c *dbzCodec) writeDebeziumFieldValue( return nil case mysql.TypeNewDecimal: + if c.isDebeziumAvro() { + writer.WriteStringField(colName, datum.GetMysqlDecimal().String()) + return nil + } v, err := datum.GetMysqlDecimal().ToFloat64() if err != nil { return errors.WrapError( @@ -711,6 +826,14 @@ func (c *dbzCodec) writeDebeziumFieldValue( isUnsigned := mysql.HasUnsignedFlag(colInfo.GetFlag()) if isUnsigned { v := datum.GetUint64() + if c.isDebeziumAvro() && ft.GetType() == mysql.TypeLonglong { + if c.config.AvroBigintUnsignedHandlingMode == common.BigintUnsignedHandlingModeString { + writer.WriteStringField(colName, strconv.FormatUint(v, 10)) + } else { + writer.WriteInt64Field(colName, int64(v)) + } + return nil + } if ft.GetType() == mysql.TypeLonglong && v == maxValue.GetUint64() || v > maxValue.GetUint64() { writer.WriteAnyField(colName, -1) } else { @@ -758,7 +881,7 @@ func (c *dbzCodec) writeBinaryField(writer *util.JSONWriter, fieldName string, v writer.WriteBase64StringField(fieldName, value) } -func (c *dbzCodec) writeSourceSchema(writer *util.JSONWriter) { +func (c *dbzCodec) writeSourceSchema(writer *util.JSONWriter, schemaName string) { writer.WriteObjectElement(func() { writer.WriteStringField("type", "struct") writer.WriteArrayField("fields", func() { @@ -785,12 +908,14 @@ func (c *dbzCodec) writeSourceSchema(writer *util.JSONWriter) { writer.WriteObjectElement(func() { writer.WriteStringField("type", "string") writer.WriteBoolField("optional", true) - writer.WriteStringField("name", "io.debezium.data.Enum") - writer.WriteIntField("version", 1) - writer.WriteObjectField("parameters", func() { - writer.WriteStringField("allowed", "true,last,false,incremental") - }) - writer.WriteStringField("default", "false") + if !c.isDebeziumAvro() { + writer.WriteStringField("name", "io.debezium.data.Enum") + writer.WriteIntField("version", 1) + writer.WriteObjectField("parameters", func() { + writer.WriteStringField("allowed", "true,last,false,incremental") + }) + writer.WriteStringField("default", "false") + } writer.WriteStringField("field", "snapshot") }) writer.WriteObjectElement(func() { @@ -798,14 +923,16 @@ func (c *dbzCodec) writeSourceSchema(writer *util.JSONWriter) { writer.WriteBoolField("optional", false) writer.WriteStringField("field", "db") }) + if !c.isDebeziumAvro() { + writer.WriteObjectElement(func() { + writer.WriteStringField("type", "string") + writer.WriteBoolField("optional", true) + writer.WriteStringField("field", "sequence") + }) + } writer.WriteObjectElement(func() { writer.WriteStringField("type", "string") - writer.WriteBoolField("optional", true) - writer.WriteStringField("field", "sequence") - }) - writer.WriteObjectElement(func() { - writer.WriteStringField("type", "string") - writer.WriteBoolField("optional", true) + writer.WriteBoolField("optional", !c.isDebeziumAvro()) writer.WriteStringField("field", "table") }) writer.WriteObjectElement(func() { @@ -843,7 +970,7 @@ func (c *dbzCodec) writeSourceSchema(writer *util.JSONWriter) { writer.WriteBoolField("optional", true) writer.WriteStringField("field", "query") }) - if c.config.EnableTiDBExtension { + if c.config.EnableTiDBExtension || c.isDebeziumAvro() { writer.WriteObjectElement(func() { writer.WriteStringField("type", "int64") writer.WriteBoolField("optional", false) @@ -857,7 +984,7 @@ func (c *dbzCodec) writeSourceSchema(writer *util.JSONWriter) { } }) writer.WriteBoolField("optional", false) - writer.WriteStringField("name", "io.debezium.connector.mysql.Source") + writer.WriteStringField("name", c.sourceSchemaName(schemaName)) writer.WriteStringField("field", "source") }) } @@ -891,8 +1018,7 @@ func (c *dbzCodec) EncodeKey( if !c.config.DebeziumDisableSchema { jWriter.WriteObjectField("schema", func() { jWriter.WriteStringField("type", "struct") - jWriter.WriteStringField("name", - fmt.Sprintf("%s.Key", getSchemaTopicName(c.clusterID, schemaName, tableName))) + jWriter.WriteStringField("name", c.keySchemaName(schemaName, tableName)) jWriter.WriteBoolField("optional", false) jWriter.WriteArrayField("fields", func() { columns := e.TableInfo.GetColumns() @@ -932,7 +1058,11 @@ func (c *dbzCodec) EncodeValue( // https://debezium.io/documentation/reference/stable/connectors/mysql.html#mysql-create-events jWriter.WriteInt64Field("ts_ms", commitTime.UnixMilli()) // snapshot field is a string of true,last,false,incremental - jWriter.WriteStringField("snapshot", "false") + if c.isDebeziumAvro() { + jWriter.WriteNullField("snapshot") + } else { + jWriter.WriteStringField("snapshot", "false") + } jWriter.WriteStringField("db", schemaName) jWriter.WriteStringField("table", tableName) jWriter.WriteInt64Field("server_id", 0) @@ -940,7 +1070,11 @@ func (c *dbzCodec) EncodeValue( jWriter.WriteStringField("file", "") jWriter.WriteInt64Field("pos", 0) jWriter.WriteInt64Field("row", 0) - jWriter.WriteInt64Field("thread", 0) + if c.isDebeziumAvro() { + jWriter.WriteNullField("thread") + } else { + jWriter.WriteInt64Field("thread", 0) + } jWriter.WriteNullField("query") // The followings are TiDB extended fields @@ -951,7 +1085,9 @@ func (c *dbzCodec) EncodeValue( // ts_ms: displays the time at which the connector processed the event // https://debezium.io/documentation/reference/stable/connectors/mysql.html#mysql-create-events jWriter.WriteInt64Field("ts_ms", c.nowFunc().UnixMilli()) - jWriter.WriteNullField("transaction") + if !c.isDebeziumAvro() { + jWriter.WriteNullField("transaction") + } if e.IsInsert() { // op: Mandatory string that describes the type of operation that caused the connector to generate the event. // Valid values are: @@ -991,8 +1127,7 @@ func (c *dbzCodec) EncodeValue( jWriter.WriteObjectField("schema", func() { jWriter.WriteStringField("type", "struct") jWriter.WriteBoolField("optional", false) - jWriter.WriteStringField("name", - fmt.Sprintf("%s.Envelope", getSchemaTopicName(c.clusterID, schemaName, tableName))) + jWriter.WriteStringField("name", c.envelopeSchemaName(schemaName, tableName)) jWriter.WriteIntField("version", 1) jWriter.WriteArrayField("fields", func() { // schema is the same for `before` and `after`. So we build a new buffer to @@ -1021,8 +1156,7 @@ func (c *dbzCodec) EncodeValue( jWriter.WriteObjectElement(func() { jWriter.WriteStringField("type", "struct") jWriter.WriteBoolField("optional", true) - jWriter.WriteStringField("name", - fmt.Sprintf("%s.Value", getSchemaTopicName(c.clusterID, schemaName, tableName))) + jWriter.WriteStringField("name", c.valueSchemaName(schemaName, tableName)) jWriter.WriteStringField("field", "before") jWriter.WriteArrayField("fields", func() { jWriter.WriteRaw(fieldsJSON) @@ -1031,14 +1165,13 @@ func (c *dbzCodec) EncodeValue( jWriter.WriteObjectElement(func() { jWriter.WriteStringField("type", "struct") jWriter.WriteBoolField("optional", true) - jWriter.WriteStringField("name", - fmt.Sprintf("%s.Value", getSchemaTopicName(c.clusterID, schemaName, tableName))) + jWriter.WriteStringField("name", c.valueSchemaName(schemaName, tableName)) jWriter.WriteStringField("field", "after") jWriter.WriteArrayField("fields", func() { jWriter.WriteRaw(fieldsJSON) }) }) - c.writeSourceSchema(jWriter) + c.writeSourceSchema(jWriter, schemaName) jWriter.WriteObjectElement(func() { jWriter.WriteStringField("type", "string") jWriter.WriteBoolField("optional", false) @@ -1046,33 +1179,35 @@ func (c *dbzCodec) EncodeValue( }) jWriter.WriteObjectElement(func() { jWriter.WriteStringField("type", "int64") - jWriter.WriteBoolField("optional", true) + jWriter.WriteBoolField("optional", !c.isDebeziumAvro()) jWriter.WriteStringField("field", "ts_ms") }) - jWriter.WriteObjectElement(func() { - jWriter.WriteStringField("type", "struct") - jWriter.WriteArrayField("fields", func() { - jWriter.WriteObjectElement(func() { - jWriter.WriteStringField("type", "string") - jWriter.WriteBoolField("optional", false) - jWriter.WriteStringField("field", "id") - }) - jWriter.WriteObjectElement(func() { - jWriter.WriteStringField("type", "int64") - jWriter.WriteBoolField("optional", false) - jWriter.WriteStringField("field", "total_order") - }) - jWriter.WriteObjectElement(func() { - jWriter.WriteStringField("type", "int64") - jWriter.WriteBoolField("optional", false) - jWriter.WriteStringField("field", "data_collection_order") + if !c.isDebeziumAvro() { + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("type", "struct") + jWriter.WriteArrayField("fields", func() { + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("type", "string") + jWriter.WriteBoolField("optional", false) + jWriter.WriteStringField("field", "id") + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("type", "int64") + jWriter.WriteBoolField("optional", false) + jWriter.WriteStringField("field", "total_order") + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("type", "int64") + jWriter.WriteBoolField("optional", false) + jWriter.WriteStringField("field", "data_collection_order") + }) }) + jWriter.WriteBoolField("optional", true) + jWriter.WriteStringField("name", "event.block") + jWriter.WriteIntField("version", 1) + jWriter.WriteStringField("field", "transaction") }) - jWriter.WriteBoolField("optional", true) - jWriter.WriteStringField("name", "event.block") - jWriter.WriteIntField("version", 1) - jWriter.WriteStringField("field", "transaction") - }) + } }) }) } @@ -1324,7 +1459,7 @@ func (c *dbzCodec) EncodeDDLEvent( jWriter.WriteIntField("version", 1) jWriter.WriteStringField("name", "io.debezium.connector.mysql.SchemaChangeValue") jWriter.WriteArrayField("fields", func() { - c.writeSourceSchema(jWriter) + c.writeSourceSchema(jWriter, dbName) jWriter.WriteObjectElement(func() { jWriter.WriteStringField("field", "ts_ms") jWriter.WriteBoolField("optional", false) @@ -1563,7 +1698,7 @@ func (c *dbzCodec) EncodeCheckpointEvent( fmt.Sprintf("%s.%s.Envelope", common.SanitizeName(c.clusterID), "watermark")) jWriter.WriteIntField("version", 1) jWriter.WriteArrayField("fields", func() { - c.writeSourceSchema(jWriter) + c.writeSourceSchema(jWriter, "") jWriter.WriteObjectElement(func() { jWriter.WriteStringField("type", "string") jWriter.WriteBoolField("optional", false) diff --git a/pkg/sink/codec/debezium/decoder.go b/pkg/sink/codec/debezium/decoder.go index d9ddab35c6..7dc91f377a 100644 --- a/pkg/sink/codec/debezium/decoder.go +++ b/pkg/sink/codec/debezium/decoder.go @@ -20,6 +20,7 @@ import ( "encoding/binary" "encoding/json" "fmt" + "strconv" "strings" "time" @@ -346,6 +347,13 @@ func decodeColumn(value interface{}, colInfo *timodel.ColumnInfo, timeZone *time } value = types.NewDuration(0, 0, 0, int(val), types.MaxFsp) case mysql.TypeLonglong, mysql.TypeLong, mysql.TypeInt24, mysql.TypeShort, mysql.TypeTiny: + if strVal, ok := value.(string); ok && mysql.HasUnsignedFlag(colInfo.GetFlag()) { + uintVal, err := strconv.ParseUint(strVal, 10, 64) + if err != nil { + log.Panic("decode value failed", zap.Error(err), zap.String("value", util.RedactAny(value))) + } + return uintVal + } var intVal int64 intVal, err = value.(json.Number).Int64() if err != nil { @@ -371,6 +379,14 @@ func decodeColumn(value interface{}, colInfo *timodel.ColumnInfo, timeZone *time return types.NewBinaryLiteralFromUint(uint64(0), -1) } case mysql.TypeNewDecimal: + if strVal, ok := value.(string); ok { + dec := new(types.MyDecimal) + err = dec.FromString([]byte(strVal)) + if err != nil { + log.Panic("decode value failed", zap.Error(err), zap.String("value", util.RedactAny(value))) + } + return dec + } var f64 float64 f64, err = value.(json.Number).Float64() if err != nil { diff --git a/pkg/sink/codec/debezium/encoder.go b/pkg/sink/codec/debezium/encoder.go index 80deca2d44..683d8ff915 100644 --- a/pkg/sink/codec/debezium/encoder.go +++ b/pkg/sink/codec/debezium/encoder.go @@ -20,6 +20,7 @@ import ( "github.com/pingcap/log" commonEvent "github.com/pingcap/ticdc/pkg/common/event" + "github.com/pingcap/ticdc/pkg/config" "github.com/pingcap/ticdc/pkg/errors" "github.com/pingcap/ticdc/pkg/sink/codec/avro" "github.com/pingcap/ticdc/pkg/sink/codec/common" @@ -38,6 +39,9 @@ type BatchEncoder struct { // EncodeCheckpointEvent implements the RowEventEncoder interface func (d *BatchEncoder) EncodeCheckpointEvent(ts uint64) (*common.Message, error) { + if d.config.Protocol == config.ProtocolDebeziumAvro { + return nil, nil + } if !d.config.EnableTiDBExtension { return nil, nil } @@ -109,6 +113,9 @@ func (d *BatchEncoder) AppendRowChangedEvent( // EncodeDDLEvent implements the RowEventEncoder interface // DDL message unresolved tso func (d *BatchEncoder) EncodeDDLEvent(e *commonEvent.DDLEvent) (*common.Message, error) { + if d.config.Protocol == config.ProtocolDebeziumAvro { + return nil, nil + } valueBuf := bytes.Buffer{} keyMap := bytes.Buffer{} err := d.codec.EncodeDDLEvent(e, &keyMap, &valueBuf) diff --git a/tests/integration_tests/debezium_avro/data/ddl.sql b/tests/integration_tests/debezium_avro/data/ddl.sql new file mode 100644 index 0000000000..fce820e74f --- /dev/null +++ b/tests/integration_tests/debezium_avro/data/ddl.sql @@ -0,0 +1,3 @@ +USE test; + +ALTER TABLE tp_account ADD COLUMN note VARCHAR(32) NULL; diff --git a/tests/integration_tests/debezium_avro/data/post_ddl_workload.sql b/tests/integration_tests/debezium_avro/data/post_ddl_workload.sql new file mode 100644 index 0000000000..9a89bbb3be --- /dev/null +++ b/tests/integration_tests/debezium_avro/data/post_ddl_workload.sql @@ -0,0 +1,8 @@ +USE test; + +UPDATE tp_account SET note = 'after ddl' WHERE id = 12; +INSERT INTO tp_account(id, account_id, name, balance, payload, note) +VALUES (15, 45, 'carol', 90.1200, x'070809', 'deleted'); +DELETE FROM tp_account WHERE id = 15; +INSERT INTO tp_account(id, account_id, name, balance, payload, note) +VALUES (16, 46, 'dave', NULL, NULL, 'final'); diff --git a/tests/integration_tests/debezium_avro/data/prepare.sql b/tests/integration_tests/debezium_avro/data/prepare.sql index f4c762f03e..deee75e17a 100644 --- a/tests/integration_tests/debezium_avro/data/prepare.sql +++ b/tests/integration_tests/debezium_avro/data/prepare.sql @@ -2,82 +2,10 @@ DROP DATABASE IF EXISTS test; CREATE DATABASE test; USE test; -CREATE TABLE tp_int ( - id INT AUTO_INCREMENT, - c_tinyint TINYINT NULL, - c_smallint SMALLINT NULL, - c_mediumint MEDIUMINT NULL, - c_int INT NULL, - c_bigint BIGINT NULL, - PRIMARY KEY (id) -); - -CREATE TABLE tp_unsigned_int ( - id INT AUTO_INCREMENT, - c_unsigned_tinyint TINYINT UNSIGNED NULL, - c_unsigned_smallint SMALLINT UNSIGNED NULL, - c_unsigned_mediumint MEDIUMINT UNSIGNED NULL, - c_unsigned_int INT UNSIGNED NULL, - c_unsigned_bigint BIGINT UNSIGNED NULL, - PRIMARY KEY (id) -); - -CREATE TABLE tp_real ( - id INT AUTO_INCREMENT, - c_float FLOAT NULL, - c_double DOUBLE NULL, - c_decimal DECIMAL NULL, - c_decimal_2 DECIMAL(10, 4) NULL, - PRIMARY KEY (id) -); - -CREATE TABLE tp_time ( - id INT AUTO_INCREMENT, - c_date DATE NULL, - c_datetime DATETIME NULL, - c_timestamp TIMESTAMP NULL, - c_time TIME NULL, - c_year YEAR NULL, - PRIMARY KEY (id) -); - -CREATE TABLE tp_text ( - id INT AUTO_INCREMENT, - c_tinytext TINYTEXT NULL, - c_text TEXT NULL, - c_mediumtext MEDIUMTEXT NULL, - c_longtext LONGTEXT NULL, - PRIMARY KEY (id) -); - -CREATE TABLE tp_blob ( - id INT AUTO_INCREMENT, - c_tinyblob TINYBLOB NULL, - c_blob BLOB NULL, - c_mediumblob MEDIUMBLOB NULL, - c_longblob LONGBLOB NULL, - PRIMARY KEY (id) -); - -CREATE TABLE tp_char_binary ( - id INT AUTO_INCREMENT, - c_char CHAR(16) NULL, - c_varchar VARCHAR(16) NULL, - c_binary BINARY(16) NULL, - c_varbinary VARBINARY(16) NULL, - PRIMARY KEY (id) -); - -CREATE TABLE tp_other ( - id INT AUTO_INCREMENT, - c_enum ENUM ('a', 'b', 'c') NULL, - c_set SET ('a', 'b', 'c') NULL, - c_bit BIT(64) NULL, - c_json JSON NULL, - PRIMARY KEY (id) -); - CREATE TABLE tp_account ( - id INT PRIMARY KEY, - account_id INT NOT NULL + id BIGINT UNSIGNED PRIMARY KEY, + account_id INT NOT NULL, + name VARCHAR(64) NULL, + balance DECIMAL(20, 4) NULL, + payload VARBINARY(16) NULL ); diff --git a/tests/integration_tests/debezium_avro/data/workload.sql b/tests/integration_tests/debezium_avro/data/workload.sql index 9931a83409..a6e3ad3bc4 100644 --- a/tests/integration_tests/debezium_avro/data/workload.sql +++ b/tests/integration_tests/debezium_avro/data/workload.sql @@ -1,75 +1,9 @@ USE test; -INSERT INTO tp_int() VALUES (); -INSERT INTO tp_int(c_tinyint, c_smallint, c_mediumint, c_int, c_bigint) -VALUES (1, 2, 3, 4, 5); -INSERT INTO tp_int(c_tinyint, c_smallint, c_mediumint, c_int, c_bigint) -VALUES (127, 32767, 8388607, 2147483647, 9223372036854775807); -INSERT INTO tp_int(c_tinyint, c_smallint, c_mediumint, c_int, c_bigint) -VALUES (-128, -32768, -8388608, -2147483648, -9223372036854775808); -UPDATE tp_int SET c_int = 0, c_tinyint = 0 WHERE id = 2; -DELETE FROM tp_int WHERE id = 2; - -INSERT INTO tp_unsigned_int() VALUES (); -INSERT INTO tp_unsigned_int( - c_unsigned_tinyint, - c_unsigned_smallint, - c_unsigned_mediumint, - c_unsigned_int, - c_unsigned_bigint -) VALUES (1, 2, 3, 4, 5); -INSERT INTO tp_unsigned_int( - c_unsigned_tinyint, - c_unsigned_smallint, - c_unsigned_mediumint, - c_unsigned_int, - c_unsigned_bigint -) VALUES (255, 65535, 16777215, 4294967295, 18446744073709551615); -UPDATE tp_unsigned_int SET c_unsigned_int = 0, c_unsigned_tinyint = 0 WHERE id = 3; -DELETE FROM tp_unsigned_int WHERE id = 3; - -INSERT INTO tp_real() VALUES (); -INSERT INTO tp_real(c_float, c_double, c_decimal, c_decimal_2) -VALUES (2020.0202, 2020.0303, 2020.0404, 2021.1208); -INSERT INTO tp_real(c_float, c_double, c_decimal, c_decimal_2) -VALUES (-2.7182818284, -3.1415926, -8000, -179394.233); -UPDATE tp_real SET c_double = 2.333 WHERE id = 2; - -INSERT INTO tp_time() VALUES (); -INSERT INTO tp_time(c_date, c_datetime, c_timestamp, c_time, c_year) -VALUES ('2020-02-20', '2020-02-20 02:20:20', '2020-02-20 02:20:20', '02:20:20', '2020'); -INSERT INTO tp_time(c_date, c_datetime, c_timestamp, c_time, c_year) -VALUES ('2022-02-22', '2022-02-22 22:22:22', '2020-02-20 02:20:20', '02:20:20', '2021'); -UPDATE tp_time SET c_year = '2022' WHERE id = 2; - -INSERT INTO tp_text() VALUES (); -INSERT INTO tp_text(c_tinytext, c_text, c_mediumtext, c_longtext) -VALUES ('89504E470D0A1A0A', '89504E470D0A1A0A', '89504E470D0A1A0A', '89504E470D0A1A0A'); -INSERT INTO tp_text(c_tinytext, c_text, c_mediumtext, c_longtext) -VALUES ('89504E470D0A1A0B', '89504E470D0A1A0B', '89504E470D0A1A0B', '89504E470D0A1A0B'); -UPDATE tp_text SET c_text = '89504E470D0A1A0B' WHERE id = 2; - -INSERT INTO tp_blob() VALUES (); -INSERT INTO tp_blob(c_tinyblob, c_blob, c_mediumblob, c_longblob) -VALUES (x'89504E470D0A1A0A', x'89504E470D0A1A0A', x'89504E470D0A1A0A', x'89504E470D0A1A0A'); -INSERT INTO tp_blob(c_tinyblob, c_blob, c_mediumblob, c_longblob) -VALUES (x'89504E470D0A1A0B', x'89504E470D0A1A0B', x'89504E470D0A1A0B', x'89504E470D0A1A0B'); -UPDATE tp_blob SET c_blob = x'89504E470D0A1A0B' WHERE id = 2; - -INSERT INTO tp_char_binary() VALUES (); -INSERT INTO tp_char_binary(c_char, c_varchar, c_binary, c_varbinary) -VALUES ('89504E470D0A1A0A', '89504E470D0A1A0A', x'89504E470D0A1A0A', x'89504E470D0A1A0A'); -INSERT INTO tp_char_binary(c_char, c_varchar, c_binary, c_varbinary) -VALUES ('89504E470D0A1A0B', '89504E470D0A1A0B', x'89504E470D0A1A0B', x'89504E470D0A1A0B'); -UPDATE tp_char_binary SET c_varchar = '89504E470D0A1A0B' WHERE id = 2; - -INSERT INTO tp_other() VALUES (); -INSERT INTO tp_other(c_enum, c_set, c_bit, c_json) -VALUES ('a', 'a,b', b'1000001', '{"key1":"value1","key2":"value2"}'); -INSERT INTO tp_other(c_enum, c_set, c_bit, c_json) -VALUES ('b', 'b,c', b'1000001', '{"key1":"value1","key2":"value2","key3":"123"}'); -UPDATE tp_other SET c_enum = 'c' WHERE id = 3; - -INSERT INTO tp_account VALUES (12, 34); -UPDATE tp_account SET account_id = 35 WHERE id = 12; -DELETE FROM tp_account WHERE id = 12; +INSERT INTO tp_account VALUES (12, 34, 'alice', 12.3400, x'010203'); +UPDATE tp_account +SET account_id = 35, name = 'bob', balance = 56.7800, payload = x'040506' +WHERE id = 12; +INSERT INTO tp_account +VALUES (18446744073709551615, 99, 'max', 1.2300, x'ff'); +DELETE FROM tp_account WHERE id = 18446744073709551615; diff --git a/tests/integration_tests/debezium_avro/run.sh b/tests/integration_tests/debezium_avro/run.sh index bfb1ceaede..31c96d73ce 100644 --- a/tests/integration_tests/debezium_avro/run.sh +++ b/tests/integration_tests/debezium_avro/run.sh @@ -26,6 +26,13 @@ function start_schema_registry() { curl -X PUT -H "Content-Type: application/vnd.schemaregistry.v1+json" --data '{"compatibility": "NONE"}' http://127.0.0.1:8088/config } +function check_schema_registry_subject() { + local subject=$1 + local expected=$2 + + curl -fsS "http://127.0.0.1:8088/subjects/${subject}/versions/latest" | grep -q "$expected" +} + function run() { if [ "$SINK_TYPE" != "kafka" ]; then return @@ -36,12 +43,15 @@ function run() { start_schema_registry start_tidb_cluster --workdir "$WORK_DIR" + run_sql_file "$CUR/data/prepare.sql" "$UP_TIDB_HOST" "$UP_TIDB_PORT" + run_sql_file "$CUR/data/prepare.sql" "$DOWN_TIDB_HOST" "$DOWN_TIDB_PORT" + start_ts=$(run_cdc_cli_tso_query "$UP_PD_HOST_1" "$UP_PD_PORT_1") run_cdc_server --workdir "$WORK_DIR" --binary "$CDC_BINARY" TOPIC_NAME="ticdc-debezium-avro-$RANDOM" - SINK_URI="kafka://127.0.0.1:9092/$TOPIC_NAME?protocol=debezium&enable-tidb-extension=true&partition-num=1&kafka-version=${KAFKA_VERSION}&max-message-bytes=10485760" + SINK_URI="kafka://127.0.0.1:9092/$TOPIC_NAME?protocol=debezium-avro&enable-tidb-extension=true&partition-num=1&kafka-version=${KAFKA_VERSION}&max-message-bytes=10485760&avro-decimal-handling-mode=precise&avro-bigint-unsigned-handling-mode=string" schema_registry_uri="http://127.0.0.1:8088" changefeed_id="debezium-avro-$RANDOM" @@ -49,11 +59,14 @@ function run() { sleep 5 # wait for changefeed to start run_kafka_consumer "$WORK_DIR" "$SINK_URI" "" "$schema_registry_uri" - run_sql_file "$CUR/data/prepare.sql" "$UP_TIDB_HOST" "$UP_TIDB_PORT" run_sql_file "$CUR/data/workload.sql" "$UP_TIDB_HOST" "$UP_TIDB_PORT" - run_sql "CREATE TABLE test.finish_mark (id int primary key);" "$UP_TIDB_HOST" "$UP_TIDB_PORT" - check_table_exists test.finish_mark "$DOWN_TIDB_HOST" "$DOWN_TIDB_PORT" 200 - check_sync_diff "$WORK_DIR" "$CUR/conf/diff_config.toml" + run_sql_file "$CUR/data/ddl.sql" "$UP_TIDB_HOST" "$UP_TIDB_PORT" + run_sql_file "$CUR/data/ddl.sql" "$DOWN_TIDB_HOST" "$DOWN_TIDB_PORT" + run_sql_file "$CUR/data/post_ddl_workload.sql" "$UP_TIDB_HOST" "$UP_TIDB_PORT" + + check_sync_diff "$WORK_DIR" "$CUR/conf/diff_config.toml" 120 + check_schema_registry_subject "$TOPIC_NAME-key" "tp_accountKey" + check_schema_registry_subject "$TOPIC_NAME-value" "tp_accountEnvelope" cleanup_process "$CDC_BINARY" } From 27329eae9ac8a155fccf8cbd6e8bb58d3555f00e Mon Sep 17 00:00:00 2001 From: wk989898 Date: Wed, 24 Jun 2026 07:14:45 +0000 Subject: [PATCH 08/10] update Signed-off-by: wk989898 --- pkg/sink/codec/debezium/avro.go | 19 +++++++++++++++++-- pkg/sink/codec/debezium/avro_decoder.go | 4 +++- pkg/sink/codec/debezium/codec.go | 5 +++++ 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/pkg/sink/codec/debezium/avro.go b/pkg/sink/codec/debezium/avro.go index 42c6e2c122..0fec1cc024 100644 --- a/pkg/sink/codec/debezium/avro.go +++ b/pkg/sink/codec/debezium/avro.go @@ -18,6 +18,7 @@ import ( "context" "encoding/base64" "encoding/json" + "math" "math/big" "strconv" "strings" @@ -404,7 +405,7 @@ func (c *debeziumAvroSchemaConverter) toNative( if err != nil { return nil, err } - return int32(v), nil + return int64ToInt32(v) case "int64": return numberToInt64(value) case "float": @@ -540,7 +541,7 @@ func numberToInt64(value any) (int64, error) { case int64: return v, nil case uint64: - return int64(v), nil + return uint64ToInt64(v) case float64: return int64(v), nil default: @@ -548,6 +549,20 @@ func numberToInt64(value any) (int64, error) { } } +func int64ToInt32(value int64) (int32, error) { + if value < math.MinInt32 || value > math.MaxInt32 { + return 0, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("number payload is out of int32 range") + } + return int32(value), nil +} + +func uint64ToInt64(value uint64) (int64, error) { + if value > math.MaxInt64 { + return 0, errors.ErrDebeziumInvalidMessage.GenWithStackByArgs("number payload is out of int64 range") + } + return int64(value), nil +} + func numberToFloat64(value any) (float64, error) { switch v := value.(type) { case json.Number: diff --git a/pkg/sink/codec/debezium/avro_decoder.go b/pkg/sink/codec/debezium/avro_decoder.go index 8c91ec813d..083fc3481d 100644 --- a/pkg/sink/codec/debezium/avro_decoder.go +++ b/pkg/sink/codec/debezium/avro_decoder.go @@ -184,7 +184,9 @@ func (d *avroDecoder) getSchema(schemaID int) (*registeredDebeziumAvroSchema, er if err != nil { return nil, errors.WrapError(errors.ErrAvroSchemaAPIError, err) } - defer resp.Body.Close() + defer func() { + _ = resp.Body.Close() + }() body, err := io.ReadAll(resp.Body) if err != nil { diff --git a/pkg/sink/codec/debezium/codec.go b/pkg/sink/codec/debezium/codec.go index 4346f8c1e7..67467a6919 100644 --- a/pkg/sink/codec/debezium/codec.go +++ b/pkg/sink/codec/debezium/codec.go @@ -18,6 +18,7 @@ import ( "encoding/hex" "fmt" "io" + "math" "strconv" "strings" "time" @@ -830,6 +831,10 @@ func (c *dbzCodec) writeDebeziumFieldValue( if c.config.AvroBigintUnsignedHandlingMode == common.BigintUnsignedHandlingModeString { writer.WriteStringField(colName, strconv.FormatUint(v, 10)) } else { + if v > math.MaxInt64 { + return errors.ErrDebeziumEncodeFailed.GenWithStackByArgs( + fmt.Sprintf("unsigned bigint value %d overflows avro long", v)) + } writer.WriteInt64Field(colName, int64(v)) } return nil From f2aa4231795abb90a23f728f216340c5fc85d142 Mon Sep 17 00:00:00 2001 From: wk989898 Date: Wed, 24 Jun 2026 08:01:23 +0000 Subject: [PATCH 09/10] . Signed-off-by: wk989898 --- cmd/kafka-consumer/writer.go | 5 ----- cmd/kafka-consumer/writer_test.go | 1 - 2 files changed, 6 deletions(-) diff --git a/cmd/kafka-consumer/writer.go b/cmd/kafka-consumer/writer.go index eeec7ac5b4..d7196b8190 100644 --- a/cmd/kafka-consumer/writer.go +++ b/cmd/kafka-consumer/writer.go @@ -373,11 +373,6 @@ func (w *writer) flushPartitionDMLEvents( zap.Uint64("watermark", watermark), zap.Int("total", total), zap.Duration("duration", time.Since(start))) progress.updateWatermark(watermark, progress.watermarkOffset) - for _, item := range resolvedGroups { - if item.maxCommitTs > item.group.AppliedWatermark { - item.group.AppliedWatermark = item.maxCommitTs - } - } return nil case <-ticker.C: log.Warn("partition DML events cannot be flushed in time", diff --git a/cmd/kafka-consumer/writer_test.go b/cmd/kafka-consumer/writer_test.go index 1460057a57..270216fb08 100644 --- a/cmd/kafka-consumer/writer_test.go +++ b/cmd/kafka-consumer/writer_test.go @@ -352,7 +352,6 @@ func TestFlushPartitionDMLEventsFlushesWithoutResolved(t *testing.T) { require.NoError(t, w.flushPartitionDMLEvents(ctx, progress, 150)) require.Equal(t, []uint64{100}, flushedCommitTs) - require.Equal(t, uint64(100), group.AppliedWatermark) remaining := group.ResolveInto(300, nil) require.Len(t, remaining, 1) From 3c875f63d3a524c96934668a6b1b752ae2f769b6 Mon Sep 17 00:00:00 2001 From: wk989898 Date: Wed, 24 Jun 2026 09:32:17 +0000 Subject: [PATCH 10/10] update Signed-off-by: wk989898 --- cmd/kafka-consumer/writer.go | 85 ---------------------------- cmd/kafka-consumer/writer_test.go | 39 ------------- pkg/sink/codec/debezium/avro_test.go | 33 ++++++++++- pkg/sink/codec/debezium/codec.go | 2 +- pkg/sink/codec/debezium/encoder.go | 2 +- 5 files changed, 34 insertions(+), 127 deletions(-) diff --git a/cmd/kafka-consumer/writer.go b/cmd/kafka-consumer/writer.go index d7196b8190..897a42c4ff 100644 --- a/cmd/kafka-consumer/writer.go +++ b/cmd/kafka-consumer/writer.go @@ -311,78 +311,6 @@ func (w *writer) flushDMLEventsByWatermark(ctx context.Context) error { } } -func (w *writer) flushPartitionDMLEvents( - ctx context.Context, - progress *partitionProgress, - watermark uint64, -) error { - var ( - done = make(chan struct{}, 1) - - total int - flushed atomic.Int64 - ) - - resolvedEvents := make([]*event.DMLEvent, 0) - resolvedGroups := make([]struct { - group *util.EventsGroup - maxCommitTs uint64 - }, 0) - for _, group := range progress.eventsGroup { - before := len(resolvedEvents) - resolvedEvents = group.ResolveInto(watermark, resolvedEvents) - resolvedCount := len(resolvedEvents) - before - if resolvedCount == 0 { - continue - } - - resolvedGroups = append(resolvedGroups, struct { - group *util.EventsGroup - maxCommitTs uint64 - }{ - group: group, - maxCommitTs: resolvedEvents[len(resolvedEvents)-1].GetCommitTs(), - }) - total += resolvedCount - } - if total == 0 { - return nil - } - for _, e := range resolvedEvents { - e.AddPostFlushFunc(func() { - if flushed.Inc() == int64(total) { - close(done) - } - }) - w.mysqlSink.AddDMLEvent(e) - log.Debug("flush partition DML event", zap.Int64("tableID", e.GetTableID()), - zap.Uint64("commitTs", e.GetCommitTs()), zap.Any("startTs", e.GetStartTs())) - } - - log.Info("flush partition DML events", zap.Int32("partition", progress.partition), - zap.Uint64("watermark", watermark), zap.Int("total", total)) - start := time.Now() - ticker := time.NewTicker(time.Minute) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return context.Cause(ctx) - case <-done: - log.Info("flush partition DML events done", zap.Int32("partition", progress.partition), - zap.Uint64("watermark", watermark), zap.Int("total", total), - zap.Duration("duration", time.Since(start))) - progress.updateWatermark(watermark, progress.watermarkOffset) - return nil - case <-ticker.C: - log.Warn("partition DML events cannot be flushed in time", - zap.Int32("partition", progress.partition), - zap.Uint64("watermark", watermark), - zap.Int("total", total), zap.Int64("flushed", flushed.Load())) - } - } -} - // WriteMessage is to decode kafka message to event. // return true if the message is flushed to the downstream. // return error if flush messages failed. @@ -455,7 +383,6 @@ func (w *writer) WriteMessage(ctx context.Context, message *kafka.Message) bool break } - maxCommitTs := row.GetCommitTs() w.appendRow2Group(row, progress, offset) counter++ for { @@ -464,9 +391,6 @@ func (w *writer) WriteMessage(ctx context.Context, message *kafka.Message) bool break } row = progress.decoder.NextDMLEvent() - if row.GetCommitTs() > maxCommitTs { - maxCommitTs = row.GetCommitTs() - } w.appendRow2Group(row, progress, offset) counter++ } @@ -482,15 +406,6 @@ func (w *writer) WriteMessage(ctx context.Context, message *kafka.Message) bool zap.Int("maxBatchSize", w.maxBatchSize), zap.Int("actualBatchSize", counter), zap.Int32("partition", partition), zap.Any("offset", offset)) } - if w.protocol == config.ProtocolDebeziumAvro { - progress.watermarkOffset = offset - if err := w.flushPartitionDMLEvents(ctx, progress, maxCommitTs); err != nil { - log.Panic("flush debezium avro dml events failed", zap.Error(err), - zap.Int32("partition", partition), zap.Any("offset", offset), - zap.Uint64("watermark", maxCommitTs)) - } - return true - } default: log.Panic("unknown message type", zap.Any("messageType", messageType), zap.Int32("partition", partition), zap.Any("offset", offset)) diff --git a/cmd/kafka-consumer/writer_test.go b/cmd/kafka-consumer/writer_test.go index 270216fb08..0ff0c3d859 100644 --- a/cmd/kafka-consumer/writer_test.go +++ b/cmd/kafka-consumer/writer_test.go @@ -318,42 +318,3 @@ func TestOnDDLMarksRoutedCreateTableLikePartitionTableForAvro(t *testing.T) { require.Len(t, resolved, 1) require.Equal(t, uint64(100), resolved[0].CommitTs) } - -func TestFlushPartitionDMLEventsFlushesWithoutResolved(t *testing.T) { - ctx := context.Background() - ctrl := gomock.NewController(t) - s := sinkmock.NewMockSink(ctrl) - - flushedCommitTs := make([]uint64, 0) - s.EXPECT().AddDMLEvent(gomock.Any()).Do(func(e *commonEvent.DMLEvent) { - flushedCommitTs = append(flushedCommitTs, e.CommitTs) - e.PostFlush() - }).AnyTimes() - - w := &writer{mysqlSink: s} - progress := &partitionProgress{ - partition: 0, - eventsGroup: map[int64]*util.EventsGroup{1: util.NewEventsGroup(0, 1)}, - } - newDMLEvent := func(commitTs uint64) *commonEvent.DMLEvent { - return &commonEvent.DMLEvent{ - PhysicalTableID: 1, - CommitTs: commitTs, - RowTypes: []common.RowType{common.RowTypeUpdate}, - Rows: chunk.NewChunkWithCapacity(nil, 0), - TableInfo: &common.TableInfo{ - TableName: common.TableName{Schema: "test", Table: "t"}, - }, - } - } - group := progress.eventsGroup[1] - group.Append(newDMLEvent(100), false) - group.Append(newDMLEvent(200), false) - - require.NoError(t, w.flushPartitionDMLEvents(ctx, progress, 150)) - require.Equal(t, []uint64{100}, flushedCommitTs) - - remaining := group.ResolveInto(300, nil) - require.Len(t, remaining, 1) - require.Equal(t, uint64(200), remaining[0].CommitTs) -} diff --git a/pkg/sink/codec/debezium/avro_test.go b/pkg/sink/codec/debezium/avro_test.go index 45680ff413..0fe63b6bce 100644 --- a/pkg/sink/codec/debezium/avro_test.go +++ b/pkg/sink/codec/debezium/avro_test.go @@ -480,7 +480,7 @@ func TestDebeziumConfluentAvroDoesNotEncodeDDLEvent(t *testing.T) { require.Nil(t, message) } -func TestDebeziumConfluentAvroDoesNotEncodeCheckpointEvent(t *testing.T) { +func TestDebeziumConfluentAvroEncodeCheckpointEvent(t *testing.T) { ctx := context.Background() _, err := avro.SetupEncoderAndSchemaRegistry4Testing( ctx, @@ -498,6 +498,37 @@ func TestDebeziumConfluentAvroDoesNotEncodeCheckpointEvent(t *testing.T) { encoder, err := NewAvroBatchEncoder(ctx, cfg, "dbserver1") require.NoError(t, err) + message, err := encoder.EncodeCheckpointEvent(100) + require.NoError(t, err) + require.NotNil(t, message) + + decoder, err := NewAvroDecoder(ctx, cfg, 0, nil) + require.NoError(t, err) + decoder.AddKeyValue(message.Key, message.Value) + + messageType, hasNext := decoder.HasNext() + require.True(t, hasNext) + require.Equal(t, common.MessageTypeResolved, messageType) + require.Equal(t, uint64(100), decoder.NextResolvedEvent()) +} + +func TestDebeziumConfluentAvroDoesNotEncodeCheckpointEventByDefault(t *testing.T) { + ctx := context.Background() + _, err := avro.SetupEncoderAndSchemaRegistry4Testing( + ctx, + common.NewConfig(config.ProtocolAvro), + ) + require.NoError(t, err) + defer avro.TeardownEncoderAndSchemaRegistry4Testing() + + cfg := common.NewConfig(config.ProtocolDebeziumAvro) + cfg.AvroConfluentSchemaRegistry = "http://127.0.0.1:8081" + cfg.EnableTiDBExtension = true + cfg.TimeZone = time.UTC + + encoder, err := NewAvroBatchEncoder(ctx, cfg, "dbserver1") + require.NoError(t, err) + message, err := encoder.EncodeCheckpointEvent(100) require.NoError(t, err) require.Nil(t, message) diff --git a/pkg/sink/codec/debezium/codec.go b/pkg/sink/codec/debezium/codec.go index 67467a6919..1ceea135e8 100644 --- a/pkg/sink/codec/debezium/codec.go +++ b/pkg/sink/codec/debezium/codec.go @@ -1703,7 +1703,7 @@ func (c *dbzCodec) EncodeCheckpointEvent( fmt.Sprintf("%s.%s.Envelope", common.SanitizeName(c.clusterID), "watermark")) jWriter.WriteIntField("version", 1) jWriter.WriteArrayField("fields", func() { - c.writeSourceSchema(jWriter, "") + c.writeSourceSchema(jWriter, "watermark") jWriter.WriteObjectElement(func() { jWriter.WriteStringField("type", "string") jWriter.WriteBoolField("optional", false) diff --git a/pkg/sink/codec/debezium/encoder.go b/pkg/sink/codec/debezium/encoder.go index 683d8ff915..2f49412e59 100644 --- a/pkg/sink/codec/debezium/encoder.go +++ b/pkg/sink/codec/debezium/encoder.go @@ -39,7 +39,7 @@ type BatchEncoder struct { // EncodeCheckpointEvent implements the RowEventEncoder interface func (d *BatchEncoder) EncodeCheckpointEvent(ts uint64) (*common.Message, error) { - if d.config.Protocol == config.ProtocolDebeziumAvro { + if d.config.Protocol == config.ProtocolDebeziumAvro && !d.config.AvroEnableWatermark { return nil, nil } if !d.config.EnableTiDBExtension {