diff --git a/src/context.rs b/src/context.rs index 3071119..3bb942d 100644 --- a/src/context.rs +++ b/src/context.rs @@ -307,13 +307,15 @@ where let checkpoint_name = self.get_checkpoint_name(name, &())?; let duration_ms = duration.as_millis() as i64; - let (needs_suspend,): (bool,) = sqlx::query_as("SELECT durable.sleep_for($1, $2, $3, $4)") - .bind(&self.queue_name) - .bind(self.run_id) - .bind(&checkpoint_name) - .bind(duration_ms) - .fetch_one(&self.pool) - .await?; + let (needs_suspend,): (bool,) = + sqlx::query_as("SELECT durable.sleep_for($1, $2, $3, $4, $5)") + .bind(&self.queue_name) + .bind(self.task_id) + .bind(self.run_id) + .bind(&checkpoint_name) + .bind(duration_ms) + .fetch_one(&self.pool) + .await?; if needs_suspend { return Err(TaskError::Control(ControlFlow::Suspend)); diff --git a/src/postgres/migrations/20251202002136_initial_setup.sql b/src/postgres/migrations/20251202002136_initial_setup.sql index db2325d..a053015 100644 --- a/src/postgres/migrations/20251202002136_initial_setup.sql +++ b/src/postgres/migrations/20251202002136_initial_setup.sql @@ -334,6 +334,7 @@ declare v_claim_until timestamptz; v_sql text; v_expired_run record; + v_cancelled_task uuid; begin if v_claim_timeout <= 0 then raise exception 'claim_timeout must be greater than zero'; @@ -344,59 +345,82 @@ begin -- Apply cancellation rules before claiming. -- These are max_delay (delay before starting) and -- max_duration (duration from created to finished) - execute format( - 'with limits as ( - select task_id, - (cancellation->>''max_delay'')::bigint as max_delay, - (cancellation->>''max_duration'')::bigint as max_duration, - enqueue_at, - first_started_at, - state - from durable.%I - where state in (''pending'', ''sleeping'', ''running'') - ), - to_cancel as ( - select task_id - from limits - where - ( - max_delay is not null - and first_started_at is null - and extract(epoch from ($1 - enqueue_at)) >= max_delay - ) - or - ( - max_duration is not null - and first_started_at is not null - and extract(epoch from ($1 - first_started_at)) >= max_duration - ) - ) - update durable.%I t - set state = ''cancelled'', - cancelled_at = coalesce(t.cancelled_at, $1) - where t.task_id in (select task_id from to_cancel)', - 't_' || p_queue_name, - 't_' || p_queue_name - ) using v_now; + -- Use a loop so we can cleanup each cancelled task properly. 
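+    -- The UPDATE ... RETURNING below feeds each newly cancelled task_id into
+    -- this loop, where its non-terminal runs are cancelled and
+    -- durable.cleanup_task_terminal does the rest (delete waiters, notify a
+    -- waiting parent via the '$child:' event, cascade-cancel children).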
+ for v_cancelled_task in + execute format( + 'with limits as ( + select task_id, + (cancellation->>''max_delay'')::bigint as max_delay, + (cancellation->>''max_duration'')::bigint as max_duration, + enqueue_at, + first_started_at, + state + from durable.%I + where state in (''pending'', ''sleeping'', ''running'') + ), + to_cancel as ( + select task_id + from limits + where + ( + max_delay is not null + and first_started_at is null + and extract(epoch from ($1 - enqueue_at)) >= max_delay + ) + or + ( + max_duration is not null + and first_started_at is not null + and extract(epoch from ($1 - first_started_at)) >= max_duration + ) + ) + update durable.%I t + set state = ''cancelled'', + cancelled_at = coalesce(t.cancelled_at, $1) + where t.task_id in (select task_id from to_cancel) + returning t.task_id', + 't_' || p_queue_name, + 't_' || p_queue_name + ) using v_now + loop + -- Cancel all runs for this task + execute format( + 'update durable.%I + set state = ''cancelled'', + claimed_by = null, + claim_expires_at = null + where task_id = $1 + and state not in (''completed'', ''failed'', ''cancelled'')', + 'r_' || p_queue_name + ) using v_cancelled_task; + + -- Cleanup: delete waiters, emit event, cascade cancel children + perform durable.cleanup_task_terminal(p_queue_name, v_cancelled_task, 'cancelled', null, true); + end loop; - -- Fail any run claims that have timed out + -- Fail any run claims that have timed out. + -- Lock tasks first to keep a consistent task -> run lock order. for v_expired_run in execute format( - 'select run_id, - claimed_by, - claim_expires_at, - attempt - from durable.%I - where state = ''running'' - and claim_expires_at is not null - and claim_expires_at <= $1 - for update skip locked', - 'r_' || p_queue_name + 'select r.run_id, + r.task_id, + r.claimed_by, + r.claim_expires_at, + r.attempt + from durable.%I r + join durable.%I t on t.task_id = r.task_id + where r.state = ''running'' + and r.claim_expires_at is not null + and r.claim_expires_at <= $1 + for update of t skip locked', + 'r_' || p_queue_name, + 't_' || p_queue_name ) using v_now loop perform durable.fail_run( p_queue_name, + v_expired_run.task_id, v_expired_run.run_id, jsonb_strip_nulls(jsonb_build_object( 'name', '$ClaimTimeout', @@ -499,6 +523,7 @@ $$; -- Marks a run as completed create function durable.complete_run ( p_queue_name text, + p_task_id uuid, p_run_id uuid, p_state jsonb default null ) @@ -506,11 +531,27 @@ create function durable.complete_run ( language plpgsql as $$ declare - v_task_id uuid; + v_task_id_locked uuid; + v_run_task_id uuid; v_state text; - v_parent_task_id uuid; v_now timestamptz := durable.current_time(); begin + -- Lock task first to keep a consistent task -> run lock order. 
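+    -- Every function that touches both tables now locks task first, then run;
+    -- locking the run first here could deadlock against e.g. cancel_task,
+    -- which holds the task lock while updating that task's runs.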
+ execute format( + 'select task_id + from durable.%I + where task_id = $1 + for update', + 't_' || p_queue_name + ) + into v_task_id_locked + using p_task_id; + + if v_task_id_locked is null then + raise exception 'Task "%" not found in queue "%"', p_task_id, p_queue_name; + end if; + + -- Lock the run after the task lock execute format( 'select task_id, state from durable.%I @@ -518,17 +559,22 @@ begin for update', 'r_' || p_queue_name ) - into v_task_id, v_state + into v_run_task_id, v_state using p_run_id; - if v_task_id is null then + if v_run_task_id is null then raise exception 'Run "%" not found in queue "%"', p_run_id, p_queue_name; end if; + if v_run_task_id <> p_task_id then + raise exception 'Run "%" does not belong to task "%"', p_run_id, p_task_id; + end if; + if v_state <> 'running' then raise exception 'Run "%" is not currently running in queue "%"', p_run_id, p_queue_name; end if; + -- Update run to completed execute format( 'update durable.%I set state = ''completed'', @@ -538,38 +584,30 @@ begin 'r_' || p_queue_name ) using p_run_id, v_now, p_state; - -- Get parent_task_id to check if this is a subtask + -- Update task to completed execute format( 'update durable.%I set state = ''completed'', completed_payload = $2, last_attempt_run = $3 - where task_id = $1 - returning parent_task_id', + where task_id = $1', 't_' || p_queue_name - ) - into v_parent_task_id - using v_task_id, p_state, p_run_id; - - -- Clean up any wait registrations for this run - execute format( - 'delete from durable.%I where run_id = $1', - 'w_' || p_queue_name - ) using p_run_id; - - -- Emit completion event for parent to join on (only if this is a subtask) - if v_parent_task_id is not null then - perform durable.emit_event( - p_queue_name, - '$child:' || v_task_id::text, - jsonb_build_object('status', 'completed', 'result', p_state) - ); - end if; + ) using p_task_id, p_state, p_run_id; + + -- Cleanup: delete waiters and emit completion event for parent + perform durable.cleanup_task_terminal( + p_queue_name, + p_task_id, + 'completed', + jsonb_build_object('result', p_state), + false -- don't cascade cancel children for completed tasks + ); end; $$; create function durable.sleep_for( p_queue_name text, + p_task_id uuid, p_run_id uuid, p_checkpoint_name text, p_duration_ms bigint @@ -581,33 +619,66 @@ declare v_wake_at timestamptz; v_existing_state jsonb; v_now timestamptz := durable.current_time(); - v_task_id uuid; + v_run_task_id uuid; + v_run_state text; + v_task_state text; begin - -- Get task_id from run (needed for checkpoint table key) + -- Lock task first to keep a consistent task -> run lock order. 
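+    -- Taking the task lock first also surfaces a concurrent cancellation: if
+    -- the task is already 'cancelled' we raise SQLSTATE 'AB001' below instead
+    -- of writing a sleep checkpoint for a dead task.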
execute format( - 'select task_id from durable.%I where run_id = $1 and state = ''running'' for update', + 'select state from durable.%I where task_id = $1 for update', + 't_' || p_queue_name + ) into v_task_state using p_task_id; + + if v_task_state is null then + raise exception 'Task "%" not found in queue "%"', p_task_id, p_queue_name; + end if; + + if v_task_state = 'cancelled' then + raise exception sqlstate 'AB001' using message = 'Task has been cancelled'; + end if; + + -- Lock run after task + execute format( + 'select task_id, state from durable.%I where run_id = $1 for update', 'r_' || p_queue_name - ) into v_task_id using p_run_id; + ) into v_run_task_id, v_run_state using p_run_id; - if v_task_id is null then + if v_run_task_id is null then + raise exception 'Run "%" not found in queue "%"', p_run_id, p_queue_name; + end if; + + if v_run_task_id <> p_task_id then + raise exception 'Run "%" does not belong to task "%"', p_run_id, p_task_id; + end if; + + if v_run_state <> 'running' then raise exception 'Run "%" is not currently running in queue "%"', p_run_id, p_queue_name; end if; - -- Check for existing checkpoint, else compute and store wake time + -- Check for existing checkpoint execute format( 'select state from durable.%I where task_id = $1 and checkpoint_name = $2', 'c_' || p_queue_name - ) into v_existing_state using v_task_id, p_checkpoint_name; + ) into v_existing_state using p_task_id, p_checkpoint_name; if v_existing_state is not null then v_wake_at := (v_existing_state #>> '{}')::timestamptz; else + -- Compute wake time and store checkpoint (first-writer-wins) v_wake_at := v_now + (p_duration_ms || ' milliseconds')::interval; execute format( 'insert into durable.%I (task_id, checkpoint_name, state, owner_run_id, updated_at) - values ($1, $2, $3, $4, $5)', + values ($1, $2, $3, $4, $5) + on conflict (task_id, checkpoint_name) do nothing', + 'c_' || p_queue_name + ) using p_task_id, p_checkpoint_name, to_jsonb(v_wake_at::text), p_run_id, v_now; + + -- Re-read in case we lost the race (first-writer-wins) + execute format( + 'select state from durable.%I where task_id = $1 and checkpoint_name = $2', 'c_' || p_queue_name - ) using v_task_id, p_checkpoint_name, to_jsonb(v_wake_at::text), p_run_id, v_now; + ) into v_existing_state using p_task_id, p_checkpoint_name; + v_wake_at := (v_existing_state #>> '{}')::timestamptz; end if; -- If wake time passed, return false (no suspend needed) @@ -632,12 +703,60 @@ begin set state = ''sleeping'' where task_id = $1', 't_' || p_queue_name - ) using v_task_id; + ) using p_task_id; return true; end; $$; +-- Consolidates cleanup logic for a task that has reached a terminal state. +-- This function: +-- 1. Deletes wait registrations for the task +-- 2. Emits a completion event for the parent (if this is a subtask) +-- 3. 
Optionally cascades cancellation to children +-- +-- Called by: complete_run, fail_run, cancel_task, cascade_cancel_children, claim_task +create function durable.cleanup_task_terminal ( + p_queue_name text, + p_task_id uuid, + p_status text, -- 'completed', 'failed', 'cancelled' + p_payload jsonb default null, + p_cascade_children boolean default false +) + returns void + language plpgsql +as $$ +declare + v_parent_task_id uuid; +begin + -- Get parent_task_id for event emission + execute format( + 'select parent_task_id from durable.%I where task_id = $1', + 't_' || p_queue_name + ) into v_parent_task_id using p_task_id; + + -- Delete wait registrations for this task + execute format( + 'delete from durable.%I where task_id = $1', + 'w_' || p_queue_name + ) using p_task_id; + + -- Emit completion event for parent (if subtask) + if v_parent_task_id is not null then + perform durable.emit_event( + p_queue_name, + '$child:' || p_task_id::text, + jsonb_build_object('status', p_status) || coalesce(p_payload, '{}'::jsonb) + ); + end if; + + -- Cascade cancel children if requested + if p_cascade_children then + perform durable.cascade_cancel_children(p_queue_name, p_task_id); + end if; +end; +$$; + -- Recursively cancels all children of a parent task. -- Used when a parent task fails or is cancelled to cascade the cancellation. create function durable.cascade_cancel_children ( @@ -649,13 +768,12 @@ create function durable.cascade_cancel_children ( as $$ declare v_child_id uuid; - v_child_state text; v_now timestamptz := durable.current_time(); begin -- Find all children of this parent that are not in terminal state - for v_child_id, v_child_state in + for v_child_id in execute format( - 'select task_id, state + 'select task_id from durable.%I where parent_task_id = $1 and state not in (''completed'', ''failed'', ''cancelled'') @@ -684,27 +802,15 @@ begin 'r_' || p_queue_name ) using v_child_id; - -- Delete wait registrations - execute format( - 'delete from durable.%I where task_id = $1', - 'w_' || p_queue_name - ) using v_child_id; - - -- Emit cancellation event so parent's join() can receive it - perform durable.emit_event( - p_queue_name, - '$child:' || v_child_id::text, - jsonb_build_object('status', 'cancelled') - ); - - -- Recursively cancel grandchildren - perform durable.cascade_cancel_children(p_queue_name, v_child_id); + -- Cleanup: delete waiters, emit event, and recursively cascade to grandchildren + perform durable.cleanup_task_terminal(p_queue_name, v_child_id, 'cancelled', null, true); end loop; end; $$; create function durable.fail_run ( p_queue_name text, + p_task_id uuid, p_run_id uuid, p_reason jsonb, p_retry_at timestamptz default null @@ -713,7 +819,7 @@ create function durable.fail_run ( language plpgsql as $$ declare - v_task_id uuid; + v_run_task_id uuid; v_attempt integer; v_retry_strategy jsonb; v_max_attempts integer; @@ -735,36 +841,43 @@ declare v_recorded_attempt integer; v_last_attempt_run uuid := p_run_id; v_cancelled_at timestamptz := null; - v_parent_task_id uuid; begin - -- find the run to fail + -- Lock task first to keep a consistent task -> run lock order. 
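+    -- Reading retry_strategy, max_attempts and the cancellation policy under
+    -- the task lock means the retry scheduling below works from a consistent
+    -- snapshot of the task row.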
execute format( - 'select r.task_id, r.attempt - from durable.%I r - where r.run_id = $1 - and r.state in (''running'', ''sleeping'') + 'select retry_strategy, max_attempts, first_started_at, cancellation, state + from durable.%I + where task_id = $1 for update', - 'r_' || p_queue_name + 't_' || p_queue_name ) - into v_task_id, v_attempt - using p_run_id; + into v_retry_strategy, v_max_attempts, v_first_started, v_cancellation, v_task_state + using p_task_id; - if v_task_id is null then - raise exception 'Run "%" cannot be failed in queue "%"', p_run_id, p_queue_name; + if v_task_state is null then + raise exception 'Task "%" not found in queue "%"', p_task_id, p_queue_name; end if; - -- get the retry strategy and metadata about task + -- Lock run after task and ensure it's still eligible execute format( - 'select retry_strategy, max_attempts, first_started_at, cancellation, state, parent_task_id + 'select task_id, attempt from durable.%I - where task_id = $1 + where run_id = $1 + and state in (''running'', ''sleeping'') for update', - 't_' || p_queue_name + 'r_' || p_queue_name ) - into v_retry_strategy, v_max_attempts, v_first_started, v_cancellation, v_task_state, v_parent_task_id - using v_task_id; + into v_run_task_id, v_attempt + using p_run_id; + + if v_run_task_id is null then + raise exception 'Run "%" cannot be failed in queue "%"', p_run_id, p_queue_name; + end if; + + if v_run_task_id <> p_task_id then + raise exception 'Run "%" does not belong to task "%"', p_run_id, p_task_id; + end if; - -- actually fail the run + -- Actually fail the run execute format( 'update durable.%I set state = ''failed'', @@ -779,7 +892,7 @@ begin v_task_state_after := 'failed'; v_recorded_attempt := v_attempt; - -- compute the next retry time + -- Compute the next retry time if v_max_attempts is null or v_next_attempt <= v_max_attempts then if p_retry_at is not null then v_next_available := p_retry_at; @@ -815,7 +928,7 @@ begin end if; end if; - -- set up the new run if not cancelling + -- Set up the new run if not cancelling if not v_task_cancel then v_task_state_after := case when v_next_available > v_now then 'sleeping' else 'pending' end; v_new_run_id := durable.portable_uuidv7(); @@ -827,7 +940,7 @@ begin 'r_' || p_queue_name, v_task_state_after ) - using v_new_run_id, v_task_id, v_next_attempt, v_next_available; + using v_new_run_id, p_task_id, v_next_attempt, v_next_available; end if; end if; @@ -847,26 +960,23 @@ begin where task_id = $1', 't_' || p_queue_name, v_task_state_after - ) using v_task_id, v_task_state_after, v_recorded_attempt, v_last_attempt_run, v_cancelled_at; + ) using p_task_id, v_task_state_after, v_recorded_attempt, v_last_attempt_run, v_cancelled_at; + -- Delete wait registrations for this run execute format( 'delete from durable.%I where run_id = $1', 'w_' || p_queue_name ) using p_run_id; - -- If task reached terminal failure state (failed or cancelled), emit event and cascade cancel + -- If task reached terminal state, cleanup (emit event, cascade cancel) if v_task_state_after in ('failed', 'cancelled') then - -- Cascade cancel all children - perform durable.cascade_cancel_children(p_queue_name, v_task_id); - - -- Emit completion event for parent to join on (only if this is a subtask) - if v_parent_task_id is not null then - perform durable.emit_event( - p_queue_name, - '$child:' || v_task_id::text, - jsonb_build_object('status', v_task_state_after, 'error', p_reason) - ); - end if; + perform durable.cleanup_task_terminal( + p_queue_name, + p_task_id, + 
v_task_state_after, + jsonb_build_object('error', p_reason), + true -- cascade cancel children + ); end if; end; $$; @@ -889,8 +999,6 @@ as $$ declare v_now timestamptz := durable.current_time(); v_new_attempt integer; - v_existing_attempt integer; - v_existing_owner uuid; v_task_state text; begin if p_step_name is null or length(trim(p_step_name)) = 0 then @@ -932,29 +1040,22 @@ begin end if; execute format( - 'select c.owner_run_id, - r.attempt - from durable.%I c - left join durable.%I r on r.run_id = c.owner_run_id - where c.task_id = $1 - and c.checkpoint_name = $2', + 'insert into durable.%I (task_id, checkpoint_name, state, owner_run_id, updated_at) + values ($1, $2, $3, $4, $5) + on conflict (task_id, checkpoint_name) + do update set state = excluded.state, + owner_run_id = excluded.owner_run_id, + updated_at = excluded.updated_at + where $6 >= coalesce( + (select r.attempt + from durable.%I r + where r.run_id = durable.%I.owner_run_id), + $6 + )', 'c_' || p_queue_name, - 'r_' || p_queue_name - ) - into v_existing_owner, v_existing_attempt - using p_task_id, p_step_name; - - if v_existing_owner is null or v_existing_attempt is null or v_new_attempt >= v_existing_attempt then - execute format( - 'insert into durable.%I (task_id, checkpoint_name, state, owner_run_id, updated_at) - values ($1, $2, $3, $4, $5) - on conflict (task_id, checkpoint_name) - do update set state = excluded.state, - owner_run_id = excluded.owner_run_id, - updated_at = excluded.updated_at', - 'c_' || p_queue_name - ) using p_task_id, p_step_name, p_state, p_owner_run, v_now; - end if; + 'r_' || p_queue_name, + 'c_' || p_queue_name + ) using p_task_id, p_step_name, p_state, p_owner_run, v_now, v_new_attempt; end; $$; @@ -1081,6 +1182,7 @@ create function durable.await_event ( as $$ declare v_run_state text; + v_run_task_id uuid; v_existing_payload jsonb; v_event_payload jsonb; v_checkpoint_payload jsonb; @@ -1123,25 +1225,39 @@ begin return; end if; - -- let's get the run state, any existing event payload and wake event name + -- Lock task first to keep a consistent task -> run lock order. 
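+    -- Same task -> run order as complete_run/fail_run/sleep_for; the
+    -- cancellation check on the locked task replaces the old joined read of
+    -- t.state.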
execute format( - 'select r.state, r.event_payload, r.wake_event, t.state - from durable.%I r - join durable.%I t on t.task_id = r.task_id - where r.run_id = $1 - for update', - 'r_' || p_queue_name, + 'select state from durable.%I where task_id = $1 for update', 't_' || p_queue_name ) - into v_run_state, v_existing_payload, v_wake_event, v_task_state + into v_task_state + using p_task_id; + + if v_task_state is null then + raise exception 'Task "%" not found in queue "%"', p_task_id, p_queue_name; + end if; + + if v_task_state = 'cancelled' then + raise exception sqlstate 'AB001' using message = 'Task has been cancelled'; + end if; + + -- Lock run after task + execute format( + 'select task_id, state, event_payload, wake_event + from durable.%I + where run_id = $1 + for update', + 'r_' || p_queue_name + ) + into v_run_task_id, v_run_state, v_existing_payload, v_wake_event using p_run_id; if v_run_state is null then raise exception 'Run "%" not found while awaiting event', p_run_id; end if; - if v_task_state = 'cancelled' then - raise exception sqlstate 'AB001' using message = 'Task has been cancelled'; + if v_run_task_id <> p_task_id then + raise exception 'Run "%" does not belong to task "%"', p_run_id, p_task_id; end if; execute format( @@ -1248,6 +1364,7 @@ as $$ declare v_now timestamptz := durable.current_time(); v_payload jsonb := coalesce(p_payload, 'null'::jsonb); + v_inserted_count integer; begin if p_event_name is null or length(trim(p_event_name)) = 0 then raise exception 'event_name must be provided'; @@ -1265,6 +1382,14 @@ begin 'e_' || p_queue_name ) using p_event_name, v_payload, v_now; + get diagnostics v_inserted_count = row_count; + + -- Only wake waiters if we actually inserted (first emit). + -- Subsequent emits are no-ops to maintain consistency. + if v_inserted_count = 0 then + return; + end if; + execute format( 'with expired_waits as ( delete from durable.%1$I w @@ -1279,6 +1404,17 @@ begin where event_name = $1 and (timeout_at is null or timeout_at > $2) ), + -- Lock tasks before updating runs to prevent waking cancelled tasks. + -- Only lock sleeping tasks to avoid interfering with other operations. + -- This prevents waking cancelled tasks (e.g., when cascade_cancel_children + -- is running concurrently). 
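+      -- claim_task cooperates by using FOR UPDATE ... SKIP LOCKED on the task
+      -- table, so it skips tasks held here instead of blocking (see
+      -- test_claim_task_skips_locked_tasks_no_deadlock).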
+ locked_tasks as ( + select t.task_id + from durable.%4$I t + where t.task_id in (select task_id from affected) + and t.state = ''sleeping'' + for update + ), -- update the run table for all waiting runs so they are pending again updated_runs as ( update durable.%2$I r @@ -1290,6 +1426,7 @@ begin claim_expires_at = null where r.run_id in (select run_id from affected) and r.state = ''sleeping'' + and r.task_id in (select task_id from locked_tasks) returning r.run_id, r.task_id ), -- update checkpoints for all affected tasks/steps so they contain the event payload @@ -1336,16 +1473,15 @@ as $$ declare v_now timestamptz := durable.current_time(); v_task_state text; - v_parent_task_id uuid; begin execute format( - 'select state, parent_task_id + 'select state from durable.%I where task_id = $1 for update', 't_' || p_queue_name ) - into v_task_state, v_parent_task_id + into v_task_state using p_task_id; if v_task_state is null then @@ -1374,22 +1510,8 @@ begin 'r_' || p_queue_name ) using p_task_id; - execute format( - 'delete from durable.%I where task_id = $1', - 'w_' || p_queue_name - ) using p_task_id; - - -- Cascade cancel all children - perform durable.cascade_cancel_children(p_queue_name, p_task_id); - - -- Emit cancellation event for parent to join on (only if this is a subtask) - if v_parent_task_id is not null then - perform durable.emit_event( - p_queue_name, - '$child:' || p_task_id::text, - jsonb_build_object('status', 'cancelled') - ); - end if; + -- Cleanup: delete waiters, emit event, cascade cancel children + perform durable.cleanup_task_terminal(p_queue_name, p_task_id, 'cancelled', null, true); end; $$; diff --git a/src/worker.rs b/src/worker.rs index c93619b..8714715 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -339,7 +339,7 @@ impl Worker { Ok(ctx) => ctx, Err(e) => { tracing::error!("Failed to create task context: {}", e); - Self::fail_run(&pool, &queue_name, task.run_id, &e.into()).await; + Self::fail_run(&pool, &queue_name, task.task_id, task.run_id, &e.into()).await; return; } }; @@ -353,6 +353,7 @@ impl Worker { Self::fail_run( &pool, &queue_name, + task.task_id, task.run_id, &TaskError::Validation { message: format!("Unknown task: {}", task.task_name), @@ -493,7 +494,7 @@ impl Worker { { outcome = "completed"; } - Self::complete_run(&pool, &queue_name, task.run_id, output).await; + Self::complete_run(&pool, &queue_name, task.task_id, task.run_id, output).await; #[cfg(feature = "telemetry")] crate::telemetry::record_task_completed(&queue_name_for_metrics, &task_name); @@ -521,7 +522,7 @@ impl Worker { outcome = "failed"; } tracing::error!("Task {} failed: {}", task_label, e); - Self::fail_run(&pool, &queue_name, task.run_id, e).await; + Self::fail_run(&pool, &queue_name, task.task_id, task.run_id, e).await; #[cfg(feature = "telemetry")] crate::telemetry::record_task_failed( @@ -545,10 +546,17 @@ impl Worker { } } - async fn complete_run(pool: &PgPool, queue_name: &str, run_id: Uuid, result: JsonValue) { - let query = "SELECT durable.complete_run($1, $2, $3)"; + async fn complete_run( + pool: &PgPool, + queue_name: &str, + task_id: Uuid, + run_id: Uuid, + result: JsonValue, + ) { + let query = "SELECT durable.complete_run($1, $2, $3, $4)"; if let Err(e) = sqlx::query(query) .bind(queue_name) + .bind(task_id) .bind(run_id) .bind(&result) .execute(pool) @@ -558,11 +566,18 @@ impl Worker { } } - async fn fail_run(pool: &PgPool, queue_name: &str, run_id: Uuid, error: &TaskError) { + async fn fail_run( + pool: &PgPool, + queue_name: &str, + task_id: Uuid, + run_id: 
Uuid, + error: &TaskError, + ) { let error_json = serialize_task_error(error); - let query = "SELECT durable.fail_run($1, $2, $3, $4)"; + let query = "SELECT durable.fail_run($1, $2, $3, $4, $5)"; if let Err(e) = sqlx::query(query) .bind(queue_name) + .bind(task_id) .bind(run_id) .bind(&error_json) .bind(None::>) diff --git a/tests/common/helpers.rs b/tests/common/helpers.rs index d8302aa..cfd5085 100644 --- a/tests/common/helpers.rs +++ b/tests/common/helpers.rs @@ -1,8 +1,31 @@ use chrono::{DateTime, Utc}; +use sqlx::postgres::PgPoolOptions; use sqlx::{AssertSqlSafe, PgPool}; use std::time::Duration; use uuid::Uuid; +// ============================================================================ +// Pool helpers +// ============================================================================ + +/// Create a single-connection pool for tests that use fake_time. +/// +/// PostgreSQL session variables like `durable.fake_now` are scoped to a single +/// connection. When using a multi-connection pool, different connections may +/// have different values for the session variable, causing flaky tests. +/// +/// This function creates a pool with max_connections=1, ensuring all queries +/// run on the same connection and see the same fake_now value. +#[allow(dead_code)] +pub async fn single_conn_pool(pool: &PgPool) -> PgPool { + let connect_options = (*pool.connect_options()).clone(); + PgPoolOptions::new() + .max_connections(1) + .connect_with(connect_options) + .await + .expect("Failed to create single-connection pool") +} + /// Set fake time for deterministic testing. /// Uses the durable.fake_now session variable. #[allow(dead_code)] @@ -245,3 +268,61 @@ pub async fn get_failed_payload( .await?; Ok(result.and_then(|(p,)| p)) } + +// ============================================================================ +// Run inspection helpers (detailed) +// ============================================================================ + +/// Information about a run. +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct RunInfo { + pub run_id: Uuid, + pub task_id: Uuid, + pub attempt: i32, + pub state: String, + pub available_at: DateTime, + pub claim_expires_at: Option>, + pub failure_reason: Option, +} + +/// Get all runs for a task, ordered by attempt number. +#[allow(dead_code)] +pub async fn get_runs_for_task( + pool: &PgPool, + queue: &str, + task_id: Uuid, +) -> sqlx::Result> { + let query = AssertSqlSafe(format!( + "SELECT run_id, task_id, attempt, state, available_at, claim_expires_at, failure_reason + FROM durable.r_{} WHERE task_id = $1 ORDER BY attempt", + queue + )); + type RunInfoRow = ( + Uuid, + Uuid, + i32, + String, + DateTime, + Option>, + Option, + ); + let rows: Vec = sqlx::query_as(query).bind(task_id).fetch_all(pool).await?; + + Ok(rows + .into_iter() + .map( + |(run_id, task_id, attempt, state, available_at, claim_expires_at, failure_reason)| { + RunInfo { + run_id, + task_id, + attempt, + state, + available_at, + claim_expires_at, + failure_reason, + } + }, + ) + .collect()) +} diff --git a/tests/fanout_test.rs b/tests/fanout_test.rs index d5e7ddf..4a41267 100644 --- a/tests/fanout_test.rs +++ b/tests/fanout_test.rs @@ -355,6 +355,100 @@ async fn test_cascade_cancel_when_parent_cancelled(pool: PgPool) -> sqlx::Result Ok(()) } +/// Test that cascade cancellation happens when a parent task is auto-cancelled +/// due to max_duration expiring while waiting for a child. 
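+/// Exercises the cancellation loop in claim_task: max_duration expires, the
+/// parent is marked cancelled, and cleanup_task_terminal cascades the
+/// cancellation to the still-sleeping child.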
+#[sqlx::test(migrator = "MIGRATOR")] +async fn test_cascade_cancel_when_parent_auto_cancelled_by_max_duration( + pool: PgPool, +) -> sqlx::Result<()> { + use common::helpers::{advance_time, set_fake_time, single_conn_pool, wait_for_task_terminal}; + use durable::{CancellationPolicy, SpawnOptions}; + + // Use single-conn pool for fake_time + let test_pool = single_conn_pool(&pool).await; + + let client = create_client(test_pool.clone(), "fanout_auto_cancel").await; + client.create_queue(None).await.unwrap(); + client.register::().await.unwrap(); + client.register::().await.unwrap(); + + let start_time = chrono::Utc::now(); + set_fake_time(&test_pool, start_time).await?; + + // Spawn parent with max_duration of 2 seconds + // The child will sleep for 10 seconds, so the parent will auto-cancel + let spawn_result = client + .spawn_with_options::( + SpawnSlowChildParams { + child_sleep_ms: 10000, // 10 seconds + }, + SpawnOptions { + cancellation: Some(CancellationPolicy { + max_pending_time: None, + max_running_time: Some(Duration::from_secs(2)), // 2 seconds max duration + }), + ..Default::default() + }, + ) + .await + .expect("Failed to spawn task"); + + // Start worker + let worker = client + .start_worker(WorkerOptions { + poll_interval: Duration::from_millis(50), + claim_timeout: Duration::from_secs(30), + concurrency: 2, + ..Default::default() + }) + .await + .unwrap(); + + // Wait for parent to spawn child + tokio::time::sleep(Duration::from_millis(500)).await; + + // Verify child was spawned + let query = "SELECT task_id FROM durable.t_fanout_auto_cancel WHERE parent_task_id = $1"; + let child_ids: Vec<(uuid::Uuid,)> = sqlx::query_as(query) + .bind(spawn_result.task_id) + .fetch_all(&test_pool) + .await?; + + assert!(!child_ids.is_empty(), "Child should have been spawned"); + + // Advance time past max_duration + advance_time(&test_pool, 3).await?; + + // Give time for auto-cancellation to trigger + tokio::time::sleep(Duration::from_millis(500)).await; + + // Wait for parent to reach terminal state + let terminal = wait_for_task_terminal( + &test_pool, + "fanout_auto_cancel", + spawn_result.task_id, + Duration::from_secs(5), + ) + .await?; + worker.shutdown().await; + + // Parent should be cancelled due to max_duration + assert_eq!( + terminal, + Some("cancelled".to_string()), + "Parent should be auto-cancelled due to max_duration" + ); + + // Child should also be cancelled (cascade) + let child_state = get_task_state(&pool, "fanout_auto_cancel", child_ids[0].0).await; + assert_eq!( + child_state, "cancelled", + "Child should be cascade cancelled when parent is auto-cancelled" + ); + + Ok(()) +} + // ============================================================================ // spawn_by_name Tests // ============================================================================ diff --git a/tests/lock_order_test.rs b/tests/lock_order_test.rs new file mode 100644 index 0000000..9056154 --- /dev/null +++ b/tests/lock_order_test.rs @@ -0,0 +1,505 @@ +#![allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)] + +//! Tests for lock ordering in SQL functions. +//! +//! These tests verify that functions that touch both tasks and runs +//! acquire locks in a consistent order (task first, then run) to prevent +//! deadlocks. +//! +//! The lock ordering pattern is: +//! 1. Find the task_id (no lock) +//! 2. Lock the task FOR UPDATE +//! 3. Lock the run FOR UPDATE +//! +//! Without this ordering, two concurrent transactions could deadlock: +//! 
- Transaction A: locks run, waits for task +//! - Transaction B: locks task, waits for run + +mod common; + +use common::helpers::{get_task_state, single_conn_pool, wait_for_task_terminal}; +use common::tasks::{ + DoubleParams, DoubleTask, FailingParams, FailingTask, SleepParams, SleepingTask, +}; +use durable::{Durable, MIGRATOR, RetryStrategy, SpawnOptions, WorkerOptions}; +use sqlx::postgres::{PgConnectOptions, PgConnection}; +use sqlx::{AssertSqlSafe, Connection, PgPool}; +use std::time::{Duration, Instant}; +use uuid::Uuid; + +async fn create_client(pool: PgPool, queue_name: &str) -> Durable { + Durable::builder() + .pool(pool) + .queue_name(queue_name) + .build() + .await + .expect("Failed to create Durable client") +} + +// ============================================================================ +// Lock Ordering Tests +// ============================================================================ + +/// Test that complete_run works correctly with the lock ordering. +/// Completes a task and verifies the task reaches completed state. +#[sqlx::test(migrator = "MIGRATOR")] +async fn test_complete_run_with_lock_ordering(pool: PgPool) -> sqlx::Result<()> { + let client = create_client(pool.clone(), "lock_complete").await; + client.create_queue(None).await.unwrap(); + client.register::().await.unwrap(); + + let spawn_result = client + .spawn::(DoubleParams { value: 21 }) + .await + .expect("Failed to spawn task"); + + let worker = client + .start_worker(WorkerOptions { + poll_interval: Duration::from_millis(50), + claim_timeout: Duration::from_secs(30), + ..Default::default() + }) + .await + .unwrap(); + + let terminal = wait_for_task_terminal( + &pool, + "lock_complete", + spawn_result.task_id, + Duration::from_secs(5), + ) + .await?; + worker.shutdown().await; + + assert_eq!(terminal, Some("completed".to_string())); + + Ok(()) +} + +/// Test that fail_run works correctly with the lock ordering. +/// Fails a task and verifies it eventually reaches failed state after retries. +#[sqlx::test(migrator = "MIGRATOR")] +async fn test_fail_run_with_lock_ordering(pool: PgPool) -> sqlx::Result<()> { + let client = create_client(pool.clone(), "lock_fail").await; + client.create_queue(None).await.unwrap(); + client.register::().await.unwrap(); + + let spawn_result = client + .spawn_with_options::( + FailingParams { + error_message: "intentional failure".to_string(), + }, + SpawnOptions { + retry_strategy: Some(RetryStrategy::Fixed { + base_delay: Duration::from_secs(0), + }), + max_attempts: Some(2), + ..Default::default() + }, + ) + .await + .expect("Failed to spawn task"); + + let worker = client + .start_worker(WorkerOptions { + poll_interval: Duration::from_millis(50), + claim_timeout: Duration::from_secs(30), + ..Default::default() + }) + .await + .unwrap(); + + let terminal = wait_for_task_terminal( + &pool, + "lock_fail", + spawn_result.task_id, + Duration::from_secs(5), + ) + .await?; + worker.shutdown().await; + + assert_eq!(terminal, Some("failed".to_string())); + + Ok(()) +} + +/// Test that sleep_for works correctly with the lock ordering. +/// Sleeps and verifies the task suspends and then completes. 
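+/// The sleep path exercises the new durable.sleep_for signature, which takes
+/// the task_id and locks the task row before the run row.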
+#[sqlx::test(migrator = "MIGRATOR")] +async fn test_sleep_for_with_lock_ordering(pool: PgPool) -> sqlx::Result<()> { + let client = create_client(pool.clone(), "lock_sleep").await; + client.create_queue(None).await.unwrap(); + client.register::().await.unwrap(); + + let spawn_result = client + .spawn::(SleepParams { seconds: 1 }) + .await + .expect("Failed to spawn task"); + + let worker = client + .start_worker(WorkerOptions { + poll_interval: Duration::from_millis(50), + claim_timeout: Duration::from_secs(30), + ..Default::default() + }) + .await + .unwrap(); + + // Wait for task to complete + let terminal = wait_for_task_terminal( + &pool, + "lock_sleep", + spawn_result.task_id, + Duration::from_secs(10), + ) + .await?; + worker.shutdown().await; + + assert_eq!(terminal, Some("completed".to_string())); + + Ok(()) +} + +/// Test concurrent complete and cancel operations don't deadlock. +/// This would deadlock if lock ordering were inconsistent. +#[sqlx::test(migrator = "MIGRATOR")] +async fn test_concurrent_complete_and_cancel(pool: PgPool) -> sqlx::Result<()> { + let client = create_client(pool.clone(), "lock_conc_cc").await; + client.create_queue(None).await.unwrap(); + client.register::().await.unwrap(); + + // Spawn several tasks + let mut task_ids = Vec::new(); + for _ in 0..5 { + let spawn_result = client + .spawn::(SleepParams { seconds: 1 }) + .await + .expect("Failed to spawn task"); + task_ids.push(spawn_result.task_id); + } + + let worker = client + .start_worker(WorkerOptions { + poll_interval: Duration::from_millis(50), + claim_timeout: Duration::from_secs(30), + concurrency: 5, + ..Default::default() + }) + .await + .unwrap(); + + // Let tasks start + tokio::time::sleep(Duration::from_millis(50)).await; + + // Cancel some tasks while others are completing + for (i, task_id) in task_ids.iter().enumerate() { + if i % 2 == 0 { + // Ignore errors - task might already be completed + let _ = client.cancel_task(*task_id, None).await; + } + } + + // Wait for all tasks to reach terminal state + for task_id in &task_ids { + let _ = + wait_for_task_terminal(&pool, "lock_conc_cc", *task_id, Duration::from_secs(5)).await?; + } + + worker.shutdown().await; + + // All tasks should be in terminal state (completed or cancelled) + for task_id in &task_ids { + let state = get_task_state(&pool, "lock_conc_cc", *task_id).await?; + assert!( + state == Some("completed".to_string()) || state == Some("cancelled".to_string()), + "Task should be terminal, got {:?}", + state + ); + } + + Ok(()) +} + +/// Test that emit_event wakes sleeping tasks correctly with lock ordering. +/// This tests the emit_event function's locked_tasks CTE. 
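+/// Only runs whose task appears in the locked_tasks CTE (still sleeping, now
+/// locked FOR UPDATE) are moved back to pending, so a concurrently cancelled
+/// task is never woken.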
+#[sqlx::test(migrator = "MIGRATOR")] +async fn test_emit_event_with_lock_ordering(pool: PgPool) -> sqlx::Result<()> { + use common::tasks::{EventWaitParams, EventWaitingTask}; + + let client = create_client(pool.clone(), "lock_emit").await; + client.create_queue(None).await.unwrap(); + client.register::().await.unwrap(); + + let spawn_result = client + .spawn::(EventWaitParams { + event_name: "test_event".to_string(), + timeout_seconds: Some(30), + }) + .await + .expect("Failed to spawn task"); + + let worker = client + .start_worker(WorkerOptions { + poll_interval: Duration::from_millis(50), + claim_timeout: Duration::from_secs(30), + ..Default::default() + }) + .await + .unwrap(); + + // Wait for task to start sleeping (awaiting event) + tokio::time::sleep(Duration::from_millis(200)).await; + + let state = get_task_state(&pool, "lock_emit", spawn_result.task_id).await?; + assert_eq!( + state, + Some("sleeping".to_string()), + "Task should be sleeping waiting for event" + ); + + // Emit the event + let emit_query = AssertSqlSafe( + "SELECT durable.emit_event('lock_emit', 'test_event', '\"hello\"'::jsonb)".to_string(), + ); + sqlx::query(emit_query).execute(&pool).await?; + + // Wait for task to complete + let terminal = wait_for_task_terminal( + &pool, + "lock_emit", + spawn_result.task_id, + Duration::from_secs(5), + ) + .await?; + worker.shutdown().await; + + assert_eq!(terminal, Some("completed".to_string())); + + Ok(()) +} + +/// Test concurrent emit and cancel operations. +/// This tests that emit_event's locked_tasks CTE properly handles +/// tasks being cancelled concurrently. +#[sqlx::test(migrator = "MIGRATOR")] +async fn test_concurrent_emit_and_cancel(pool: PgPool) -> sqlx::Result<()> { + use common::tasks::{EventWaitParams, EventWaitingTask}; + + // Use single-conn pool to ensure deterministic event emission + let test_pool = single_conn_pool(&pool).await; + + let client = create_client(pool.clone(), "lock_emit_cancel").await; + client.create_queue(None).await.unwrap(); + client.register::().await.unwrap(); + + // Spawn multiple tasks waiting for the same event + let mut task_ids = Vec::new(); + for _ in 0..3 { + let spawn_result = client + .spawn::(EventWaitParams { + event_name: "shared_event".to_string(), + timeout_seconds: Some(30), + }) + .await + .expect("Failed to spawn task"); + task_ids.push(spawn_result.task_id); + } + + let worker = client + .start_worker(WorkerOptions { + poll_interval: Duration::from_millis(50), + claim_timeout: Duration::from_secs(30), + concurrency: 5, + ..Default::default() + }) + .await + .unwrap(); + + // Wait for all tasks to start sleeping + tokio::time::sleep(Duration::from_millis(500)).await; + + for task_id in &task_ids { + let state = get_task_state(&pool, "lock_emit_cancel", *task_id).await?; + assert_eq!( + state, + Some("sleeping".to_string()), + "Task should be sleeping" + ); + } + + // Cancel one task while emitting the event + let cancel_task_id = task_ids[0]; + let emit_handle = tokio::spawn({ + let test_pool = test_pool.clone(); + async move { + let emit_query = AssertSqlSafe( + "SELECT durable.emit_event('lock_emit_cancel', 'shared_event', '\"wakeup\"'::jsonb)" + .to_string(), + ); + sqlx::query(emit_query).execute(&test_pool).await + } + }); + + // Cancel concurrently + let _ = client.cancel_task(cancel_task_id, None).await; + + // Wait for emit to complete + emit_handle.await.unwrap()?; + + // Wait for all tasks to reach terminal state + for task_id in &task_ids { + let _ = wait_for_task_terminal(&pool, "lock_emit_cancel", 
*task_id, Duration::from_secs(5)) + .await?; + } + + worker.shutdown().await; + + // The cancelled task should be cancelled, others should be completed + let cancelled_state = get_task_state(&pool, "lock_emit_cancel", cancel_task_id).await?; + assert_eq!( + cancelled_state, + Some("cancelled".to_string()), + "Cancelled task should be cancelled" + ); + + for task_id in &task_ids[1..] { + let state = get_task_state(&pool, "lock_emit_cancel", *task_id).await?; + assert_eq!( + state, + Some("completed".to_string()), + "Non-cancelled task should be completed" + ); + } + + Ok(()) +} + +/// Regression test: claim_task uses SKIP LOCKED to avoid deadlock with emit_event. +/// +/// emit_event locks tasks first (locked_tasks CTE with FOR UPDATE), then updates runs. +/// claim_task joins runs+tasks with FOR UPDATE SKIP LOCKED. +/// +/// This test verifies that when a task is locked (simulating emit_event holding the lock), +/// claim_task skips that task instead of blocking (which would cause deadlock). +/// +/// We make the test deterministic by: +/// - Creating a task and making it claimable (pending state) +/// - Holding a FOR UPDATE lock on the task row in a separate connection +/// - Calling claim_task - it should complete immediately with 0 results (not block) +#[sqlx::test(migrator = "MIGRATOR")] +async fn test_claim_task_skips_locked_tasks_no_deadlock(pool: PgPool) -> sqlx::Result<()> { + let queue = "skip_locked_test"; + + // Setup: Create queue and spawn a task + sqlx::query("SELECT durable.create_queue($1)") + .bind(queue) + .execute(&pool) + .await?; + + let (task_id, run_id): (Uuid, Uuid) = + sqlx::query_as("SELECT task_id, run_id FROM durable.spawn_task($1, $2, $3, $4)") + .bind(queue) + .bind("test-task") + .bind(serde_json::json!({})) + .bind(serde_json::json!({})) + .fetch_one(&pool) + .await?; + + // Verify task is pending and claimable + let state = get_task_state(&pool, queue, task_id).await?; + assert_eq!(state, Some("pending".to_string())); + + // Get connect options from pool for creating separate connections + let connect_opts: PgConnectOptions = (*pool.connect_options()).clone(); + + // Open lock connection and hold FOR UPDATE lock on the task row + // This simulates emit_event's locked_tasks CTE holding the lock mid-transaction + let lock_opts = connect_opts.clone().application_name("durable-task-locker"); + let mut lock_conn = PgConnection::connect_with(&lock_opts).await?; + + sqlx::query("BEGIN").execute(&mut lock_conn).await?; + sqlx::query(AssertSqlSafe(format!( + "SELECT 1 FROM durable.t_{} WHERE task_id = $1 FOR UPDATE", + queue + ))) + .bind(task_id) + .execute(&mut lock_conn) + .await?; + + // Wait until the lock is confirmed held by checking pg_stat_activity + let deadline = Instant::now() + Duration::from_secs(5); + loop { + let row: Option<(String,)> = + sqlx::query_as("SELECT state FROM pg_stat_activity WHERE application_name = $1") + .bind("durable-task-locker") + .fetch_optional(&pool) + .await?; + + if let Some((ref state,)) = row + && state == "idle in transaction" + { + break; + } + assert!( + Instant::now() < deadline, + "Lock connection did not reach expected state" + ); + tokio::time::sleep(Duration::from_millis(10)).await; + } + + // Now call claim_task from another connection + // If SKIP LOCKED works correctly, it should complete immediately with 0 results + // If SKIP LOCKED didn't apply to the task table, it would block and timeout + let claim_opts = connect_opts.clone().application_name("durable-claimer"); + let mut claim_conn = 
PgConnection::connect_with(&claim_opts).await?; + + // Set a short statement timeout - if claim_task blocks, it will fail + sqlx::query("SET statement_timeout = '500ms'") + .execute(&mut claim_conn) + .await?; + + let claim_result: Vec<(Uuid,)> = + sqlx::query_as("SELECT run_id FROM durable.claim_task($1, $2, $3, $4)") + .bind(queue) + .bind("worker") + .bind(60) + .bind(1) + .fetch_all(&mut claim_conn) + .await?; + + // claim_task should have completed (not timed out) and returned 0 results + // because the task was locked and SKIP LOCKED caused it to be skipped + assert!( + claim_result.is_empty(), + "claim_task should skip locked task, but got {} results", + claim_result.len() + ); + + // Reset statement timeout + sqlx::query("SET statement_timeout = 0") + .execute(&mut claim_conn) + .await?; + + // Release the lock + sqlx::query("ROLLBACK").execute(&mut lock_conn).await?; + drop(lock_conn); + + // Now claim_task should be able to claim the task + let claim_result2: Vec<(Uuid,)> = + sqlx::query_as("SELECT run_id FROM durable.claim_task($1, $2, $3, $4)") + .bind(queue) + .bind("worker") + .bind(60) + .bind(1) + .fetch_all(&mut claim_conn) + .await?; + + assert_eq!( + claim_result2.len(), + 1, + "claim_task should claim the task after lock is released" + ); + assert_eq!(claim_result2[0].0, run_id); + + Ok(()) +}