From 972b0263c631f8e12a35eaf49dc79ca1706a925a Mon Sep 17 00:00:00 2001 From: jx2lee Date: Sun, 14 Jun 2026 22:31:42 +0900 Subject: [PATCH] modified cond statements when heatbeat returned cancellation --- go-sdk/pkg/worker/runner.go | 2 +- go-sdk/pkg/worker/runner_test.go | 34 ++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/go-sdk/pkg/worker/runner.go b/go-sdk/pkg/worker/runner.go index 4065e30ad70ed..d30c2351bbe00 100644 --- a/go-sdk/pkg/worker/runner.go +++ b/go-sdk/pkg/worker/runner.go @@ -327,7 +327,7 @@ func (w *worker) ExecuteTaskWorkload(ctx context.Context, workload api.ExecuteTa var finalState api.TerminalTIState body := &api.TIUpdateStatePayload{} - if taskContext.Err() == ErrTaskCancelledAfterFailedHeartbeat { + if errors.Is(context.Cause(taskContext), ErrTaskCancelledAfterFailedHeartbeat) { // We've already logged when we failed to heartbeat, don't do it again finalState = api.TerminalTIStateFailed body.FromTITerminalStatePayload(api.TITerminalStatePayload{ diff --git a/go-sdk/pkg/worker/runner_test.go b/go-sdk/pkg/worker/runner_test.go index 72673c00df376..01ab4ba62b29a 100644 --- a/go-sdk/pkg/worker/runner_test.go +++ b/go-sdk/pkg/worker/runner_test.go @@ -21,6 +21,7 @@ import ( "context" "fmt" "log/slog" + "net/http" "sync/atomic" "testing" "time" @@ -31,6 +32,7 @@ import ( "github.com/spf13/viper" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" + "resty.dev/v3" "github.com/apache/airflow/go-sdk/bundle/bundlev1" "github.com/apache/airflow/go-sdk/pkg/api" @@ -218,6 +220,38 @@ func (s *WorkerSuite) TestTaskHeartbeatsWhileRunning() { True(count <= 11 && count >= 9, fmt.Sprintf("Call count of %d was not within the margin of error of 10+/-1", count)) } +func (s *WorkerSuite) TestTaskHeartbeatConflictStopsTask() { + id := uuid.New().String() + testWorkload := newTestWorkLoad(id, id[:8]) + + s.registry.AddDag(testWorkload.TI.DagId). + AddTaskWithName(testWorkload.TI.TaskId, func(ctx context.Context) error { + select { + case <-ctx.Done(): + return nil + case <-time.After(2 * time.Second): + return fmt.Errorf("task context was not cancelled") + } + }) + + s.ExpectTaskRun(id) + s.ExpectTaskState(id, api.TerminalTIStateFailed) + s.ti.EXPECT(). + Heartbeat(mock.Anything, uuid.MustParse(id), mock.Anything). + Return(&api.GeneralHTTPError{ + Response: &resty.Response{ + RawResponse: &http.Response{ + Status: "409 Conflict", + StatusCode: http.StatusConflict, + }, + }, + }) + s.client.EXPECT().TaskInstances().Return(s.ti) + + err := s.worker.ExecuteTaskWorkload(context.Background(), testWorkload) + s.NoError(err) +} + func (s *WorkerSuite) TestTaskHeartbeatErrorStopsTaskAndLogs() { s.T().Skip("TODO: Not implemented yet") }