diff --git a/logservice/logpuller/priority_queue.go b/logservice/logpuller/priority_queue.go deleted file mode 100644 index f19b6d34b9..0000000000 --- a/logservice/logpuller/priority_queue.go +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2025 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package logpuller - -import ( - "context" - "sync" - - "github.com/pingcap/errors" - "github.com/pingcap/ticdc/utils/heap" -) - -// PriorityQueue is a thread-safe priority queue for region tasks -// It integrates a signal channel to support blocking operations -type PriorityQueue struct { - mu sync.Mutex - heap *heap.Heap[PriorityTask] - - // signal channel for blocking operations - signal chan struct{} -} - -// NewPriorityQueue creates a new priority queue -func NewPriorityQueue() *PriorityQueue { - return &PriorityQueue{ - heap: heap.NewHeap[PriorityTask](), - signal: make(chan struct{}, 1024), - } -} - -// Push adds a task to the priority queue and sends a signal -// This is a non-blocking operation -func (pq *PriorityQueue) Push(task PriorityTask) { - pq.mu.Lock() - pq.heap.AddOrUpdate(task) - pq.mu.Unlock() - - // Send signal to notify waiting consumers - select { - case pq.signal <- struct{}{}: - default: - // Signal channel is full, ignore - } -} - -// Pop removes and returns the highest priority task -// This is a blocking operation that waits for a signal -// Returns nil if the context is cancelled -func (pq *PriorityQueue) Pop(ctx context.Context) (PriorityTask, error) { - for { - // First try to pop without waiting - pq.mu.Lock() - task, ok := pq.heap.PopTop() - pq.mu.Unlock() - - if ok { - return task, nil - } - - // Queue is empty, wait for signal - select { - case <-ctx.Done(): - return nil, ctx.Err() - case _, ok := <-pq.signal: - if !ok { - // Signal channel is closed. - return nil, errors.New("signal channel is closed") - } - // Got signal, try to pop again - continue - } - } -} - -// TryPop attempts to pop a task without blocking -// Returns nil if the queue is empty -func (pq *PriorityQueue) TryPop() PriorityTask { - pq.mu.Lock() - defer pq.mu.Unlock() - - task, ok := pq.heap.PopTop() - if !ok { - return nil - } - return task -} - -// Peek returns the highest priority task without removing it -// Returns nil if the queue is empty -func (pq *PriorityQueue) Peek() PriorityTask { - pq.mu.Lock() - defer pq.mu.Unlock() - - task, ok := pq.heap.PeekTop() - if !ok { - return nil - } - return task -} - -// Len returns the number of tasks in the queue -func (pq *PriorityQueue) Len() int { - pq.mu.Lock() - defer pq.mu.Unlock() - - return pq.heap.Len() -} - -// Close closes the signal channel -func (pq *PriorityQueue) Close() { - // pop all tasks - for pq.Len() > 0 { - pq.TryPop() - } -} diff --git a/logservice/logpuller/priority_queue_test.go b/logservice/logpuller/priority_queue_test.go deleted file mode 100644 index 2bae2456db..0000000000 --- a/logservice/logpuller/priority_queue_test.go +++ /dev/null @@ -1,493 +0,0 @@ -// Copyright 2025 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package logpuller - -import ( - "context" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/pingcap/ticdc/heartbeatpb" - "github.com/stretchr/testify/require" - "github.com/tikv/client-go/v2/oracle" - "github.com/tikv/client-go/v2/tikv" -) - -// mockPriorityTask is a simple mock implementation of PriorityTask for testing -type mockPriorityTask struct { - priority int - heapIndex int - regionInfo regionInfo - description string -} - -func newMockPriorityTask(priority int, description string) *mockPriorityTask { - // Create a minimal regionInfo for testing - verID := tikv.NewRegionVerID(1, 1, 1) - span := heartbeatpb.TableSpan{TableID: 1, StartKey: []byte("a"), EndKey: []byte("z")} - - // Create a subscribedSpan with atomic resolvedTs - subscribedSpan := &subscribedSpan{ - resolvedTs: atomic.Uint64{}, - } - subscribedSpan.resolvedTs.Store(oracle.GoTimeToTS(time.Now())) - - regionInfo := regionInfo{ - verID: verID, - span: span, - subscribedSpan: subscribedSpan, - } - - return &mockPriorityTask{ - priority: priority, - heapIndex: 0, - regionInfo: regionInfo, - description: description, - } -} - -func (m *mockPriorityTask) Priority() int { - return m.priority -} - -func (m *mockPriorityTask) GetRegionInfo() regionInfo { - return m.regionInfo -} - -func (m *mockPriorityTask) SetHeapIndex(index int) { - m.heapIndex = index -} - -func (m *mockPriorityTask) GetHeapIndex() int { - return m.heapIndex -} - -func (m *mockPriorityTask) LessThan(other PriorityTask) bool { - return m.Priority() < other.Priority() -} - -func TestNewPriorityQueue(t *testing.T) { - pq := NewPriorityQueue() - require.NotNil(t, pq) - require.NotNil(t, pq.heap) - require.NotNil(t, pq.signal) - require.Equal(t, 0, pq.Len()) -} - -func TestPriorityQueue_Push(t *testing.T) { - pq := NewPriorityQueue() - - task1 := newMockPriorityTask(10, "task1") - task2 := newMockPriorityTask(5, "task2") - - // Test pushing single task - pq.Push(task1) - require.Equal(t, 1, pq.Len()) - - // Test pushing multiple tasks - pq.Push(task2) - require.Equal(t, 2, pq.Len()) - - // Verify signal channel receives notifications - select { - case <-pq.signal: - // Expected - signal received - case <-time.After(time.Millisecond * 100): - t.Fatal("Expected signal but none received") - } -} - -func TestPriorityQueue_Peek(t *testing.T) { - pq := NewPriorityQueue() - - // Test peek on empty queue - task := pq.Peek() - require.Nil(t, task) - - // Add tasks with different priorities - task1 := newMockPriorityTask(10, "task1") - task2 := newMockPriorityTask(5, "task2") // Higher priority (lower value) - task3 := newMockPriorityTask(15, "task3") - - pq.Push(task1) - pq.Push(task2) - pq.Push(task3) - - // Peek should return highest priority task (lowest value) - topTask := pq.Peek() - require.NotNil(t, topTask) - require.Equal(t, 5, topTask.Priority()) - require.Equal(t, "task2", topTask.(*mockPriorityTask).description) - - // Verify peek doesn't remove the task - require.Equal(t, 3, pq.Len()) - - // Peek again should return the same task - topTaskAgain := pq.Peek() - require.Equal(t, topTask, topTaskAgain) -} - -func TestPriorityQueue_PopBlocking(t *testing.T) { - pq := NewPriorityQueue() - - // Test pop on empty queue with context cancellation - t.Run("PopWithCancellation", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50) - defer cancel() - - start := time.Now() - task, err := pq.Pop(ctx) - require.Error(t, err) - elapsed := time.Since(start) - - require.Nil(t, task) - require.True(t, elapsed >= time.Millisecond*50) - }) - - // Test pop with signal - t.Run("PopWithSignal", func(t *testing.T) { - ctx := context.Background() - - // Add a task in a goroutine after a short delay - go func() { - time.Sleep(time.Millisecond * 50) - task1 := newMockPriorityTask(10, "task1") - pq.Push(task1) - }() - - start := time.Now() - task, err := pq.Pop(ctx) - require.NoError(t, err) - elapsed := time.Since(start) - - require.NotNil(t, task) - require.Equal(t, 10, task.Priority()) - require.True(t, elapsed >= time.Millisecond*50) - require.True(t, elapsed < time.Millisecond*200) // Should not wait too long - }) -} - -func TestPriorityQueue_PopOrder(t *testing.T) { - pq := NewPriorityQueue() - ctx := context.Background() - - // Add tasks with different priorities - tasks := []*mockPriorityTask{ - newMockPriorityTask(10, "task1"), - newMockPriorityTask(5, "task2"), // Highest priority - newMockPriorityTask(15, "task3"), - newMockPriorityTask(7, "task4"), - newMockPriorityTask(12, "task5"), - } - - for _, task := range tasks { - pq.Push(task) - } - - // Pop tasks and verify they come out in priority order - expectedOrder := []string{"task2", "task4", "task1", "task5", "task3"} - expectedPriorities := []int{5, 7, 10, 12, 15} - - for i, expectedDesc := range expectedOrder { - task, err := pq.Pop(ctx) - require.NoError(t, err) - require.NotNil(t, task) - require.Equal(t, expectedPriorities[i], task.Priority()) - require.Equal(t, expectedDesc, task.(*mockPriorityTask).description) - } - - // Verify queue is empty - require.Equal(t, 0, pq.Len()) -} - -func TestPriorityQueue_Len(t *testing.T) { - pq := NewPriorityQueue() - - // Test empty queue - require.Equal(t, 0, pq.Len()) - - // Add tasks and verify length - for i := 0; i < 5; i++ { - task := newMockPriorityTask(i, "task") - pq.Push(task) - require.Equal(t, i+1, pq.Len()) - } - - // Remove tasks and verify length - ctx := context.Background() - for i := 4; i >= 0; i-- { - pq.Pop(ctx) - require.Equal(t, i, pq.Len()) - } -} - -func TestPriorityQueue_ConcurrentOperations(t *testing.T) { - pq := NewPriorityQueue() - - numProducers := 3 - numConsumers := 2 - tasksPerProducer := 10 - totalTasks := numProducers * tasksPerProducer - - var wg sync.WaitGroup - var consumedCount int64 - var mu sync.Mutex - consumedTasks := make([]PriorityTask, 0, totalTasks) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Start consumers - for i := 0; i < numConsumers; i++ { - wg.Add(1) - go func(consumerID int) { - defer wg.Done() - for { - task, err := pq.Pop(ctx) - if err != nil { - return - } - - mu.Lock() - consumedTasks = append(consumedTasks, task) - count := atomic.AddInt64(&consumedCount, 1) - mu.Unlock() - - if count >= int64(totalTasks) { - cancel() // Signal other consumers to stop - return - } - } - }(i) - } - - // Start producers - for i := 0; i < numProducers; i++ { - wg.Add(1) - go func(producerID int) { - defer wg.Done() - for j := 0; j < tasksPerProducer; j++ { - priority := (producerID * tasksPerProducer) + j - task := newMockPriorityTask(priority, "concurrent_task") - pq.Push(task) - time.Sleep(time.Microsecond * 10) // Small delay to simulate real work - } - }(i) - } - - // Wait for all producers to finish - done := make(chan struct{}) - go func() { - wg.Wait() - close(done) - }() - - // Wait with timeout - select { - case <-done: - // Success - case <-time.After(time.Second * 5): - cancel() // Cancel to stop consumers - t.Fatal("Test timed out") - } - - // Verify all tasks were consumed - require.Equal(t, int64(totalTasks), atomic.LoadInt64(&consumedCount)) - require.Equal(t, totalTasks, len(consumedTasks)) - - // Verify all tasks were processed - for i := 0; i < len(consumedTasks); i++ { - require.NotNil(t, consumedTasks[i]) - } -} - -func TestPriorityQueue_SignalChannelFull(t *testing.T) { - pq := NewPriorityQueue() - - // Fill the signal channel to capacity - for i := 0; i < cap(pq.signal); i++ { - select { - case pq.signal <- struct{}{}: - default: - t.Fatalf("Failed to fill signal channel at iteration %d", i) - } - } - - // Push a task when signal channel is full - should not block - task := newMockPriorityTask(10, "task") - start := time.Now() - pq.Push(task) - elapsed := time.Since(start) - - // Should complete quickly even though signal channel is full - require.True(t, elapsed < time.Millisecond*100) - require.Equal(t, 1, pq.Len()) -} - -func TestPriorityQueue_UpdateExistingTask(t *testing.T) { - pq := NewPriorityQueue() - - // Create a task and add it to queue - task := newMockPriorityTask(10, "task") - pq.Push(task) - require.Equal(t, 1, pq.Len()) - - // Update the task's priority and push again - task.priority = 5 - pq.Push(task) - - // Length should still be 1 (task was updated, not added) - require.Equal(t, 1, pq.Len()) - - // Verify the task has the updated priority - ctx := context.Background() - poppedTask, err := pq.Pop(ctx) - require.NoError(t, err) - require.NotNil(t, poppedTask) - require.Equal(t, 5, poppedTask.Priority()) -} - -func TestPriorityQueue_Close(t *testing.T) { - pq := NewPriorityQueue() - - // Add 3 task before closing - task := newMockPriorityTask(10, "task") - pq.Push(task) - require.Equal(t, 1, pq.Len()) - task2 := newMockPriorityTask(5, "task2") - pq.Push(task2) - require.Equal(t, 2, pq.Len()) - task3 := newMockPriorityTask(15, "task3") - pq.Push(task3) - require.Equal(t, 3, pq.Len()) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - wg := sync.WaitGroup{} - wg.Add(3) - go func() { - defer wg.Done() - for i := 0; i < 1000; i++ { - pq.Push(newMockPriorityTask(i, "task")) - } - }() - - go func() { - defer wg.Done() - defer cancel() - - for i := 0; i < 1000; i++ { - // Test that close doesn't panic - require.NotPanics(t, func() { - pq.Close() - }) - } - }() - - go func() { - defer wg.Done() - for i := 0; i < 1000; i++ { - // Make sure it won't block when the queue is closed - pq.Pop(ctx) - } - }() - - wg.Wait() - require.NotPanics(t, func() { - pq.Close() - }) - // Test that the tasks are popped - require.Equal(t, 0, pq.Len()) -} - -func TestPriorityQueue_EmptyQueueOperations(t *testing.T) { - pq := NewPriorityQueue() - - // Test peek on empty queue - task := pq.Peek() - require.Nil(t, task) - - // Test len on empty queue - require.Equal(t, 0, pq.Len()) - - // Test pop on empty queue with immediate cancellation - ctx, cancel := context.WithCancel(context.Background()) - cancel() // Cancel immediately - - task2, err := pq.Pop(ctx) - require.Nil(t, task2) - require.Error(t, err) -} - -func TestPriorityQueue_RealPriorityTaskIntegration(t *testing.T) { - pq := NewPriorityQueue() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - currentTs := oracle.GoTimeToTS(time.Now()) - - // Create real priority tasks with different types - verID := tikv.NewRegionVerID(1, 1, 1) - span := heartbeatpb.TableSpan{TableID: 1, StartKey: []byte("a"), EndKey: []byte("z")} - - subscribedSpan := &subscribedSpan{ - resolvedTs: atomic.Uint64{}, - } - subscribedSpan.resolvedTs.Store(oracle.GoTimeToTS(time.Now().Add(-time.Second))) - - regionInfo := regionInfo{ - verID: verID, - span: span, - subscribedSpan: subscribedSpan, - } - - // Create tasks with different priorities - errorTask := NewRegionPriorityTask(TaskHighPrior, regionInfo, currentTs+1) - highTask := NewRegionPriorityTask(TaskHighPrior, regionInfo, currentTs) - lowTask := NewRegionPriorityTask(TaskLowPrior, regionInfo, currentTs) - - // Add tasks in non-priority order - pq.Push(lowTask) - pq.Push(errorTask) - pq.Push(highTask) - - require.Equal(t, 3, pq.Len()) - - // Pop tasks and verify they come out in priority order - // TaskRegionError should have highest priority (lowest value) - first, err := pq.Pop(ctx) - require.NoError(t, err) - require.NotNil(t, first) - require.Equal(t, TaskHighPrior, first.(*regionPriorityTask).taskType) - - second, err := pq.Pop(ctx) - require.NoError(t, err) - require.NotNil(t, second) - require.Equal(t, TaskHighPrior, second.(*regionPriorityTask).taskType) - - third, err := pq.Pop(ctx) - require.NoError(t, err) - require.NotNil(t, third) - require.Equal(t, TaskLowPrior, third.(*regionPriorityTask).taskType) - - require.Equal(t, 0, pq.Len()) - - pq.Close() - cancel() - task, err := pq.Pop(ctx) - require.Nil(t, task) - require.Error(t, err) -} diff --git a/logservice/logpuller/priority_task_test.go b/logservice/logpuller/priority_task_test.go index 5b1b373afb..d8b3f26d48 100644 --- a/logservice/logpuller/priority_task_test.go +++ b/logservice/logpuller/priority_task_test.go @@ -14,11 +14,15 @@ package logpuller import ( + "sync/atomic" "testing" "time" + "github.com/pingcap/ticdc/heartbeatpb" + "github.com/pingcap/ticdc/utils/priorityqueue" "github.com/stretchr/testify/require" "github.com/tikv/client-go/v2/oracle" + "github.com/tikv/client-go/v2/tikv" ) // TestPriorityCalculationLogic tests the priority calculation logic in isolation @@ -198,3 +202,45 @@ func TestEdgeCases(t *testing.T) { require.Less(t, priority2, priority1, "wait time longer task priority should be higher") }) } + +func TestRegionPriorityTaskQueueOrder(t *testing.T) { + queue := priorityqueue.New[PriorityTask]() + ctx := t.Context() + + currentTs := oracle.GoTimeToTS(time.Now()) + verID := tikv.NewRegionVerID(1, 1, 1) + span := heartbeatpb.TableSpan{TableID: 1, StartKey: []byte("a"), EndKey: []byte("z")} + + subscribedSpan := &subscribedSpan{ + resolvedTs: atomic.Uint64{}, + } + subscribedSpan.resolvedTs.Store(oracle.GoTimeToTS(time.Now().Add(-time.Second))) + + regionInfo := regionInfo{ + verID: verID, + span: span, + subscribedSpan: subscribedSpan, + } + + errorTask := NewRegionPriorityTask(TaskHighPrior, regionInfo, currentTs+1) + highTask := NewRegionPriorityTask(TaskHighPrior, regionInfo, currentTs) + lowTask := NewRegionPriorityTask(TaskLowPrior, regionInfo, currentTs) + + require.True(t, queue.Push(lowTask)) + require.True(t, queue.Push(errorTask)) + require.True(t, queue.Push(highTask)) + + first, err := queue.Pop(ctx) + require.NoError(t, err) + require.Equal(t, TaskHighPrior, first.(*regionPriorityTask).taskType) + + second, err := queue.Pop(ctx) + require.NoError(t, err) + require.Equal(t, TaskHighPrior, second.(*regionPriorityTask).taskType) + + third, err := queue.Pop(ctx) + require.NoError(t, err) + require.Equal(t, TaskLowPrior, third.(*regionPriorityTask).taskType) + + require.Equal(t, 0, queue.Len()) +} diff --git a/logservice/logpuller/subscription_client.go b/logservice/logpuller/subscription_client.go index 0ba3b4be24..fbcffdc4f1 100644 --- a/logservice/logpuller/subscription_client.go +++ b/logservice/logpuller/subscription_client.go @@ -34,6 +34,7 @@ import ( "github.com/pingcap/ticdc/pkg/spanz" "github.com/pingcap/ticdc/pkg/util" "github.com/pingcap/ticdc/utils/dynstream" + "github.com/pingcap/ticdc/utils/priorityqueue" "github.com/prometheus/client_golang/prometheus" kvclientv2 "github.com/tikv/client-go/v2/kv" "github.com/tikv/client-go/v2/oracle" @@ -212,7 +213,7 @@ type subscriptionClient struct { rangeTaskCh chan rangeTask // regionTaskQueue is used to receive region tasks with priority. // The region will be handled in `handleRegions` goroutine. - regionTaskQueue *PriorityQueue + regionTaskQueue *priorityqueue.PriorityQueue[PriorityTask] // resolveLockTaskCh is used to receive resolve lock tasks. // The tasks will be handled in `handleResolveLockTasks` goroutine. resolveLockTaskCh chan resolveLockTask @@ -241,7 +242,7 @@ func NewSubscriptionClient( credential: credential, rangeTaskCh: make(chan rangeTask, 1024), - regionTaskQueue: NewPriorityQueue(), + regionTaskQueue: priorityqueue.New[PriorityTask](), resolveLockTaskCh: make(chan resolveLockTask, 1024), resolveLockRateLimiter: newResolveLockRateLimiter(), errCache: newErrCache(), @@ -602,6 +603,9 @@ func (s *subscriptionClient) handleRegions(ctx context.Context, eg *errgroup.Gro // Use blocking Pop to wait for tasks regionTask, err := s.regionTaskQueue.Pop(ctx) if err != nil { + if errors.Is(err, priorityqueue.ErrClosed) { + return nil + } return err } diff --git a/logservice/logpuller/subscription_client_test.go b/logservice/logpuller/subscription_client_test.go index 2e87faa4fc..b502c4d7d1 100644 --- a/logservice/logpuller/subscription_client_test.go +++ b/logservice/logpuller/subscription_client_test.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/ticdc/pkg/pdutil" "github.com/pingcap/ticdc/pkg/security" "github.com/pingcap/ticdc/utils/dynstream" + "github.com/pingcap/ticdc/utils/priorityqueue" "github.com/pingcap/tidb/pkg/store/mockstore/mockcopr" "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/require" @@ -278,7 +279,7 @@ func TestResolveLockTaskDroppedWhenChannelFull(t *testing.T) { func TestStopTaskUsesSubscribedSpanFilterLoop(t *testing.T) { client := &subscriptionClient{ resolveLockTaskCh: make(chan resolveLockTask, 1), - regionTaskQueue: NewPriorityQueue(), + regionTaskQueue: priorityqueue.New[PriorityTask](), } client.ctx, client.cancel = context.WithCancel(context.Background()) defer client.cancel() @@ -384,7 +385,7 @@ func (s *mockDynamicStream) GetMetrics() dynstream.Metrics[int, SubscriptionID] func TestPushRegionEventToDSUnblocksOnClose(t *testing.T) { client := &subscriptionClient{ ds: &mockDynamicStream{}, - regionTaskQueue: NewPriorityQueue(), + regionTaskQueue: priorityqueue.New[PriorityTask](), } client.ctx, client.cancel = context.WithCancel(context.Background()) client.cond = sync.NewCond(&client.mu) diff --git a/utils/priorityqueue/priority_queue.go b/utils/priorityqueue/priority_queue.go new file mode 100644 index 0000000000..e6f5fca0c6 --- /dev/null +++ b/utils/priorityqueue/priority_queue.go @@ -0,0 +1,151 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package priorityqueue provides a thread-safe blocking priority queue. +package priorityqueue + +import ( + "context" + "sync" + + "github.com/pingcap/ticdc/pkg/errors" + "github.com/pingcap/ticdc/utils/heap" +) + +// ErrClosed is returned by Pop when the queue has been closed and drained. +var ErrClosed = errors.New("priority queue is closed") + +// PriorityQueue is a thread-safe priority queue based on utils/heap. +// +// The queue uses heap.Item.LessThan to order items. Push and AddOrUpdate both +// use heap.AddOrUpdate semantics: if an item is already queued, its heap +// position is updated instead of inserting a duplicate. If an item's ordering +// fields change while it is queued, callers must call AddOrUpdate again to +// restore heap order. +// +// Pop blocks until an item is available, the context is canceled, or the queue +// is closed and drained. TryPop never blocks: it returns the current top item if +// one exists, or ok=false immediately when the queue is empty. +type PriorityQueue[T heap.Item[T]] struct { + mu sync.Mutex + heap *heap.Heap[T] + notify chan struct{} + closed bool +} + +// New creates an empty priority queue. +func New[T heap.Item[T]]() *PriorityQueue[T] { + return &PriorityQueue[T]{ + heap: heap.NewHeap[T](), + notify: make(chan struct{}, 1), + } +} + +// Push adds or updates an item and wakes one blocked Pop caller. +// It returns false if the queue has been closed. +func (q *PriorityQueue[T]) Push(item T) bool { + return q.AddOrUpdate(item) +} + +// AddOrUpdate adds an item if it is not in the queue, or updates its heap +// position if it is already queued. +func (q *PriorityQueue[T]) AddOrUpdate(item T) bool { + q.mu.Lock() + if q.closed { + q.mu.Unlock() + return false + } + q.heap.AddOrUpdate(item) + q.notifyOneLocked() + q.mu.Unlock() + return true +} + +// Pop blocks until an item is available, the queue is closed and drained, or ctx +// is done. +func (q *PriorityQueue[T]) Pop(ctx context.Context) (item T, err error) { + for { + q.mu.Lock() + var ok bool + item, ok = q.heap.PopTop() + if ok { + if q.heap.Len() > 0 { + q.notifyOneLocked() + } + q.mu.Unlock() + return item, nil + } + if q.closed { + q.mu.Unlock() + return item, ErrClosed + } + q.mu.Unlock() + + select { + case <-ctx.Done(): + return item, ctx.Err() + case _, open := <-q.notify: + if !open { + continue + } + } + } +} + +// TryPop removes and returns the top item without blocking. +func (q *PriorityQueue[T]) TryPop() (item T, ok bool) { + q.mu.Lock() + defer q.mu.Unlock() + + item, ok = q.heap.PopTop() + return item, ok +} + +// Peek returns the top item without removing it. +func (q *PriorityQueue[T]) Peek() (item T, ok bool) { + q.mu.Lock() + defer q.mu.Unlock() + + return q.heap.PeekTop() +} + +// Len returns the number of queued items. +func (q *PriorityQueue[T]) Len() int { + q.mu.Lock() + defer q.mu.Unlock() + + return q.heap.Len() +} + +// Close prevents future pushes and wakes blocked Pop callers. Items already in +// the queue remain available to Pop. +func (q *PriorityQueue[T]) Close() { + q.mu.Lock() + if q.closed { + q.mu.Unlock() + return + } + q.closed = true + close(q.notify) + q.mu.Unlock() +} + +func (q *PriorityQueue[T]) notifyOneLocked() { + if q.closed { + return + } + select { + case q.notify <- struct{}{}: + default: + } +} diff --git a/utils/priorityqueue/priority_queue_test.go b/utils/priorityqueue/priority_queue_test.go new file mode 100644 index 0000000000..7e9dba6d1e --- /dev/null +++ b/utils/priorityqueue/priority_queue_test.go @@ -0,0 +1,248 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package priorityqueue + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type mockItem struct { + priority int + heapIndex int + description string +} + +func newMockItem(priority int, description string) *mockItem { + return &mockItem{priority: priority, description: description} +} + +func (m *mockItem) SetHeapIndex(index int) { + m.heapIndex = index +} + +func (m *mockItem) GetHeapIndex() int { + return m.heapIndex +} + +func (m *mockItem) LessThan(other *mockItem) bool { + return m.priority < other.priority +} + +func TestQueuePushPeekPopOrder(t *testing.T) { + q := New[*mockItem]() + + _, ok := q.Peek() + require.False(t, ok) + + tasks := []*mockItem{ + newMockItem(10, "task1"), + newMockItem(5, "task2"), + newMockItem(15, "task3"), + newMockItem(7, "task4"), + newMockItem(12, "task5"), + } + for _, task := range tasks { + require.True(t, q.Push(task)) + } + require.Equal(t, 5, q.Len()) + + top, ok := q.Peek() + require.True(t, ok) + require.Equal(t, "task2", top.description) + require.Equal(t, 5, q.Len()) + + expectedOrder := []string{"task2", "task4", "task1", "task5", "task3"} + for _, expected := range expectedOrder { + task, err := q.Pop(context.Background()) + require.NoError(t, err) + require.Equal(t, expected, task.description) + } + require.Equal(t, 0, q.Len()) +} + +func TestQueuePopBlocking(t *testing.T) { + q := New[*mockItem]() + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + start := time.Now() + task, err := q.Pop(ctx) + require.ErrorIs(t, err, context.DeadlineExceeded) + require.Nil(t, task) + require.GreaterOrEqual(t, time.Since(start), 50*time.Millisecond) + + go func() { + time.Sleep(50 * time.Millisecond) + q.Push(newMockItem(10, "task1")) + }() + + start = time.Now() + task, err = q.Pop(context.Background()) + require.NoError(t, err) + require.Equal(t, "task1", task.description) + require.GreaterOrEqual(t, time.Since(start), 50*time.Millisecond) +} + +func TestQueueTryPopAndUpdateExistingItem(t *testing.T) { + q := New[*mockItem]() + + _, ok := q.TryPop() + require.False(t, ok) + + task := newMockItem(10, "task") + require.True(t, q.Push(task)) + task.priority = 5 + require.True(t, q.AddOrUpdate(task)) + require.Equal(t, 1, q.Len()) + + poppedTask, ok := q.TryPop() + require.True(t, ok) + require.Equal(t, 5, poppedTask.priority) + require.Equal(t, 0, poppedTask.heapIndex) +} + +func TestQueueConcurrentOperations(t *testing.T) { + q := New[*mockItem]() + + const ( + numProducers = 3 + numConsumers = 2 + tasksPerProducer = 10 + ) + totalTasks := numProducers * tasksPerProducer + + var wg sync.WaitGroup + var consumedCount int64 + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for range numConsumers { + wg.Go(func() { + for { + task, err := q.Pop(ctx) + if err != nil { + return + } + require.NotNil(t, task) + if atomic.AddInt64(&consumedCount, 1) >= int64(totalTasks) { + cancel() + return + } + } + }) + } + + for producerID := range numProducers { + wg.Go(func() { + for j := range tasksPerProducer { + priority := producerID*tasksPerProducer + j + require.True(t, q.Push(newMockItem(priority, "task"))) + } + }) + } + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + cancel() + t.Fatal("test timed out") + } + require.Equal(t, int64(totalTasks), consumedCount) +} + +func TestQueueClose(t *testing.T) { + q := New[*mockItem]() + require.True(t, q.Push(newMockItem(10, "task1"))) + require.True(t, q.Push(newMockItem(5, "task2"))) + + q.Close() + require.False(t, q.Push(newMockItem(1, "closed"))) + + task, err := q.Pop(context.Background()) + require.NoError(t, err) + require.Equal(t, "task2", task.description) + + task, err = q.Pop(context.Background()) + require.NoError(t, err) + require.Equal(t, "task1", task.description) + + task, err = q.Pop(context.Background()) + require.ErrorIs(t, err, ErrClosed) + require.Nil(t, task) + + require.NotPanics(t, q.Close) +} + +func TestQueueCloseWakesBlockedPop(t *testing.T) { + q := New[*mockItem]() + + done := make(chan struct{}) + go func() { + defer close(done) + task, err := q.Pop(context.Background()) + require.ErrorIs(t, err, ErrClosed) + require.Nil(t, task) + }() + + q.Close() + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("Pop was not woken by Close") + } +} + +func TestQueuePushMultipleItemsWakesMultipleBlockedPop(t *testing.T) { + q := New[*mockItem]() + + const waiters = 2 + ready := make(chan struct{}, waiters) + done := make(chan struct{}, waiters) + for range waiters { + go func() { + ready <- struct{}{} + task, err := q.Pop(context.Background()) + require.NoError(t, err) + require.NotNil(t, task) + done <- struct{}{} + }() + } + + for range waiters { + <-ready + } + require.True(t, q.Push(newMockItem(10, "task1"))) + require.True(t, q.Push(newMockItem(20, "task2"))) + + for range waiters { + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("blocked Pop was not woken") + } + } +}