diff --git a/service/matching/matching_engine.go b/service/matching/matching_engine.go index a482124f16..8d8f105344 100644 --- a/service/matching/matching_engine.go +++ b/service/matching/matching_engine.go @@ -33,6 +33,7 @@ import ( "go.temporal.io/server/client/matching" "go.temporal.io/server/common" "go.temporal.io/server/common/backoff" + "go.temporal.io/server/common/cache" "go.temporal.io/server/common/clock" hlc "go.temporal.io/server/common/clock/hybrid_logical_clock" "go.temporal.io/server/common/cluster" @@ -69,6 +70,14 @@ const ( // If sticky poller is not seem in last 10s, we treat it as sticky worker unavailable // This seems aggressive, but the default sticky schedule_to_start timeout is 5s, so 10s seems reasonable. stickyPollerUnavailableWindow = 10 * time.Second + + // shutdownWorkersCacheMaxSize is generous: each entry is a UUID string (~36 bytes), + // entries auto-expire after shutdownWorkersCacheTTL, and the cache only grows when + // workers shut down. Even with aggressive autoscaling, a single matching node is + // unlikely to see more than a few hundred worker shutdowns within the TTL window. + // LRU eviction ensures the oldest entries (least likely to re-poll) are evicted first. + shutdownWorkersCacheMaxSize = 10000 + shutdownWorkersCacheTTL = 30 * time.Second // If a compatible poller hasn't been seen for this time, we fail the CommitBuildId // Set to 70s so that it's a little over the max time a poller should be kept waiting. versioningPollerSeenWindow = 70 * time.Second @@ -166,6 +175,10 @@ type ( outstandingPollers collection.SyncMap[string, context.CancelFunc] // workerInstancePollers tracks pollers by worker instance key for bulk cancellation during shutdown. workerInstancePollers workerPollerTracker + // shutdownWorkers is a TTL cache of recently-shutdown worker instance keys. + // Polls from workers in this cache are rejected immediately to prevent + // zombie re-polls from stealing tasks after ShutdownWorker. + shutdownWorkers cache.Cache // Only set if global namespaces are enabled on the cluster. namespaceReplicationQueue persistence.NamespaceReplicationQueue // Lock to serialize replication queue updates. @@ -294,6 +307,7 @@ func NewEngine( nexusResults: collection.NewSyncMap[string, chan *nexusResult](), outstandingPollers: collection.NewSyncMap[string, context.CancelFunc](), workerInstancePollers: workerPollerTracker{pollers: make(map[string]map[string]context.CancelFunc)}, + shutdownWorkers: cache.New(shutdownWorkersCacheMaxSize, &cache.Options{TTL: shutdownWorkersCacheTTL}), namespaceReplicationQueue: namespaceReplicationQueue, userDataUpdateBatchers: collection.NewSyncMap[namespace.ID, *stream_batcher.Batcher[*userDataUpdate, error]](), rateLimiter: rateLimiter, @@ -1218,6 +1232,9 @@ func (e *matchingEngineImpl) CancelOutstandingWorkerPolls( ctx context.Context, request *matchingservice.CancelOutstandingWorkerPollsRequest, ) (*matchingservice.CancelOutstandingWorkerPollsResponse, error) { + if request.WorkerInstanceKey != "" { + e.shutdownWorkers.Put(request.WorkerInstanceKey, struct{}{}) + } cancelledCount := e.workerInstancePollers.CancelAll(request.WorkerInstanceKey) e.removePollerFromHistory(ctx, request) return &matchingservice.CancelOutstandingWorkerPollsResponse{CancelledCount: cancelledCount}, nil @@ -2850,6 +2867,11 @@ func (e *matchingEngineImpl) pollTask( // reached, instead of emptyTask, context timeout error is returned to the frontend by the rpc stack, // which counts against our SLO. By shortening the timeout by a very small amount, the emptyTask can be // returned to the handler before a context timeout error is generated. + workerInstanceKey := pollMetadata.workerInstanceKey + if workerInstanceKey != "" && e.shutdownWorkers.Get(workerInstanceKey) != nil { + return nil, false, errNoTasks + } + ctx, cancel := contextutil.WithDeadlineBuffer(ctx, pm.LongPollExpirationInterval(), returnEmptyTaskTimeBudget) defer cancel() @@ -2858,7 +2880,6 @@ func (e *matchingEngineImpl) pollTask( // Also track by worker instance key for bulk cancellation during shutdown. // Use UUID (not pollerID) because pollerID is reused when forwarded. - workerInstanceKey := pollMetadata.workerInstanceKey pollerTrackerKey := uuid.NewString() if workerInstanceKey != "" { e.workerInstancePollers.Add(workerInstanceKey, pollerTrackerKey, cancel) diff --git a/service/matching/matching_engine_test.go b/service/matching/matching_engine_test.go index 5ed0cd24d7..d2274e263f 100644 --- a/service/matching/matching_engine_test.go +++ b/service/matching/matching_engine_test.go @@ -42,6 +42,7 @@ import ( taskqueuespb "go.temporal.io/server/api/taskqueue/v1" tokenspb "go.temporal.io/server/api/token/v1" "go.temporal.io/server/common" + "go.temporal.io/server/common/cache" "go.temporal.io/server/common/clock" hlc "go.temporal.io/server/common/clock/hybrid_logical_clock" "go.temporal.io/server/common/cluster" @@ -5690,6 +5691,7 @@ func TestCancelOutstandingWorkerPolls(t *testing.T) { t.Parallel() engine := &matchingEngineImpl{ workerInstancePollers: workerPollerTracker{pollers: make(map[string]map[string]context.CancelFunc)}, + shutdownWorkers: cache.New(shutdownWorkersCacheMaxSize, &cache.Options{TTL: shutdownWorkersCacheTTL}), } resp, err := engine.CancelOutstandingWorkerPolls(context.Background(), @@ -5705,6 +5707,7 @@ func TestCancelOutstandingWorkerPolls(t *testing.T) { t.Parallel() engine := &matchingEngineImpl{ workerInstancePollers: workerPollerTracker{pollers: make(map[string]map[string]context.CancelFunc)}, + shutdownWorkers: cache.New(shutdownWorkersCacheMaxSize, &cache.Options{TTL: shutdownWorkersCacheTTL}), } workerKey := "test-worker" @@ -5731,6 +5734,7 @@ func TestCancelOutstandingWorkerPolls(t *testing.T) { worker2Cancelled := false engine := &matchingEngineImpl{ workerInstancePollers: workerPollerTracker{pollers: make(map[string]map[string]context.CancelFunc)}, + shutdownWorkers: cache.New(shutdownWorkersCacheMaxSize, &cache.Options{TTL: shutdownWorkersCacheTTL}), } // Set up pollers for worker1 and worker2 @@ -5753,6 +5757,7 @@ func TestCancelOutstandingWorkerPolls(t *testing.T) { t.Parallel() engine := &matchingEngineImpl{ workerInstancePollers: workerPollerTracker{pollers: make(map[string]map[string]context.CancelFunc)}, + shutdownWorkers: cache.New(shutdownWorkersCacheMaxSize, &cache.Options{TTL: shutdownWorkersCacheTTL}), } workerKey := "test-worker" @@ -5775,4 +5780,38 @@ func TestCancelOutstandingWorkerPolls(t *testing.T) { require.True(t, childCancelled, "child partition poll should be cancelled") require.True(t, parentCancelled, "parent partition poll should be cancelled") }) + + t.Run("adds worker to shutdown cache", func(t *testing.T) { + t.Parallel() + engine := &matchingEngineImpl{ + workerInstancePollers: workerPollerTracker{pollers: make(map[string]map[string]context.CancelFunc)}, + shutdownWorkers: cache.New(shutdownWorkersCacheMaxSize, &cache.Options{TTL: shutdownWorkersCacheTTL}), + } + + workerKey := "test-worker" + + _, err := engine.CancelOutstandingWorkerPolls(context.Background(), + &matchingservice.CancelOutstandingWorkerPollsRequest{ + WorkerInstanceKey: workerKey, + }) + + require.NoError(t, err) + require.NotNil(t, engine.shutdownWorkers.Get(workerKey), "worker should be in shutdown cache") + }) + + t.Run("empty worker key does not populate shutdown cache", func(t *testing.T) { + t.Parallel() + engine := &matchingEngineImpl{ + workerInstancePollers: workerPollerTracker{pollers: make(map[string]map[string]context.CancelFunc)}, + shutdownWorkers: cache.New(shutdownWorkersCacheMaxSize, &cache.Options{TTL: shutdownWorkersCacheTTL}), + } + + _, err := engine.CancelOutstandingWorkerPolls(context.Background(), + &matchingservice.CancelOutstandingWorkerPollsRequest{ + WorkerInstanceKey: "", + }) + + require.NoError(t, err) + require.Equal(t, 0, engine.shutdownWorkers.Size()) + }) } diff --git a/tests/task_queue_test.go b/tests/task_queue_test.go index b9818d643b..a071ec34f2 100644 --- a/tests/task_queue_test.go +++ b/tests/task_queue_test.go @@ -1500,4 +1500,43 @@ func (s *TaskQueueSuite) TestShutdownWorkerCancelsOutstandingPolls() { s.NotEqual(tv.WorkerIdentity(), poller.GetIdentity(), "poller should be removed from DescribeTaskQueue after shutdown") } + + // Verify that subsequent polls from the same worker are rejected immediately + // (the shutdown worker cache prevents zombie re-polls from stealing tasks). + // Use a long timeout so we can distinguish "rejected quickly" from "timed out". + rePollTimeout := 5 * time.Minute + + // Workflow poll should be rejected immediately. + wfStart := time.Now() + rePollCtx, rePollCancel := context.WithTimeout(ctx, rePollTimeout) + defer rePollCancel() + rePollResp, err := s.FrontendClient().PollWorkflowTaskQueue(rePollCtx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: s.Namespace().String(), + TaskQueue: tv.TaskQueue(), + Identity: tv.WorkerIdentity(), + WorkerInstanceKey: workerInstanceKey, + }) + s.NoError(err) + s.NotNil(rePollResp) + s.Empty(rePollResp.GetTaskToken(), "re-poll from shutdown worker should return empty response") + // TODO: Replace timing assertion with an explicit poll response field indicating + // shutdown rejection, so we don't rely on timing to distinguish cache rejection + // from natural poll timeout. Requires adding a field to PollWorkflowTaskQueueResponse + // and PollActivityTaskQueueResponse in the public API proto. + s.Less(time.Since(wfStart), 2*time.Minute, "workflow re-poll should be rejected quickly, not wait for timeout") + + // Activity poll should also be rejected immediately. + actStart := time.Now() + actCtx, actCancel := context.WithTimeout(ctx, rePollTimeout) + defer actCancel() + actResp, err := s.FrontendClient().PollActivityTaskQueue(actCtx, &workflowservice.PollActivityTaskQueueRequest{ + Namespace: s.Namespace().String(), + TaskQueue: tv.TaskQueue(), + Identity: tv.WorkerIdentity(), + WorkerInstanceKey: workerInstanceKey, + }) + s.NoError(err) + s.NotNil(actResp) + s.Empty(actResp.GetTaskToken(), "activity re-poll from shutdown worker should return empty response") + s.Less(time.Since(actStart), 2*time.Minute, "activity re-poll should be rejected quickly, not wait for timeout") }