diff --git a/maintainer/maintainer.go b/maintainer/maintainer.go index 3e4e152312..580286ff9e 100644 --- a/maintainer/maintainer.go +++ b/maintainer/maintainer.go @@ -309,7 +309,7 @@ func (m *Maintainer) HandleEvent(event *Event) bool { zap.Stringer("changefeedID", m.changefeedID), zap.Int("eventType", event.eventType), zap.Duration("duration", duration), - zap.Any("Message", event.message), + zap.Any("MessageType", event.message.Type), ) } else { log.Info("maintainer is too slow", diff --git a/maintainer/replica/replication_span.go b/maintainer/replica/replication_span.go index 17fc2b1a84..e35aa9ac40 100644 --- a/maintainer/replica/replication_span.go +++ b/maintainer/replica/replication_span.go @@ -173,13 +173,18 @@ func (r *SpanReplication) GetMode() int64 { // // The new status is only stored if its checkpointTs is greater than or equal to // the current status's checkpointTs. -func (r *SpanReplication) UpdateStatus(newStatus *heartbeatpb.TableSpanStatus) { - if newStatus != nil { - oldStatus := r.status.Load() - if newStatus.CheckpointTs >= oldStatus.CheckpointTs { - r.status.Store(newStatus) - } +// +// It returns true when the stored checkpointTs changes. +func (r *SpanReplication) UpdateStatus(newStatus *heartbeatpb.TableSpanStatus) bool { + if newStatus == nil { + return false + } + oldStatus := r.status.Load() + if newStatus.CheckpointTs < oldStatus.CheckpointTs { + return false } + r.status.Store(newStatus) + return newStatus.CheckpointTs != oldStatus.CheckpointTs } // ShouldRun always returns true. diff --git a/maintainer/span/checkpoint_ts_tracker.go b/maintainer/span/checkpoint_ts_tracker.go new file mode 100644 index 0000000000..b6c41346fe --- /dev/null +++ b/maintainer/span/checkpoint_ts_tracker.go @@ -0,0 +1,175 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package span + +import ( + "container/heap" + + "github.com/pingcap/ticdc/pkg/common" +) + +// checkpointTsTracker maintains the minimum checkpointTs among non-DDL spans +// that are not replicating yet. The owning SpanController must hold its mutex +// when accessing this tracker. +type checkpointTsTracker struct { + // checkpointTsBySpanID contains exactly the non-DDL spans that are absent or + // scheduling. Replicating spans are removed because they no longer block the + // non-replicating minimum checkpoint. + checkpointTsBySpanID map[common.DispatcherID]uint64 + + // checkpointTsRefCounts tracks how many spans currently hold each checkpointTs. + // The heap stores each checkpointTs once, so the count keeps duplicate values + // from being removed too early. + checkpointTsRefCounts map[uint64]int + + // minCheckpointTsHeap stores unique checkpointTs values and gives O(1) access + // to the current minimum. Insertions and removals are O(log n). + minCheckpointTsHeap checkpointTsHeap +} + +func newCheckpointTsTracker() *checkpointTsTracker { + return &checkpointTsTracker{ + checkpointTsBySpanID: make(map[common.DispatcherID]uint64), + checkpointTsRefCounts: make(map[uint64]int), + minCheckpointTsHeap: newCheckpointTsHeap(), + } +} + +// trackSpan records a span that has entered a non-replicating state. It also +// handles duplicate calls for the same span by replacing the old checkpointTs. +func (t *checkpointTsTracker) trackSpan(id common.DispatcherID, checkpointTs uint64) { + if old, ok := t.checkpointTsBySpanID[id]; ok { + if old == checkpointTs { + return + } + t.decrement(old) + } + t.checkpointTsBySpanID[id] = checkpointTs + t.increment(checkpointTs) +} + +// updateTrackedSpan updates checkpointTs only for spans that are already +// tracked. Missing spans are ignored because DDL or replicating spans are not +// part of the non-replicating minimum. +func (t *checkpointTsTracker) updateTrackedSpan(id common.DispatcherID, checkpointTs uint64) { + old, ok := t.checkpointTsBySpanID[id] + if !ok || old == checkpointTs { + return + } + t.decrement(old) + t.checkpointTsBySpanID[id] = checkpointTs + t.increment(checkpointTs) +} + +// untrackSpan removes a span after it becomes replicating or leaves the +// controller. Missing spans are ignored for the same reason as updateTrackedSpan. +func (t *checkpointTsTracker) untrackSpan(id common.DispatcherID) { + old, ok := t.checkpointTsBySpanID[id] + if !ok { + return + } + delete(t.checkpointTsBySpanID, id) + t.decrement(old) + if len(t.checkpointTsBySpanID) == 0 { + // Release large maps after a bootstrap wave drains. A 1M-table changefeed + // can otherwise retain the tracker backing storage for its whole lifetime. + t.reset() + } +} + +// min returns the current minimum checkpointTs among tracked spans. +func (t *checkpointTsTracker) min() (uint64, bool) { + if t.minCheckpointTsHeap.Len() == 0 { + return 0, false + } + return t.minCheckpointTsHeap.peek(), true +} + +func (t *checkpointTsTracker) increment(checkpointTs uint64) { + if t.checkpointTsRefCounts[checkpointTs] > 0 { + t.checkpointTsRefCounts[checkpointTs]++ + return + } + t.checkpointTsRefCounts[checkpointTs] = 1 + heap.Push(&t.minCheckpointTsHeap, checkpointTs) +} + +func (t *checkpointTsTracker) decrement(checkpointTs uint64) { + count := t.checkpointTsRefCounts[checkpointTs] + if count <= 1 { + delete(t.checkpointTsRefCounts, checkpointTs) + t.minCheckpointTsHeap.remove(checkpointTs) + return + } + t.checkpointTsRefCounts[checkpointTs] = count - 1 +} + +func (t *checkpointTsTracker) reset() { + t.checkpointTsBySpanID = make(map[common.DispatcherID]uint64) + t.checkpointTsRefCounts = make(map[uint64]int) + t.minCheckpointTsHeap = newCheckpointTsHeap() +} + +// checkpointTsHeap is a removable min-heap for unique checkpointTs values. +type checkpointTsHeap struct { + values []uint64 + indexes map[uint64]int +} + +func newCheckpointTsHeap() checkpointTsHeap { + return checkpointTsHeap{ + indexes: make(map[uint64]int), + } +} + +func (h checkpointTsHeap) Len() int { + return len(h.values) +} + +func (h checkpointTsHeap) Less(i, j int) bool { + return h.values[i] < h.values[j] +} + +func (h checkpointTsHeap) Swap(i, j int) { + h.values[i], h.values[j] = h.values[j], h.values[i] + h.indexes[h.values[i]] = i + h.indexes[h.values[j]] = j +} + +func (h *checkpointTsHeap) Push(x any) { + checkpointTs := x.(uint64) + h.indexes[checkpointTs] = len(h.values) + h.values = append(h.values, checkpointTs) +} + +func (h *checkpointTsHeap) Pop() any { + n := len(h.values) + checkpointTs := h.values[n-1] + delete(h.indexes, checkpointTs) + h.values[n-1] = 0 + h.values = h.values[:n-1] + return checkpointTs +} + +func (h *checkpointTsHeap) peek() uint64 { + return h.values[0] +} + +func (h *checkpointTsHeap) remove(checkpointTs uint64) { + index, ok := h.indexes[checkpointTs] + if !ok { + return + } + heap.Remove(h, index) +} diff --git a/maintainer/span/checkpoint_ts_tracker_test.go b/maintainer/span/checkpoint_ts_tracker_test.go new file mode 100644 index 0000000000..e71e14ef63 --- /dev/null +++ b/maintainer/span/checkpoint_ts_tracker_test.go @@ -0,0 +1,92 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package span + +import ( + "testing" + + "github.com/pingcap/ticdc/pkg/common" + "github.com/stretchr/testify/require" +) + +func TestCheckpointTsTrackerMin(t *testing.T) { + t.Parallel() + + tracker := newCheckpointTsTracker() + id1 := common.NewDispatcherID() + id2 := common.NewDispatcherID() + id3 := common.NewDispatcherID() + + tracker.trackSpan(id1, 100) + tracker.trackSpan(id2, 80) + tracker.trackSpan(id3, 80) + + got, ok := tracker.min() + require.True(t, ok) + require.Equal(t, uint64(80), got) + + tracker.updateTrackedSpan(id2, 120) + got, ok = tracker.min() + require.True(t, ok) + require.Equal(t, uint64(80), got) + + tracker.untrackSpan(id3) + got, ok = tracker.min() + require.True(t, ok) + require.Equal(t, uint64(100), got) + + tracker.untrackSpan(id1) + got, ok = tracker.min() + require.True(t, ok) + require.Equal(t, uint64(120), got) + + tracker.untrackSpan(id2) + got, ok = tracker.min() + require.False(t, ok) + require.Equal(t, uint64(0), got) +} + +func TestCheckpointTsTrackerIgnoresMissingUpdate(t *testing.T) { + t.Parallel() + + tracker := newCheckpointTsTracker() + id := common.NewDispatcherID() + tracker.updateTrackedSpan(id, 100) + tracker.untrackSpan(id) + + got, ok := tracker.min() + require.False(t, ok) + require.Equal(t, uint64(0), got) +} + +func TestCheckpointTsTrackerRemovesStaleCheckpointTs(t *testing.T) { + t.Parallel() + + tracker := newCheckpointTsTracker() + blockingID := common.NewDispatcherID() + movingID := common.NewDispatcherID() + tracker.trackSpan(blockingID, 1) + tracker.trackSpan(movingID, 2) + + for checkpointTs := uint64(3); checkpointTs < 100; checkpointTs++ { + tracker.updateTrackedSpan(movingID, checkpointTs) + } + + require.Equal(t, 2, tracker.minCheckpointTsHeap.Len()) + + tracker.untrackSpan(blockingID) + got, ok := tracker.min() + require.True(t, ok) + require.Equal(t, uint64(99), got) +} diff --git a/maintainer/span/span_controller.go b/maintainer/span/span_controller.go index 45938cb9bc..40fcea4d47 100644 --- a/maintainer/span/span_controller.go +++ b/maintainer/span/span_controller.go @@ -57,7 +57,8 @@ type Controller struct { // so no need to schedule it ddlSpan *replica.SpanReplication - // mu protects concurrent access to [pkgreplica.ReplicationDB, ddlSpan, allTasks, schemaTasks, tableTasks] + // mu protects concurrent access to [pkgreplica.ReplicationDB, ddlSpan, allTasks, schemaTasks, tableTasks, + // nonReplicatingCheckpointTs] mu sync.RWMutex // ReplicationDB tracks the scheduling status of spans pkgreplica.ReplicationDB[common.DispatcherID, *replica.SpanReplication] @@ -67,6 +68,8 @@ type Controller struct { schemaTasks map[int64]map[common.DispatcherID]*replica.SpanReplication // tableTasks provides quick access to spans by table ID tableTasks map[int64]map[common.DispatcherID]*replica.SpanReplication + // nonReplicatingCheckpointTs tracks absent and scheduling spans so checkpoint calculation does not scan all spans. + nonReplicatingCheckpointTs *checkpointTsTracker // newGroupChecker creates a GroupChecker for validating span groups newGroupChecker func(groupID pkgreplica.GroupID) pkgreplica.GroupChecker[common.DispatcherID, *replica.SpanReplication] @@ -108,9 +111,10 @@ func NewController( keyspaceID: keyspaceID, maintainerCommittedCheckpointTs: atomic.NewUint64(ddlSpan.GetStatus().CheckpointTs), - schemaTasks: make(map[int64]map[common.DispatcherID]*replica.SpanReplication), - tableTasks: make(map[int64]map[common.DispatcherID]*replica.SpanReplication), - allTasks: make(map[common.DispatcherID]*replica.SpanReplication), + schemaTasks: make(map[int64]map[common.DispatcherID]*replica.SpanReplication), + tableTasks: make(map[int64]map[common.DispatcherID]*replica.SpanReplication), + allTasks: make(map[common.DispatcherID]*replica.SpanReplication), + nonReplicatingCheckpointTs: newCheckpointTsTracker(), } c.ReplicationDB = pkgreplica.NewReplicationDB(changefeedID.String(), c.doWithRLock, c.newGroupChecker) c.initializeDDLSpan(ddlSpan) @@ -229,15 +233,12 @@ func (c *Controller) AddNewSpans(schemaID int64, tableSpans []*heartbeatpb.Table } func (c *Controller) GetMinCheckpointTsForNonReplicatingSpans(minCheckpointTs uint64) uint64 { - for _, span := range c.GetAbsent() { - if span.GetStatus().CheckpointTs < minCheckpointTs { - minCheckpointTs = span.GetStatus().CheckpointTs - } - } - for _, span := range c.GetScheduling() { - if span.GetStatus().CheckpointTs < minCheckpointTs { - minCheckpointTs = span.GetStatus().CheckpointTs - } + c.mu.RLock() + defer c.mu.RUnlock() + + checkpointTs, ok := c.nonReplicatingCheckpointTs.min() + if ok && checkpointTs < minCheckpointTs { + return checkpointTs } return minCheckpointTs } @@ -352,10 +353,9 @@ func (c *Controller) UpdateSchemaID(tableID, newSchemaID int64) { // UpdateStatus updates the status of a span func (c *Controller) UpdateStatus(span *replica.SpanReplication, status *heartbeatpb.TableSpanStatus) { - span.UpdateStatus(status) - if span == c.ddlSpan { // ddl span don't need check by checker + span.UpdateStatus(status) return } // Note: a read lock is required inside the `GetGroupChecker` method. @@ -363,6 +363,9 @@ func (c *Controller) UpdateStatus(span *replica.SpanReplication, status *heartbe c.mu.Lock() defer c.mu.Unlock() + if span.UpdateStatus(status) { + c.nonReplicatingCheckpointTs.updateTrackedSpan(span.ID, span.GetStatus().CheckpointTs) + } checker.UpdateStatus(span) } @@ -388,6 +391,7 @@ func (c *Controller) AddReplicatingSpan(span *replica.SpanReplication) { c.allTasks[span.ID] = span c.addToSchemaAndTableMap(span) c.AddReplicatingWithoutLock(span) + c.untrackNonReplicatingSpan(span) } // MarkSpanAbsent marks span as absent @@ -395,6 +399,7 @@ func (c *Controller) MarkSpanAbsent(span *replica.SpanReplication) { c.mu.Lock() defer c.mu.Unlock() c.MarkAbsentWithoutLock(span) + c.trackNonReplicatingSpan(span) } // MarkSpanScheduling marks span as scheduling @@ -402,6 +407,7 @@ func (c *Controller) MarkSpanScheduling(span *replica.SpanReplication) { c.mu.Lock() defer c.mu.Unlock() c.MarkSchedulingWithoutLock(span) + c.trackNonReplicatingSpan(span) } // MarkSpanReplicating marks span as replicating @@ -409,6 +415,7 @@ func (c *Controller) MarkSpanReplicating(span *replica.SpanReplication) { c.mu.Lock() defer c.mu.Unlock() c.MarkReplicatingWithoutLock(span) + c.untrackNonReplicatingSpan(span) } // BindSpanToNode binds span to node @@ -416,6 +423,7 @@ func (c *Controller) BindSpanToNode(old, new node.ID, span *replica.SpanReplicat c.mu.Lock() defer c.mu.Unlock() c.BindReplicaToNodeWithoutLock(old, new, span) + c.trackNonReplicatingSpan(span) } // RemoveReplicatingSpan removes replicating span @@ -432,6 +440,7 @@ func (c *Controller) addAbsentReplicaSetWithoutLock(spans ...*replica.SpanReplic c.allTasks[span.ID] = span c.AddAbsentWithoutLock(span) c.addToSchemaAndTableMap(span) + c.trackNonReplicatingSpan(span) } } @@ -441,6 +450,7 @@ func (c *Controller) addSchedulingReplicaSetWithoutLock(span *replica.SpanReplic c.allTasks[span.ID] = span c.AddSchedulingReplicaWithoutLock(span, targetNodeID) c.addToSchemaAndTableMap(span) + c.trackNonReplicatingSpan(span) } // ReplaceReplicaSet replaces old replica sets with new spans and returns the newly created replicas. @@ -581,6 +591,7 @@ func (c *Controller) RemoveBySchemaID(schemaID int64) { func (c *Controller) removeSpanWithoutLock(spans ...*replica.SpanReplication) { for _, span := range spans { c.RemoveReplicaWithoutLock(span) + c.untrackNonReplicatingSpan(span) tableID := span.Span.TableID schemaID := span.GetSchemaID() @@ -596,6 +607,17 @@ func (c *Controller) removeSpanWithoutLock(spans ...*replica.SpanReplication) { } } +func (c *Controller) trackNonReplicatingSpan(span *replica.SpanReplication) { + if span == c.ddlSpan { + return + } + c.nonReplicatingCheckpointTs.trackSpan(span.ID, span.GetStatus().CheckpointTs) +} + +func (c *Controller) untrackNonReplicatingSpan(span *replica.SpanReplication) { + c.nonReplicatingCheckpointTs.untrackSpan(span.ID) +} + // addToSchemaAndTableMap adds the span to the schema and table map func (c *Controller) addToSchemaAndTableMap(span *replica.SpanReplication) { tableID := span.Span.TableID diff --git a/maintainer/span/span_controller_test.go b/maintainer/span/span_controller_test.go index 0e22c88039..827f175e12 100644 --- a/maintainer/span/span_controller_test.go +++ b/maintainer/span/span_controller_test.go @@ -29,6 +29,172 @@ import ( "github.com/stretchr/testify/require" ) +func newControllerForCheckpointTsTrackerTest(t *testing.T) *Controller { + t.Helper() + + changefeedID := common.NewChangeFeedIDWithName("test", common.DefaultKeyspaceName) + ddlDispatcherID := common.NewDispatcherID() + ddlSpan := replica.NewWorkingSpanReplication(changefeedID, ddlDispatcherID, + common.DDLSpanSchemaID, + common.KeyspaceDDLSpan(common.DefaultKeyspaceID), &heartbeatpb.TableSpanStatus{ + ID: ddlDispatcherID.ToPB(), + ComponentStatus: heartbeatpb.ComponentState_Working, + CheckpointTs: 1, + }, "node1", false) + appcontext.SetService(watcher.NodeManagerName, watcher.NewNodeManager(nil, nil)) + return NewController(changefeedID, ddlSpan, nil, nil, nil, common.DefaultKeyspaceID, common.DefaultMode) +} + +func newSpanReplicationForCheckpointTsTrackerTest( + controller *Controller, + schemaID int64, + tableID int64, + checkpointTs uint64, +) *replica.SpanReplication { + span := common.TableIDToComparableSpan(common.DefaultKeyspaceID, tableID) + tableSpan := &heartbeatpb.TableSpan{ + TableID: tableID, + StartKey: span.StartKey, + EndKey: span.EndKey, + KeyspaceID: common.DefaultKeyspaceID, + } + return replica.NewSpanReplication( + controller.changefeedID, + common.NewDispatcherID(), + schemaID, + tableSpan, + checkpointTs, + common.DefaultMode, + false, + ) +} + +func TestControllerGetMinCheckpointTsForNonReplicatingSpans(t *testing.T) { + controller := newControllerForCheckpointTsTrackerTest(t) + span1 := newSpanReplicationForCheckpointTsTrackerTest(controller, 1, 100, 100) + span2 := newSpanReplicationForCheckpointTsTrackerTest(controller, 1, 101, 80) + controller.AddAbsentReplicaSet(span1, span2) + + require.Equal(t, uint64(80), controller.GetMinCheckpointTsForNonReplicatingSpans(1000)) + + controller.MarkSpanScheduling(span2) + controller.UpdateStatus(span2, &heartbeatpb.TableSpanStatus{ + ID: span2.ID.ToPB(), + ComponentStatus: heartbeatpb.ComponentState_Working, + CheckpointTs: 120, + }) + require.Equal(t, uint64(100), controller.GetMinCheckpointTsForNonReplicatingSpans(1000)) + + controller.MarkSpanReplicating(span1) + require.Equal(t, uint64(120), controller.GetMinCheckpointTsForNonReplicatingSpans(1000)) + + controller.MarkSpanReplicating(span2) + require.Equal(t, uint64(1000), controller.GetMinCheckpointTsForNonReplicatingSpans(1000)) +} + +func TestControllerCheckpointTsTrackerBindSpanToNode(t *testing.T) { + controller := newControllerForCheckpointTsTrackerTest(t) + span1 := newSpanReplicationForCheckpointTsTrackerTest(controller, 1, 100, 100) + span2 := newSpanReplicationForCheckpointTsTrackerTest(controller, 1, 101, 80) + controller.AddAbsentReplicaSet(span1, span2) + + controller.BindSpanToNode("", "node1", span2) + require.Equal(t, uint64(80), controller.GetMinCheckpointTsForNonReplicatingSpans(1000)) + + controller.UpdateStatus(span2, &heartbeatpb.TableSpanStatus{ + ID: span2.ID.ToPB(), + ComponentStatus: heartbeatpb.ComponentState_Working, + CheckpointTs: 120, + }) + require.Equal(t, uint64(100), controller.GetMinCheckpointTsForNonReplicatingSpans(1000)) + + controller.MarkSpanReplicating(span2) + require.Equal(t, uint64(100), controller.GetMinCheckpointTsForNonReplicatingSpans(1000)) + + controller.MarkSpanReplicating(span1) + require.Equal(t, uint64(1000), controller.GetMinCheckpointTsForNonReplicatingSpans(1000)) +} + +func TestControllerCheckpointTsTrackerDirectScheduling(t *testing.T) { + controller := newControllerForCheckpointTsTrackerTest(t) + span := newSpanReplicationForCheckpointTsTrackerTest(controller, 1, 100, 90) + controller.AddSchedulingReplicaSet(span, "node1") + + require.Equal(t, uint64(90), controller.GetMinCheckpointTsForNonReplicatingSpans(1000)) + + controller.UpdateStatus(span, &heartbeatpb.TableSpanStatus{ + ID: span.ID.ToPB(), + ComponentStatus: heartbeatpb.ComponentState_Working, + CheckpointTs: 80, + }) + require.Equal(t, uint64(90), controller.GetMinCheckpointTsForNonReplicatingSpans(1000)) + + controller.UpdateStatus(controller.GetDDLDispatcher(), &heartbeatpb.TableSpanStatus{ + ID: controller.GetDDLDispatcherID().ToPB(), + ComponentStatus: heartbeatpb.ComponentState_Working, + CheckpointTs: 1, + }) + require.Equal(t, uint64(90), controller.GetMinCheckpointTsForNonReplicatingSpans(1000)) + + controller.MarkSpanReplicating(span) + require.Equal(t, uint64(1000), controller.GetMinCheckpointTsForNonReplicatingSpans(1000)) +} + +func TestControllerCheckpointTsTrackerReplaceAndRemove(t *testing.T) { + controller := newControllerForCheckpointTsTrackerTest(t) + oldSpan := newSpanReplicationForCheckpointTsTrackerTest(controller, 1, 100, 50) + controller.AddAbsentReplicaSet(oldSpan) + newSpan := common.TableIDToComparableSpan(common.DefaultKeyspaceID, 101) + newTableSpan := &heartbeatpb.TableSpan{ + TableID: 101, + StartKey: newSpan.StartKey, + EndKey: newSpan.EndKey, + KeyspaceID: common.DefaultKeyspaceID, + } + + newSpans, inScheduling := controller.ReplaceReplicaSet( + []*replica.SpanReplication{oldSpan}, + []*heartbeatpb.TableSpan{newTableSpan}, + 80, + nil, + ) + + require.False(t, inScheduling) + require.Len(t, newSpans, 1) + require.Nil(t, controller.GetTaskByID(oldSpan.ID)) + require.Equal(t, uint64(50), controller.GetMinCheckpointTsForNonReplicatingSpans(1000)) + + controller.RemoveByTableIDs(101) + require.Equal(t, uint64(1000), controller.GetMinCheckpointTsForNonReplicatingSpans(1000)) +} + +func TestControllerCheckpointTsTrackerReplaceIntoScheduling(t *testing.T) { + controller := newControllerForCheckpointTsTrackerTest(t) + oldSpan := newSpanReplicationForCheckpointTsTrackerTest(controller, 1, 100, 50) + controller.AddAbsentReplicaSet(oldSpan) + newSpan := common.TableIDToComparableSpan(common.DefaultKeyspaceID, 101) + newTableSpan := &heartbeatpb.TableSpan{ + TableID: 101, + StartKey: newSpan.StartKey, + EndKey: newSpan.EndKey, + KeyspaceID: common.DefaultKeyspaceID, + } + + newSpans, inScheduling := controller.ReplaceReplicaSet( + []*replica.SpanReplication{oldSpan}, + []*heartbeatpb.TableSpan{newTableSpan}, + 80, + []node.ID{"node1"}, + ) + + require.True(t, inScheduling) + require.Len(t, newSpans, 1) + require.Equal(t, uint64(50), controller.GetMinCheckpointTsForNonReplicatingSpans(1000)) + + controller.MarkSpanReplicating(newSpans[0]) + require.Equal(t, uint64(1000), controller.GetMinCheckpointTsForNonReplicatingSpans(1000)) +} + func TestNewController(t *testing.T) { cfID := common.NewChangeFeedIDWithName("test", common.DefaultKeyspaceName) ddlDispatcherID := common.NewDispatcherID() diff --git a/pkg/scheduler/replica/replication.go b/pkg/scheduler/replica/replication.go index 0bcf8cd583..faba76fc6e 100644 --- a/pkg/scheduler/replica/replication.go +++ b/pkg/scheduler/replica/replication.go @@ -171,15 +171,10 @@ func (db *replicationDB[T, R]) GetAbsentSize() int { } func (db *replicationDB[T, R]) GetAbsentByGroup(id GroupID, batch int) []R { - buffer := make([]R, 0, batch) + var buffer []R db.withRLock(func() { g := db.mustGetGroup(id) - for _, stm := range g.GetAbsent() { - buffer = append(buffer, stm) - if len(buffer) >= batch { - break - } - } + buffer = g.GetAbsentBatch(batch) }) return buffer } @@ -249,9 +244,11 @@ func (db *replicationDB[T, R]) GetSchedulingWithoutLock() (ret []R) { func (db *replicationDB[T, R]) GetSchedulingSize() int { size := 0 - for _, g := range db.taskGroups { - size += g.GetSchedulingSize() - } + db.withRLock(func() { + for _, g := range db.taskGroups { + size += g.GetSchedulingSize() + } + }) return size } diff --git a/pkg/scheduler/replica/replication_group.go b/pkg/scheduler/replica/replication_group.go index afd9b1ea93..3f24c5e02d 100644 --- a/pkg/scheduler/replica/replication_group.go +++ b/pkg/scheduler/replica/replication_group.go @@ -15,6 +15,7 @@ package replica import ( "sync" + "sync/atomic" "github.com/pingcap/log" "github.com/pingcap/ticdc/pkg/node" @@ -224,6 +225,24 @@ func (g *replicationGroup[T, R]) GetAbsent() []R { return res } +func (g *replicationGroup[T, R]) GetAbsentBatch(batch int) []R { + if batch <= 0 { + return nil + } + capacity := batch + if absentSize := g.absent.Len(); absentSize < capacity { + capacity = absentSize + } + res := make([]R, 0, capacity) + g.absent.Range(func(_ T, r R) bool { + if r.ShouldRun() { + res = append(res, r) + } + return len(res) < batch + }) + return res +} + func (g *replicationGroup[T, R]) GetSchedulingSize() int { return g.scheduling.Len() } @@ -268,6 +287,7 @@ func (g *replicationGroup[T, R]) IsReplicating(replica R) bool { type iMap[T ReplicationID, R Replication[T]] struct { inner sync.Map + size atomic.Int64 } func newIMap[T ReplicationID, R Replication[T]]() *iMap[T, R] { @@ -289,20 +309,19 @@ func (m *iMap[T, R]) Get(key T) (R, bool) { } func (m *iMap[T, R]) Set(key T, value R) { - m.inner.Store(key, value) + if _, loaded := m.inner.Swap(key, value); !loaded { + m.size.Add(1) + } } func (m *iMap[T, R]) Delete(key T) { - m.inner.Delete(key) + if _, loaded := m.inner.LoadAndDelete(key); loaded { + m.size.Add(-1) + } } func (m *iMap[T, R]) Len() int { - var count int - m.inner.Range(func(_, _ interface{}) bool { - count++ - return true - }) - return count + return int(m.size.Load()) } func (m *iMap[T, R]) Range(f func(T, R) bool) { diff --git a/pkg/scheduler/replica/replication_group_test.go b/pkg/scheduler/replica/replication_group_test.go new file mode 100644 index 0000000000..63fc5df26f --- /dev/null +++ b/pkg/scheduler/replica/replication_group_test.go @@ -0,0 +1,124 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package replica + +import ( + "fmt" + "sync/atomic" + "testing" + + "github.com/pingcap/ticdc/pkg/node" + "github.com/stretchr/testify/require" +) + +type testReplicationID string + +func (id testReplicationID) String() string { + return string(id) +} + +type testReplication struct { + id testReplicationID + groupID GroupID + nodeID node.ID + shouldRun bool + shouldRunCalls *atomic.Int64 +} + +func (r *testReplication) GetID() testReplicationID { + return r.id +} + +func (r *testReplication) GetGroupID() GroupID { + return r.groupID +} + +func (r *testReplication) GetNodeID() node.ID { + return r.nodeID +} + +func (r *testReplication) SetNodeID(nodeID node.ID) { + r.nodeID = nodeID +} + +func (r *testReplication) ShouldRun() bool { + if r.shouldRunCalls != nil { + r.shouldRunCalls.Add(1) + } + return r.shouldRun +} + +func TestIMapLenTracksOverwriteAndDelete(t *testing.T) { + t.Parallel() + + replicaMap := newIMap[testReplicationID, *testReplication]() + id := testReplicationID("a") + + replicaMap.Set(id, &testReplication{id: id}) + replicaMap.Set(id, &testReplication{id: id}) + require.Equal(t, 1, replicaMap.Len()) + + replicaMap.Delete(testReplicationID("missing")) + require.Equal(t, 1, replicaMap.Len()) + + replicaMap.Delete(id) + require.Equal(t, 0, replicaMap.Len()) +} + +func TestGetAbsentByGroupStopsAtBatch(t *testing.T) { + t.Parallel() + + var shouldRunCalls atomic.Int64 + db := NewReplicationDB[testReplicationID, *testReplication]( + "test", + func(action func()) { action() }, + NewEmptyChecker[testReplicationID, *testReplication], + ) + for i := 0; i < 100; i++ { + id := testReplicationID(fmt.Sprintf("r%d", i)) + db.AddAbsentWithoutLock(&testReplication{ + id: id, + groupID: DefaultGroupID, + shouldRun: true, + shouldRunCalls: &shouldRunCalls, + }) + } + + absent := db.GetAbsentByGroup(DefaultGroupID, 3) + require.Len(t, absent, 3) + require.Equal(t, int64(3), shouldRunCalls.Load()) +} + +func TestGetAbsentByGroupSkipsNotRunnableTasks(t *testing.T) { + t.Parallel() + + var shouldRunCalls atomic.Int64 + db := NewReplicationDB[testReplicationID, *testReplication]( + "test", + func(action func()) { action() }, + NewEmptyChecker[testReplicationID, *testReplication], + ) + for i := 0; i < 100; i++ { + id := testReplicationID(fmt.Sprintf("r%d", i)) + db.AddAbsentWithoutLock(&testReplication{ + id: id, + groupID: DefaultGroupID, + shouldRunCalls: &shouldRunCalls, + }) + } + + absent := db.GetAbsentByGroup(DefaultGroupID, 3) + require.Len(t, absent, 0) + require.Equal(t, int64(100), shouldRunCalls.Load()) +}