From 5415114105d2c9978cd4fd05ab54d4125b552030 Mon Sep 17 00:00:00 2001 From: wlggraham Date: Mon, 17 Nov 2025 16:44:16 -0700 Subject: [PATCH 01/10] cancel jobs --- .../databases/heimdall/data/job_statuses.sql | 4 +- internal/pkg/heimdall/heimdall.go | 5 +- internal/pkg/heimdall/job.go | 40 ++++++++- internal/pkg/heimdall/job_dal.go | 19 +++++ internal/pkg/heimdall/jobs_async.go | 15 +--- .../object/command/clickhouse/clickhouse.go | 3 +- internal/pkg/object/command/dynamo/dynamo.go | 2 +- internal/pkg/object/command/ecs/ecs.go | 15 ++-- internal/pkg/object/command/glue/glue.go | 4 +- internal/pkg/object/command/ping/ping.go | 3 +- internal/pkg/object/command/shell/shell.go | 3 +- .../pkg/object/command/snowflake/snowflake.go | 3 +- internal/pkg/object/command/spark/spark.go | 2 +- .../pkg/object/command/sparkeks/sparkeks.go | 7 +- internal/pkg/object/command/trino/trino.go | 5 +- internal/pkg/pool/pool.go | 84 ++++++++++++++++++- internal/pkg/pool/pool_dal.go | 59 +++++++++++++ .../pool/queries/cancelling_jobs_select.sql | 9 ++ .../queries/job_status_cancelled_update.sql | 6 ++ pkg/object/job/status/status.go | 28 ++++--- pkg/plugin/plugin.go | 4 +- 21 files changed, 270 insertions(+), 50 deletions(-) create mode 100644 internal/pkg/pool/pool_dal.go create mode 100644 internal/pkg/pool/queries/cancelling_jobs_select.sql create mode 100644 internal/pkg/pool/queries/job_status_cancelled_update.sql diff --git a/assets/databases/heimdall/data/job_statuses.sql b/assets/databases/heimdall/data/job_statuses.sql index 0acc895..b0e5420 100644 --- a/assets/databases/heimdall/data/job_statuses.sql +++ b/assets/databases/heimdall/data/job_statuses.sql @@ -9,7 +9,9 @@ values (3, 'RUNNING'), (4, 'FAILED'), (5, 'KILLED'), - (6, 'SUCCEEDED') + (6, 'SUCCEEDED'), + (7, 'CANCELLING'), + (8, 'CANCELLED') on conflict (job_status_id) do update set job_status_name = excluded.job_status_name; diff --git a/internal/pkg/heimdall/heimdall.go b/internal/pkg/heimdall/heimdall.go index d8a5840..dbf5fd7 100644 --- a/internal/pkg/heimdall/heimdall.go +++ b/internal/pkg/heimdall/heimdall.go @@ -15,12 +15,12 @@ import ( "github.com/patterninc/heimdall/internal/pkg/database" "github.com/patterninc/heimdall/internal/pkg/janitor" "github.com/patterninc/heimdall/internal/pkg/pool" + "github.com/patterninc/heimdall/internal/pkg/rbac" "github.com/patterninc/heimdall/internal/pkg/server" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/command" "github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/plugin" - "github.com/patterninc/heimdall/internal/pkg/rbac" rbacI "github.com/patterninc/heimdall/pkg/rbac" ) @@ -156,7 +156,7 @@ func (h *Heimdall) Init() error { } // let's start the agent - return h.Pool.Start(h.runAsyncJob, h.getAsyncJobs) + return h.Pool.Start(h.runAsyncJob, h.getAsyncJobs, h.Database) } @@ -173,6 +173,7 @@ func (h *Heimdall) Start() error { // job(s) endpoints... 
apiRouter.Methods(methodGET).PathPrefix(`/job/statuses`).HandlerFunc(payloadHandler(h.getJobStatuses)) apiRouter.Methods(methodGET).PathPrefix(`/job/{id}/status`).HandlerFunc(payloadHandler(h.getJobStatus)) + apiRouter.Methods(methodPOST).PathPrefix(`/job/{id}/cancel`).HandlerFunc(payloadHandler(h.cancelJob)) apiRouter.Methods(methodGET).PathPrefix(`/job/{id}/{file}`).HandlerFunc(h.getJobFile) apiRouter.Methods(methodGET).PathPrefix(`/job/{id}`).HandlerFunc(payloadHandler(h.getJob)) apiRouter.Methods(methodGET).PathPrefix(`/jobs`).HandlerFunc(payloadHandler(h.getJobs)) diff --git a/internal/pkg/heimdall/job.go b/internal/pkg/heimdall/job.go index b0cd043..62e01b9 100644 --- a/internal/pkg/heimdall/job.go +++ b/internal/pkg/heimdall/job.go @@ -1,6 +1,7 @@ package heimdall import ( + "context" "crypto/rand" "encoding/json" "fmt" @@ -67,7 +68,7 @@ func (h *Heimdall) submitJob(j *job.Job) (any, error) { } // let's run the job - err = h.runJob(j, command, cluster) + err = h.runJob(j, command, cluster, context.Background()) // before we process the error, we'll make the best effort to record this job in the database go h.insertJob(j, cluster.ID, command.ID) @@ -76,7 +77,7 @@ func (h *Heimdall) submitJob(j *job.Job) (any, error) { } -func (h *Heimdall) runJob(job *job.Job, command *command.Command, cluster *cluster.Cluster) error { +func (h *Heimdall) runJob(job *job.Job, command *command.Command, cluster *cluster.Cluster, ctx context.Context) error { defer runJobMethod.RecordLatency(time.Now(), command.Name, cluster.Name) runJobMethod.CountRequest(command.Name, cluster.Name) @@ -107,7 +108,7 @@ func (h *Heimdall) runJob(job *job.Job, command *command.Command, cluster *clust go h.jobKeepalive(keepaliveActive, job.SystemID, h.agentName) // let's execute command - if err := h.commandHandlers[command.ID](runtime, job, cluster); err != nil { + if err := h.commandHandlers[command.ID](ctx, runtime, job, cluster); err != nil { job.Status = jobStatus.Failed job.Error = err.Error() @@ -118,6 +119,14 @@ func (h *Heimdall) runJob(job *job.Job, command *command.Command, cluster *clust } + if ctx.Err() != nil { + + job.Status = jobStatus.Cancelled + runJobMethod.LogAndCountError(ctx.Err(), command.Name, cluster.Name) + + return nil + } + if job.StoreResultSync || !job.IsSync { h.storeResults(runtime, job) } else { @@ -157,6 +166,31 @@ func (h *Heimdall) storeResults(runtime *plugin.Runtime, job *job.Job) error { return nil } +func (h *Heimdall) cancelJob(req *jobRequest) (any, error) { + + // validate that job exists and get its current status + currentJob, err := h.getJob(req) + if err != nil { + return nil, err + } + + job := currentJob.(*job.Job) + + // check if job can be cancelled (must be running or accepted) + if job.Status != jobStatus.Running && job.Status != jobStatus.Accepted { + return nil, fmt.Errorf("job cannot be cancelled: current status is %v", job.Status) + } + // update job status to CANCELLING + job.Status = jobStatus.Cancelling + if err := h.updateJobStatusToCancelling(job); err != nil { + return nil, err + } + + job.UpdatedAt = int(time.Now().Unix()) + + return job, nil +} + func (h *Heimdall) getJobFile(w http.ResponseWriter, r *http.Request) { // get vars diff --git a/internal/pkg/heimdall/job_dal.go b/internal/pkg/heimdall/job_dal.go index aaa34cd..53a4e75 100644 --- a/internal/pkg/heimdall/job_dal.go +++ b/internal/pkg/heimdall/job_dal.go @@ -304,3 +304,22 @@ func (h *Heimdall) getJobStatuses(_ *database.Filter) (any, error) { return database.GetSlice(h.Database, 
queryJobStatusesSelect) } + +func (h *Heimdall) updateJobStatusToCancelling(job *job.Job) error { + + // open connection + sess, err := h.Database.NewSession(true) + if err != nil { + return err + } + defer sess.Close() + + // update job status to CANCELLING + _, err = sess.Exec(queryJobStatusUpdate, job.Status, job.Error, job.SystemID) + if err != nil { + return err + } + + // commit transaction + return sess.Commit() +} diff --git a/internal/pkg/heimdall/jobs_async.go b/internal/pkg/heimdall/jobs_async.go index bdb0c71..cebc819 100644 --- a/internal/pkg/heimdall/jobs_async.go +++ b/internal/pkg/heimdall/jobs_async.go @@ -1,6 +1,7 @@ package heimdall import ( + "context" _ "embed" "fmt" @@ -88,9 +89,9 @@ func (h *Heimdall) getAsyncJobs(limit int) ([]*job.Job, error) { } -func (h *Heimdall) runAsyncJob(j *job.Job) error { +func (h *Heimdall) runAsyncJob(ctx context.Context, j *job.Job) error { - // let's updte job status that we're running it... + // let's update job status to RUNNING... sess, err := h.Database.NewSession(false) if err != nil { return h.updateAsyncJobStatus(j, err) @@ -113,20 +114,12 @@ func (h *Heimdall) runAsyncJob(j *job.Job) error { return h.updateAsyncJobStatus(j, fmt.Errorf(formatErrUnknownCluster, j.CluserID)) } - return h.updateAsyncJobStatus(j, h.runJob(j, command, cluster)) + return h.updateAsyncJobStatus(j, h.runJob(j, command, cluster, ctx)) } func (h *Heimdall) updateAsyncJobStatus(j *job.Job, jobError error) error { - // we updte the final job status based on presence of the error - if jobError == nil { - j.Status = status.Succeeded - } else { - j.Status = status.Failed - j.Error = jobError.Error() - } - // now we update that status in the database sess, err := h.Database.NewSession(true) if err != nil { diff --git a/internal/pkg/object/command/clickhouse/clickhouse.go b/internal/pkg/object/command/clickhouse/clickhouse.go index 819f83a..472c781 100644 --- a/internal/pkg/object/command/clickhouse/clickhouse.go +++ b/internal/pkg/object/command/clickhouse/clickhouse.go @@ -57,8 +57,7 @@ func New(ctx *hdctx.Context) (plugin.Handler, error) { return t.handler, nil } -func (cmd *commandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { - ctx := context.Background() +func (cmd *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { jobContext, err := cmd.createJobContext(j, c) if err != nil { diff --git a/internal/pkg/object/command/dynamo/dynamo.go b/internal/pkg/object/command/dynamo/dynamo.go index f0a3334..b36b0d8 100644 --- a/internal/pkg/object/command/dynamo/dynamo.go +++ b/internal/pkg/object/command/dynamo/dynamo.go @@ -44,7 +44,7 @@ func New(_ *context.Context) (plugin.Handler, error) { } // Handler for the Spark job submission. 
-func (d *dynamoCommandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster.Cluster) (err error) { +func (d *dynamoCommandContext) handler(ct ct.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) (err error) { // let's unmarshal job context jobContext := &dynamoJobContext{} diff --git a/internal/pkg/object/command/ecs/ecs.go b/internal/pkg/object/command/ecs/ecs.go index fc4c12a..85c2bf2 100644 --- a/internal/pkg/object/command/ecs/ecs.go +++ b/internal/pkg/object/command/ecs/ecs.go @@ -121,7 +121,7 @@ func New(commandContext *context.Context) (plugin.Handler, error) { } // handler implements the main ECS plugin logic -func (e *ecsCommandContext) handler(r *plugin.Runtime, job *job.Job, cluster *cluster.Cluster) error { +func (e *ecsCommandContext) handler(ctx ct.Context, r *plugin.Runtime, job *job.Job, cluster *cluster.Cluster) error { // Build execution context with resolved configuration and loaded template execCtx, err := buildExecutionContext(e, job, cluster) @@ -140,7 +140,7 @@ func (e *ecsCommandContext) handler(r *plugin.Runtime, job *job.Job, cluster *cl } // Poll for completion - if err := execCtx.pollForCompletion(); err != nil { + if err := execCtx.pollForCompletion(ctx); err != nil { return err } @@ -197,7 +197,7 @@ func (execCtx *executionContext) startTasks(jobID string) error { } // monitor tasks until completion, faliure, or timeout -func (execCtx *executionContext) pollForCompletion() error { +func (execCtx *executionContext) pollForCompletion(ctx ct.Context) error { startTime := time.Now() stopTime := startTime.Add(time.Duration(execCtx.Timeout)) @@ -305,8 +305,13 @@ func (execCtx *executionContext) pollForCompletion() error { return fmt.Errorf("%s", reason) } - // Sleep until next poll time - time.Sleep(time.Duration(execCtx.PollingInterval)) + // Check for cancellation or sleep until next poll time + select { + case <-ctx.Done(): + stopAllTasks(execCtx, "Job cancelled by user") + return nil + case <-time.After(time.Duration(execCtx.PollingInterval)): + } } // If you're here, all tasks are complete diff --git a/internal/pkg/object/command/glue/glue.go b/internal/pkg/object/command/glue/glue.go index 63128a9..b2ac5b2 100644 --- a/internal/pkg/object/command/glue/glue.go +++ b/internal/pkg/object/command/glue/glue.go @@ -1,6 +1,8 @@ package glue import ( + ct "context" + "github.com/patterninc/heimdall/internal/pkg/aws" "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/object/cluster" @@ -31,7 +33,7 @@ func New(commandContext *context.Context) (plugin.Handler, error) { } -func (g *glueCommandContext) handler(_ *plugin.Runtime, j *job.Job, _ *cluster.Cluster) (err error) { +func (g *glueCommandContext) handler(ct ct.Context, _ *plugin.Runtime, j *job.Job, _ *cluster.Cluster) (err error) { // let's unmarshal job context jc := &glueJobContext{} diff --git a/internal/pkg/object/command/ping/ping.go b/internal/pkg/object/command/ping/ping.go index 64fd333..3b438f4 100644 --- a/internal/pkg/object/command/ping/ping.go +++ b/internal/pkg/object/command/ping/ping.go @@ -1,6 +1,7 @@ package ping import ( + ct "context" "fmt" "github.com/patterninc/heimdall/pkg/context" @@ -23,7 +24,7 @@ func New(_ *context.Context) (plugin.Handler, error) { } -func (p *pingCommandContext) handler(_ *plugin.Runtime, j *job.Job, _ *cluster.Cluster) (err error) { +func (p *pingCommandContext) handler(ct ct.Context, _ *plugin.Runtime, j *job.Job, _ *cluster.Cluster) (err error) { j.Result, err = result.FromMessage(fmt.Sprintf(messageFormat, j.User)) 
return diff --git a/internal/pkg/object/command/shell/shell.go b/internal/pkg/object/command/shell/shell.go index d16b20d..9b21037 100644 --- a/internal/pkg/object/command/shell/shell.go +++ b/internal/pkg/object/command/shell/shell.go @@ -1,6 +1,7 @@ package shell import ( + ct "context" "encoding/json" "os" "os/exec" @@ -47,7 +48,7 @@ func New(commandContext *context.Context) (plugin.Handler, error) { } -func (s *shellCommandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { +func (s *shellCommandContext) handler(ct ct.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { // let's unmarshal job context jc := &shellJobContext{} diff --git a/internal/pkg/object/command/snowflake/snowflake.go b/internal/pkg/object/command/snowflake/snowflake.go index be267a7..f4d2c1a 100644 --- a/internal/pkg/object/command/snowflake/snowflake.go +++ b/internal/pkg/object/command/snowflake/snowflake.go @@ -1,6 +1,7 @@ package snowflake import ( + ct "context" "crypto/rsa" "crypto/x509" "database/sql" @@ -70,7 +71,7 @@ func New(_ *context.Context) (plugin.Handler, error) { return s.handler, nil } -func (s *snowflakeCommandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { +func (s *snowflakeCommandContext) handler(ct ct.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { clusterContext := &snowflakeClusterContext{} if c.Context != nil { diff --git a/internal/pkg/object/command/spark/spark.go b/internal/pkg/object/command/spark/spark.go index 9d0f382..a624933 100644 --- a/internal/pkg/object/command/spark/spark.go +++ b/internal/pkg/object/command/spark/spark.go @@ -92,7 +92,7 @@ func New(commandContext *context.Context) (plugin.Handler, error) { } // Handler for the Spark job submission. -func (s *sparkCommandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster.Cluster) (err error) { +func (s *sparkCommandContext) handler(ct ct.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) (err error) { // let's unmarshal job context jobContext := &sparkJobContext{} diff --git a/internal/pkg/object/command/sparkeks/sparkeks.go b/internal/pkg/object/command/sparkeks/sparkeks.go index 4750889..2aa007a 100644 --- a/internal/pkg/object/command/sparkeks/sparkeks.go +++ b/internal/pkg/object/command/sparkeks/sparkeks.go @@ -3,6 +3,7 @@ package sparkeks import ( "bytes" "context" + ct "context" "encoding/json" "fmt" "io" @@ -155,7 +156,11 @@ func New(commandContext *heimdallContext.Context) (plugin.Handler, error) { } // handler executes the Spark EKS job submission and execution. -func (s *sparkEksCommandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { +func (s *sparkEksCommandContext) handler(ct ct.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { + + // Assign global context to incoming cancellation context + ctx = ct + // 1. 
Build execution context, create URIs, and upload query execCtx, err := buildExecutionContextAndURI(r, j, c, s) if err != nil { diff --git a/internal/pkg/object/command/trino/trino.go b/internal/pkg/object/command/trino/trino.go index ce1891c..1a96146 100644 --- a/internal/pkg/object/command/trino/trino.go +++ b/internal/pkg/object/command/trino/trino.go @@ -1,6 +1,7 @@ package trino import ( + ct "context" "fmt" "log" "time" @@ -52,7 +53,7 @@ func New(ctx *context.Context) (plugin.Handler, error) { } -func (t *commandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { +func (t *commandContext) handler(ct ct.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { // get job context jobCtx := &jobContext{} @@ -112,7 +113,7 @@ func canQueryBeExecuted(query, user, id string, c *cluster.Cluster) bool { return false } } - + canBeExecutedMethod.CountSuccess() return true } diff --git a/internal/pkg/pool/pool.go b/internal/pkg/pool/pool.go index 368a12d..6ceb0f1 100644 --- a/internal/pkg/pool/pool.go +++ b/internal/pkg/pool/pool.go @@ -1,8 +1,13 @@ package pool import ( + "context" "fmt" + "sync" "time" + + "github.com/patterninc/heimdall/internal/pkg/database" + "github.com/patterninc/heimdall/pkg/object/job" ) const ( @@ -13,10 +18,17 @@ const ( type Pool[T any] struct { Size int `yaml:"size,omitempty" json:"size,omitempty"` Sleep int `yaml:"sleep,omitempty" json:"sleep,omitempty"` - queue chan T + queue chan *job.Job + + runningJobs map[string]context.CancelFunc + runningJobsMux sync.RWMutex + db *database.Database } -func (p *Pool[T]) Start(worker func(T) error, getWork func(int) ([]T, error)) error { +func (p *Pool[T]) Start(worker func(context.Context, *job.Job) error, getWork func(int) ([]*job.Job, error), database *database.Database) error { + + // record database context + p.db = database // do we have the size set? if p.Size <= 0 { @@ -29,11 +41,17 @@ func (p *Pool[T]) Start(worker func(T) error, getWork func(int) ([]T, error)) er } // set the queue of the size of double the pool size - p.queue = make(chan T, p.Size*2) + p.queue = make(chan *job.Job, p.Size*2) // let's set the counter tokens := &counter{} + // Initialize tracking + p.runningJobs = make(map[string]context.CancelFunc) + + // Start cancellation polling loop + go p.pollForCancellations() + // let's provision workers for i := 0; i < p.Size; i++ { @@ -52,8 +70,18 @@ func (p *Pool[T]) Start(worker func(T) error, getWork func(int) ([]T, error)) er break } + ctx, cancel := context.WithCancel(context.Background()) + + // register job as running + p.registerRunningJob(w.ID, cancel) + // do the work.... 
- if err := worker(w); err != nil { + err := worker(ctx, w) + + // remove job from running jobs + p.unregisterRunningJob(w.ID) + + if err != nil { // TODO: implement proper error logging fmt.Println(`worker:`, err) } @@ -106,5 +134,53 @@ func (p *Pool[T]) Start(worker func(T) error, getWork func(int) ([]T, error)) er }(tokens) return nil +} + +func (p *Pool[T]) pollForCancellations() { + // let's poll for cancellations every 15 seconds + ticker := time.NewTicker(15 * time.Second) + defer ticker.Stop() + + for range ticker.C { + + // Get jobs in CANCELLING state from database + cancellingJobs := p.getCancellingJobs() + + // Check each cancelling job + for _, jobID := range cancellingJobs { + if cancelFunc, isLocal := p.isJobRunningLocally(jobID); isLocal { + cancelFunc() // Trigger context cancellation + + // Update job status to CANCELLED in database + p.updateJobStatusToCancelled(jobID) + } + } + } +} + +func (p *Pool[T]) registerRunningJob(jobID string, cancel context.CancelFunc) { + + p.runningJobsMux.Lock() + defer p.runningJobsMux.Unlock() + + p.runningJobs[jobID] = cancel + +} + +func (p *Pool[T]) unregisterRunningJob(jobID string) { + + p.runningJobsMux.Lock() + defer p.runningJobsMux.Unlock() + + delete(p.runningJobs, jobID) +} + +// Check if a job is running locally +func (p *Pool[T]) isJobRunningLocally(jobID string) (context.CancelFunc, bool) { + + p.runningJobsMux.RLock() + defer p.runningJobsMux.RUnlock() + cancelFunc, exists := p.runningJobs[jobID] + return cancelFunc, exists } diff --git a/internal/pkg/pool/pool_dal.go b/internal/pkg/pool/pool_dal.go new file mode 100644 index 0000000..c09d89b --- /dev/null +++ b/internal/pkg/pool/pool_dal.go @@ -0,0 +1,59 @@ +package pool + +import ( + _ "embed" +) + +//go:embed queries/cancelling_jobs_select.sql +var queryCancellingJobsSelect string + +//go:embed queries/job_status_cancelled_update.sql +var queryJobStatusCancelledUpdate string + +// getCancellingJobs retrieves jobs in CANCELLING state from database +func (p *Pool[T]) getCancellingJobs() []string { + + sess, err := p.db.NewSession(false) + if err != nil { + return nil + } + defer sess.Close() + + rows, err := sess.Query(queryCancellingJobsSelect) + if err != nil { + return nil + } + defer rows.Close() + + var jobIDs []string + for rows.Next() { + var jobID string + if err := rows.Scan(&jobID); err != nil { + continue + } + jobIDs = append(jobIDs, jobID) + } + + return jobIDs +} + +// updateJobStatusToCancelled updates job status to CANCELLED in database +func (p *Pool[T]) updateJobStatusToCancelled(jobID string) error { + if p.db == nil { + return nil + } + + sess, err := p.db.NewSession(true) + if err != nil { + return err + } + defer sess.Close() + + _, err = sess.Exec(queryJobStatusCancelledUpdate, jobID) + + if err == nil { + return sess.Commit() + } + + return err +} diff --git a/internal/pkg/pool/queries/cancelling_jobs_select.sql b/internal/pkg/pool/queries/cancelling_jobs_select.sql new file mode 100644 index 0000000..520c9b9 --- /dev/null +++ b/internal/pkg/pool/queries/cancelling_jobs_select.sql @@ -0,0 +1,9 @@ +select + j.job_id +from + jobs j + join job_statuses js on j.job_status_id = js.job_status_id +where + js.job_status_name = 'CANCELLING' +limit + 25; diff --git a/internal/pkg/pool/queries/job_status_cancelled_update.sql b/internal/pkg/pool/queries/job_status_cancelled_update.sql new file mode 100644 index 0000000..8d97fd5 --- /dev/null +++ b/internal/pkg/pool/queries/job_status_cancelled_update.sql @@ -0,0 +1,6 @@ +update jobs +set + job_status_id = 
(select job_status_id from job_statuses where job_status_name = 'CANCELLED'), + updated_at = extract(epoch from now())::int +where + job_id = $1; diff --git a/pkg/object/job/status/status.go b/pkg/object/job/status/status.go index efe449a..375c540 100644 --- a/pkg/object/job/status/status.go +++ b/pkg/object/job/status/status.go @@ -11,12 +11,14 @@ import ( type Status status.Status const ( - New Status = 1 - Accepted Status = 2 - Running Status = 3 - Failed Status = 4 - Killed Status = 5 - Succeeded Status = 6 + New Status = 1 + Accepted Status = 2 + Running Status = 3 + Failed Status = 4 + Killed Status = 5 + Succeeded Status = 6 + Cancelling Status = 7 + Cancelled Status = 8 ) const ( @@ -25,12 +27,14 @@ const ( var ( statusMapping = map[string]status.Status{ - `new`: status.Status(New), - `accepted`: status.Status(Accepted), - `running`: status.Status(Running), - `failed`: status.Status(Failed), - `killed`: status.Status(Killed), - `succeeded`: status.Status(Succeeded), + `new`: status.Status(New), + `accepted`: status.Status(Accepted), + `running`: status.Status(Running), + `failed`: status.Status(Failed), + `killed`: status.Status(Killed), + `succeeded`: status.Status(Succeeded), + `cancelling`: status.Status(Cancelling), + `cancelled`: status.Status(Cancelled), } ) diff --git a/pkg/plugin/plugin.go b/pkg/plugin/plugin.go index 3512cbf..cfb6db6 100644 --- a/pkg/plugin/plugin.go +++ b/pkg/plugin/plugin.go @@ -1,8 +1,10 @@ package plugin import ( + "context" + "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" ) -type Handler func(*Runtime, *job.Job, *cluster.Cluster) error +type Handler func(context.Context, *Runtime, *job.Job, *cluster.Cluster) error From 925c8897ecb488487b6ef9577861d52486f702b6 Mon Sep 17 00:00:00 2001 From: wlggraham Date: Mon, 17 Nov 2025 17:17:08 -0700 Subject: [PATCH 02/10] update query mod --- internal/pkg/heimdall/job.go | 10 ++++++++-- internal/pkg/heimdall/job_dal.go | 3 ++- internal/pkg/pool/pool_dal.go | 8 +++++--- .../pkg/pool/queries/job_status_cancelled_update.sql | 6 ------ internal/pkg/pool/queries/job_status_update_by_id.sql | 7 +++++++ 5 files changed, 22 insertions(+), 12 deletions(-) delete mode 100644 internal/pkg/pool/queries/job_status_cancelled_update.sql create mode 100644 internal/pkg/pool/queries/job_status_update_by_id.sql diff --git a/internal/pkg/heimdall/job.go b/internal/pkg/heimdall/job.go index 62e01b9..cdb0e8c 100644 --- a/internal/pkg/heimdall/job.go +++ b/internal/pkg/heimdall/job.go @@ -119,6 +119,7 @@ func (h *Heimdall) runJob(job *job.Job, command *command.Command, cluster *clust } + // Check if context was cancelled if ctx.Err() != nil { job.Status = jobStatus.Cancelled @@ -174,19 +175,24 @@ func (h *Heimdall) cancelJob(req *jobRequest) (any, error) { return nil, err } - job := currentJob.(*job.Job) + // make sure we have a job object + job, ok := currentJob.(*job.Job) + if !ok { + return nil, fmt.Errorf("expected *job.Job, got %T", currentJob) + } // check if job can be cancelled (must be running or accepted) if job.Status != jobStatus.Running && job.Status != jobStatus.Accepted { return nil, fmt.Errorf("job cannot be cancelled: current status is %v", job.Status) } // update job status to CANCELLING - job.Status = jobStatus.Cancelling if err := h.updateJobStatusToCancelling(job); err != nil { return nil, err } + // update object to return to caller job.UpdatedAt = int(time.Now().Unix()) + job.Status = jobStatus.Cancelling return job, nil } diff --git 
a/internal/pkg/heimdall/job_dal.go b/internal/pkg/heimdall/job_dal.go index 53a4e75..5ae0856 100644 --- a/internal/pkg/heimdall/job_dal.go +++ b/internal/pkg/heimdall/job_dal.go @@ -11,6 +11,7 @@ import ( "github.com/patterninc/heimdall/internal/pkg/database" "github.com/patterninc/heimdall/pkg/object" "github.com/patterninc/heimdall/pkg/object/job" + jobStatus "github.com/patterninc/heimdall/pkg/object/job/status" ) //go:embed queries/job/insert.sql @@ -315,7 +316,7 @@ func (h *Heimdall) updateJobStatusToCancelling(job *job.Job) error { defer sess.Close() // update job status to CANCELLING - _, err = sess.Exec(queryJobStatusUpdate, job.Status, job.Error, job.SystemID) + _, err = sess.Exec(queryJobStatusUpdate, jobStatus.Cancelling, job.Error, job.SystemID) if err != nil { return err } diff --git a/internal/pkg/pool/pool_dal.go b/internal/pkg/pool/pool_dal.go index c09d89b..ba07889 100644 --- a/internal/pkg/pool/pool_dal.go +++ b/internal/pkg/pool/pool_dal.go @@ -2,13 +2,15 @@ package pool import ( _ "embed" + + "github.com/patterninc/heimdall/pkg/object/job/status" ) //go:embed queries/cancelling_jobs_select.sql var queryCancellingJobsSelect string -//go:embed queries/job_status_cancelled_update.sql -var queryJobStatusCancelledUpdate string +//go:embed queries/job_status_update_by_id.sql +var queryJobStatusUpdate string // getCancellingJobs retrieves jobs in CANCELLING state from database func (p *Pool[T]) getCancellingJobs() []string { @@ -49,7 +51,7 @@ func (p *Pool[T]) updateJobStatusToCancelled(jobID string) error { } defer sess.Close() - _, err = sess.Exec(queryJobStatusCancelledUpdate, jobID) + _, err = sess.Exec(queryJobStatusUpdate, status.Cancelled, "", jobID) if err == nil { return sess.Commit() diff --git a/internal/pkg/pool/queries/job_status_cancelled_update.sql b/internal/pkg/pool/queries/job_status_cancelled_update.sql deleted file mode 100644 index 8d97fd5..0000000 --- a/internal/pkg/pool/queries/job_status_cancelled_update.sql +++ /dev/null @@ -1,6 +0,0 @@ -update jobs -set - job_status_id = (select job_status_id from job_statuses where job_status_name = 'CANCELLED'), - updated_at = extract(epoch from now())::int -where - job_id = $1; diff --git a/internal/pkg/pool/queries/job_status_update_by_id.sql b/internal/pkg/pool/queries/job_status_update_by_id.sql new file mode 100644 index 0000000..c1ff488 --- /dev/null +++ b/internal/pkg/pool/queries/job_status_update_by_id.sql @@ -0,0 +1,7 @@ +update jobs +set + job_status_id = $1, + job_error = $2, + updated_at = extract(epoch from now())::int +where + job_id = $3; From fadc32e222c126d90223f17012f4c12176f9434a Mon Sep 17 00:00:00 2001 From: wlggraham Date: Tue, 18 Nov 2025 14:27:42 -0700 Subject: [PATCH 03/10] move cancellation poll to main job routine --- internal/pkg/heimdall/cluster_dal.go | 17 +-- internal/pkg/heimdall/command_dal.go | 17 +-- internal/pkg/heimdall/handler.go | 5 +- internal/pkg/heimdall/heimdall.go | 2 +- internal/pkg/heimdall/job.go | 123 +++++++++++++----- internal/pkg/heimdall/job_dal.go | 9 +- internal/pkg/heimdall/jobs_async.go | 2 +- internal/pkg/object/command/ecs/ecs.go | 9 +- internal/pkg/pool/pool.go | 76 +---------- internal/pkg/pool/pool_dal.go | 61 --------- .../pool/queries/cancelling_jobs_select.sql | 9 -- .../pool/queries/job_status_update_by_id.sql | 7 - 12 files changed, 122 insertions(+), 215 deletions(-) delete mode 100644 internal/pkg/pool/pool_dal.go delete mode 100644 internal/pkg/pool/queries/cancelling_jobs_select.sql delete mode 100644 
internal/pkg/pool/queries/job_status_update_by_id.sql diff --git a/internal/pkg/heimdall/cluster_dal.go b/internal/pkg/heimdall/cluster_dal.go index 99a31c6..22820aa 100644 --- a/internal/pkg/heimdall/cluster_dal.go +++ b/internal/pkg/heimdall/cluster_dal.go @@ -1,6 +1,7 @@ package heimdall import ( + "context" "database/sql" _ "embed" "encoding/json" @@ -70,13 +71,13 @@ var ( ErrUnknownClusterID = fmt.Errorf(`unknown cluster_id`) ) -func (h *Heimdall) submitCluster(c *cluster.Cluster) (any, error) { +func (h *Heimdall) submitCluster(ctx context.Context, c *cluster.Cluster) (any, error) { if err := h.clusterUpsert(c); err != nil { return nil, err } - return h.getCluster(&cluster.Cluster{Object: object.Object{ID: c.ID}}) + return h.getCluster(ctx, &cluster.Cluster{Object: object.Object{ID: c.ID}}) } @@ -116,7 +117,7 @@ func (h *Heimdall) clusterUpsert(c *cluster.Cluster) error { } -func (h *Heimdall) getCluster(c *cluster.Cluster) (any, error) { +func (h *Heimdall) getCluster(ctx context.Context, c *cluster.Cluster) (any, error) { // open connection sess, err := h.Database.NewSession(false) @@ -156,7 +157,7 @@ func (h *Heimdall) getCluster(c *cluster.Cluster) (any, error) { } -func (h *Heimdall) getClusterStatus(c *cluster.Cluster) (any, error) { +func (h *Heimdall) getClusterStatus(ctx context.Context, c *cluster.Cluster) (any, error) { // open connection sess, err := h.Database.NewSession(false) @@ -185,7 +186,7 @@ func (h *Heimdall) getClusterStatus(c *cluster.Cluster) (any, error) { } -func (h *Heimdall) updateClusterStatus(c *cluster.Cluster) (any, error) { +func (h *Heimdall) updateClusterStatus(ctx context.Context, c *cluster.Cluster) (any, error) { // open connection sess, err := h.Database.NewSession(false) @@ -203,11 +204,11 @@ func (h *Heimdall) updateClusterStatus(c *cluster.Cluster) (any, error) { return nil, ErrUnknownClusterID } - return h.getClusterStatus(c) + return h.getClusterStatus(ctx, c) } -func (h *Heimdall) getClusters(f *database.Filter) (any, error) { +func (h *Heimdall) getClusters(ctx context.Context, f *database.Filter) (any, error) { // open connection sess, err := h.Database.NewSession(false) @@ -253,7 +254,7 @@ func (h *Heimdall) getClusters(f *database.Filter) (any, error) { } -func (h *Heimdall) getClusterStatuses(_ *database.Filter) (any, error) { +func (h *Heimdall) getClusterStatuses(ctx context.Context, _ *database.Filter) (any, error) { return database.GetSlice(h.Database, queryClusterStatusesSelect) diff --git a/internal/pkg/heimdall/command_dal.go b/internal/pkg/heimdall/command_dal.go index c47262b..6195711 100644 --- a/internal/pkg/heimdall/command_dal.go +++ b/internal/pkg/heimdall/command_dal.go @@ -1,6 +1,7 @@ package heimdall import ( + "context" "database/sql" _ "embed" "encoding/json" @@ -82,13 +83,13 @@ var ( ErrUnknownCommandID = fmt.Errorf(`unknown command_id`) ) -func (h *Heimdall) submitCommand(c *command.Command) (any, error) { +func (h *Heimdall) submitCommand(ctx context.Context, c *command.Command) (any, error) { if err := h.commandUpsert(c); err != nil { return nil, err } - return h.getCommand(&command.Command{Object: object.Object{ID: c.ID}}) + return h.getCommand(ctx, &command.Command{Object: object.Object{ID: c.ID}}) } @@ -145,7 +146,7 @@ func (h *Heimdall) commandUpsert(c *command.Command) error { } -func (h *Heimdall) getCommand(c *command.Command) (any, error) { +func (h *Heimdall) getCommand(ctx context.Context, c *command.Command) (any, error) { // open connection sess, err := h.Database.NewSession(false) @@ -185,7 +186,7 @@ 
func (h *Heimdall) getCommand(c *command.Command) (any, error) { } -func (h *Heimdall) getCommandStatus(c *command.Command) (any, error) { +func (h *Heimdall) getCommandStatus(ctx context.Context, c *command.Command) (any, error) { // open connection sess, err := h.Database.NewSession(false) @@ -214,7 +215,7 @@ func (h *Heimdall) getCommandStatus(c *command.Command) (any, error) { } -func (h *Heimdall) updateCommandStatus(c *command.Command) (any, error) { +func (h *Heimdall) updateCommandStatus(ctx context.Context, c *command.Command) (any, error) { // open connection sess, err := h.Database.NewSession(false) @@ -232,11 +233,11 @@ func (h *Heimdall) updateCommandStatus(c *command.Command) (any, error) { return nil, ErrUnknownCommandID } - return h.getCommandStatus(c) + return h.getCommandStatus(ctx, c) } -func (h *Heimdall) getCommands(f *database.Filter) (any, error) { +func (h *Heimdall) getCommands(ctx context.Context, f *database.Filter) (any, error) { // open connection sess, err := h.Database.NewSession(false) @@ -282,7 +283,7 @@ func (h *Heimdall) getCommands(f *database.Filter) (any, error) { } -func (h *Heimdall) getCommandStatuses(_ *database.Filter) (any, error) { +func (h *Heimdall) getCommandStatuses(ctx context.Context, _ *database.Filter) (any, error) { return database.GetSlice(h.Database, queryCommandStatusesSelect) diff --git a/internal/pkg/heimdall/handler.go b/internal/pkg/heimdall/handler.go index 194169b..55607ff 100644 --- a/internal/pkg/heimdall/handler.go +++ b/internal/pkg/heimdall/handler.go @@ -1,6 +1,7 @@ package heimdall import ( + "context" "encoding/json" "fmt" "io" @@ -48,7 +49,7 @@ func writeAPIError(w http.ResponseWriter, err error, obj any) { w.Write(responseJSON) } -func payloadHandler[T any](fn func(*T) (any, error)) http.HandlerFunc { +func payloadHandler[T any](fn func(context.Context, *T) (any, error)) http.HandlerFunc { // start latency timer defer payloadHandlerMethod.RecordLatency(time.Now()) @@ -81,7 +82,7 @@ func payloadHandler[T any](fn func(*T) (any, error)) http.HandlerFunc { } // execute request - result, err := fn(&payload) + result, err := fn(r.Context(), &payload) if err != nil { writeAPIError(w, err, result) return diff --git a/internal/pkg/heimdall/heimdall.go b/internal/pkg/heimdall/heimdall.go index dbf5fd7..871546d 100644 --- a/internal/pkg/heimdall/heimdall.go +++ b/internal/pkg/heimdall/heimdall.go @@ -156,7 +156,7 @@ func (h *Heimdall) Init() error { } // let's start the agent - return h.Pool.Start(h.runAsyncJob, h.getAsyncJobs, h.Database) + return h.Pool.Start(h.runAsyncJob, h.getAsyncJobs) } diff --git a/internal/pkg/heimdall/job.go b/internal/pkg/heimdall/job.go index cdb0e8c..4bd660f 100644 --- a/internal/pkg/heimdall/job.go +++ b/internal/pkg/heimdall/job.go @@ -3,6 +3,7 @@ package heimdall import ( "context" "crypto/rand" + _ "embed" "encoding/json" "fmt" "math/big" @@ -45,7 +46,7 @@ type commandOnCluster struct { cluster *cluster.Cluster } -func (h *Heimdall) submitJob(j *job.Job) (any, error) { +func (h *Heimdall) submitJob(ctx context.Context, j *job.Job) (any, error) { // set / add job properties if err := j.Init(); err != nil { @@ -68,7 +69,7 @@ func (h *Heimdall) submitJob(j *job.Job) (any, error) { } // let's run the job - err = h.runJob(j, command, cluster, context.Background()) + err = h.runJob(ctx, j, command, cluster) // before we process the error, we'll make the best effort to record this job in the database go h.insertJob(j, cluster.ID, command.ID) @@ -77,7 +78,7 @@ func (h *Heimdall) submitJob(j *job.Job) 
(any, error) { } -func (h *Heimdall) runJob(job *job.Job, command *command.Command, cluster *cluster.Cluster, ctx context.Context) error { +func (h *Heimdall) runJob(ctx context.Context, job *job.Job, command *command.Command, cluster *cluster.Cluster) error { defer runJobMethod.RecordLatency(time.Now(), command.Name, cluster.Name) runJobMethod.CountRequest(command.Name, cluster.Name) @@ -107,25 +108,54 @@ func (h *Heimdall) runJob(job *job.Job, command *command.Command, cluster *clust // ...and now we just start keepalive function for this job go h.jobKeepalive(keepaliveActive, job.SystemID, h.agentName) - // let's execute command - if err := h.commandHandlers[command.ID](ctx, runtime, job, cluster); err != nil { - - job.Status = jobStatus.Failed - job.Error = err.Error() - - runJobMethod.LogAndCountError(err, command.Name, cluster.Name) + // Create channels for coordination between plugin execution and cancellation monitoring + jobDone := make(chan error, 1) + cancelMonitorDone := make(chan struct{}) + + // Create cancellable context for the job + pluginCtx, cancel := context.WithCancel(ctx) + + // Start plugin execution in goroutine + go func() { + defer close(cancelMonitorDone) // signal monitoring to stop + err := h.commandHandlers[command.ID](pluginCtx, runtime, job, cluster) + jobDone <- err + }() + + // Start cancellation monitoring in goroutine + go func() { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for { + select { + case <-cancelMonitorDone: + return // plugin finished, stop monitoring + case <-ticker.C: + if h.isJobCancelling(job) { + cancel() // trigger context cancellation + return + } + } + } + }() - return err + // Wait for job execution to complete + jobErr := <-jobDone + // Check if context was cancelled FIRST (takes precedence over plugin errors) + if pluginCtx.Err() != nil { + job.Status = jobStatus.Cancelling // janitor finishes the cancellation process + runJobMethod.LogAndCountError(pluginCtx.Err(), command.Name, cluster.Name) + return nil } - // Check if context was cancelled - if ctx.Err() != nil { - - job.Status = jobStatus.Cancelled - runJobMethod.LogAndCountError(ctx.Err(), command.Name, cluster.Name) - - return nil + // Handle plugin execution result (only if not cancelled) + if jobErr != nil { + job.Status = jobStatus.Failed + job.Error = jobErr.Error() + runJobMethod.LogAndCountError(jobErr, command.Name, cluster.Name) + return jobErr } if job.StoreResultSync || !job.IsSync { @@ -167,10 +197,10 @@ func (h *Heimdall) storeResults(runtime *plugin.Runtime, job *job.Job) error { return nil } -func (h *Heimdall) cancelJob(req *jobRequest) (any, error) { +func (h *Heimdall) cancelJob(ctx context.Context, req *jobRequest) (any, error) { // validate that job exists and get its current status - currentJob, err := h.getJob(req) + currentJob, err := h.getJob(ctx, req) if err != nil { return nil, err } @@ -181,20 +211,24 @@ func (h *Heimdall) cancelJob(req *jobRequest) (any, error) { return nil, fmt.Errorf("expected *job.Job, got %T", currentJob) } - // check if job can be cancelled (must be running or accepted) - if job.Status != jobStatus.Running && job.Status != jobStatus.Accepted { + // check current job status + switch job.Status { + // already cancelled/cancelling - return success (idempotent) + case jobStatus.Cancelling, jobStatus.Cancelled: + return job, nil + case jobStatus.Running, jobStatus.Accepted: + // can be cancelled - proceed with cancellation + if err := h.updateJobStatusToCancelling(job); err != nil { + return nil, err + } 
+ // update object to return to caller + job.UpdatedAt = int(time.Now().Unix()) + job.Status = jobStatus.Cancelling + return job, nil + default: + // job is in a final state (succeeded, failed, etc.) - cannot be cancelled return nil, fmt.Errorf("job cannot be cancelled: current status is %v", job.Status) } - // update job status to CANCELLING - if err := h.updateJobStatusToCancelling(job); err != nil { - return nil, err - } - - // update object to return to caller - job.UpdatedAt = int(time.Now().Unix()) - job.Status = jobStatus.Cancelling - - return job, nil } func (h *Heimdall) getJobFile(w http.ResponseWriter, r *http.Request) { @@ -213,7 +247,7 @@ func (h *Heimdall) getJobFile(w http.ResponseWriter, r *http.Request) { } // let's validate jobID we got - if _, err := h.getJobStatus(&jobRequest{ID: jobID}); err != nil { + if _, err := h.getJobStatus(r.Context(), &jobRequest{ID: jobID}); err != nil { writeAPIError(w, err, nil) return } @@ -304,3 +338,26 @@ func (h *Heimdall) resolveJob(commandCriteria, clusterCriteria *set.Set[string]) return pairs[pairIndex].command, pairs[pairIndex].cluster, nil } + +// isJobCancelling checks if a specific job is in CANCELLING state +func (h *Heimdall) isJobCancelling(j *job.Job) bool { + sess, err := h.Database.NewSession(false) + if err != nil { + return false + } + defer sess.Close() + + row, err := sess.QueryRow(queryJobStatusSelect, j.ID) + if err != nil { + return false + } + + r := &job.Job{} + + err = row.Scan(&r.Status, &r.Error, &r.UpdatedAt) + if err != nil { + return false + } + + return r.Status == jobStatus.Cancelling +} diff --git a/internal/pkg/heimdall/job_dal.go b/internal/pkg/heimdall/job_dal.go index 5ae0856..ef015f6 100644 --- a/internal/pkg/heimdall/job_dal.go +++ b/internal/pkg/heimdall/job_dal.go @@ -1,6 +1,7 @@ package heimdall import ( + "context" "database/sql" _ "embed" "encoding/json" @@ -157,7 +158,7 @@ func (h *Heimdall) insertJob(j *job.Job, clusterID, commandID string) (int64, er } -func (h *Heimdall) getJob(j *jobRequest) (any, error) { +func (h *Heimdall) getJob(ctx context.Context, j *jobRequest) (any, error) { // open connection sess, err := h.Database.NewSession(false) @@ -197,7 +198,7 @@ func (h *Heimdall) getJob(j *jobRequest) (any, error) { } -func (h *Heimdall) getJobs(f *database.Filter) (any, error) { +func (h *Heimdall) getJobs(ctx context.Context, f *database.Filter) (any, error) { // open connection sess, err := h.Database.NewSession(false) @@ -243,7 +244,7 @@ func (h *Heimdall) getJobs(f *database.Filter) (any, error) { } -func (h *Heimdall) getJobStatus(j *jobRequest) (any, error) { +func (h *Heimdall) getJobStatus(ctx context.Context, j *jobRequest) (any, error) { // open connection sess, err := h.Database.NewSession(false) @@ -300,7 +301,7 @@ func jobParseContextAndTags(j *job.Job, jobContext string, sess *database.Sessio } -func (h *Heimdall) getJobStatuses(_ *database.Filter) (any, error) { +func (h *Heimdall) getJobStatuses(ctx context.Context, _ *database.Filter) (any, error) { return database.GetSlice(h.Database, queryJobStatusesSelect) diff --git a/internal/pkg/heimdall/jobs_async.go b/internal/pkg/heimdall/jobs_async.go index cebc819..3da8400 100644 --- a/internal/pkg/heimdall/jobs_async.go +++ b/internal/pkg/heimdall/jobs_async.go @@ -114,7 +114,7 @@ func (h *Heimdall) runAsyncJob(ctx context.Context, j *job.Job) error { return h.updateAsyncJobStatus(j, fmt.Errorf(formatErrUnknownCluster, j.CluserID)) } - return h.updateAsyncJobStatus(j, h.runJob(j, command, cluster, ctx)) + return 
h.updateAsyncJobStatus(j, h.runJob(ctx, j, command, cluster)) } diff --git a/internal/pkg/object/command/ecs/ecs.go b/internal/pkg/object/command/ecs/ecs.go index 85c2bf2..31205a5 100644 --- a/internal/pkg/object/command/ecs/ecs.go +++ b/internal/pkg/object/command/ecs/ecs.go @@ -305,13 +305,8 @@ func (execCtx *executionContext) pollForCompletion(ctx ct.Context) error { return fmt.Errorf("%s", reason) } - // Check for cancellation or sleep until next poll time - select { - case <-ctx.Done(): - stopAllTasks(execCtx, "Job cancelled by user") - return nil - case <-time.After(time.Duration(execCtx.PollingInterval)): - } + // sleep for polling interval + time.Sleep(time.Duration(execCtx.PollingInterval)) } // If you're here, all tasks are complete diff --git a/internal/pkg/pool/pool.go b/internal/pkg/pool/pool.go index 6ceb0f1..254160d 100644 --- a/internal/pkg/pool/pool.go +++ b/internal/pkg/pool/pool.go @@ -3,10 +3,8 @@ package pool import ( "context" "fmt" - "sync" "time" - "github.com/patterninc/heimdall/internal/pkg/database" "github.com/patterninc/heimdall/pkg/object/job" ) @@ -19,16 +17,9 @@ type Pool[T any] struct { Size int `yaml:"size,omitempty" json:"size,omitempty"` Sleep int `yaml:"sleep,omitempty" json:"sleep,omitempty"` queue chan *job.Job - - runningJobs map[string]context.CancelFunc - runningJobsMux sync.RWMutex - db *database.Database } -func (p *Pool[T]) Start(worker func(context.Context, *job.Job) error, getWork func(int) ([]*job.Job, error), database *database.Database) error { - - // record database context - p.db = database +func (p *Pool[T]) Start(worker func(context.Context, *job.Job) error, getWork func(int) ([]*job.Job, error)) error { // do we have the size set? if p.Size <= 0 { @@ -46,12 +37,6 @@ func (p *Pool[T]) Start(worker func(context.Context, *job.Job) error, getWork fu // let's set the counter tokens := &counter{} - // Initialize tracking - p.runningJobs = make(map[string]context.CancelFunc) - - // Start cancellation polling loop - go p.pollForCancellations() - // let's provision workers for i := 0; i < p.Size; i++ { @@ -70,16 +55,8 @@ func (p *Pool[T]) Start(worker func(context.Context, *job.Job) error, getWork fu break } - ctx, cancel := context.WithCancel(context.Background()) - - // register job as running - p.registerRunningJob(w.ID, cancel) - // do the work.... 
- err := worker(ctx, w) - - // remove job from running jobs - p.unregisterRunningJob(w.ID) + err := worker(context.Background(), w) if err != nil { // TODO: implement proper error logging @@ -135,52 +112,3 @@ func (p *Pool[T]) Start(worker func(context.Context, *job.Job) error, getWork fu return nil } - -func (p *Pool[T]) pollForCancellations() { - // let's poll for cancellations every 15 seconds - ticker := time.NewTicker(15 * time.Second) - defer ticker.Stop() - - for range ticker.C { - - // Get jobs in CANCELLING state from database - cancellingJobs := p.getCancellingJobs() - - // Check each cancelling job - for _, jobID := range cancellingJobs { - if cancelFunc, isLocal := p.isJobRunningLocally(jobID); isLocal { - cancelFunc() // Trigger context cancellation - - // Update job status to CANCELLED in database - p.updateJobStatusToCancelled(jobID) - } - } - } -} - -func (p *Pool[T]) registerRunningJob(jobID string, cancel context.CancelFunc) { - - p.runningJobsMux.Lock() - defer p.runningJobsMux.Unlock() - - p.runningJobs[jobID] = cancel - -} - -func (p *Pool[T]) unregisterRunningJob(jobID string) { - - p.runningJobsMux.Lock() - defer p.runningJobsMux.Unlock() - - delete(p.runningJobs, jobID) -} - -// Check if a job is running locally -func (p *Pool[T]) isJobRunningLocally(jobID string) (context.CancelFunc, bool) { - - p.runningJobsMux.RLock() - defer p.runningJobsMux.RUnlock() - - cancelFunc, exists := p.runningJobs[jobID] - return cancelFunc, exists -} diff --git a/internal/pkg/pool/pool_dal.go b/internal/pkg/pool/pool_dal.go deleted file mode 100644 index ba07889..0000000 --- a/internal/pkg/pool/pool_dal.go +++ /dev/null @@ -1,61 +0,0 @@ -package pool - -import ( - _ "embed" - - "github.com/patterninc/heimdall/pkg/object/job/status" -) - -//go:embed queries/cancelling_jobs_select.sql -var queryCancellingJobsSelect string - -//go:embed queries/job_status_update_by_id.sql -var queryJobStatusUpdate string - -// getCancellingJobs retrieves jobs in CANCELLING state from database -func (p *Pool[T]) getCancellingJobs() []string { - - sess, err := p.db.NewSession(false) - if err != nil { - return nil - } - defer sess.Close() - - rows, err := sess.Query(queryCancellingJobsSelect) - if err != nil { - return nil - } - defer rows.Close() - - var jobIDs []string - for rows.Next() { - var jobID string - if err := rows.Scan(&jobID); err != nil { - continue - } - jobIDs = append(jobIDs, jobID) - } - - return jobIDs -} - -// updateJobStatusToCancelled updates job status to CANCELLED in database -func (p *Pool[T]) updateJobStatusToCancelled(jobID string) error { - if p.db == nil { - return nil - } - - sess, err := p.db.NewSession(true) - if err != nil { - return err - } - defer sess.Close() - - _, err = sess.Exec(queryJobStatusUpdate, status.Cancelled, "", jobID) - - if err == nil { - return sess.Commit() - } - - return err -} diff --git a/internal/pkg/pool/queries/cancelling_jobs_select.sql b/internal/pkg/pool/queries/cancelling_jobs_select.sql deleted file mode 100644 index 520c9b9..0000000 --- a/internal/pkg/pool/queries/cancelling_jobs_select.sql +++ /dev/null @@ -1,9 +0,0 @@ -select - j.job_id -from - jobs j - join job_statuses js on j.job_status_id = js.job_status_id -where - js.job_status_name = 'CANCELLING' -limit - 25; diff --git a/internal/pkg/pool/queries/job_status_update_by_id.sql b/internal/pkg/pool/queries/job_status_update_by_id.sql deleted file mode 100644 index c1ff488..0000000 --- a/internal/pkg/pool/queries/job_status_update_by_id.sql +++ /dev/null @@ -1,7 +0,0 @@ -update jobs -set - 
job_status_id = $1, - job_error = $2, - updated_at = extract(epoch from now())::int -where - job_id = $3; From fe9d65b20747e1fcc2c787ce9d1bddcd391ccb40 Mon Sep 17 00:00:00 2001 From: wlggraham Date: Thu, 20 Nov 2025 12:26:36 -0700 Subject: [PATCH 04/10] update plugins to use context --- internal/pkg/aws/cloudwatch.go | 3 +- internal/pkg/aws/glue.go | 9 +- internal/pkg/aws/s3.go | 5 +- internal/pkg/heimdall/job.go | 85 ++++++++----------- internal/pkg/heimdall/job_dal.go | 20 ----- .../object/command/clickhouse/clickhouse.go | 18 ++-- internal/pkg/object/command/dynamo/dynamo.go | 9 +- internal/pkg/object/command/ecs/ecs.go | 43 +++++----- internal/pkg/object/command/glue/glue.go | 10 +-- internal/pkg/object/command/ping/ping.go | 8 +- internal/pkg/object/command/shell/shell.go | 10 +-- .../pkg/object/command/snowflake/snowflake.go | 10 +-- internal/pkg/object/command/spark/spark.go | 17 ++-- .../pkg/object/command/sparkeks/sparkeks.go | 75 ++++++++-------- .../object/command/sparkeks/sparkeks_test.go | 5 +- internal/pkg/object/command/trino/client.go | 15 ++-- internal/pkg/object/command/trino/trino.go | 16 ++-- internal/pkg/pool/pool.go | 8 +- 18 files changed, 162 insertions(+), 204 deletions(-) diff --git a/internal/pkg/aws/cloudwatch.go b/internal/pkg/aws/cloudwatch.go index d8467ba..543c9f9 100644 --- a/internal/pkg/aws/cloudwatch.go +++ b/internal/pkg/aws/cloudwatch.go @@ -1,6 +1,7 @@ package aws import ( + "context" "fmt" "os" "time" @@ -10,7 +11,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs" ) -func PullLogs(writer *os.File, logGroup, logStream string, chunkSize int, memoryLimit int64) error { +func PullLogs(ctx context.Context, writer *os.File, logGroup, logStream string, chunkSize int, memoryLimit int64) error { // initialize AWS session cfg, err := config.LoadDefaultConfig(ctx) diff --git a/internal/pkg/aws/glue.go b/internal/pkg/aws/glue.go index 94489a0..a540d94 100644 --- a/internal/pkg/aws/glue.go +++ b/internal/pkg/aws/glue.go @@ -1,6 +1,7 @@ package aws import ( + "context" "fmt" "strings" @@ -17,7 +18,7 @@ var ( ErrMissingCatalogTableMetadata = fmt.Errorf(`missing table metadata in the glue catalog`) ) -func GetTableMetadata(catalogID, tableName string) ([]byte, error) { +func GetTableMetadata(ctx context.Context, catalogID, tableName string) ([]byte, error) { // split tableName to namespace and table names tableNameParts := strings.Split(tableName, `.`) @@ -27,18 +28,18 @@ func GetTableMetadata(catalogID, tableName string) ([]byte, error) { } // let's get the latest metadata file location - location, err := getTableMetadataLocation(catalogID, tableNameParts[0], tableNameParts[1]) + location, err := getTableMetadataLocation(ctx, catalogID, tableNameParts[0], tableNameParts[1]) if err != nil { return nil, err } // let's pull the file content - return ReadFromS3(location) + return ReadFromS3(ctx, location) } // function that calls AWS glue catalog to get the snapshot ID for a given database, table and branch -func getTableMetadataLocation(catalogID, databaseName, tableName string) (string, error) { +func getTableMetadataLocation(ctx context.Context, catalogID, databaseName, tableName string) (string, error) { // Return an error if databaseName or tableName is empty if databaseName == `` || tableName == `` { diff --git a/internal/pkg/aws/s3.go b/internal/pkg/aws/s3.go index 1be4514..51843e2 100644 --- a/internal/pkg/aws/s3.go +++ b/internal/pkg/aws/s3.go @@ -13,12 +13,11 @@ import ( ) var ( - ctx = context.Background() rxS3Path = 
regexp.MustCompile(`^s3://([^/]+)/(.*)$`) ) // WriteToS3 writes a file to S3, providing the same interface as os.WriteFile function -func WriteToS3(name string, data []byte, _ os.FileMode) error { +func WriteToS3(ctx context.Context, name string, data []byte, _ os.FileMode) error { bucket, key, err := parseS3Path(name) if err != nil { @@ -47,7 +46,7 @@ func WriteToS3(name string, data []byte, _ os.FileMode) error { } -func ReadFromS3(name string) ([]byte, error) { +func ReadFromS3(ctx context.Context, name string) ([]byte, error) { bucket, key, err := parseS3Path(name) if err != nil { diff --git a/internal/pkg/heimdall/job.go b/internal/pkg/heimdall/job.go index 4bd660f..fd2d8f3 100644 --- a/internal/pkg/heimdall/job.go +++ b/internal/pkg/heimdall/job.go @@ -78,16 +78,16 @@ func (h *Heimdall) submitJob(ctx context.Context, j *job.Job) (any, error) { } -func (h *Heimdall) runJob(ctx context.Context, job *job.Job, command *command.Command, cluster *cluster.Cluster) error { +func (h *Heimdall) runJob(ctx context.Context, j *job.Job, command *command.Command, cluster *cluster.Cluster) error { defer runJobMethod.RecordLatency(time.Now(), command.Name, cluster.Name) runJobMethod.CountRequest(command.Name, cluster.Name) // let's set environment runtime := &plugin.Runtime{ - WorkingDirectory: h.JobsDirectory + separator + job.ID, - ArchiveDirectory: h.ArchiveDirectory + separator + job.ID, - ResultDirectory: h.ResultDirectory + separator + job.ID, + WorkingDirectory: h.JobsDirectory + separator + j.ID, + ArchiveDirectory: h.ArchiveDirectory + separator + j.ID, + ResultDirectory: h.ResultDirectory + separator + j.ID, Version: h.Version, UserAgent: fmt.Sprintf(formatUserAgent, h.Version), } @@ -106,7 +106,7 @@ func (h *Heimdall) runJob(ctx context.Context, job *job.Job, command *command.Co defer close(keepaliveActive) // ...and now we just start keepalive function for this job - go h.jobKeepalive(keepaliveActive, job.SystemID, h.agentName) + go h.jobKeepalive(keepaliveActive, j.SystemID, h.agentName) // Create channels for coordination between plugin execution and cancellation monitoring jobDone := make(chan error, 1) @@ -118,7 +118,7 @@ func (h *Heimdall) runJob(ctx context.Context, job *job.Job, command *command.Co // Start plugin execution in goroutine go func() { defer close(cancelMonitorDone) // signal monitoring to stop - err := h.commandHandlers[command.ID](pluginCtx, runtime, job, cluster) + err := h.commandHandlers[command.ID](pluginCtx, runtime, j, cluster) jobDone <- err }() @@ -129,12 +129,17 @@ func (h *Heimdall) runJob(ctx context.Context, job *job.Job, command *command.Co for { select { + // plugin finished, stop monitoring case <-cancelMonitorDone: - return // plugin finished, stop monitoring + return case <-ticker.C: - if h.isJobCancelling(job) { - cancel() // trigger context cancellation - return + // If job is in cancelling state, trigger context cancellation + result, err := h.getJobStatus(ctx, &jobRequest{ID: j.ID}) + if err == nil { + if job, ok := result.(*job.Job); ok && job.Status == jobStatus.Cancelling { + cancel() + return + } } } } @@ -143,42 +148,42 @@ func (h *Heimdall) runJob(ctx context.Context, job *job.Job, command *command.Co // Wait for job execution to complete jobErr := <-jobDone - // Check if context was cancelled FIRST (takes precedence over plugin errors) + // Check if context was cancelled and mark status appropriately if pluginCtx.Err() != nil { - job.Status = jobStatus.Cancelling // janitor finishes the cancellation process + j.Status = 
jobStatus.Cancelling // janitor will update to cancelled when resources are cleaned up runJobMethod.LogAndCountError(pluginCtx.Err(), command.Name, cluster.Name) return nil } // Handle plugin execution result (only if not cancelled) if jobErr != nil { - job.Status = jobStatus.Failed - job.Error = jobErr.Error() + j.Status = jobStatus.Failed + j.Error = jobErr.Error() runJobMethod.LogAndCountError(jobErr, command.Name, cluster.Name) return jobErr } - if job.StoreResultSync || !job.IsSync { - h.storeResults(runtime, job) + if j.StoreResultSync || !j.IsSync { + h.storeResults(runtime, j) } else { - go h.storeResults(runtime, job) + go h.storeResults(runtime, j) } - job.Status = jobStatus.Succeeded + j.Status = jobStatus.Succeeded runJobMethod.CountSuccess(command.Name, cluster.Name) return nil } -func (h *Heimdall) storeResults(runtime *plugin.Runtime, job *job.Job) error { +func (h *Heimdall) storeResults(runtime *plugin.Runtime, j *job.Job) error { // do we have result to be written? - if job.Result == nil { + if j.Result == nil { return nil } // prepare result - data, err := json.Marshal(job.Result) + data, err := json.Marshal(j.Result) if err != nil { return err @@ -187,7 +192,9 @@ func (h *Heimdall) storeResults(runtime *plugin.Runtime, job *job.Job) error { // write result writeFileFunc := os.WriteFile if strings.HasPrefix(runtime.ResultDirectory, s3Prefix) { - writeFileFunc = aws.WriteToS3 + writeFileFunc = func(name string, data []byte, perm os.FileMode) error { + return aws.WriteToS3(context.Background(), name, data, perm) + } } if err := writeFileFunc(runtime.ResultDirectory+separator+resultFilename, data, 0600); err != nil { @@ -218,12 +225,11 @@ func (h *Heimdall) cancelJob(ctx context.Context, req *jobRequest) (any, error) return job, nil case jobStatus.Running, jobStatus.Accepted: // can be cancelled - proceed with cancellation - if err := h.updateJobStatusToCancelling(job); err != nil { + job.Status = jobStatus.Cancelling + job.UpdatedAt = int(time.Now().Unix()) + if err := h.updateAsyncJobStatus(job, nil); err != nil { return nil, err } - // update object to return to caller - job.UpdatedAt = int(time.Now().Unix()) - job.Status = jobStatus.Cancelling return job, nil default: // job is in a final state (succeeded, failed, etc.) 
- cannot be cancelled @@ -266,7 +272,9 @@ func (h *Heimdall) getJobFile(w http.ResponseWriter, r *http.Request) { readFileFunc := os.ReadFile filenamePath := fmt.Sprintf(jobFileFormat, sourceDirectory, jobID, filename) if strings.HasPrefix(filenamePath, s3Prefix) { - readFileFunc = aws.ReadFromS3 + readFileFunc = func(path string) ([]byte, error) { + return aws.ReadFromS3(context.Background(), path) + } } // get file's content @@ -338,26 +346,3 @@ func (h *Heimdall) resolveJob(commandCriteria, clusterCriteria *set.Set[string]) return pairs[pairIndex].command, pairs[pairIndex].cluster, nil } - -// isJobCancelling checks if a specific job is in CANCELLING state -func (h *Heimdall) isJobCancelling(j *job.Job) bool { - sess, err := h.Database.NewSession(false) - if err != nil { - return false - } - defer sess.Close() - - row, err := sess.QueryRow(queryJobStatusSelect, j.ID) - if err != nil { - return false - } - - r := &job.Job{} - - err = row.Scan(&r.Status, &r.Error, &r.UpdatedAt) - if err != nil { - return false - } - - return r.Status == jobStatus.Cancelling -} diff --git a/internal/pkg/heimdall/job_dal.go b/internal/pkg/heimdall/job_dal.go index ef015f6..2fcb95e 100644 --- a/internal/pkg/heimdall/job_dal.go +++ b/internal/pkg/heimdall/job_dal.go @@ -12,7 +12,6 @@ import ( "github.com/patterninc/heimdall/internal/pkg/database" "github.com/patterninc/heimdall/pkg/object" "github.com/patterninc/heimdall/pkg/object/job" - jobStatus "github.com/patterninc/heimdall/pkg/object/job/status" ) //go:embed queries/job/insert.sql @@ -306,22 +305,3 @@ func (h *Heimdall) getJobStatuses(ctx context.Context, _ *database.Filter) (any, return database.GetSlice(h.Database, queryJobStatusesSelect) } - -func (h *Heimdall) updateJobStatusToCancelling(job *job.Job) error { - - // open connection - sess, err := h.Database.NewSession(true) - if err != nil { - return err - } - defer sess.Close() - - // update job status to CANCELLING - _, err = sess.Exec(queryJobStatusUpdate, jobStatus.Cancelling, job.Error, job.SystemID) - if err != nil { - return err - } - - // commit transaction - return sess.Commit() -} diff --git a/internal/pkg/object/command/clickhouse/clickhouse.go b/internal/pkg/object/command/clickhouse/clickhouse.go index 472c781..5c67cf3 100644 --- a/internal/pkg/object/command/clickhouse/clickhouse.go +++ b/internal/pkg/object/command/clickhouse/clickhouse.go @@ -7,7 +7,7 @@ import ( "github.com/ClickHouse/clickhouse-go/v2" "github.com/ClickHouse/clickhouse-go/v2/lib/driver" "github.com/hladush/go-telemetry/pkg/telemetry" - hdctx "github.com/patterninc/heimdall/pkg/context" + heimdallContext "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/object/job/status" @@ -16,7 +16,7 @@ import ( "github.com/patterninc/heimdall/pkg/result/column" ) -type commandContext struct { +type clickhouseCommandContext struct { Username string `yaml:"username,omitempty" json:"username,omitempty"` Password string `yaml:"password,omitempty" json:"password,omitempty"` } @@ -45,11 +45,11 @@ var ( ) // New creates a new clickhouse plugin handler -func New(ctx *hdctx.Context) (plugin.Handler, error) { - t := &commandContext{} +func New(commandContext *heimdallContext.Context) (plugin.Handler, error) { + t := &clickhouseCommandContext{} - if ctx != nil { - if err := ctx.Unmarshal(t); err != nil { + if commandContext != nil { + if err := commandContext.Unmarshal(t); err != nil { return nil, err } } @@ 
-57,9 +57,9 @@ func New(ctx *hdctx.Context) (plugin.Handler, error) { return t.handler, nil } -func (cmd *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { +func (cmd *clickhouseCommandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { - jobContext, err := cmd.createJobContext(j, c) + jobContext, err := cmd.createJobContext(ctx, j, c) if err != nil { handleMethod.LogAndCountError(err, "create_job_context") return err @@ -81,7 +81,7 @@ func (cmd *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *jo return nil } -func (cmd *commandContext) createJobContext(j *job.Job, c *cluster.Cluster) (*jobContext, error) { +func (cmd *clickhouseCommandContext) createJobContext(ctx context.Context, j *job.Job, c *cluster.Cluster) (*jobContext, error) { // get cluster context clusterCtx := &clusterContext{} if c.Context != nil { diff --git a/internal/pkg/object/command/dynamo/dynamo.go b/internal/pkg/object/command/dynamo/dynamo.go index b36b0d8..77dae5e 100644 --- a/internal/pkg/object/command/dynamo/dynamo.go +++ b/internal/pkg/object/command/dynamo/dynamo.go @@ -1,7 +1,7 @@ package dynamo import ( - ct "context" + "context" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" @@ -9,7 +9,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/dynamodb" "github.com/aws/aws-sdk-go-v2/service/sts" - "github.com/patterninc/heimdall/pkg/context" + heimdallContext "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/plugin" @@ -31,12 +31,11 @@ type dynamoClusterContext struct { type dynamoCommandContext struct{} var ( - ctx = ct.Background() assumeRoleSession = aws.String("AssumeRoleSession") ) // New creates a new dynamo plugin handler. -func New(_ *context.Context) (plugin.Handler, error) { +func New(_ *heimdallContext.Context) (plugin.Handler, error) { s := &dynamoCommandContext{} return s.handler, nil @@ -44,7 +43,7 @@ func New(_ *context.Context) (plugin.Handler, error) { } // Handler for the Spark job submission. 
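The handler refactors above give every plugin the same context-first shape. A minimal sketch of that shape, using placeholder types rather than the real heimdall plugin.Runtime, job.Job, and cluster.Cluster structs:

package main

import (
	"context"
	"fmt"
)

// Stand-ins for plugin.Runtime, job.Job, and cluster.Cluster; only the shape
// of the signature matters here.
type runtime struct{}
type jobObj struct{ ID string }
type clusterObj struct{ Name string }

// handler mirrors the context-first signature: func(ctx, runtime, job, cluster) error.
type handler func(ctx context.Context, r *runtime, j *jobObj, c *clusterObj) error

func pingHandler(ctx context.Context, _ *runtime, j *jobObj, _ *clusterObj) error {
	// A well-behaved handler stops early once the job has been cancelled.
	if err := ctx.Err(); err != nil {
		return fmt.Errorf("job %s cancelled before execution: %w", j.ID, err)
	}
	fmt.Printf("hello from job %s\n", j.ID)
	return nil
}

func main() {
	var h handler = pingHandler
	_ = h(context.Background(), &runtime{}, &jobObj{ID: "demo"}, &clusterObj{Name: "local"})
}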
-func (d *dynamoCommandContext) handler(ct ct.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) (err error) { +func (d *dynamoCommandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) (err error) { // let's unmarshal job context jobContext := &dynamoJobContext{} diff --git a/internal/pkg/object/command/ecs/ecs.go b/internal/pkg/object/command/ecs/ecs.go index d43d1b8..0256284 100644 --- a/internal/pkg/object/command/ecs/ecs.go +++ b/internal/pkg/object/command/ecs/ecs.go @@ -1,7 +1,7 @@ package ecs import ( - ct "context" + "context" "encoding/json" "fmt" "os" @@ -15,7 +15,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/hladush/go-telemetry/pkg/telemetry" heimdallAws "github.com/patterninc/heimdall/internal/pkg/aws" - "github.com/patterninc/heimdall/pkg/context" + heimdallContext "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/duration" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" @@ -116,12 +116,11 @@ const ( ) var ( - ctx = ct.Background() errMissingTemplate = fmt.Errorf("task definition template is required") methodMetrics = telemetry.NewMethod("ecs", "ecs plugin") ) -func New(commandContext *context.Context) (plugin.Handler, error) { +func New(commandContext *heimdallContext.Context) (plugin.Handler, error) { e := &ecsCommandContext{ PollingInterval: defaultPollingInterval, @@ -141,21 +140,21 @@ func New(commandContext *context.Context) (plugin.Handler, error) { } // handler implements the main ECS plugin logic -func (e *ecsCommandContext) handler(ctx ct.Context, r *plugin.Runtime, job *job.Job, cluster *cluster.Cluster) error { +func (e *ecsCommandContext) handler(ctx context.Context, r *plugin.Runtime, job *job.Job, cluster *cluster.Cluster) error { // Build execution context with resolved configuration and loaded template - execCtx, err := buildExecutionContext(e, job, cluster, r) + execCtx, err := buildExecutionContext(ctx, e, job, cluster, r) if err != nil { return err } // register task definition - if err := execCtx.registerTaskDefinition(); err != nil { + if err := execCtx.registerTaskDefinition(ctx); err != nil { return err } // Start tasks - if err := execCtx.startTasks(job.ID); err != nil { + if err := execCtx.startTasks(ctx, job.ID); err != nil { return err } @@ -165,7 +164,7 @@ func (e *ecsCommandContext) handler(ctx ct.Context, r *plugin.Runtime, job *job. } // Try to retrieve logs, but don't fail the job if it fails - if err := execCtx.retrieveLogs(); err != nil { + if err := execCtx.retrieveLogs(ctx); err != nil { execCtx.runtime.Stderr.WriteString(fmt.Sprintf("Failed to retrieve logs: %v\n", err)) } @@ -180,7 +179,7 @@ func (e *ecsCommandContext) handler(ctx ct.Context, r *plugin.Runtime, job *job. 
} // prepare and register task definition with ECS -func (execCtx *executionContext) registerTaskDefinition() error { +func (execCtx *executionContext) registerTaskDefinition(ctx context.Context) error { registerInput := &ecs.RegisterTaskDefinitionInput{ Family: aws.String(aws.ToString(execCtx.TaskDefinitionWrapper.TaskDefinition.Family)), RequiresCompatibilities: []types.Compatibility{types.CompatibilityFargate}, @@ -204,10 +203,10 @@ func (execCtx *executionContext) registerTaskDefinition() error { } // startTasks launches all tasks and returns a map of task trackers -func (execCtx *executionContext) startTasks(jobID string) error { +func (execCtx *executionContext) startTasks(ctx context.Context, jobID string) error { for i := 0; i < execCtx.TaskCount; i++ { - taskARN, err := runTask(execCtx, fmt.Sprintf("%s%s-%d", startedByPrefix, jobID, i), i) + taskARN, err := runTask(ctx, execCtx, fmt.Sprintf("%s%s-%d", startedByPrefix, jobID, i), i) if err != nil { return err } @@ -223,7 +222,7 @@ func (execCtx *executionContext) startTasks(jobID string) error { } // monitor tasks until completion, faliure, or timeout -func (execCtx *executionContext) pollForCompletion(ctx ct.Context) error { +func (execCtx *executionContext) pollForCompletion(ctx context.Context) error { startTime := time.Now() stopTime := startTime.Add(time.Duration(execCtx.Timeout)) @@ -287,7 +286,7 @@ func (execCtx *executionContext) pollForCompletion(ctx ct.Context) error { // Stop all other running tasks reason := fmt.Sprintf(errMaxFailCount, tracker.ActiveARN, tracker.Retries, execCtx.MaxFailCount) - if err := stopAllTasks(execCtx, reason); err != nil { + if err := stopAllTasks(ctx, execCtx, reason); err != nil { return err } @@ -296,7 +295,7 @@ func (execCtx *executionContext) pollForCompletion(ctx ct.Context) error { break } - newTaskARN, err := runTask(execCtx, tracker.Name, tracker.TaskNum) + newTaskARN, err := runTask(ctx, execCtx, tracker.Name, tracker.TaskNum) if err != nil { return err } @@ -330,7 +329,7 @@ func (execCtx *executionContext) pollForCompletion(ctx ct.Context) error { // Stop all remaining tasks reason := fmt.Sprintf(errPollingTimeout, incompleteARNs, execCtx.Timeout) - if err := stopAllTasks(execCtx, reason); err != nil { + if err := stopAllTasks(ctx, execCtx, reason); err != nil { return err } @@ -347,7 +346,7 @@ func (execCtx *executionContext) pollForCompletion(ctx ct.Context) error { } -func buildExecutionContext(commandCtx *ecsCommandContext, j *job.Job, c *cluster.Cluster, runtime *plugin.Runtime) (*executionContext, error) { +func buildExecutionContext(ctx context.Context, commandCtx *ecsCommandContext, j *job.Job, c *cluster.Cluster, runtime *plugin.Runtime) (*executionContext, error) { execCtx := &executionContext{ tasks: make(map[string]*taskTracker), @@ -355,7 +354,7 @@ func buildExecutionContext(commandCtx *ecsCommandContext, j *job.Job, c *cluster } // Create a context from commandCtx and unmarshal onto execCtx (defaults) - commandContext := context.New(commandCtx) + commandContext := heimdallContext.New(commandCtx) if err := commandContext.Unmarshal(execCtx); err != nil { return nil, err } @@ -464,7 +463,7 @@ func buildContainerOverrides(execCtx *executionContext) error { } // stopAllTasks stops all non-completed tasks with the given reason -func stopAllTasks(execCtx *executionContext, reason string) error { +func stopAllTasks(ctx context.Context, execCtx *executionContext, reason string) error { // AWS ECS has a 1024 character limit on the reason field if len(reason) > 1024 { reason = 
reason[:1021] + "..." @@ -537,7 +536,7 @@ func loadTaskDefinitionTemplate(templatePath string) (*taskDefinitionWrapper, er } // runTask runs a single task and returns the task ARN -func runTask(execCtx *executionContext, startedBy string, taskNum int) (string, error) { +func runTask(ctx context.Context, execCtx *executionContext, startedBy string, taskNum int) (string, error) { // Create a copy of the overrides and add TASK_NAME and TASK_NUM env variables finalOverrides := append([]types.ContainerOverride{}, execCtx.ContainerOverrides...) @@ -603,7 +602,7 @@ func isTaskSuccessful(task types.Task, execCtx *executionContext) bool { } // We pull logs from cloudwatch for all containers in a single task that represents the job outcome -func (execCtx *executionContext) retrieveLogs() error { +func (execCtx *executionContext) retrieveLogs(ctx context.Context) error { var selectedTask *taskTracker var writer *os.File @@ -655,7 +654,7 @@ func (execCtx *executionContext) retrieveLogs() error { case types.LogDriverAwslogs: logGroup := logInfo.options["awslogs-group"] logStream := fmt.Sprintf("%s/%s/%s", logInfo.options["awslogs-stream-prefix"], logInfo.containerName, taskID) - if err := heimdallAws.PullLogs(writer, logGroup, logStream, maxLogChunkSize, maxLogMemoryBytes); err != nil { + if err := heimdallAws.PullLogs(ctx, writer, logGroup, logStream, maxLogChunkSize, maxLogMemoryBytes); err != nil { return err } default: diff --git a/internal/pkg/object/command/glue/glue.go b/internal/pkg/object/command/glue/glue.go index b2ac5b2..cece65b 100644 --- a/internal/pkg/object/command/glue/glue.go +++ b/internal/pkg/object/command/glue/glue.go @@ -1,10 +1,10 @@ package glue import ( - ct "context" + "context" "github.com/patterninc/heimdall/internal/pkg/aws" - "github.com/patterninc/heimdall/pkg/context" + heimdallContext "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/plugin" @@ -19,7 +19,7 @@ type glueJobContext struct { TableName string `yaml:"table_name,omitempty" json:"table_name,omitempty"` } -func New(commandContext *context.Context) (plugin.Handler, error) { +func New(commandContext *heimdallContext.Context) (plugin.Handler, error) { g := &glueCommandContext{} @@ -33,7 +33,7 @@ func New(commandContext *context.Context) (plugin.Handler, error) { } -func (g *glueCommandContext) handler(ct ct.Context, _ *plugin.Runtime, j *job.Job, _ *cluster.Cluster) (err error) { +func (g *glueCommandContext) handler(ctx context.Context, _ *plugin.Runtime, j *job.Job, _ *cluster.Cluster) (err error) { // let's unmarshal job context jc := &glueJobContext{} @@ -44,7 +44,7 @@ func (g *glueCommandContext) handler(ct ct.Context, _ *plugin.Runtime, j *job.Jo } // let's get our metadata - metadata, err := aws.GetTableMetadata(g.CatalogID, jc.TableName) + metadata, err := aws.GetTableMetadata(ctx, g.CatalogID, jc.TableName) if err != nil { return } diff --git a/internal/pkg/object/command/ping/ping.go b/internal/pkg/object/command/ping/ping.go index 3b438f4..caaa686 100644 --- a/internal/pkg/object/command/ping/ping.go +++ b/internal/pkg/object/command/ping/ping.go @@ -1,10 +1,10 @@ package ping import ( - ct "context" + "context" "fmt" - "github.com/patterninc/heimdall/pkg/context" + heimdallContext "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/plugin" @@ 
-17,14 +17,14 @@ const ( type pingCommandContext struct{} -func New(_ *context.Context) (plugin.Handler, error) { +func New(_ *heimdallContext.Context) (plugin.Handler, error) { p := &pingCommandContext{} return p.handler, nil } -func (p *pingCommandContext) handler(ct ct.Context, _ *plugin.Runtime, j *job.Job, _ *cluster.Cluster) (err error) { +func (p *pingCommandContext) handler(ctx context.Context, _ *plugin.Runtime, j *job.Job, _ *cluster.Cluster) (err error) { j.Result, err = result.FromMessage(fmt.Sprintf(messageFormat, j.User)) return diff --git a/internal/pkg/object/command/shell/shell.go b/internal/pkg/object/command/shell/shell.go index 9b21037..acbbd4c 100644 --- a/internal/pkg/object/command/shell/shell.go +++ b/internal/pkg/object/command/shell/shell.go @@ -1,13 +1,13 @@ package shell import ( - ct "context" + "context" "encoding/json" "os" "os/exec" "path" - "github.com/patterninc/heimdall/pkg/context" + heimdallContext "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/plugin" @@ -34,7 +34,7 @@ type runtimeContext struct { Runtime *plugin.Runtime `yaml:"runtime,omitempty" json:"runtime,omitempty"` } -func New(commandContext *context.Context) (plugin.Handler, error) { +func New(commandContext *heimdallContext.Context) (plugin.Handler, error) { s := &shellCommandContext{} @@ -48,7 +48,7 @@ func New(commandContext *context.Context) (plugin.Handler, error) { } -func (s *shellCommandContext) handler(ct ct.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { +func (s *shellCommandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { // let's unmarshal job context jc := &shellJobContext{} @@ -83,7 +83,7 @@ func (s *shellCommandContext) handler(ct ct.Context, r *plugin.Runtime, j *job.J commandWithArguments = append(commandWithArguments, jc.Arguments...) // configure command - cmd := exec.Command(commandWithArguments[0], commandWithArguments[1:]...) + cmd := exec.CommandContext(ctx, commandWithArguments[0], commandWithArguments[1:]...) 
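exec.CommandContext ties the child process to the job's context: once the context is cancelled, Go kills the process. A small standalone sketch of that behaviour, with an illustrative command and timeout:

package main

import (
	"context"
	"fmt"
	"os/exec"
	"time"
)

func main() {
	// A 2-second deadline stands in for the job's cancellation signal.
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	// When ctx is cancelled, the process is killed instead of running to completion.
	cmd := exec.CommandContext(ctx, "sleep", "10")
	if err := cmd.Run(); err != nil {
		// err is typically "signal: killed"; ctx.Err() reports context.DeadlineExceeded.
		fmt.Println(err, ctx.Err())
	}
}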
cmd.Stdout = r.Stdout cmd.Stderr = r.Stderr diff --git a/internal/pkg/object/command/snowflake/snowflake.go b/internal/pkg/object/command/snowflake/snowflake.go index f4d2c1a..efdd133 100644 --- a/internal/pkg/object/command/snowflake/snowflake.go +++ b/internal/pkg/object/command/snowflake/snowflake.go @@ -1,7 +1,7 @@ package snowflake import ( - ct "context" + "context" "crypto/rsa" "crypto/x509" "database/sql" @@ -11,7 +11,7 @@ import ( sf "github.com/snowflakedb/gosnowflake" - "github.com/patterninc/heimdall/pkg/context" + heimdallContext "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/plugin" @@ -66,12 +66,12 @@ func parsePrivateKey(privateKeyBytes []byte) (*rsa.PrivateKey, error) { } -func New(_ *context.Context) (plugin.Handler, error) { +func New(_ *heimdallContext.Context) (plugin.Handler, error) { s := &snowflakeCommandContext{} return s.handler, nil } -func (s *snowflakeCommandContext) handler(ct ct.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { +func (s *snowflakeCommandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { clusterContext := &snowflakeClusterContext{} if c.Context != nil { @@ -118,7 +118,7 @@ func (s *snowflakeCommandContext) handler(ct ct.Context, r *plugin.Runtime, j *j } defer db.Close() - rows, err := db.Query(jobContext.Query) + rows, err := db.QueryContext(ctx, jobContext.Query) if err != nil { return err } diff --git a/internal/pkg/object/command/spark/spark.go b/internal/pkg/object/command/spark/spark.go index a624933..0f99790 100644 --- a/internal/pkg/object/command/spark/spark.go +++ b/internal/pkg/object/command/spark/spark.go @@ -1,7 +1,7 @@ package spark import ( - ct "context" + "context" "encoding/json" "fmt" "os" @@ -19,7 +19,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/babourine/x/pkg/set" - "github.com/patterninc/heimdall/pkg/context" + heimdallContext "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/plugin" @@ -64,7 +64,6 @@ const ( ) var ( - ctx = ct.Background() sparkDefaults = aws.String(`spark-defaults`) assumeRoleSession = aws.String("AssumeRoleSession") runtimeStates = set.New([]types.JobRunState{types.JobRunStateCompleted, types.JobRunStateFailed, types.JobRunStateCancelled}) @@ -77,7 +76,7 @@ var ( ) // New creates a new Spark plugin handler. -func New(commandContext *context.Context) (plugin.Handler, error) { +func New(commandContext *heimdallContext.Context) (plugin.Handler, error) { s := &sparkCommandContext{} @@ -92,7 +91,7 @@ func New(commandContext *context.Context) (plugin.Handler, error) { } // Handler for the Spark job submission. 
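The snowflake handler's switch to db.QueryContext follows the standard database/sql cancellation pattern. A hedged sketch of the same idea; the DSN, account, and query below are placeholders, not values from the patch:

package main

import (
	"context"
	"database/sql"
	"log"
	"time"

	_ "github.com/snowflakedb/gosnowflake" // registers the "snowflake" driver
)

// runQuery passes the job's context to QueryContext, so cancelling the job
// abandons the in-flight query instead of blocking until it returns.
func runQuery(ctx context.Context, db *sql.DB, query string) error {
	rows, err := db.QueryContext(ctx, query)
	if err != nil {
		return err
	}
	defer rows.Close()

	for rows.Next() {
		var value string
		if err := rows.Scan(&value); err != nil {
			return err
		}
		log.Println(value)
	}
	return rows.Err()
}

func main() {
	db, err := sql.Open("snowflake", "user:password@my_account/my_db")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	if err := runQuery(ctx, db, "select current_timestamp()"); err != nil {
		log.Println("query failed:", err)
	}
}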
-func (s *sparkCommandContext) handler(ct ct.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) (err error) { +func (s *sparkCommandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) (err error) { // let's unmarshal job context jobContext := &sparkJobContext{} @@ -165,7 +164,7 @@ func (s *sparkCommandContext) handler(ct ct.Context, r *plugin.Runtime, j *job.J svc := emrcontainers.NewFromConfig(awsConfig, assumeRoleOptions) // let's get the cluster ID - clusterID, err := getClusterID(svc, c.Name) + clusterID, err := getClusterID(ctx, svc, c.Name) if err != nil { return err } @@ -175,7 +174,7 @@ func (s *sparkCommandContext) handler(ct ct.Context, r *plugin.Runtime, j *job.J // upload query to s3 here... queryURI := fmt.Sprintf("%s/%s/query.sql", s.QueriesURI, j.ID) - if err := uploadFileToS3(queryURI, jobContext.Query); err != nil { + if err := uploadFileToS3(ctx, queryURI, jobContext.Query); err != nil { return err } @@ -288,7 +287,7 @@ func (s *sparkCommandContext) setJobDriver(jobContext *sparkJobContext, jobDrive } -func getClusterID(svc *emrcontainers.Client, clusterName string) (*string, error) { +func getClusterID(ctx context.Context, svc *emrcontainers.Client, clusterName string) (*string, error) { // let's get the cluster ID outputListClusters, err := svc.ListVirtualClusters(ctx, &emrcontainers.ListVirtualClustersInput{ @@ -327,7 +326,7 @@ func printState(stdout *os.File, state types.JobRunState) { stdout.WriteString(fmt.Sprintf("%v - job is still running. latest status: %v\n", time.Now(), state)) } -func uploadFileToS3(fileURI, content string) error { +func uploadFileToS3(ctx context.Context, fileURI, content string) error { // get bucket name and prefix s3Parts := rxS3.FindAllStringSubmatch(fileURI, -1) diff --git a/internal/pkg/object/command/sparkeks/sparkeks.go b/internal/pkg/object/command/sparkeks/sparkeks.go index 2aa007a..c41702d 100644 --- a/internal/pkg/object/command/sparkeks/sparkeks.go +++ b/internal/pkg/object/command/sparkeks/sparkeks.go @@ -3,7 +3,6 @@ package sparkeks import ( "bytes" "context" - ct "context" "encoding/json" "fmt" "io" @@ -74,7 +73,6 @@ const ( ) var ( - ctx = context.Background() rxS3 = regexp.MustCompile(`^s3://([^/]+)/(.*)$`) runtimeStates = []v1beta2.ApplicationStateType{ v1beta2.ApplicationStateCompleted, @@ -156,27 +154,24 @@ func New(commandContext *heimdallContext.Context) (plugin.Handler, error) { } // handler executes the Spark EKS job submission and execution. -func (s *sparkEksCommandContext) handler(ct ct.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { - - // Assign global context to incoming cancellation context - ctx = ct +func (s *sparkEksCommandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { // 1. Build execution context, create URIs, and upload query - execCtx, err := buildExecutionContextAndURI(r, j, c, s) + execCtx, err := buildExecutionContextAndURI(ctx, r, j, c, s) if err != nil { return fmt.Errorf("failed to build execution context: %w", err) } // 2. Submit the Spark Application to the cluster - if err := execCtx.submitSparkApp(); err != nil { + if err := execCtx.submitSparkApp(ctx); err != nil { return err } // 3. Monitor the job until completion and collect logs - monitorErr := execCtx.monitorJobAndCollectLogs() + monitorErr := execCtx.monitorJobAndCollectLogs(ctx) // 4. 
Cleanup any resources that are still pending - if err := execCtx.cleanupSparkApp(); err != nil { + if err := execCtx.cleanupSparkApp(ctx); err != nil { // Log cleanup error but don't override the main monitoring error execCtx.runtime.Stderr.WriteString(fmt.Sprintf("Warning: failed to cleanup application %s: %v\n", execCtx.submittedApp.Name, err)) } @@ -187,7 +182,7 @@ func (s *sparkEksCommandContext) handler(ct ct.Context, r *plugin.Runtime, j *jo } // 5. Get and store results if required - if err := execCtx.getAndStoreResults(); err != nil { + if err := execCtx.getAndStoreResults(ctx); err != nil { return err } @@ -195,7 +190,7 @@ func (s *sparkEksCommandContext) handler(ct ct.Context, r *plugin.Runtime, j *jo } // buildExecutionContextAndURI prepares the context, merges configurations, and uploads the query. -func buildExecutionContextAndURI(r *plugin.Runtime, j *job.Job, c *cluster.Cluster, s *sparkEksCommandContext) (*executionContext, error) { +func buildExecutionContextAndURI(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster, s *sparkEksCommandContext) (*executionContext, error) { execCtx := &executionContext{ runtime: r, job: j, @@ -253,12 +248,12 @@ func buildExecutionContextAndURI(r *plugin.Runtime, j *job.Job, c *cluster.Clust execCtx.logURI = fmt.Sprintf("%s/%s/%s", s.JobsURI, j.ID, logsPath) // Upload query to S3 - if err := uploadFileToS3(execCtx.awsConfig, execCtx.queryURI, execCtx.jobContext.Query); err != nil { + if err := uploadFileToS3(ctx, execCtx.awsConfig, execCtx.queryURI, execCtx.jobContext.Query); err != nil { return nil, fmt.Errorf("failed to upload query to S3: %w", err) } // create empty log s3 directory to avoid spark event log dir errors - if err := uploadFileToS3(execCtx.awsConfig, fmt.Sprintf("%s/.keepdir", execCtx.logURI), ""); err != nil { + if err := uploadFileToS3(ctx, execCtx.awsConfig, fmt.Sprintf("%s/.keepdir", execCtx.logURI), ""); err != nil { return nil, fmt.Errorf("failed to create log directory in S3: %w", err) } @@ -266,9 +261,9 @@ func buildExecutionContextAndURI(r *plugin.Runtime, j *job.Job, c *cluster.Clust } // submitSparkApp creates clients, generates the spec, and submits it to Kubernetes. -func (e *executionContext) submitSparkApp() error { +func (e *executionContext) submitSparkApp(ctx context.Context) error { // Create Kubernetes and Spark Operator clients - if err := createSparkClients(e); err != nil { + if err := createSparkClients(ctx, e); err != nil { return fmt.Errorf("failed to create Spark Operator client: %w", err) } @@ -302,7 +297,7 @@ func (e *executionContext) submitSparkApp() error { } // cleanupSparkApp removes the SparkApplication from the cluster if it still exists. -func (e *executionContext) cleanupSparkApp() error { +func (e *executionContext) cleanupSparkApp(ctx context.Context) error { if e.submittedApp == nil { return nil } @@ -321,12 +316,12 @@ func (e *executionContext) cleanupSparkApp() error { } // getAndStoreResults fetches the job output from S3 and stores it. 
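Threading the context into uploadFileToS3 lets a cancelled job abandon S3 uploads mid-flight. A trimmed-down sketch of a context-aware upload with aws-sdk-go-v2; the bucket, key, and content are illustrative:

package main

import (
	"context"
	"log"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/s3"
)

// putObject passes the caller's context into both config loading and the
// PutObject call, so cancellation stops the upload.
func putObject(ctx context.Context, bucket, key, content string) error {
	cfg, err := config.LoadDefaultConfig(ctx)
	if err != nil {
		return err
	}
	client := s3.NewFromConfig(cfg)
	_, err = client.PutObject(ctx, &s3.PutObjectInput{
		Bucket: aws.String(bucket),
		Key:    aws.String(key),
		Body:   strings.NewReader(content),
	})
	return err
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	// Bucket and key are placeholders; credentials come from the default chain.
	if err := putObject(ctx, "example-bucket", "jobs/demo/query.sql", "select 1"); err != nil {
		log.Println("upload failed:", err)
	}
}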
-func (e *executionContext) getAndStoreResults() error { +func (e *executionContext) getAndStoreResults(ctx context.Context) error { if !e.jobContext.ReturnResult { return nil } - returnResultFileURI, err := getS3FileURI(e.awsConfig, e.resultURI, avroFileExtension) + returnResultFileURI, err := getS3FileURI(ctx, e.awsConfig, e.resultURI, avroFileExtension) if err != nil { e.runtime.Stdout.WriteString(fmt.Sprintf("failed to find .avro file in results directory %s: %s", e.resultURI, err)) return fmt.Errorf("failed to find .avro file in results directory %s: %w", e.resultURI, err) @@ -340,7 +335,7 @@ func (e *executionContext) getAndStoreResults() error { } // uploadFileToS3 uploads content to S3. -func uploadFileToS3(awsConfig aws.Config, fileURI, content string) error { +func uploadFileToS3(ctx context.Context, awsConfig aws.Config, fileURI, content string) error { s3Parts := rxS3.FindAllStringSubmatch(fileURI, -1) if len(s3Parts) == 0 || len(s3Parts[0]) < 3 { return fmt.Errorf("unexpected S3 URI format: %s", fileURI) @@ -363,7 +358,7 @@ func updateS3ToS3aURI(uri string) string { } // getS3FileURI finds a file in an S3 directory that matches the given extension. -func getS3FileURI(awsConfig aws.Config, directoryURI, matchingExtension string) (string, error) { +func getS3FileURI(ctx context.Context, awsConfig aws.Config, directoryURI, matchingExtension string) (string, error) { s3Parts := rxS3.FindAllStringSubmatch(directoryURI, -1) if len(s3Parts) == 0 || len(s3Parts[0]) < 3 { return "", fmt.Errorf("invalid S3 URI format: %s", directoryURI) @@ -408,7 +403,7 @@ func getSparkSubmitParameters(context *sparkEksJobContext) *string { } // getSparkApplicationPods returns the list of pods associated with a Spark application. -func getSparkApplicationPods(kubeClient *kubernetes.Clientset, appName, namespace string) ([]corev1.Pod, error) { +func getSparkApplicationPods(ctx context.Context, kubeClient *kubernetes.Clientset, appName, namespace string) ([]corev1.Pod, error) { labelSelector := fmt.Sprintf("%s=%s", sparkAppLabelSelectorFormat, appName) podList, err := kubeClient.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{LabelSelector: labelSelector}) if err != nil { @@ -442,7 +437,7 @@ func writeDriverLogsToStderr(execCtx *executionContext, pod corev1.Pod, logConte } // getAndUploadPodContainerLogs fetches logs from a specific container in a pod and uploads them to S3. -func getAndUploadPodContainerLogs(execCtx *executionContext, pod corev1.Pod, container corev1.Container, previous bool, logType string, writeToStderr bool) { +func getAndUploadPodContainerLogs(ctx context.Context, execCtx *executionContext, pod corev1.Pod, container corev1.Container, previous bool, logType string, writeToStderr bool) { logOptions := &corev1.PodLogOptions{Container: container.Name, Previous: previous} req := execCtx.kubeClient.CoreV1().Pods(pod.Namespace).GetLogs(pod.Name, logOptions) logs, err := req.Stream(ctx) @@ -462,31 +457,31 @@ func getAndUploadPodContainerLogs(execCtx *executionContext, pod corev1.Pod, con } logURI := fmt.Sprintf("%s/%s-%s", execCtx.logURI, pod.Name, logType) - if err := uploadFileToS3(execCtx.awsConfig, logURI, string(logContent)); err != nil { + if err := uploadFileToS3(ctx, execCtx.awsConfig, logURI, string(logContent)); err != nil { execCtx.runtime.Stderr.WriteString(fmt.Sprintf("Pod %s, container %s: %s upload error: %v\n", pod.Name, container.Name, logType, err)) } } } // getSparkApplicationPodLogs fetches logs from pods and uploads them to S3. 
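getSparkApplicationPods now lists pods with the caller's context. A standalone sketch of that lookup with client-go; the label key, namespace, and kubeconfig path are assumptions, not the plugin's actual constants:

package main

import (
	"context"
	"fmt"

	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/client-go/kubernetes"
	"k8s.io/client-go/tools/clientcmd"
)

// listAppPods filters pods by a Spark-operator style app-name label,
// honouring the caller's context for the API call.
func listAppPods(ctx context.Context, client *kubernetes.Clientset, namespace, appName string) error {
	selector := fmt.Sprintf("sparkoperator.k8s.io/app-name=%s", appName)
	pods, err := client.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{LabelSelector: selector})
	if err != nil {
		return err
	}
	for _, pod := range pods.Items {
		fmt.Println(pod.Name, pod.Status.Phase)
	}
	return nil
}

func main() {
	// Kubeconfig path is illustrative.
	cfg, err := clientcmd.BuildConfigFromFlags("", "/tmp/kubeconfig")
	if err != nil {
		panic(err)
	}
	client, err := kubernetes.NewForConfig(cfg)
	if err != nil {
		panic(err)
	}
	_ = listAppPods(context.Background(), client, "spark-jobs", "demo-app")
}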
-func getSparkApplicationPodLogs(execCtx *executionContext, pods []corev1.Pod, writeToStderr bool) error { +func getSparkApplicationPodLogs(ctx context.Context, execCtx *executionContext, pods []corev1.Pod, writeToStderr bool) error { for _, pod := range pods { if !isPodInValidPhase(pod) { continue } for _, container := range pod.Spec.Containers { // Get current logs and upload - getAndUploadPodContainerLogs(execCtx, pod, container, false, stdoutLogSuffix, writeToStderr) + getAndUploadPodContainerLogs(ctx, execCtx, pod, container, false, stdoutLogSuffix, writeToStderr) // Get logs from previous (failed) runs and upload - getAndUploadPodContainerLogs(execCtx, pod, container, true, stderrLogSuffix, false) + getAndUploadPodContainerLogs(ctx, execCtx, pod, container, true, stderrLogSuffix, false) } } return nil } // createSparkClients creates Kubernetes and Spark clients for the EKS cluster. -func createSparkClients(execCtx *executionContext) error { - kubeconfigPath, err := updateKubeConfig(execCtx) +func createSparkClients(ctx context.Context, execCtx *executionContext) error { + kubeconfigPath, err := updateKubeConfig(ctx, execCtx) if err != nil { return fmt.Errorf("failed to update kubeconfig: %w", err) } @@ -518,7 +513,7 @@ func createSparkClients(execCtx *executionContext) error { return nil } -func updateKubeConfig(execCtx *executionContext) (string, error) { +func updateKubeConfig(ctx context.Context, execCtx *executionContext) (string, error) { region := os.Getenv(awsRegionEnvVar) if execCtx.clusterContext.Region != nil { region = *execCtx.clusterContext.Region @@ -767,7 +762,7 @@ func loadTemplate(execCtx *executionContext) (*v1beta2.SparkApplication, error) } // monitorJobAndCollectLogs monitors the Spark job until completion and collects logs. 
-func (e *executionContext) monitorJobAndCollectLogs() error { +func (e *executionContext) monitorJobAndCollectLogs(ctx context.Context) error { appName, namespace := e.submittedApp.Name, e.submittedApp.Namespace e.runtime.Stdout.WriteString(fmt.Sprintf("Monitoring Spark application: %s\n", appName)) @@ -779,7 +774,7 @@ func (e *executionContext) monitorJobAndCollectLogs() error { for { if monitorCtx.Err() != nil { if finalSparkApp != nil { - collectSparkApplicationLogs(e, finalSparkApp, true) + collectSparkApplicationLogs(ctx, e, finalSparkApp, true) } if monitorCtx.Err() == context.DeadlineExceeded { return fmt.Errorf("spark job timed out after %v", jobTimeout) @@ -793,7 +788,7 @@ func (e *executionContext) monitorJobAndCollectLogs() error { if err != nil { e.runtime.Stderr.WriteString(fmt.Sprintf("Spark application %s/%s not found or deleted externally: %v\n", namespace, appName, err)) if finalSparkApp != nil { - collectSparkApplicationLogs(e, finalSparkApp, true) + collectSparkApplicationLogs(ctx, e, finalSparkApp, true) } return fmt.Errorf("spark application %s/%s not found: %w", namespace, appName, err) } @@ -807,17 +802,17 @@ func (e *executionContext) monitorJobAndCollectLogs() error { } if state == v1beta2.ApplicationStateRunning { - collectSparkApplicationLogs(e, sparkApp, false) + collectSparkApplicationLogs(ctx, e, sparkApp, false) continue } switch state { case v1beta2.ApplicationStateCompleted: - collectSparkApplicationLogs(e, sparkApp, false) + collectSparkApplicationLogs(ctx, e, sparkApp, false) e.runtime.Stdout.WriteString("Spark job completed successfully\n") return nil case v1beta2.ApplicationStateFailed: - collectSparkApplicationLogs(e, sparkApp, true) + collectSparkApplicationLogs(ctx, e, sparkApp, true) errorMessage := sparkApp.Status.AppState.ErrorMessage if errorMessage == "" { errorMessage = unknownErrorMsg @@ -825,7 +820,7 @@ func (e *executionContext) monitorJobAndCollectLogs() error { e.runtime.Stderr.WriteString(fmt.Sprintf("Spark job failed: %s\n", errorMessage)) return fmt.Errorf("spark job failed: %s", errorMessage) case v1beta2.ApplicationStateFailedSubmission, v1beta2.ApplicationStateUnknown: - collectSparkApplicationLogs(e, finalSparkApp, true) + collectSparkApplicationLogs(ctx, e, finalSparkApp, true) msg := sparkJobSubmissionFailedMsg if state == v1beta2.ApplicationStateUnknown { msg = sparkAppUnknownStateMsg @@ -836,17 +831,17 @@ func (e *executionContext) monitorJobAndCollectLogs() error { } // collectSparkApplicationLogs collects logs from Spark application pods. 
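The monitoring loop above combines a timeout context with periodic state checks. A self-contained sketch of that pattern; the state strings and check function are stand-ins for the SparkApplication lookup:

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// pollUntilDone keeps checking a status function on an interval and stops
// either when the work reaches a terminal state or when the monitoring
// context's deadline expires.
func pollUntilDone(parent context.Context, timeout, interval time.Duration, checkState func() (string, error)) error {
	ctx, cancel := context.WithTimeout(parent, timeout)
	defer cancel()

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			if errors.Is(ctx.Err(), context.DeadlineExceeded) {
				return fmt.Errorf("job timed out after %v", timeout)
			}
			return ctx.Err() // cancelled by the caller
		case <-ticker.C:
			state, err := checkState()
			if err != nil {
				return err
			}
			switch state {
			case "COMPLETED":
				return nil
			case "FAILED":
				return errors.New("job failed")
			}
			// any other state: keep polling
		}
	}
}

func main() {
	calls := 0
	err := pollUntilDone(context.Background(), 5*time.Second, 500*time.Millisecond, func() (string, error) {
		calls++
		if calls >= 3 {
			return "COMPLETED", nil
		}
		return "RUNNING", nil
	})
	fmt.Println("result:", err)
}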
-func collectSparkApplicationLogs(execCtx *executionContext, sparkApp *v1beta2.SparkApplication, writeToStderr bool) { +func collectSparkApplicationLogs(ctx context.Context, execCtx *executionContext, sparkApp *v1beta2.SparkApplication, writeToStderr bool) { if sparkApp == nil { return } - pods, err := getSparkApplicationPods(execCtx.kubeClient, sparkApp.Name, sparkApp.Namespace) + pods, err := getSparkApplicationPods(ctx, execCtx.kubeClient, sparkApp.Name, sparkApp.Namespace) if err != nil { execCtx.runtime.Stderr.WriteString(fmt.Sprintf("Warning: failed to get Spark application pods: %v\n", err)) return } - if err := getSparkApplicationPodLogs(execCtx, pods, writeToStderr); err != nil { + if err := getSparkApplicationPodLogs(ctx, execCtx, pods, writeToStderr); err != nil { execCtx.runtime.Stderr.WriteString(fmt.Sprintf("Warning: failed to collect pod logs: %v\n", err)) } } diff --git a/internal/pkg/object/command/sparkeks/sparkeks_test.go b/internal/pkg/object/command/sparkeks/sparkeks_test.go index ad56b69..8c4235b 100644 --- a/internal/pkg/object/command/sparkeks/sparkeks_test.go +++ b/internal/pkg/object/command/sparkeks/sparkeks_test.go @@ -1,6 +1,7 @@ package sparkeks import ( + "context" "os" "strings" "testing" @@ -61,7 +62,7 @@ func TestGetSparkSubmitParameters_NilParameters(t *testing.T) { } func TestGetS3FileURI_InvalidFormat(t *testing.T) { - _, err := getS3FileURI(aws.Config{}, "invalid-uri", "avro") + _, err := getS3FileURI(context.Background(), aws.Config{}, "invalid-uri", "avro") if err == nil { t.Error("Expected error for invalid S3 URI format") } @@ -70,7 +71,7 @@ func TestGetS3FileURI_InvalidFormat(t *testing.T) { func TestGetS3FileURI_ValidFormat(t *testing.T) { // This test only checks parsing, not actual AWS interaction uri := "s3://bucket/path/" - _, err := getS3FileURI(aws.Config{}, uri, "avro") + _, err := getS3FileURI(context.Background(), aws.Config{}, uri, "avro") // Should not error on parsing, but will error on AWS call (which is fine for unit test context) if err == nil || !strings.Contains(err.Error(), "failed to list S3 objects") { t.Logf("Expected AWS list objects error, got: %v", err) diff --git a/internal/pkg/object/command/trino/client.go b/internal/pkg/object/command/trino/client.go index 793ddf9..ce39bd5 100644 --- a/internal/pkg/object/command/trino/client.go +++ b/internal/pkg/object/command/trino/client.go @@ -2,6 +2,7 @@ package trino import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -45,7 +46,7 @@ type response struct { Error map[string]any `json:"error"` } -func newRequest(r *plugin.Runtime, j *job.Job, c *cluster.Cluster, jobCtx *jobContext) (*request, error) { +func newRequest(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster, jobCtx *jobContext) (*request, error) { // get cluster context clusterCtx := &clusterContext{} @@ -72,7 +73,7 @@ func newRequest(r *plugin.Runtime, j *job.Job, c *cluster.Cluster, jobCtx *jobCo } // submit query - if err := req.submit(jobCtx.Query); err != nil { + if err := req.submit(ctx, jobCtx.Query); err != nil { return nil, err } @@ -80,9 +81,9 @@ func newRequest(r *plugin.Runtime, j *job.Job, c *cluster.Cluster, jobCtx *jobCo } -func (r *request) submit(query string) error { +func (r *request) submit(ctx context.Context, query string) error { - req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/v1/statement", r.endpoint), bytes.NewBuffer([]byte(query))) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/v1/statement", r.endpoint), 
bytes.NewBuffer([]byte(query))) if err != nil { return err } @@ -91,9 +92,9 @@ func (r *request) submit(query string) error { } -func (r *request) poll() error { +func (r *request) poll(ctx context.Context) error { - req, err := http.NewRequest(http.MethodGet, r.nextUri, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, r.nextUri, nil) if err != nil { return err } @@ -174,4 +175,4 @@ func (r *request) api(req *http.Request) error { func normalizeTrinoQuery(query string) string { // Trino does not support semicolon at the end of the query, so we remove it if present return strings.TrimSuffix(query, ";") -} \ No newline at end of file +} diff --git a/internal/pkg/object/command/trino/trino.go b/internal/pkg/object/command/trino/trino.go index 1a96146..8d246da 100644 --- a/internal/pkg/object/command/trino/trino.go +++ b/internal/pkg/object/command/trino/trino.go @@ -1,13 +1,13 @@ package trino import ( - ct "context" + "context" "fmt" "log" "time" "github.com/hladush/go-telemetry/pkg/telemetry" - "github.com/patterninc/heimdall/pkg/context" + heimdallContext "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/plugin" @@ -37,14 +37,14 @@ type jobContext struct { } // New creates a new trino plugin handler -func New(ctx *context.Context) (plugin.Handler, error) { +func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { t := &commandContext{ PollInterval: defaultPollInterval, } - if ctx != nil { - if err := ctx.Unmarshal(t); err != nil { + if commandCtx != nil { + if err := commandCtx.Unmarshal(t); err != nil { return nil, err } } @@ -53,7 +53,7 @@ func New(ctx *context.Context) (plugin.Handler, error) { } -func (t *commandContext) handler(ct ct.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { +func (t *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { // get job context jobCtx := &jobContext{} @@ -69,7 +69,7 @@ func (t *commandContext) handler(ct ct.Context, r *plugin.Runtime, j *job.Job, c // this code will be enabled in prod after some testing } // let's submit our query to trino - req, err := newRequest(r, j, c, jobCtx) + req, err := newRequest(ctx, r, j, c, jobCtx) if err != nil { return err } @@ -77,7 +77,7 @@ func (t *commandContext) handler(ct ct.Context, r *plugin.Runtime, j *job.Job, c // now let's keep pooling until we get the full result... for req.nextUri != `` { time.Sleep(time.Duration(t.PollInterval) * time.Millisecond) - if err := req.poll(); err != nil { + if err := req.poll(ctx); err != nil { return err } } diff --git a/internal/pkg/pool/pool.go b/internal/pkg/pool/pool.go index 254160d..dec05db 100644 --- a/internal/pkg/pool/pool.go +++ b/internal/pkg/pool/pool.go @@ -4,8 +4,6 @@ import ( "context" "fmt" "time" - - "github.com/patterninc/heimdall/pkg/object/job" ) const ( @@ -16,10 +14,10 @@ const ( type Pool[T any] struct { Size int `yaml:"size,omitempty" json:"size,omitempty"` Sleep int `yaml:"sleep,omitempty" json:"sleep,omitempty"` - queue chan *job.Job + queue chan T } -func (p *Pool[T]) Start(worker func(context.Context, *job.Job) error, getWork func(int) ([]*job.Job, error)) error { +func (p *Pool[T]) Start(worker func(context.Context, T) error, getWork func(int) ([]T, error)) error { // do we have the size set? 
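Building Trino requests with http.NewRequestWithContext means a cancelled job aborts the in-flight HTTP call instead of waiting on the server. A small sketch of the pattern with an illustrative URL:

package main

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"time"
)

// fetchWithContext builds the request with the job's context; cancellation
// surfaces as an error from Do rather than a hung connection.
func fetchWithContext(ctx context.Context, url string) ([]byte, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return nil, err
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err // includes context.Canceled / context.DeadlineExceeded
	}
	defer resp.Body.Close()
	return io.ReadAll(resp.Body)
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	body, err := fetchWithContext(ctx, "http://localhost:8080/v1/statement/next")
	fmt.Println(len(body), err)
}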
if p.Size <= 0 { @@ -32,7 +30,7 @@ func (p *Pool[T]) Start(worker func(context.Context, *job.Job) error, getWork fu } // set the queue of the size of double the pool size - p.queue = make(chan *job.Job, p.Size*2) + p.queue = make(chan T, p.Size*2) // let's set the counter tokens := &counter{} From ae56e8e4bdf28c1163ae86fc01ef79ded13a85c3 Mon Sep 17 00:00:00 2001 From: wlggraham Date: Thu, 20 Nov 2025 12:32:08 -0700 Subject: [PATCH 05/10] runtime s3 function fix --- pkg/plugin/runtime.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pkg/plugin/runtime.go b/pkg/plugin/runtime.go index 6c0d616..179fb81 100644 --- a/pkg/plugin/runtime.go +++ b/pkg/plugin/runtime.go @@ -1,6 +1,7 @@ package plugin import ( + "context" "fmt" "io/fs" "os" @@ -97,7 +98,9 @@ func copyDir(src, dst string) error { // if we have local filesystem, crete directory as appropriate writeFileFunc := os.WriteFile if strings.HasPrefix(dst, s3Prefix) { - writeFileFunc = aws.WriteToS3 + writeFileFunc = func(name string, data []byte, perm os.FileMode) error { + return aws.WriteToS3(context.Background(), name, data, perm) + } } else { if _, err := os.Stat(dst); os.IsNotExist(err) { os.MkdirAll(dst, jobDirectoryPermissions) From 8d073a67201914cb26effca4c05dbdcf72cc20aa Mon Sep 17 00:00:00 2001 From: wlggraham Date: Thu, 18 Dec 2025 15:02:27 -0700 Subject: [PATCH 06/10] update context naming conventions --- .../object/command/clickhouse/clickhouse.go | 14 +++---- internal/pkg/object/command/dynamo/dynamo.go | 14 +++---- internal/pkg/object/command/ecs/ecs.go | 20 +++++----- internal/pkg/object/command/glue/glue.go | 16 ++++---- internal/pkg/object/command/ping/ping.go | 6 +-- internal/pkg/object/command/shell/shell.go | 24 +++++------ .../pkg/object/command/snowflake/snowflake.go | 14 +++---- internal/pkg/object/command/spark/spark.go | 24 +++++------ .../pkg/object/command/sparkeks/sparkeks.go | 40 +++++++++---------- .../object/command/sparkeks/sparkeks_test.go | 10 ++--- 10 files changed, 91 insertions(+), 91 deletions(-) diff --git a/internal/pkg/object/command/clickhouse/clickhouse.go b/internal/pkg/object/command/clickhouse/clickhouse.go index 5c67cf3..5581667 100644 --- a/internal/pkg/object/command/clickhouse/clickhouse.go +++ b/internal/pkg/object/command/clickhouse/clickhouse.go @@ -16,7 +16,7 @@ import ( "github.com/patterninc/heimdall/pkg/result/column" ) -type clickhouseCommandContext struct { +type commandContext struct { Username string `yaml:"username,omitempty" json:"username,omitempty"` Password string `yaml:"password,omitempty" json:"password,omitempty"` } @@ -45,11 +45,11 @@ var ( ) // New creates a new clickhouse plugin handler -func New(commandContext *heimdallContext.Context) (plugin.Handler, error) { - t := &clickhouseCommandContext{} +func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { + t := &commandContext{} - if commandContext != nil { - if err := commandContext.Unmarshal(t); err != nil { + if commandCtx != nil { + if err := commandCtx.Unmarshal(t); err != nil { return nil, err } } @@ -57,7 +57,7 @@ func New(commandContext *heimdallContext.Context) (plugin.Handler, error) { return t.handler, nil } -func (cmd *clickhouseCommandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { +func (cmd *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { jobContext, err := cmd.createJobContext(ctx, j, c) if err != nil { @@ -81,7 +81,7 @@ func (cmd *clickhouseCommandContext) 
handler(ctx context.Context, r *plugin.Runt return nil } -func (cmd *clickhouseCommandContext) createJobContext(ctx context.Context, j *job.Job, c *cluster.Cluster) (*jobContext, error) { +func (cmd *commandContext) createJobContext(ctx context.Context, j *job.Job, c *cluster.Cluster) (*jobContext, error) { // get cluster context clusterCtx := &clusterContext{} if c.Context != nil { diff --git a/internal/pkg/object/command/dynamo/dynamo.go b/internal/pkg/object/command/dynamo/dynamo.go index 77dae5e..283a6b8 100644 --- a/internal/pkg/object/command/dynamo/dynamo.go +++ b/internal/pkg/object/command/dynamo/dynamo.go @@ -17,18 +17,18 @@ import ( ) // dynamoJobContext represents the context for a dynamo job -type dynamoJobContext struct { +type jobContext struct { Query string `yaml:"query,omitempty" json:"query,omitempty"` Limit int `yaml:"limit,omitempty" json:"limit,omitempty"` } // dynamoClusterContext represents the context for a dynamo endpoint -type dynamoClusterContext struct { +type clusterContext struct { RoleARN *string `yaml:"role_arn,omitempty" json:"role_arn,omitempty"` } // dynamoCommandContext represents the dynamo command context -type dynamoCommandContext struct{} +type commandContext struct{} var ( assumeRoleSession = aws.String("AssumeRoleSession") @@ -37,16 +37,16 @@ var ( // New creates a new dynamo plugin handler. func New(_ *heimdallContext.Context) (plugin.Handler, error) { - s := &dynamoCommandContext{} + s := &commandContext{} return s.handler, nil } // Handler for the Spark job submission. -func (d *dynamoCommandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) (err error) { +func (d *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) (err error) { // let's unmarshal job context - jobContext := &dynamoJobContext{} + jobContext := &jobContext{} if j.Context != nil { if err := j.Context.Unmarshal(jobContext); err != nil { return err @@ -54,7 +54,7 @@ func (d *dynamoCommandContext) handler(ctx context.Context, r *plugin.Runtime, j } // let's unmarshal cluster context - clusterContext := &dynamoClusterContext{} + clusterContext := &clusterContext{} if c.Context != nil { if err := c.Context.Unmarshal(clusterContext); err != nil { return err diff --git a/internal/pkg/object/command/ecs/ecs.go b/internal/pkg/object/command/ecs/ecs.go index 0256284..a15819f 100644 --- a/internal/pkg/object/command/ecs/ecs.go +++ b/internal/pkg/object/command/ecs/ecs.go @@ -23,7 +23,7 @@ import ( ) // ECS command context structure -type ecsCommandContext struct { +type commandContext struct { TaskDefinitionTemplate string `yaml:"task_definition_template,omitempty" json:"task_definition_template,omitempty"` TaskCount int `yaml:"task_count,omitempty" json:"task_count,omitempty"` CPU int `yaml:"cpu,omitempty" json:"cpu,omitempty"` @@ -35,7 +35,7 @@ type ecsCommandContext struct { } // ECS cluster context structure -type ecsClusterContext struct { +type clusterContext struct { MaxCPU int `yaml:"max_cpu,omitempty" json:"max_cpu,omitempty"` MaxMemory int `yaml:"max_memory,omitempty" json:"max_memory,omitempty"` MaxTaskCount int `yaml:"max_task_count,omitempty" json:"max_task_count,omitempty"` @@ -86,7 +86,7 @@ type executionContext struct { Memory int `json:"memory"` TaskDefinitionWrapper *taskDefinitionWrapper `json:"task_definition_wrapper"` ContainerOverrides []types.ContainerOverride `json:"container_overrides"` - ClusterConfig *ecsClusterContext `json:"cluster_config"` + ClusterConfig *clusterContext 
`json:"cluster_config"` PollingInterval duration.Duration `json:"polling_interval"` Timeout duration.Duration `json:"timeout"` @@ -120,17 +120,17 @@ var ( methodMetrics = telemetry.NewMethod("ecs", "ecs plugin") ) -func New(commandContext *heimdallContext.Context) (plugin.Handler, error) { +func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { - e := &ecsCommandContext{ + e := &commandContext{ PollingInterval: defaultPollingInterval, Timeout: defaultTaskTimeout, MaxFailCount: defaultMaxFailCount, TaskCount: defaultTaskCount, } - if commandContext != nil { - if err := commandContext.Unmarshal(e); err != nil { + if commandCtx != nil { + if err := commandCtx.Unmarshal(e); err != nil { return nil, err } } @@ -140,7 +140,7 @@ func New(commandContext *heimdallContext.Context) (plugin.Handler, error) { } // handler implements the main ECS plugin logic -func (e *ecsCommandContext) handler(ctx context.Context, r *plugin.Runtime, job *job.Job, cluster *cluster.Cluster) error { +func (e *commandContext) handler(ctx context.Context, r *plugin.Runtime, job *job.Job, cluster *cluster.Cluster) error { // Build execution context with resolved configuration and loaded template execCtx, err := buildExecutionContext(ctx, e, job, cluster, r) @@ -346,7 +346,7 @@ func (execCtx *executionContext) pollForCompletion(ctx context.Context) error { } -func buildExecutionContext(ctx context.Context, commandCtx *ecsCommandContext, j *job.Job, c *cluster.Cluster, runtime *plugin.Runtime) (*executionContext, error) { +func buildExecutionContext(ctx context.Context, commandCtx *commandContext, j *job.Job, c *cluster.Cluster, runtime *plugin.Runtime) (*executionContext, error) { execCtx := &executionContext{ tasks: make(map[string]*taskTracker), @@ -367,7 +367,7 @@ func buildExecutionContext(ctx context.Context, commandCtx *ecsCommandContext, j } // Add cluster config (no overlapping values) - clusterContext := &ecsClusterContext{} + clusterContext := &clusterContext{} if c.Context != nil { if err := c.Context.Unmarshal(clusterContext); err != nil { return nil, err diff --git a/internal/pkg/object/command/glue/glue.go b/internal/pkg/object/command/glue/glue.go index cece65b..63503b6 100644 --- a/internal/pkg/object/command/glue/glue.go +++ b/internal/pkg/object/command/glue/glue.go @@ -11,20 +11,20 @@ import ( "github.com/patterninc/heimdall/pkg/result" ) -type glueCommandContext struct { +type commandContext struct { CatalogID string `yaml:"catalog_id,omitempty" json:"catalog_id,omitempty"` } -type glueJobContext struct { +type jobContext struct { TableName string `yaml:"table_name,omitempty" json:"table_name,omitempty"` } -func New(commandContext *heimdallContext.Context) (plugin.Handler, error) { +func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { - g := &glueCommandContext{} + g := &commandContext{} - if commandContext != nil { - if err := commandContext.Unmarshal(g); err != nil { + if commandCtx != nil { + if err := commandCtx.Unmarshal(g); err != nil { return nil, err } } @@ -33,10 +33,10 @@ func New(commandContext *heimdallContext.Context) (plugin.Handler, error) { } -func (g *glueCommandContext) handler(ctx context.Context, _ *plugin.Runtime, j *job.Job, _ *cluster.Cluster) (err error) { +func (g *commandContext) handler(ctx context.Context, _ *plugin.Runtime, j *job.Job, _ *cluster.Cluster) (err error) { // let's unmarshal job context - jc := &glueJobContext{} + jc := &jobContext{} if j.Context != nil { if err = j.Context.Unmarshal(jc); err != nil { return diff --git 
a/internal/pkg/object/command/ping/ping.go b/internal/pkg/object/command/ping/ping.go index caaa686..b09a22c 100644 --- a/internal/pkg/object/command/ping/ping.go +++ b/internal/pkg/object/command/ping/ping.go @@ -15,16 +15,16 @@ const ( messageFormat = `Hello, %s!` ) -type pingCommandContext struct{} +type commandContext struct{} func New(_ *heimdallContext.Context) (plugin.Handler, error) { - p := &pingCommandContext{} + p := &commandContext{} return p.handler, nil } -func (p *pingCommandContext) handler(ctx context.Context, _ *plugin.Runtime, j *job.Job, _ *cluster.Cluster) (err error) { +func (p *commandContext) handler(ctx context.Context, _ *plugin.Runtime, j *job.Job, _ *cluster.Cluster) (err error) { j.Result, err = result.FromMessage(fmt.Sprintf(messageFormat, j.User)) return diff --git a/internal/pkg/object/command/shell/shell.go b/internal/pkg/object/command/shell/shell.go index acbbd4c..aa34040 100644 --- a/internal/pkg/object/command/shell/shell.go +++ b/internal/pkg/object/command/shell/shell.go @@ -19,27 +19,27 @@ const ( contextFilename = `context.json` ) -type shellCommandContext struct { +type commandContext struct { Command []string `yaml:"command,omitempty" json:"command,omitempty"` } -type shellJobContext struct { +type jobContext struct { Arguments []string `yaml:"arguments,omitempty" json:"arguments,omitempty"` } type runtimeContext struct { - Job *job.Job `yaml:"job,omitempty" json:"job,omitempty"` - Command *shellCommandContext `yaml:"command,omitempty" json:"command,omitempty"` - Cluster *cluster.Cluster `yaml:"cluster,omitempty" json:"cluster,omitempty"` - Runtime *plugin.Runtime `yaml:"runtime,omitempty" json:"runtime,omitempty"` + Job *job.Job `yaml:"job,omitempty" json:"job,omitempty"` + Command *commandContext `yaml:"command,omitempty" json:"command,omitempty"` + Cluster *cluster.Cluster `yaml:"cluster,omitempty" json:"cluster,omitempty"` + Runtime *plugin.Runtime `yaml:"runtime,omitempty" json:"runtime,omitempty"` } -func New(commandContext *heimdallContext.Context) (plugin.Handler, error) { +func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { - s := &shellCommandContext{} + s := &commandContext{} - if commandContext != nil { - if err := commandContext.Unmarshal(s); err != nil { + if commandCtx != nil { + if err := commandCtx.Unmarshal(s); err != nil { return nil, err } } @@ -48,10 +48,10 @@ func New(commandContext *heimdallContext.Context) (plugin.Handler, error) { } -func (s *shellCommandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { +func (s *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { // let's unmarshal job context - jc := &shellJobContext{} + jc := &jobContext{} if j.Context != nil { if err := j.Context.Unmarshal(jc); err != nil { return err diff --git a/internal/pkg/object/command/snowflake/snowflake.go b/internal/pkg/object/command/snowflake/snowflake.go index efdd133..cf850db 100644 --- a/internal/pkg/object/command/snowflake/snowflake.go +++ b/internal/pkg/object/command/snowflake/snowflake.go @@ -28,13 +28,13 @@ var ( ErrInvalidKeyType = fmt.Errorf(`invalida key type`) ) -type snowflakeCommandContext struct{} +type commandContext struct{} -type snowflakeJobContext struct { +type jobContext struct { Query string `yaml:"query,omitempty" json:"query,omitempty"` } -type snowflakeClusterContext struct { +type clusterContext struct { Account string `yaml:"account,omitempty" json:"account,omitempty"` User string 
`yaml:"user,omitempty" json:"user,omitempty"` Database string `yaml:"database,omitempty" json:"database,omitempty"` @@ -67,13 +67,13 @@ func parsePrivateKey(privateKeyBytes []byte) (*rsa.PrivateKey, error) { } func New(_ *heimdallContext.Context) (plugin.Handler, error) { - s := &snowflakeCommandContext{} + s := &commandContext{} return s.handler, nil } -func (s *snowflakeCommandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { +func (s *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { - clusterContext := &snowflakeClusterContext{} + clusterContext := &clusterContext{} if c.Context != nil { if err := c.Context.Unmarshal(clusterContext); err != nil { return err @@ -81,7 +81,7 @@ func (s *snowflakeCommandContext) handler(ctx context.Context, r *plugin.Runtime } // let's unmarshal job context - jobContext := &snowflakeJobContext{} + jobContext := &jobContext{} if err := j.Context.Unmarshal(jobContext); err != nil { return err } diff --git a/internal/pkg/object/command/spark/spark.go b/internal/pkg/object/command/spark/spark.go index 0f99790..0583a04 100644 --- a/internal/pkg/object/command/spark/spark.go +++ b/internal/pkg/object/command/spark/spark.go @@ -32,7 +32,7 @@ type sparkSubmitParameters struct { } // spark represents the Spark command context -type sparkCommandContext struct { +type commandContext struct { QueriesURI string `yaml:"queries_uri,omitempty" json:"queries_uri,omitempty"` ResultsURI string `yaml:"results_uri,omitempty" json:"results_uri,omitempty"` LogsURI *string `yaml:"logs_uri,omitempty" json:"logs_uri,omitempty"` @@ -41,7 +41,7 @@ type sparkCommandContext struct { } // sparkJobContext represents the context for a spark job -type sparkJobContext struct { +type jobContext struct { Query string `yaml:"query,omitempty" json:"query,omitempty"` Arguments []string `yaml:"arguments,omitempty" json:"arguments,omitempty"` Parameters *sparkSubmitParameters `yaml:"parameters,omitempty" json:"parameters,omitempty"` @@ -49,7 +49,7 @@ type sparkJobContext struct { } // sparkClusterContext represents the context for a spark cluster -type sparkClusterContext struct { +type clusterContext struct { ExecutionRoleArn *string `yaml:"execution_role_arn,omitempty" json:"execution_role_arn,omitempty"` EMRReleaseLabel *string `yaml:"emr_release_label,omitempty" json:"emr_release_label,omitempty"` RoleARN *string `yaml:"role_arn,omitempty" json:"role_arn,omitempty"` @@ -76,12 +76,12 @@ var ( ) // New creates a new Spark plugin handler. -func New(commandContext *heimdallContext.Context) (plugin.Handler, error) { +func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { - s := &sparkCommandContext{} + s := &commandContext{} - if commandContext != nil { - if err := commandContext.Unmarshal(s); err != nil { + if commandCtx != nil { + if err := commandCtx.Unmarshal(s); err != nil { return nil, err } } @@ -91,10 +91,10 @@ func New(commandContext *heimdallContext.Context) (plugin.Handler, error) { } // Handler for the Spark job submission. 
-func (s *sparkCommandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) (err error) { +func (s *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) (err error) { // let's unmarshal job context - jobContext := &sparkJobContext{} + jobContext := &jobContext{} if j.Context != nil { if err := j.Context.Unmarshal(jobContext); err != nil { return err @@ -102,7 +102,7 @@ func (s *sparkCommandContext) handler(ctx context.Context, r *plugin.Runtime, j } // let's unmarshal cluster context - clusterContext := &sparkClusterContext{} + clusterContext := &clusterContext{} if c.Context != nil { if err := c.Context.Unmarshal(clusterContext); err != nil { return err @@ -261,7 +261,7 @@ timeoutLoop: } -func (s *sparkCommandContext) setJobDriver(jobContext *sparkJobContext, jobDriver *types.JobDriver, queryURI string, resultURI string) { +func (s *commandContext) setJobDriver(jobContext *jobContext, jobDriver *types.JobDriver, queryURI string, resultURI string) { jobParameters := getSparkSubmitParameters(jobContext) if jobContext.Arguments != nil { jobDriver.SparkSubmitJobDriver = &types.SparkSubmitJobDriver{ @@ -308,7 +308,7 @@ func getClusterID(ctx context.Context, svc *emrcontainers.Client, clusterName st } -func getSparkSubmitParameters(context *sparkJobContext) *string { +func getSparkSubmitParameters(context *jobContext) *string { properties := context.Parameters.Properties conf := make([]string, 0, len(properties)) diff --git a/internal/pkg/object/command/sparkeks/sparkeks.go b/internal/pkg/object/command/sparkeks/sparkeks.go index c41702d..ac9e8a1 100644 --- a/internal/pkg/object/command/sparkeks/sparkeks.go +++ b/internal/pkg/object/command/sparkeks/sparkeks.go @@ -90,24 +90,24 @@ var ( ErrSparkApplicationFile = fmt.Errorf("failed to read SparkApplication application template file: check file path and permissions") ) -type sparkEksCommandContext struct { +type commandContext struct { JobsURI string `yaml:"jobs_uri,omitempty" json:"jobs_uri,omitempty"` WrapperURI string `yaml:"wrapper_uri,omitempty" json:"wrapper_uri,omitempty"` Properties map[string]string `yaml:"properties,omitempty" json:"properties,omitempty"` KubeNamespace string `yaml:"kube_namespace,omitempty" json:"kube_namespace,omitempty"` } -type sparkEksJobParameters struct { +type jobParameters struct { Properties map[string]string `yaml:"properties,omitempty" json:"properties,omitempty"` } -type sparkEksJobContext struct { - Query string `yaml:"query,omitempty" json:"query,omitempty"` - Parameters *sparkEksJobParameters `yaml:"parameters,omitempty" json:"parameters,omitempty"` - ReturnResult bool `yaml:"return_result,omitempty" json:"return_result,omitempty"` +type jobContext struct { + Query string `yaml:"query,omitempty" json:"query,omitempty"` + Parameters *jobParameters `yaml:"parameters,omitempty" json:"parameters,omitempty"` + ReturnResult bool `yaml:"return_result,omitempty" json:"return_result,omitempty"` } -type sparkEksClusterContext struct { +type clusterContext struct { RoleARN *string `yaml:"role_arn,omitempty" json:"role_arn,omitempty"` Properties map[string]string `yaml:"properties,omitempty" json:"properties,omitempty"` Image *string `yaml:"image,omitempty" json:"image,omitempty"` @@ -121,9 +121,9 @@ type executionContext struct { runtime *plugin.Runtime job *job.Job cluster *cluster.Cluster - commandContext *sparkEksCommandContext - jobContext *sparkEksJobContext - clusterContext *sparkEksClusterContext + commandContext *commandContext + 
jobContext *jobContext + clusterContext *clusterContext sparkClient *sparkClientSet.Clientset kubeClient *kubernetes.Clientset @@ -139,13 +139,13 @@ type executionContext struct { } // New creates a new Spark EKS plugin handler. -func New(commandContext *heimdallContext.Context) (plugin.Handler, error) { - s := &sparkEksCommandContext{ +func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { + s := &commandContext{ KubeNamespace: defaultNamespace, } - if commandContext != nil { - if err := commandContext.Unmarshal(s); err != nil { + if commandCtx != nil { + if err := commandCtx.Unmarshal(s); err != nil { return nil, err } } @@ -154,7 +154,7 @@ func New(commandContext *heimdallContext.Context) (plugin.Handler, error) { } // handler executes the Spark EKS job submission and execution. -func (s *sparkEksCommandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { +func (s *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { // 1. Build execution context, create URIs, and upload query execCtx, err := buildExecutionContextAndURI(ctx, r, j, c, s) @@ -190,7 +190,7 @@ func (s *sparkEksCommandContext) handler(ctx context.Context, r *plugin.Runtime, } // buildExecutionContextAndURI prepares the context, merges configurations, and uploads the query. -func buildExecutionContextAndURI(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster, s *sparkEksCommandContext) (*executionContext, error) { +func buildExecutionContextAndURI(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster, s *commandContext) (*executionContext, error) { execCtx := &executionContext{ runtime: r, job: j, @@ -199,7 +199,7 @@ func buildExecutionContextAndURI(ctx context.Context, r *plugin.Runtime, j *job. } // Parse job context - jobContext := &sparkEksJobContext{} + jobContext := &jobContext{} if j.Context != nil { if err := j.Context.Unmarshal(jobContext); err != nil { return nil, fmt.Errorf("failed to unmarshal job context: %w", err) @@ -208,7 +208,7 @@ func buildExecutionContextAndURI(ctx context.Context, r *plugin.Runtime, j *job. execCtx.jobContext = jobContext // Parse cluster context - clusterContext := &sparkEksClusterContext{} + clusterContext := &clusterContext{} if c.Context != nil { if err := c.Context.Unmarshal(clusterContext); err != nil { return nil, fmt.Errorf("failed to unmarshal cluster context: %w", err) @@ -218,7 +218,7 @@ func buildExecutionContextAndURI(ctx context.Context, r *plugin.Runtime, j *job. // Initialize and merge properties from command -> job if execCtx.jobContext.Parameters == nil { - execCtx.jobContext.Parameters = &sparkEksJobParameters{ + execCtx.jobContext.Parameters = &jobParameters{ Properties: make(map[string]string), } } @@ -390,7 +390,7 @@ func printState(writer io.Writer, state v1beta2.ApplicationStateType) { } // getSparkSubmitParameters returns Spark submit parameters as a string. 
-func getSparkSubmitParameters(context *sparkEksJobContext) *string { +func getSparkSubmitParameters(context *jobContext) *string { if context.Parameters == nil || len(context.Parameters.Properties) == 0 { return nil } diff --git a/internal/pkg/object/command/sparkeks/sparkeks_test.go b/internal/pkg/object/command/sparkeks/sparkeks_test.go index 8c4235b..dc823c4 100644 --- a/internal/pkg/object/command/sparkeks/sparkeks_test.go +++ b/internal/pkg/object/command/sparkeks/sparkeks_test.go @@ -25,8 +25,8 @@ func TestUpdateS3ToS3aURI(t *testing.T) { } func TestGetSparkSubmitParameters(t *testing.T) { - ctx := &sparkEksJobContext{ - Parameters: &sparkEksJobParameters{ + ctx := &jobContext{ + Parameters: &jobParameters{ Properties: map[string]string{ "spark.executor.memory": "4g", "spark.driver.cores": "2", @@ -40,8 +40,8 @@ func TestGetSparkSubmitParameters(t *testing.T) { } func TestGetSparkSubmitParameters_Empty(t *testing.T) { - ctx := &sparkEksJobContext{ - Parameters: &sparkEksJobParameters{ + ctx := &jobContext{ + Parameters: &jobParameters{ Properties: map[string]string{}, }, } @@ -52,7 +52,7 @@ func TestGetSparkSubmitParameters_Empty(t *testing.T) { } func TestGetSparkSubmitParameters_NilParameters(t *testing.T) { - ctx := &sparkEksJobContext{ + ctx := &jobContext{ Parameters: nil, } params := getSparkSubmitParameters(ctx) From 10062cd3ed9ad66e819a3c874bf83bf2c0c89876 Mon Sep 17 00:00:00 2001 From: wlggraham Date: Thu, 18 Dec 2025 15:03:33 -0700 Subject: [PATCH 07/10] add cancelled_by job field --- assets/databases/heimdall/tables/jobs.sql | 3 +- internal/pkg/heimdall/job.go | 82 +++++++++---------- internal/pkg/heimdall/job_dal.go | 7 +- internal/pkg/heimdall/queries/job/insert.sql | 6 +- internal/pkg/heimdall/queries/job/select.sql | 3 +- .../pkg/heimdall/queries/job/select_jobs.sql | 3 +- .../queries/job/status_cancel_update.sql | 8 ++ pkg/object/job/job.go | 1 + 8 files changed, 63 insertions(+), 50 deletions(-) create mode 100644 internal/pkg/heimdall/queries/job/status_cancel_update.sql diff --git a/assets/databases/heimdall/tables/jobs.sql b/assets/databases/heimdall/tables/jobs.sql index 11de46d..4876d6f 100644 --- a/assets/databases/heimdall/tables/jobs.sql +++ b/assets/databases/heimdall/tables/jobs.sql @@ -19,4 +19,5 @@ create table if not exists jobs constraint _jobs_job_id unique (job_id) ); -alter table jobs add column if not exists store_result_sync boolean not null default false; \ No newline at end of file +alter table jobs add column if not exists store_result_sync boolean not null default false; +alter table jobs add column if not exists cancelled_by varchar(64) null; \ No newline at end of file diff --git a/internal/pkg/heimdall/job.go b/internal/pkg/heimdall/job.go index fd2d8f3..b3b2a08 100644 --- a/internal/pkg/heimdall/job.go +++ b/internal/pkg/heimdall/job.go @@ -38,9 +38,13 @@ const ( var ( ErrCommandClusterPairNotFound = fmt.Errorf(`command-cluster pair is not found`) + ErrJobCancelFailed = fmt.Errorf(`async job unrecognized or already in final state`) runJobMethod = telemetry.NewMethod("runJob", "heimdall") ) +//go:embed queries/job/status_cancel_update.sql +var queryJobCancelUpdate string + type commandOnCluster struct { command *command.Command cluster *cluster.Cluster @@ -114,6 +118,7 @@ func (h *Heimdall) runJob(ctx context.Context, j *job.Job, command *command.Comm // Create cancellable context for the job pluginCtx, cancel := context.WithCancel(ctx) + defer cancel() // Start plugin execution in goroutine go func() { @@ -122,28 +127,30 @@ func (h 
*Heimdall) runJob(ctx context.Context, j *job.Job, command *command.Comm jobDone <- err }() - // Start cancellation monitoring in goroutine - go func() { - ticker := time.NewTicker(10 * time.Second) - defer ticker.Stop() - - for { - select { - // plugin finished, stop monitoring - case <-cancelMonitorDone: - return - case <-ticker.C: - // If job is in cancelling state, trigger context cancellation - result, err := h.getJobStatus(ctx, &jobRequest{ID: j.ID}) - if err == nil { - if job, ok := result.(*job.Job); ok && job.Status == jobStatus.Cancelling { - cancel() - return + // Start cancellation monitoring for async jobs + if !j.IsSync { + go func() { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for { + select { + // plugin finished, stop monitoring + case <-cancelMonitorDone: + return + case <-ticker.C: + // If job is in cancelling state, trigger context cancellation + result, err := h.getJobStatus(ctx, &jobRequest{ID: j.ID}) + if err == nil { + if job, ok := result.(*job.Job); ok && job.Status == jobStatus.Cancelling { + cancel() + return + } } } } - } - }() + }() + } // Wait for job execution to complete jobErr := <-jobDone @@ -206,35 +213,26 @@ func (h *Heimdall) storeResults(runtime *plugin.Runtime, j *job.Job) error { func (h *Heimdall) cancelJob(ctx context.Context, req *jobRequest) (any, error) { - // validate that job exists and get its current status - currentJob, err := h.getJob(ctx, req) + sess, err := h.Database.NewSession(false) if err != nil { return nil, err } + defer sess.Close() - // make sure we have a job object - job, ok := currentJob.(*job.Job) - if !ok { - return nil, fmt.Errorf("expected *job.Job, got %T", currentJob) + // Attempt to cancel + rowsAffected, err := sess.Exec(queryJobCancelUpdate, req.ID, req.User) + if err != nil { + return nil, err } - // check current job status - switch job.Status { - // already cancelled/cancelling - return success (idempotent) - case jobStatus.Cancelling, jobStatus.Cancelled: - return job, nil - case jobStatus.Running, jobStatus.Accepted: - // can be cancelled - proceed with cancellation - job.Status = jobStatus.Cancelling - job.UpdatedAt = int(time.Now().Unix()) - if err := h.updateAsyncJobStatus(job, nil); err != nil { - return nil, err - } - return job, nil - default: - // job is in a final state (succeeded, failed, etc.) 
- cannot be cancelled - return nil, fmt.Errorf("job cannot be cancelled: current status is %v", job.Status) + if rowsAffected == 0 { + return nil, ErrJobCancelFailed } + + // return job status + return &job.Job{ + Status: jobStatus.Cancelling, + }, nil } func (h *Heimdall) getJobFile(w http.ResponseWriter, r *http.Request) { @@ -273,7 +271,7 @@ func (h *Heimdall) getJobFile(w http.ResponseWriter, r *http.Request) { filenamePath := fmt.Sprintf(jobFileFormat, sourceDirectory, jobID, filename) if strings.HasPrefix(filenamePath, s3Prefix) { readFileFunc = func(path string) ([]byte, error) { - return aws.ReadFromS3(context.Background(), path) + return aws.ReadFromS3(r.Context(), path) } } diff --git a/internal/pkg/heimdall/job_dal.go b/internal/pkg/heimdall/job_dal.go index 2fcb95e..a007eae 100644 --- a/internal/pkg/heimdall/job_dal.go +++ b/internal/pkg/heimdall/job_dal.go @@ -89,6 +89,7 @@ var ( type jobRequest struct { ID string `yaml:"id,omitempty" json:"id,omitempty"` File string `yaml:"file,omitempty" json:"file,omitempty"` + User string `yaml:"user,omitempty" json:"user,omitempty"` } func (h *Heimdall) insertJob(j *job.Job, clusterID, commandID string) (int64, error) { @@ -101,7 +102,7 @@ func (h *Heimdall) insertJob(j *job.Job, clusterID, commandID string) (int64, er defer sess.Close() // insert job row - jobID, err := sess.InsertRow(queryJobInsert, clusterID, commandID, j.Status, j.ID, j.Name, j.Version, j.Description, j.Context.String(), j.Error, j.User, j.IsSync, j.StoreResultSync) + jobID, err := sess.InsertRow(queryJobInsert, clusterID, commandID, j.Status, j.ID, j.Name, j.Version, j.Description, j.Context.String(), j.Error, j.User, j.IsSync, j.StoreResultSync, j.CancelledBy) if err != nil { return 0, err } @@ -181,7 +182,7 @@ func (h *Heimdall) getJob(ctx context.Context, j *jobRequest) (any, error) { var jobContext string if err := row.Scan(&r.SystemID, &r.Status, &r.Name, &r.Version, &r.Description, &jobContext, &r.Error, &r.User, &r.IsSync, - &r.CreatedAt, &r.UpdatedAt, &r.CommandID, &r.CommandName, &r.CluserID, &r.ClusterName, &r.StoreResultSync); err != nil { + &r.CreatedAt, &r.UpdatedAt, &r.CommandID, &r.CommandName, &r.CluserID, &r.ClusterName, &r.StoreResultSync, &r.CancelledBy); err != nil { if err == sql.ErrNoRows { return nil, ErrUnknownJobID } else { @@ -225,7 +226,7 @@ func (h *Heimdall) getJobs(ctx context.Context, f *database.Filter) (any, error) r := &job.Job{} if err := rows.Scan(&r.SystemID, &r.ID, &r.Status, &r.Name, &r.Version, &r.Description, &jobContext, &r.Error, &r.User, &r.IsSync, - &r.CreatedAt, &r.UpdatedAt, &r.CommandID, &r.CommandName, &r.CluserID, &r.ClusterName, &r.StoreResultSync); err != nil { + &r.CreatedAt, &r.UpdatedAt, &r.CommandID, &r.CommandName, &r.CluserID, &r.ClusterName, &r.StoreResultSync, &r.CancelledBy); err != nil { return nil, err } diff --git a/internal/pkg/heimdall/queries/job/insert.sql b/internal/pkg/heimdall/queries/job/insert.sql index 46813a6..c641363 100644 --- a/internal/pkg/heimdall/queries/job/insert.sql +++ b/internal/pkg/heimdall/queries/job/insert.sql @@ -11,7 +11,8 @@ insert into jobs job_error, username, is_sync, - store_result_sync + store_result_sync, + cancelled_by ) select cm.system_command_id, @@ -25,7 +26,8 @@ select $9, -- job_error $10, -- username $11, -- is_sync - $12 -- store_result_sync + $12, -- store_result_sync + $13 -- cancelled_by from clusters cl, commands cm diff --git a/internal/pkg/heimdall/queries/job/select.sql b/internal/pkg/heimdall/queries/job/select.sql index bc84784..716633f 100644 --- 
a/internal/pkg/heimdall/queries/job/select.sql +++ b/internal/pkg/heimdall/queries/job/select.sql @@ -14,7 +14,8 @@ select cm.command_name, cl.cluster_id, cl.cluster_name, - j.store_result_sync + j.store_result_sync, + j.cancelled_by from jobs j left join commands cm on cm.system_command_id = j.job_command_id diff --git a/internal/pkg/heimdall/queries/job/select_jobs.sql b/internal/pkg/heimdall/queries/job/select_jobs.sql index 39b4d81..3a439fa 100644 --- a/internal/pkg/heimdall/queries/job/select_jobs.sql +++ b/internal/pkg/heimdall/queries/job/select_jobs.sql @@ -15,7 +15,8 @@ select cm.command_name, cl.cluster_id, cl.cluster_name, - j.store_result_sync + j.store_result_sync, + j.cancelled_by from jobs j join job_statuses js on js.job_status_id = j.job_status_id diff --git a/internal/pkg/heimdall/queries/job/status_cancel_update.sql b/internal/pkg/heimdall/queries/job/status_cancel_update.sql new file mode 100644 index 0000000..8597499 --- /dev/null +++ b/internal/pkg/heimdall/queries/job/status_cancel_update.sql @@ -0,0 +1,8 @@ +update jobs +set + job_status_id = 7, -- CANCELLING + cancelled_by = $2, + updated_at = extract(epoch from now())::int +where + job_id = $1 + and job_status_id not in (4, 5, 6, 8); -- Not in FAILED, KILLED, SUCCEEDED, CANCELLED diff --git a/pkg/object/job/job.go b/pkg/object/job/job.go index b0d6855..6245b59 100644 --- a/pkg/object/job/job.go +++ b/pkg/object/job/job.go @@ -21,6 +21,7 @@ type Job struct { CommandName string `yaml:"command_name,omitempty" json:"command_name,omitempty"` CluserID string `yaml:"cluster_id,omitempty" json:"cluster_id,omitempty"` ClusterName string `yaml:"cluster_name,omitempty" json:"cluster_name,omitempty"` + CancelledBy string `yaml:"cancelled_by,omitempty" json:"cancelled_by,omitempty"` Result *result.Result `yaml:"result,omitempty" json:"result,omitempty"` } From d1c858d76f2347024c7c29f091bcbf756de4cd1e Mon Sep 17 00:00:00 2001 From: wlggraham Date: Thu, 18 Dec 2025 15:04:08 -0700 Subject: [PATCH 08/10] update readme with ecs & new endpoint --- README.md | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 1ad1c91..e426523 100644 --- a/README.md +++ b/README.md @@ -88,16 +88,17 @@ GET /api/v1/job//stderr Heimdall supports a growing set of pluggable command types: -| Plugin | Description | Execution Mode | -| ----------- | -------------------------------------- | -------------- | -| `ping` | [Basic plugin used for testing](https://github.com/patterninc/heimdall/blob/main/plugins/ping/README.md) | Sync or Async | -| `shell` | [Shell command execution](https://github.com/patterninc/heimdall/blob/main/plugins/shell/README.md) | Sync or Async | -| `glue` | [Pulling Iceberg table metadata](https://github.com/patterninc/heimdall/blob/main/plugins/glue/README.md) | Sync or Async | -| `dynamo` | [DynamoDB read operation](https://github.com/patterninc/heimdall/blob/main/plugins/dynamo/README.md) | Sync or Async | -| `snowflake` | [Query execution in Snowflake](https://github.com/patterninc/heimdall/blob/main/plugins/snowflake/README.md) | Async | -| `spark` | [SparkSQL query execution on EMR on EKS](https://github.com/patterninc/heimdall/blob/main/plugins/spark/README.md) | Async | -| `trino` | [Query execution in Trino](https://github.com/patterninc/heimdall/blob/main/plugins/trino/README.md) | Async | -| `clickhouse`| [Query execution in Clickhouse](https://github.com/patterninc/heimdall/blob/main/plugins/clickhouse/README.md) | Sync | +| Plugin | Description | Execution 
Mode | +| ----------- | -------------------------------------- | -------------- | +| `ping` | [Basic plugin used for testing](https://github.com/patterninc/heimdall/blob/main/plugins/ping/README.md) | Sync or Async | +| `shell` | [Shell command execution](https://github.com/patterninc/heimdall/blob/main/plugins/shell/README.md) | Sync or Async | +| `glue` | [Pulling Iceberg table metadata](https://github.com/patterninc/heimdall/blob/main/plugins/glue/README.md) | Sync or Async | +| `dynamo` | [DynamoDB read operation](https://github.com/patterninc/heimdall/blob/main/plugins/dynamo/README.md) | Sync or Async | +| `snowflake` | [Query execution in Snowflake](https://github.com/patterninc/heimdall/blob/main/plugins/snowflake/README.md) | Async | +| `spark` | [SparkSQL query execution on EMR on EKS](https://github.com/patterninc/heimdall/blob/main/plugins/spark/README.md) | Async | +| `trino` | [Query execution in Trino](https://github.com/patterninc/heimdall/blob/main/plugins/trino/README.md) | Async | +| `clickhouse` | [Query execution in Clickhouse](https://github.com/patterninc/heimdall/blob/main/plugins/clickhouse/README.md) | Sync | +| `ecs fargate` | [Task Deployment in ECS Fargate](https://github.com/patterninc/heimdall/blob/main/plugins/ecs/README.md) | Async | --- @@ -163,6 +164,7 @@ It centralizes execution logic, logging, and auditing—all accessible via API o | `POST /api/v1/job` | Submit a job | | `GET /api/v1/job/` | Get job details | | `GET /api/v1/job//status` | Check job status | +| `POST /api/v1/job//cancel` | Cancel an async job | | `GET /api/v1/job//stdout` | Get stdout for a completed job | | `GET /api/v1/job//stderr` | Get stderr for a completed job | | `GET /api/v1/job//result` | Get job's result | From caaa34c78e8a1f328b4259b9f599d1621ebc50a4 Mon Sep 17 00:00:00 2001 From: wlggraham Date: Thu, 18 Dec 2025 15:44:23 -0700 Subject: [PATCH 09/10] update old cancelled_by job fields --- assets/databases/heimdall/tables/jobs.sql | 3 ++- web/src/modules/Jobs/Helper.tsx | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/assets/databases/heimdall/tables/jobs.sql b/assets/databases/heimdall/tables/jobs.sql index 4876d6f..e40a6fb 100644 --- a/assets/databases/heimdall/tables/jobs.sql +++ b/assets/databases/heimdall/tables/jobs.sql @@ -20,4 +20,5 @@ create table if not exists jobs ); alter table jobs add column if not exists store_result_sync boolean not null default false; -alter table jobs add column if not exists cancelled_by varchar(64) null; \ No newline at end of file +alter table jobs add column if not exists cancelled_by varchar(64) null; +update jobs set cancelled_by = '' where cancelled_by is null; \ No newline at end of file diff --git a/web/src/modules/Jobs/Helper.tsx b/web/src/modules/Jobs/Helper.tsx index 50ce30d..c47f205 100644 --- a/web/src/modules/Jobs/Helper.tsx +++ b/web/src/modules/Jobs/Helper.tsx @@ -36,6 +36,7 @@ export type JobType = { command_name: string cluster_id: string cluster_name: string + cancelled_by: string error?: string context?: { properties: { From 34f6d86a8a1747147b67b4bfbcf29ff1f4d237c9 Mon Sep 17 00:00:00 2001 From: wlggraham Date: Thu, 18 Dec 2025 15:59:29 -0700 Subject: [PATCH 10/10] spelling typo --- internal/pkg/heimdall/job_dal.go | 4 ++-- internal/pkg/heimdall/jobs_async.go | 6 +++--- pkg/object/job/job.go | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/internal/pkg/heimdall/job_dal.go b/internal/pkg/heimdall/job_dal.go index a007eae..7719b27 100644 --- a/internal/pkg/heimdall/job_dal.go +++ 
b/internal/pkg/heimdall/job_dal.go @@ -182,7 +182,7 @@ func (h *Heimdall) getJob(ctx context.Context, j *jobRequest) (any, error) { var jobContext string if err := row.Scan(&r.SystemID, &r.Status, &r.Name, &r.Version, &r.Description, &jobContext, &r.Error, &r.User, &r.IsSync, - &r.CreatedAt, &r.UpdatedAt, &r.CommandID, &r.CommandName, &r.CluserID, &r.ClusterName, &r.StoreResultSync, &r.CancelledBy); err != nil { + &r.CreatedAt, &r.UpdatedAt, &r.CommandID, &r.CommandName, &r.ClusterID, &r.ClusterName, &r.StoreResultSync, &r.CancelledBy); err != nil { if err == sql.ErrNoRows { return nil, ErrUnknownJobID } else { @@ -226,7 +226,7 @@ func (h *Heimdall) getJobs(ctx context.Context, f *database.Filter) (any, error) r := &job.Job{} if err := rows.Scan(&r.SystemID, &r.ID, &r.Status, &r.Name, &r.Version, &r.Description, &jobContext, &r.Error, &r.User, &r.IsSync, - &r.CreatedAt, &r.UpdatedAt, &r.CommandID, &r.CommandName, &r.CluserID, &r.ClusterName, &r.StoreResultSync, &r.CancelledBy); err != nil { + &r.CreatedAt, &r.UpdatedAt, &r.CommandID, &r.CommandName, &r.ClusterID, &r.ClusterName, &r.StoreResultSync, &r.CancelledBy); err != nil { return nil, err } diff --git a/internal/pkg/heimdall/jobs_async.go b/internal/pkg/heimdall/jobs_async.go index 3da8400..67804f4 100644 --- a/internal/pkg/heimdall/jobs_async.go +++ b/internal/pkg/heimdall/jobs_async.go @@ -51,7 +51,7 @@ func (h *Heimdall) getAsyncJobs(limit int) ([]*job.Job, error) { jobContext, j := ``, &job.Job{} - if err := rows.Scan(&j.SystemID, &j.CommandID, &j.CluserID, &j.Status, &j.ID, &j.Name, + if err := rows.Scan(&j.SystemID, &j.CommandID, &j.ClusterID, &j.Status, &j.ID, &j.Name, &j.Version, &j.Description, &jobContext, &j.User, &j.IsSync, &j.CreatedAt, &j.UpdatedAt, &j.StoreResultSync); err != nil { return nil, err } @@ -109,9 +109,9 @@ func (h *Heimdall) runAsyncJob(ctx context.Context, j *job.Job) error { } // do we have hte cluster? - cluster, found := h.Clusters[j.CluserID] + cluster, found := h.Clusters[j.ClusterID] if !found { - return h.updateAsyncJobStatus(j, fmt.Errorf(formatErrUnknownCluster, j.CluserID)) + return h.updateAsyncJobStatus(j, fmt.Errorf(formatErrUnknownCluster, j.ClusterID)) } return h.updateAsyncJobStatus(j, h.runJob(ctx, j, command, cluster)) diff --git a/pkg/object/job/job.go b/pkg/object/job/job.go index 6245b59..22b9c83 100644 --- a/pkg/object/job/job.go +++ b/pkg/object/job/job.go @@ -19,7 +19,7 @@ type Job struct { ClusterCriteria *set.Set[string] `yaml:"cluster_criteria,omitempty" json:"cluster_criteria,omitempty"` CommandID string `yaml:"command_id,omitempty" json:"command_id,omitempty"` CommandName string `yaml:"command_name,omitempty" json:"command_name,omitempty"` - CluserID string `yaml:"cluster_id,omitempty" json:"cluster_id,omitempty"` + ClusterID string `yaml:"cluster_id,omitempty" json:"cluster_id,omitempty"` ClusterName string `yaml:"cluster_name,omitempty" json:"cluster_name,omitempty"` CancelledBy string `yaml:"cancelled_by,omitempty" json:"cancelled_by,omitempty"` Result *result.Result `yaml:"result,omitempty" json:"result,omitempty"`
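
The cancellation flow added in this series (the POST /api/v1/job/{id}/cancel route, the CANCELLING and CANCELLED statuses, and the guarded status_cancel_update.sql) can be exercised from a client roughly as sketched below. This is not part of the patches: the base URL http://localhost:8080 is an assumption, as is the assumption that no request body is needed (the job ID comes from the URL path and the cancelling user from the authenticated request); it only illustrates the request/response shape implied by the handler above.

package main

import (
	"fmt"
	"io"
	"net/http"
	"os"
)

func main() {
	if len(os.Args) < 2 {
		fmt.Fprintln(os.Stderr, "usage: cancel <job-id>")
		os.Exit(1)
	}

	// Assumed local Heimdall endpoint; replace with your deployment's base URL.
	url := fmt.Sprintf("http://localhost:8080/api/v1/job/%s/cancel", os.Args[1])

	// The cancel route is registered as a POST; an empty body is assumed here.
	req, err := http.NewRequest(http.MethodPost, url, nil)
	if err != nil {
		panic(err)
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// On success the handler returns the job with status CANCELLING; the agent
	// running the job later marks it CANCELLED once the plugin's context is
	// cancelled by the monitoring goroutine.
	body, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status)
	fmt.Println(string(body))
}

Note on the guard in status_cancel_update.sql: repeating the request while a job is still CANCELLING matches the update again (status 7 is not in the excluded set), whereas a job already in FAILED, KILLED, SUCCEEDED, or CANCELLED leaves zero rows affected and the handler returns ErrJobCancelFailed.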