From 6263d73616aa151a25e41f5ac20e5594ca57a738 Mon Sep 17 00:00:00 2001 From: Viraj Mehta Date: Sat, 20 Dec 2025 18:43:45 -0500 Subject: [PATCH 1/7] fix: prevent lost wakeups in event await/emit race condition Add advisory locks to serialize concurrent await_event and emit_event operations on the same event. This prevents a race condition where: 1. Task A checks if event exists (not yet) 2. Task B emits the event and wakes waiters (none yet) 3. Task A registers as a waiter (missed the wake) The lock_event() function uses pg_advisory_xact_lock with hashed queue_name and event_name to ensure atomicity. Also changes emit_event to first-writer-wins semantics (ON CONFLICT DO NOTHING) to maintain consistency - subsequent emits for the same event are no-ops. Tests: - test_event_functions_use_advisory_locks: Verifies both functions call lock_event - test_event_race_stress: Stress test with 128 concurrent tasks x 4 rounds - test_event_first_writer_wins: Renamed from test_event_last_write_wins --- .../20251202002136_initial_setup.sql | 26 ++- tests/event_test.rs | 158 +++++++++++++++++- 2 files changed, 176 insertions(+), 8 deletions(-) diff --git a/src/postgres/migrations/20251202002136_initial_setup.sql b/src/postgres/migrations/20251202002136_initial_setup.sql index 7c1ba1b..db2325d 100644 --- a/src/postgres/migrations/20251202002136_initial_setup.sql +++ b/src/postgres/migrations/20251202002136_initial_setup.sql @@ -1048,6 +1048,19 @@ begin end; $$; +-- Advisory lock to serialize await_event and emit_event operations on the same event. +-- This prevents lost wakeups when a waiter is being set up while an emit is happening. +-- Called at the top of await_event and emit_event. +create function durable.lock_event ( + p_queue_name text, + p_event_name text +) + returns void + language sql +as $$ + select pg_advisory_xact_lock(hashtext(p_queue_name), hashtext(p_event_name)); +$$; + -- awaits an event for a given task's run and step name. -- this will immediately return if it the event has already returned -- it will also time out if the event has taken too long @@ -1082,6 +1095,9 @@ begin raise exception 'event_name must be provided'; end if; + -- Serialize with concurrent emit_event calls on the same event + perform durable.lock_event(p_queue_name, p_event_name); + if p_timeout is not null then if p_timeout < 0 then raise exception 'timeout must be non-negative'; @@ -1237,13 +1253,15 @@ begin raise exception 'event_name must be provided'; end if; - -- insert the event into the events table + -- Serialize with concurrent await_event calls on the same event + perform durable.lock_event(p_queue_name, p_event_name); + + -- Insert the event into the events table (first-writer-wins). + -- Subsequent emits for the same event are no-ops. execute format( 'insert into durable.%I (event_name, payload, emitted_at) values ($1, $2, $3) - on conflict (event_name) - do update set payload = excluded.payload, - emitted_at = excluded.emitted_at', + on conflict (event_name) do nothing', 'e_' || p_queue_name ) using p_event_name, v_payload, v_now; diff --git a/tests/event_test.rs b/tests/event_test.rs index f162800..0a4ca1e 100644 --- a/tests/event_test.rs +++ b/tests/event_test.rs @@ -330,9 +330,10 @@ async fn test_event_payload_preserved_on_retry(pool: PgPool) -> sqlx::Result<()> Ok(()) } -/// Test that emitting an event with the same name updates the payload (last-write-wins). +/// Test that emitting an event with the same name keeps the first payload (first-writer-wins). 
+/// Subsequent emits for the same event are no-ops to maintain consistency with lost-wakeup prevention. #[sqlx::test(migrator = "MIGRATOR")] -async fn test_event_last_write_wins(pool: PgPool) -> sqlx::Result<()> { +async fn test_event_first_writer_wins(pool: PgPool) -> sqlx::Result<()> { let client = create_client(pool.clone(), "event_dedup").await; client.create_queue(None).await.unwrap(); client.register::().await.unwrap(); @@ -376,7 +377,7 @@ async fn test_event_last_write_wins(pool: PgPool) -> sqlx::Result<()> { assert_eq!(terminal, Some("completed".to_string())); - // Should receive the second payload (last-write-wins) + // Should receive the first payload (first-writer-wins) let query = AssertSqlSafe( "SELECT completed_payload FROM durable.t_event_dedup WHERE task_id = $1".to_string(), ); @@ -385,7 +386,7 @@ async fn test_event_last_write_wins(pool: PgPool) -> sqlx::Result<()> { .fetch_one(&pool) .await?; - assert_eq!(result.0, json!({"version": "second"})); + assert_eq!(result.0, json!({"version": "first"})); Ok(()) } @@ -791,3 +792,152 @@ async fn test_emit_event_with_empty_name_fails(pool: PgPool) -> sqlx::Result<()> Ok(()) } + +// ============================================================================ +// Advisory Lock Tests +// ============================================================================ + +/// Test that both await_event and emit_event use advisory locks for synchronization. +/// This verifies the implementation calls lock_event() by inspecting function definitions. +#[sqlx::test(migrator = "MIGRATOR")] +async fn test_event_functions_use_advisory_locks(pool: PgPool) -> sqlx::Result<()> { + // Check that await_event calls lock_event + let await_def: (String,) = sqlx::query_as( + "SELECT pg_get_functiondef(oid) FROM pg_proc WHERE proname = 'await_event' AND pronamespace = (SELECT oid FROM pg_namespace WHERE nspname = 'durable')" + ) + .fetch_one(&pool) + .await?; + + assert!( + await_def.0.contains("lock_event"), + "await_event should call lock_event for advisory locking" + ); + + // Check that emit_event calls lock_event + let emit_def: (String,) = sqlx::query_as( + "SELECT pg_get_functiondef(oid) FROM pg_proc WHERE proname = 'emit_event' AND pronamespace = (SELECT oid FROM pg_namespace WHERE nspname = 'durable')" + ) + .fetch_one(&pool) + .await?; + + assert!( + emit_def.0.contains("lock_event"), + "emit_event should call lock_event for advisory locking" + ); + + Ok(()) +} + +/// Stress test to verify that advisory locks prevent lost wakeups. +/// This test spawns many tasks waiting on distinct events and emits all events +/// with jittered timing to maximize race condition likelihood. 
+#[sqlx::test(migrator = "MIGRATOR")] +async fn test_event_race_stress(pool: PgPool) -> sqlx::Result<()> { + // Configurable via environment variables for CI tuning + let rounds: usize = std::env::var("DURABLE_EVENT_RACE_ROUNDS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(4); + let tasks_per_round: usize = std::env::var("DURABLE_EVENT_RACE_TASKS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(128); + let jitter_ms: u64 = std::env::var("DURABLE_EVENT_RACE_JITTER_MS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(8); + + let client = create_client(pool.clone(), "event_race").await; + client.create_queue(None).await.unwrap(); + client.register::().await.unwrap(); + + let worker = client + .start_worker(WorkerOptions { + poll_interval: 0.01, + claim_timeout: 60, + concurrency: 32, + ..Default::default() + }) + .await; + + for round in 0..rounds { + let mut task_ids = Vec::with_capacity(tasks_per_round); + let mut event_names = Vec::with_capacity(tasks_per_round); + + // Spawn tasks waiting on unique events + for i in 0..tasks_per_round { + let event_name = format!("race_event_r{}_{}", round, i); + event_names.push(event_name.clone()); + + let spawn_result = client + .spawn::(EventWaitParams { + event_name, + timeout_seconds: Some(30), + }) + .await + .expect("Failed to spawn task"); + task_ids.push(spawn_result.task_id); + } + + // Brief pause to let tasks start waiting + tokio::time::sleep(Duration::from_millis(100)).await; + + // Emit all events with jittered timing to maximize race conditions + let pool_for_emit = pool.clone(); + let emit_handles: Vec<_> = event_names + .into_iter() + .enumerate() + .map(|(i, event_name)| { + let pool = pool_for_emit.clone(); + tokio::spawn(async move { + // Jitter: vary start times + let jitter = Duration::from_micros((i as u64 * 17) % (jitter_ms * 1000)); + tokio::time::sleep(jitter).await; + let emit_client = create_client(pool, "event_race").await; + emit_client + .emit_event::(&event_name, &json!({"idx": i}), None) + .await + .expect("Failed to emit event"); + }) + }) + .collect(); + + // Wait for all emits to complete + for handle in emit_handles { + handle.await.expect("Emit task panicked"); + } + + // Check for orphaned waiters (wait registrations for already-emitted events) + // This indicates a lost wakeup + let orphaned_count: (i64,) = sqlx::query_as(AssertSqlSafe( + "SELECT COUNT(*) FROM durable.w_event_race w + WHERE EXISTS (SELECT 1 FROM durable.e_event_race e WHERE e.event_name = w.event_name)" + .to_string(), + )) + .fetch_one(&pool) + .await?; + + if orphaned_count.0 > 0 { + panic!( + "Round {}: Found {} orphaned waiters for already-emitted events (lost wakeup detected!)", + round, orphaned_count.0 + ); + } + + // Wait for all tasks to complete + for task_id in task_ids { + let terminal = + wait_for_task_terminal(&pool, "event_race", task_id, Duration::from_secs(10)) + .await?; + assert_eq!( + terminal, + Some("completed".to_string()), + "Round {}: Task should complete after event is emitted", + round + ); + } + } + + worker.shutdown().await; + Ok(()) +} From 6f851f4954308f726c39c1bd28f87426498d1074 Mon Sep 17 00:00:00 2001 From: Viraj Mehta Date: Sat, 20 Dec 2025 19:14:07 -0500 Subject: [PATCH 2/7] Lock Ordering + Cleanup Consolidation Problem 1. Deadlock risk: Functions that touch both tasks and runs could deadlock if they acquired locks in inconsistent order 2. 
Scattered cleanup logic: Terminal task cleanup (deleting waiters, emitting parent events, cascading cancellation) was duplicated across multiple functions 3. Incomplete cascade cancellation: Auto-cancelled tasks (via max_duration) didn't cascade cancel their children Solution Lock Ordering: All functions now acquire locks in consistent order (task first, then run): - complete_run, fail_run, sleep_for, await_event: lock task FOR UPDATE before locking run - claim_task: lock task before run when handling expired claims - emit_event: lock sleeping tasks before waking runs (via locked_tasks CTE) Cleanup Consolidation: New cleanup_task_terminal() function handles: - Deleting wait registrations for the task - Emitting completion event for parent ($child:) - Optionally cascading cancellation to children Used by: complete_run, fail_run, cancel_task, cascade_cancel_children, claim_task Other improvements: - emit_event early return when event already exists (optimization) - sleep_for now takes task_id parameter for proper lock ordering Tests - New lock_order_test.rs with 6 tests verifying lock ordering works correctly - New test_cascade_cancel_when_parent_auto_cancelled_by_max_duration in fanout_test.rs - New test helpers: single_conn_pool(), RunInfo, get_runs_for_task() --- src/context.rs | 16 +- .../20251202002136_initial_setup.sql | 460 ++++++++++++------ tests/common/helpers.rs | 81 +++ tests/fanout_test.rs | 93 ++++ tests/lock_order_test.rs | 365 ++++++++++++++ 5 files changed, 856 insertions(+), 159 deletions(-) create mode 100644 tests/lock_order_test.rs diff --git a/src/context.rs b/src/context.rs index 80edad6..e9c07e6 100644 --- a/src/context.rs +++ b/src/context.rs @@ -308,13 +308,15 @@ where let checkpoint_name = self.get_checkpoint_name(name, &())?; let duration_ms = duration.as_millis() as i64; - let (needs_suspend,): (bool,) = sqlx::query_as("SELECT durable.sleep_for($1, $2, $3, $4)") - .bind(&self.queue_name) - .bind(self.run_id) - .bind(&checkpoint_name) - .bind(duration_ms) - .fetch_one(&self.pool) - .await?; + let (needs_suspend,): (bool,) = + sqlx::query_as("SELECT durable.sleep_for($1, $2, $3, $4, $5)") + .bind(&self.queue_name) + .bind(self.task_id) + .bind(self.run_id) + .bind(&checkpoint_name) + .bind(duration_ms) + .fetch_one(&self.pool) + .await?; if needs_suspend { return Err(TaskError::Control(ControlFlow::Suspend)); diff --git a/src/postgres/migrations/20251202002136_initial_setup.sql b/src/postgres/migrations/20251202002136_initial_setup.sql index db2325d..f2971fc 100644 --- a/src/postgres/migrations/20251202002136_initial_setup.sql +++ b/src/postgres/migrations/20251202002136_initial_setup.sql @@ -334,6 +334,7 @@ declare v_claim_until timestamptz; v_sql text; v_expired_run record; + v_cancelled_task uuid; begin if v_claim_timeout <= 0 then raise exception 'claim_timeout must be greater than zero'; @@ -344,54 +345,76 @@ begin -- Apply cancellation rules before claiming. 
-- These are max_delay (delay before starting) and -- max_duration (duration from created to finished) - execute format( - 'with limits as ( - select task_id, - (cancellation->>''max_delay'')::bigint as max_delay, - (cancellation->>''max_duration'')::bigint as max_duration, - enqueue_at, - first_started_at, - state - from durable.%I - where state in (''pending'', ''sleeping'', ''running'') - ), - to_cancel as ( - select task_id - from limits - where - ( - max_delay is not null - and first_started_at is null - and extract(epoch from ($1 - enqueue_at)) >= max_delay - ) - or - ( - max_duration is not null - and first_started_at is not null - and extract(epoch from ($1 - first_started_at)) >= max_duration - ) - ) - update durable.%I t - set state = ''cancelled'', - cancelled_at = coalesce(t.cancelled_at, $1) - where t.task_id in (select task_id from to_cancel)', - 't_' || p_queue_name, - 't_' || p_queue_name - ) using v_now; + -- Use a loop so we can cleanup each cancelled task properly. + for v_cancelled_task in + execute format( + 'with limits as ( + select task_id, + (cancellation->>''max_delay'')::bigint as max_delay, + (cancellation->>''max_duration'')::bigint as max_duration, + enqueue_at, + first_started_at, + state + from durable.%I + where state in (''pending'', ''sleeping'', ''running'') + ), + to_cancel as ( + select task_id + from limits + where + ( + max_delay is not null + and first_started_at is null + and extract(epoch from ($1 - enqueue_at)) >= max_delay + ) + or + ( + max_duration is not null + and first_started_at is not null + and extract(epoch from ($1 - first_started_at)) >= max_duration + ) + ) + update durable.%I t + set state = ''cancelled'', + cancelled_at = coalesce(t.cancelled_at, $1) + where t.task_id in (select task_id from to_cancel) + returning t.task_id', + 't_' || p_queue_name, + 't_' || p_queue_name + ) using v_now + loop + -- Cancel all runs for this task + execute format( + 'update durable.%I + set state = ''cancelled'', + claimed_by = null, + claim_expires_at = null + where task_id = $1 + and state not in (''completed'', ''failed'', ''cancelled'')', + 'r_' || p_queue_name + ) using v_cancelled_task; + + -- Cleanup: delete waiters, emit event, cascade cancel children + perform durable.cleanup_task_terminal(p_queue_name, v_cancelled_task, 'cancelled', null, true); + end loop; - -- Fail any run claims that have timed out + -- Fail any run claims that have timed out. + -- Lock tasks first to keep a consistent task -> run lock order. for v_expired_run in execute format( - 'select run_id, - claimed_by, - claim_expires_at, - attempt - from durable.%I - where state = ''running'' - and claim_expires_at is not null - and claim_expires_at <= $1 - for update skip locked', - 'r_' || p_queue_name + 'select r.run_id, + r.task_id, + r.claimed_by, + r.claim_expires_at, + r.attempt + from durable.%I r + join durable.%I t on t.task_id = r.task_id + where r.state = ''running'' + and r.claim_expires_at is not null + and r.claim_expires_at <= $1 + for update of t skip locked', + 'r_' || p_queue_name, + 't_' || p_queue_name ) using v_now loop @@ -507,10 +530,42 @@ create function durable.complete_run ( as $$ declare v_task_id uuid; + v_task_id_locked uuid; + v_run_task_id uuid; v_state text; - v_parent_task_id uuid; v_now timestamptz := durable.current_time(); begin + -- Lock task first to keep a consistent task -> run lock order. + -- Find task for this run (no lock). 
+ execute format( + 'select task_id + from durable.%I + where run_id = $1', + 'r_' || p_queue_name + ) + into v_task_id + using p_run_id; + + if v_task_id is null then + raise exception 'Run "%" not found in queue "%"', p_run_id, p_queue_name; + end if; + + -- Lock the task + execute format( + 'select task_id + from durable.%I + where task_id = $1 + for update', + 't_' || p_queue_name + ) + into v_task_id_locked + using v_task_id; + + if v_task_id_locked is null then + raise exception 'Task "%" not found in queue "%"', v_task_id, p_queue_name; + end if; + + -- Lock the run after the task lock execute format( 'select task_id, state from durable.%I @@ -518,17 +573,22 @@ begin for update', 'r_' || p_queue_name ) - into v_task_id, v_state + into v_run_task_id, v_state using p_run_id; - if v_task_id is null then + if v_run_task_id is null then raise exception 'Run "%" not found in queue "%"', p_run_id, p_queue_name; end if; + if v_run_task_id <> v_task_id then + raise exception 'Run "%" does not belong to task "%"', p_run_id, v_task_id; + end if; + if v_state <> 'running' then raise exception 'Run "%" is not currently running in queue "%"', p_run_id, p_queue_name; end if; + -- Update run to completed execute format( 'update durable.%I set state = ''completed'', @@ -538,38 +598,30 @@ begin 'r_' || p_queue_name ) using p_run_id, v_now, p_state; - -- Get parent_task_id to check if this is a subtask + -- Update task to completed execute format( 'update durable.%I set state = ''completed'', completed_payload = $2, last_attempt_run = $3 - where task_id = $1 - returning parent_task_id', + where task_id = $1', 't_' || p_queue_name - ) - into v_parent_task_id - using v_task_id, p_state, p_run_id; - - -- Clean up any wait registrations for this run - execute format( - 'delete from durable.%I where run_id = $1', - 'w_' || p_queue_name - ) using p_run_id; - - -- Emit completion event for parent to join on (only if this is a subtask) - if v_parent_task_id is not null then - perform durable.emit_event( - p_queue_name, - '$child:' || v_task_id::text, - jsonb_build_object('status', 'completed', 'result', p_state) - ); - end if; + ) using v_task_id, p_state, p_run_id; + + -- Cleanup: delete waiters and emit completion event for parent + perform durable.cleanup_task_terminal( + p_queue_name, + v_task_id, + 'completed', + jsonb_build_object('result', p_state), + false -- don't cascade cancel children for completed tasks + ); end; $$; create function durable.sleep_for( p_queue_name text, + p_task_id uuid, p_run_id uuid, p_checkpoint_name text, p_duration_ms bigint @@ -581,33 +633,66 @@ declare v_wake_at timestamptz; v_existing_state jsonb; v_now timestamptz := durable.current_time(); - v_task_id uuid; + v_run_task_id uuid; + v_run_state text; + v_task_state text; begin - -- Get task_id from run (needed for checkpoint table key) + -- Lock task first to keep a consistent task -> run lock order. 
+ execute format( + 'select state from durable.%I where task_id = $1 for update', + 't_' || p_queue_name + ) into v_task_state using p_task_id; + + if v_task_state is null then + raise exception 'Task "%" not found in queue "%"', p_task_id, p_queue_name; + end if; + + if v_task_state = 'cancelled' then + raise exception sqlstate 'AB001' using message = 'Task has been cancelled'; + end if; + + -- Lock run after task execute format( - 'select task_id from durable.%I where run_id = $1 and state = ''running'' for update', + 'select task_id, state from durable.%I where run_id = $1 for update', 'r_' || p_queue_name - ) into v_task_id using p_run_id; + ) into v_run_task_id, v_run_state using p_run_id; - if v_task_id is null then + if v_run_task_id is null then + raise exception 'Run "%" not found in queue "%"', p_run_id, p_queue_name; + end if; + + if v_run_task_id <> p_task_id then + raise exception 'Run "%" does not belong to task "%"', p_run_id, p_task_id; + end if; + + if v_run_state <> 'running' then raise exception 'Run "%" is not currently running in queue "%"', p_run_id, p_queue_name; end if; - -- Check for existing checkpoint, else compute and store wake time + -- Check for existing checkpoint execute format( 'select state from durable.%I where task_id = $1 and checkpoint_name = $2', 'c_' || p_queue_name - ) into v_existing_state using v_task_id, p_checkpoint_name; + ) into v_existing_state using p_task_id, p_checkpoint_name; if v_existing_state is not null then v_wake_at := (v_existing_state #>> '{}')::timestamptz; else + -- Compute wake time and store checkpoint (first-writer-wins) v_wake_at := v_now + (p_duration_ms || ' milliseconds')::interval; execute format( 'insert into durable.%I (task_id, checkpoint_name, state, owner_run_id, updated_at) - values ($1, $2, $3, $4, $5)', + values ($1, $2, $3, $4, $5) + on conflict (task_id, checkpoint_name) do nothing', + 'c_' || p_queue_name + ) using p_task_id, p_checkpoint_name, to_jsonb(v_wake_at::text), p_run_id, v_now; + + -- Re-read in case we lost the race (first-writer-wins) + execute format( + 'select state from durable.%I where task_id = $1 and checkpoint_name = $2', 'c_' || p_queue_name - ) using v_task_id, p_checkpoint_name, to_jsonb(v_wake_at::text), p_run_id, v_now; + ) into v_existing_state using p_task_id, p_checkpoint_name; + v_wake_at := (v_existing_state #>> '{}')::timestamptz; end if; -- If wake time passed, return false (no suspend needed) @@ -632,12 +717,60 @@ begin set state = ''sleeping'' where task_id = $1', 't_' || p_queue_name - ) using v_task_id; + ) using p_task_id; return true; end; $$; +-- Consolidates cleanup logic for a task that has reached a terminal state. +-- This function: +-- 1. Deletes wait registrations for the task +-- 2. Emits a completion event for the parent (if this is a subtask) +-- 3. 
Optionally cascades cancellation to children +-- +-- Called by: complete_run, fail_run, cancel_task, cascade_cancel_children, claim_task +create function durable.cleanup_task_terminal ( + p_queue_name text, + p_task_id uuid, + p_status text, -- 'completed', 'failed', 'cancelled' + p_payload jsonb default null, + p_cascade_children boolean default false +) + returns void + language plpgsql +as $$ +declare + v_parent_task_id uuid; +begin + -- Get parent_task_id for event emission + execute format( + 'select parent_task_id from durable.%I where task_id = $1', + 't_' || p_queue_name + ) into v_parent_task_id using p_task_id; + + -- Delete wait registrations for this task + execute format( + 'delete from durable.%I where task_id = $1', + 'w_' || p_queue_name + ) using p_task_id; + + -- Emit completion event for parent (if subtask) + if v_parent_task_id is not null then + perform durable.emit_event( + p_queue_name, + '$child:' || p_task_id::text, + jsonb_build_object('status', p_status) || coalesce(p_payload, '{}'::jsonb) + ); + end if; + + -- Cascade cancel children if requested + if p_cascade_children then + perform durable.cascade_cancel_children(p_queue_name, p_task_id); + end if; +end; +$$; + -- Recursively cancels all children of a parent task. -- Used when a parent task fails or is cancelled to cascade the cancellation. create function durable.cascade_cancel_children ( @@ -649,13 +782,12 @@ create function durable.cascade_cancel_children ( as $$ declare v_child_id uuid; - v_child_state text; v_now timestamptz := durable.current_time(); begin -- Find all children of this parent that are not in terminal state - for v_child_id, v_child_state in + for v_child_id in execute format( - 'select task_id, state + 'select task_id from durable.%I where parent_task_id = $1 and state not in (''completed'', ''failed'', ''cancelled'') @@ -684,21 +816,8 @@ begin 'r_' || p_queue_name ) using v_child_id; - -- Delete wait registrations - execute format( - 'delete from durable.%I where task_id = $1', - 'w_' || p_queue_name - ) using v_child_id; - - -- Emit cancellation event so parent's join() can receive it - perform durable.emit_event( - p_queue_name, - '$child:' || v_child_id::text, - jsonb_build_object('status', 'cancelled') - ); - - -- Recursively cancel grandchildren - perform durable.cascade_cancel_children(p_queue_name, v_child_id); + -- Cleanup: delete waiters, emit event, and recursively cascade to grandchildren + perform durable.cleanup_task_terminal(p_queue_name, v_child_id, 'cancelled', null, true); end loop; end; $$; @@ -714,6 +833,7 @@ create function durable.fail_run ( as $$ declare v_task_id uuid; + v_run_task_id uuid; v_attempt integer; v_retry_strategy jsonb; v_max_attempts integer; @@ -735,36 +855,54 @@ declare v_recorded_attempt integer; v_last_attempt_run uuid := p_run_id; v_cancelled_at timestamptz := null; - v_parent_task_id uuid; begin - -- find the run to fail + -- Lock task first to keep a consistent task -> run lock order. + -- Find task for this run (no lock). 
execute format( - 'select r.task_id, r.attempt - from durable.%I r - where r.run_id = $1 - and r.state in (''running'', ''sleeping'') - for update', + 'select task_id + from durable.%I + where run_id = $1', 'r_' || p_queue_name ) - into v_task_id, v_attempt + into v_task_id using p_run_id; if v_task_id is null then raise exception 'Run "%" cannot be failed in queue "%"', p_run_id, p_queue_name; end if; - -- get the retry strategy and metadata about task + -- Lock task and get retry strategy execute format( - 'select retry_strategy, max_attempts, first_started_at, cancellation, state, parent_task_id + 'select retry_strategy, max_attempts, first_started_at, cancellation, state from durable.%I where task_id = $1 for update', 't_' || p_queue_name ) - into v_retry_strategy, v_max_attempts, v_first_started, v_cancellation, v_task_state, v_parent_task_id + into v_retry_strategy, v_max_attempts, v_first_started, v_cancellation, v_task_state using v_task_id; - -- actually fail the run + -- Lock run after task and ensure it's still eligible + execute format( + 'select task_id, attempt + from durable.%I + where run_id = $1 + and state in (''running'', ''sleeping'') + for update', + 'r_' || p_queue_name + ) + into v_run_task_id, v_attempt + using p_run_id; + + if v_run_task_id is null then + raise exception 'Run "%" cannot be failed in queue "%"', p_run_id, p_queue_name; + end if; + + if v_run_task_id <> v_task_id then + raise exception 'Run "%" does not belong to task "%"', p_run_id, v_task_id; + end if; + + -- Actually fail the run execute format( 'update durable.%I set state = ''failed'', @@ -779,7 +917,7 @@ begin v_task_state_after := 'failed'; v_recorded_attempt := v_attempt; - -- compute the next retry time + -- Compute the next retry time if v_max_attempts is null or v_next_attempt <= v_max_attempts then if p_retry_at is not null then v_next_available := p_retry_at; @@ -815,7 +953,7 @@ begin end if; end if; - -- set up the new run if not cancelling + -- Set up the new run if not cancelling if not v_task_cancel then v_task_state_after := case when v_next_available > v_now then 'sleeping' else 'pending' end; v_new_run_id := durable.portable_uuidv7(); @@ -849,24 +987,21 @@ begin v_task_state_after ) using v_task_id, v_task_state_after, v_recorded_attempt, v_last_attempt_run, v_cancelled_at; + -- Delete wait registrations for this run execute format( 'delete from durable.%I where run_id = $1', 'w_' || p_queue_name ) using p_run_id; - -- If task reached terminal failure state (failed or cancelled), emit event and cascade cancel + -- If task reached terminal state, cleanup (emit event, cascade cancel) if v_task_state_after in ('failed', 'cancelled') then - -- Cascade cancel all children - perform durable.cascade_cancel_children(p_queue_name, v_task_id); - - -- Emit completion event for parent to join on (only if this is a subtask) - if v_parent_task_id is not null then - perform durable.emit_event( - p_queue_name, - '$child:' || v_task_id::text, - jsonb_build_object('status', v_task_state_after, 'error', p_reason) - ); - end if; + perform durable.cleanup_task_terminal( + p_queue_name, + v_task_id, + v_task_state_after, + jsonb_build_object('error', p_reason), + true -- cascade cancel children + ); end if; end; $$; @@ -1081,6 +1216,7 @@ create function durable.await_event ( as $$ declare v_run_state text; + v_run_task_id uuid; v_existing_payload jsonb; v_event_payload jsonb; v_checkpoint_payload jsonb; @@ -1123,25 +1259,39 @@ begin return; end if; - -- let's get the run state, any existing event 
payload and wake event name + -- Lock task first to keep a consistent task -> run lock order. execute format( - 'select r.state, r.event_payload, r.wake_event, t.state - from durable.%I r - join durable.%I t on t.task_id = r.task_id - where r.run_id = $1 - for update', - 'r_' || p_queue_name, + 'select state from durable.%I where task_id = $1 for update', 't_' || p_queue_name ) - into v_run_state, v_existing_payload, v_wake_event, v_task_state + into v_task_state + using p_task_id; + + if v_task_state is null then + raise exception 'Task "%" not found in queue "%"', p_task_id, p_queue_name; + end if; + + if v_task_state = 'cancelled' then + raise exception sqlstate 'AB001' using message = 'Task has been cancelled'; + end if; + + -- Lock run after task + execute format( + 'select task_id, state, event_payload, wake_event + from durable.%I + where run_id = $1 + for update', + 'r_' || p_queue_name + ) + into v_run_task_id, v_run_state, v_existing_payload, v_wake_event using p_run_id; if v_run_state is null then raise exception 'Run "%" not found while awaiting event', p_run_id; end if; - if v_task_state = 'cancelled' then - raise exception sqlstate 'AB001' using message = 'Task has been cancelled'; + if v_run_task_id <> p_task_id then + raise exception 'Run "%" does not belong to task "%"', p_run_id, p_task_id; end if; execute format( @@ -1248,6 +1398,7 @@ as $$ declare v_now timestamptz := durable.current_time(); v_payload jsonb := coalesce(p_payload, 'null'::jsonb); + v_inserted_count integer; begin if p_event_name is null or length(trim(p_event_name)) = 0 then raise exception 'event_name must be provided'; @@ -1265,6 +1416,14 @@ begin 'e_' || p_queue_name ) using p_event_name, v_payload, v_now; + get diagnostics v_inserted_count = row_count; + + -- Only wake waiters if we actually inserted (first emit). + -- Subsequent emits are no-ops to maintain consistency. + if v_inserted_count = 0 then + return; + end if; + execute format( 'with expired_waits as ( delete from durable.%1$I w @@ -1279,6 +1438,17 @@ begin where event_name = $1 and (timeout_at is null or timeout_at > $2) ), + -- Lock tasks before updating runs to prevent waking cancelled tasks. + -- Only lock sleeping tasks to avoid interfering with other operations. + -- This prevents waking cancelled tasks (e.g., when cascade_cancel_children + -- is running concurrently). 
+ locked_tasks as ( + select t.task_id + from durable.%4$I t + where t.task_id in (select task_id from affected) + and t.state = ''sleeping'' + for update + ), -- update the run table for all waiting runs so they are pending again updated_runs as ( update durable.%2$I r @@ -1290,6 +1460,7 @@ begin claim_expires_at = null where r.run_id in (select run_id from affected) and r.state = ''sleeping'' + and r.task_id in (select task_id from locked_tasks) returning r.run_id, r.task_id ), -- update checkpoints for all affected tasks/steps so they contain the event payload @@ -1336,16 +1507,15 @@ as $$ declare v_now timestamptz := durable.current_time(); v_task_state text; - v_parent_task_id uuid; begin execute format( - 'select state, parent_task_id + 'select state from durable.%I where task_id = $1 for update', 't_' || p_queue_name ) - into v_task_state, v_parent_task_id + into v_task_state using p_task_id; if v_task_state is null then @@ -1374,22 +1544,8 @@ begin 'r_' || p_queue_name ) using p_task_id; - execute format( - 'delete from durable.%I where task_id = $1', - 'w_' || p_queue_name - ) using p_task_id; - - -- Cascade cancel all children - perform durable.cascade_cancel_children(p_queue_name, p_task_id); - - -- Emit cancellation event for parent to join on (only if this is a subtask) - if v_parent_task_id is not null then - perform durable.emit_event( - p_queue_name, - '$child:' || p_task_id::text, - jsonb_build_object('status', 'cancelled') - ); - end if; + -- Cleanup: delete waiters, emit event, cascade cancel children + perform durable.cleanup_task_terminal(p_queue_name, p_task_id, 'cancelled', null, true); end; $$; diff --git a/tests/common/helpers.rs b/tests/common/helpers.rs index d8302aa..cfd5085 100644 --- a/tests/common/helpers.rs +++ b/tests/common/helpers.rs @@ -1,8 +1,31 @@ use chrono::{DateTime, Utc}; +use sqlx::postgres::PgPoolOptions; use sqlx::{AssertSqlSafe, PgPool}; use std::time::Duration; use uuid::Uuid; +// ============================================================================ +// Pool helpers +// ============================================================================ + +/// Create a single-connection pool for tests that use fake_time. +/// +/// PostgreSQL session variables like `durable.fake_now` are scoped to a single +/// connection. When using a multi-connection pool, different connections may +/// have different values for the session variable, causing flaky tests. +/// +/// This function creates a pool with max_connections=1, ensuring all queries +/// run on the same connection and see the same fake_now value. +#[allow(dead_code)] +pub async fn single_conn_pool(pool: &PgPool) -> PgPool { + let connect_options = (*pool.connect_options()).clone(); + PgPoolOptions::new() + .max_connections(1) + .connect_with(connect_options) + .await + .expect("Failed to create single-connection pool") +} + /// Set fake time for deterministic testing. /// Uses the durable.fake_now session variable. #[allow(dead_code)] @@ -245,3 +268,61 @@ pub async fn get_failed_payload( .await?; Ok(result.and_then(|(p,)| p)) } + +// ============================================================================ +// Run inspection helpers (detailed) +// ============================================================================ + +/// Information about a run. 
+#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct RunInfo { + pub run_id: Uuid, + pub task_id: Uuid, + pub attempt: i32, + pub state: String, + pub available_at: DateTime, + pub claim_expires_at: Option>, + pub failure_reason: Option, +} + +/// Get all runs for a task, ordered by attempt number. +#[allow(dead_code)] +pub async fn get_runs_for_task( + pool: &PgPool, + queue: &str, + task_id: Uuid, +) -> sqlx::Result> { + let query = AssertSqlSafe(format!( + "SELECT run_id, task_id, attempt, state, available_at, claim_expires_at, failure_reason + FROM durable.r_{} WHERE task_id = $1 ORDER BY attempt", + queue + )); + type RunInfoRow = ( + Uuid, + Uuid, + i32, + String, + DateTime, + Option>, + Option, + ); + let rows: Vec = sqlx::query_as(query).bind(task_id).fetch_all(pool).await?; + + Ok(rows + .into_iter() + .map( + |(run_id, task_id, attempt, state, available_at, claim_expires_at, failure_reason)| { + RunInfo { + run_id, + task_id, + attempt, + state, + available_at, + claim_expires_at, + failure_reason, + } + }, + ) + .collect()) +} diff --git a/tests/fanout_test.rs b/tests/fanout_test.rs index ec039f9..53ccee4 100644 --- a/tests/fanout_test.rs +++ b/tests/fanout_test.rs @@ -350,6 +350,99 @@ async fn test_cascade_cancel_when_parent_cancelled(pool: PgPool) -> sqlx::Result Ok(()) } +/// Test that cascade cancellation happens when a parent task is auto-cancelled +/// due to max_duration expiring while waiting for a child. +#[sqlx::test(migrator = "MIGRATOR")] +async fn test_cascade_cancel_when_parent_auto_cancelled_by_max_duration( + pool: PgPool, +) -> sqlx::Result<()> { + use common::helpers::{advance_time, set_fake_time, single_conn_pool, wait_for_task_terminal}; + use durable::{CancellationPolicy, SpawnOptions}; + + // Use single-conn pool for fake_time + let test_pool = single_conn_pool(&pool).await; + + let client = create_client(test_pool.clone(), "fanout_auto_cancel").await; + client.create_queue(None).await.unwrap(); + client.register::().await.unwrap(); + client.register::().await.unwrap(); + + let start_time = chrono::Utc::now(); + set_fake_time(&test_pool, start_time).await?; + + // Spawn parent with max_duration of 2 seconds + // The child will sleep for 10 seconds, so the parent will auto-cancel + let spawn_result = client + .spawn_with_options::( + SpawnSlowChildParams { + child_sleep_ms: 10000, // 10 seconds + }, + SpawnOptions { + cancellation: Some(CancellationPolicy { + max_delay: None, + max_duration: Some(2), // 2 seconds max duration + }), + ..Default::default() + }, + ) + .await + .expect("Failed to spawn task"); + + // Start worker + let worker = client + .start_worker(WorkerOptions { + poll_interval: 0.05, + claim_timeout: 30, + concurrency: 2, + ..Default::default() + }) + .await; + + // Wait for parent to spawn child + tokio::time::sleep(Duration::from_millis(500)).await; + + // Verify child was spawned + let query = "SELECT task_id FROM durable.t_fanout_auto_cancel WHERE parent_task_id = $1"; + let child_ids: Vec<(uuid::Uuid,)> = sqlx::query_as(query) + .bind(spawn_result.task_id) + .fetch_all(&test_pool) + .await?; + + assert!(!child_ids.is_empty(), "Child should have been spawned"); + + // Advance time past max_duration + advance_time(&test_pool, 3).await?; + + // Give time for auto-cancellation to trigger + tokio::time::sleep(Duration::from_millis(500)).await; + + // Wait for parent to reach terminal state + let terminal = wait_for_task_terminal( + &test_pool, + "fanout_auto_cancel", + spawn_result.task_id, + Duration::from_secs(5), + ) + .await?; + 
worker.shutdown().await; + + // Parent should be cancelled due to max_duration + assert_eq!( + terminal, + Some("cancelled".to_string()), + "Parent should be auto-cancelled due to max_duration" + ); + + // Child should also be cancelled (cascade) + let child_state = get_task_state(&pool, "fanout_auto_cancel", child_ids[0].0).await; + assert_eq!( + child_state, "cancelled", + "Child should be cascade cancelled when parent is auto-cancelled" + ); + + Ok(()) +} + // ============================================================================ // spawn_by_name Tests // ============================================================================ diff --git a/tests/lock_order_test.rs b/tests/lock_order_test.rs new file mode 100644 index 0000000..45a0d6e --- /dev/null +++ b/tests/lock_order_test.rs @@ -0,0 +1,365 @@ +#![allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)] + +//! Tests for lock ordering in SQL functions. +//! +//! These tests verify that functions that touch both tasks and runs +//! acquire locks in a consistent order (task first, then run) to prevent +//! deadlocks. +//! +//! The lock ordering pattern is: +//! 1. Find the task_id (no lock) +//! 2. Lock the task FOR UPDATE +//! 3. Lock the run FOR UPDATE +//! +//! Without this ordering, two concurrent transactions could deadlock: +//! - Transaction A: locks run, waits for task +//! - Transaction B: locks task, waits for run + +mod common; + +use common::helpers::{get_task_state, single_conn_pool, wait_for_task_terminal}; +use common::tasks::{ + DoubleParams, DoubleTask, FailingParams, FailingTask, SleepParams, SleepingTask, +}; +use durable::{Durable, MIGRATOR, RetryStrategy, SpawnOptions, WorkerOptions}; +use sqlx::{AssertSqlSafe, PgPool}; +use std::time::Duration; + +async fn create_client(pool: PgPool, queue_name: &str) -> Durable { + Durable::builder() + .pool(pool) + .queue_name(queue_name) + .build() + .await + .expect("Failed to create Durable client") +} + +// ============================================================================ +// Lock Ordering Tests +// ============================================================================ + +/// Test that complete_run works correctly with the lock ordering. +/// Completes a task and verifies the task reaches completed state. +#[sqlx::test(migrator = "MIGRATOR")] +async fn test_complete_run_with_lock_ordering(pool: PgPool) -> sqlx::Result<()> { + let client = create_client(pool.clone(), "lock_complete").await; + client.create_queue(None).await.unwrap(); + client.register::().await.unwrap(); + + let spawn_result = client + .spawn::(DoubleParams { value: 21 }) + .await + .expect("Failed to spawn task"); + + let worker = client + .start_worker(WorkerOptions { + poll_interval: 0.05, + claim_timeout: 30, + ..Default::default() + }) + .await; + + let terminal = wait_for_task_terminal( + &pool, + "lock_complete", + spawn_result.task_id, + Duration::from_secs(5), + ) + .await?; + worker.shutdown().await; + + assert_eq!(terminal, Some("completed".to_string())); + + Ok(()) +} + +/// Test that fail_run works correctly with the lock ordering. +/// Fails a task and verifies it eventually reaches failed state after retries. 
+#[sqlx::test(migrator = "MIGRATOR")] +async fn test_fail_run_with_lock_ordering(pool: PgPool) -> sqlx::Result<()> { + let client = create_client(pool.clone(), "lock_fail").await; + client.create_queue(None).await.unwrap(); + client.register::().await.unwrap(); + + let spawn_result = client + .spawn_with_options::( + FailingParams { + error_message: "intentional failure".to_string(), + }, + SpawnOptions { + retry_strategy: Some(RetryStrategy::Fixed { base_seconds: 0 }), + max_attempts: Some(2), + ..Default::default() + }, + ) + .await + .expect("Failed to spawn task"); + + let worker = client + .start_worker(WorkerOptions { + poll_interval: 0.05, + claim_timeout: 30, + ..Default::default() + }) + .await; + + let terminal = wait_for_task_terminal( + &pool, + "lock_fail", + spawn_result.task_id, + Duration::from_secs(5), + ) + .await?; + worker.shutdown().await; + + assert_eq!(terminal, Some("failed".to_string())); + + Ok(()) +} + +/// Test that sleep_for works correctly with the lock ordering. +/// Sleeps and verifies the task suspends and then completes. +#[sqlx::test(migrator = "MIGRATOR")] +async fn test_sleep_for_with_lock_ordering(pool: PgPool) -> sqlx::Result<()> { + let client = create_client(pool.clone(), "lock_sleep").await; + client.create_queue(None).await.unwrap(); + client.register::().await.unwrap(); + + let spawn_result = client + .spawn::(SleepParams { seconds: 1 }) + .await + .expect("Failed to spawn task"); + + let worker = client + .start_worker(WorkerOptions { + poll_interval: 0.05, + claim_timeout: 30, + ..Default::default() + }) + .await; + + // Wait for task to complete + let terminal = wait_for_task_terminal( + &pool, + "lock_sleep", + spawn_result.task_id, + Duration::from_secs(10), + ) + .await?; + worker.shutdown().await; + + assert_eq!(terminal, Some("completed".to_string())); + + Ok(()) +} + +/// Test concurrent complete and cancel operations don't deadlock. +/// This would deadlock if lock ordering were inconsistent. 
+#[sqlx::test(migrator = "MIGRATOR")] +async fn test_concurrent_complete_and_cancel(pool: PgPool) -> sqlx::Result<()> { + let client = create_client(pool.clone(), "lock_conc_cc").await; + client.create_queue(None).await.unwrap(); + client.register::().await.unwrap(); + + // Spawn several tasks + let mut task_ids = Vec::new(); + for _ in 0..5 { + let spawn_result = client + .spawn::(SleepParams { seconds: 1 }) + .await + .expect("Failed to spawn task"); + task_ids.push(spawn_result.task_id); + } + + let worker = client + .start_worker(WorkerOptions { + poll_interval: 0.05, + claim_timeout: 30, + concurrency: 5, + ..Default::default() + }) + .await; + + // Let tasks start + tokio::time::sleep(Duration::from_millis(50)).await; + + // Cancel some tasks while others are completing + for (i, task_id) in task_ids.iter().enumerate() { + if i % 2 == 0 { + // Ignore errors - task might already be completed + let _ = client.cancel_task(*task_id, None).await; + } + } + + // Wait for all tasks to reach terminal state + for task_id in &task_ids { + let _ = + wait_for_task_terminal(&pool, "lock_conc_cc", *task_id, Duration::from_secs(5)).await?; + } + + worker.shutdown().await; + + // All tasks should be in terminal state (completed or cancelled) + for task_id in &task_ids { + let state = get_task_state(&pool, "lock_conc_cc", *task_id).await?; + assert!( + state == Some("completed".to_string()) || state == Some("cancelled".to_string()), + "Task should be terminal, got {:?}", + state + ); + } + + Ok(()) +} + +/// Test that emit_event wakes sleeping tasks correctly with lock ordering. +/// This tests the emit_event function's locked_tasks CTE. +#[sqlx::test(migrator = "MIGRATOR")] +async fn test_emit_event_with_lock_ordering(pool: PgPool) -> sqlx::Result<()> { + use common::tasks::{EventWaitParams, EventWaitingTask}; + + let client = create_client(pool.clone(), "lock_emit").await; + client.create_queue(None).await.unwrap(); + client.register::().await.unwrap(); + + let spawn_result = client + .spawn::(EventWaitParams { + event_name: "test_event".to_string(), + timeout_seconds: Some(30), + }) + .await + .expect("Failed to spawn task"); + + let worker = client + .start_worker(WorkerOptions { + poll_interval: 0.05, + claim_timeout: 30, + ..Default::default() + }) + .await; + + // Wait for task to start sleeping (awaiting event) + tokio::time::sleep(Duration::from_millis(200)).await; + + let state = get_task_state(&pool, "lock_emit", spawn_result.task_id).await?; + assert_eq!( + state, + Some("sleeping".to_string()), + "Task should be sleeping waiting for event" + ); + + // Emit the event + let emit_query = AssertSqlSafe( + "SELECT durable.emit_event('lock_emit', 'test_event', '\"hello\"'::jsonb)".to_string(), + ); + sqlx::query(emit_query).execute(&pool).await?; + + // Wait for task to complete + let terminal = wait_for_task_terminal( + &pool, + "lock_emit", + spawn_result.task_id, + Duration::from_secs(5), + ) + .await?; + worker.shutdown().await; + + assert_eq!(terminal, Some("completed".to_string())); + + Ok(()) +} + +/// Test concurrent emit and cancel operations. +/// This tests that emit_event's locked_tasks CTE properly handles +/// tasks being cancelled concurrently. 
+#[sqlx::test(migrator = "MIGRATOR")] +async fn test_concurrent_emit_and_cancel(pool: PgPool) -> sqlx::Result<()> { + use common::tasks::{EventWaitParams, EventWaitingTask}; + + // Use single-conn pool to ensure deterministic event emission + let test_pool = single_conn_pool(&pool).await; + + let client = create_client(pool.clone(), "lock_emit_cancel").await; + client.create_queue(None).await.unwrap(); + client.register::().await.unwrap(); + + // Spawn multiple tasks waiting for the same event + let mut task_ids = Vec::new(); + for _ in 0..3 { + let spawn_result = client + .spawn::(EventWaitParams { + event_name: "shared_event".to_string(), + timeout_seconds: Some(30), + }) + .await + .expect("Failed to spawn task"); + task_ids.push(spawn_result.task_id); + } + + let worker = client + .start_worker(WorkerOptions { + poll_interval: 0.05, + claim_timeout: 30, + concurrency: 5, + ..Default::default() + }) + .await; + + // Wait for all tasks to start sleeping + tokio::time::sleep(Duration::from_millis(500)).await; + + for task_id in &task_ids { + let state = get_task_state(&pool, "lock_emit_cancel", *task_id).await?; + assert_eq!( + state, + Some("sleeping".to_string()), + "Task should be sleeping" + ); + } + + // Cancel one task while emitting the event + let cancel_task_id = task_ids[0]; + let emit_handle = tokio::spawn({ + let test_pool = test_pool.clone(); + async move { + let emit_query = AssertSqlSafe( + "SELECT durable.emit_event('lock_emit_cancel', 'shared_event', '\"wakeup\"'::jsonb)" + .to_string(), + ); + sqlx::query(emit_query).execute(&test_pool).await + } + }); + + // Cancel concurrently + let _ = client.cancel_task(cancel_task_id, None).await; + + // Wait for emit to complete + emit_handle.await.unwrap()?; + + // Wait for all tasks to reach terminal state + for task_id in &task_ids { + let _ = wait_for_task_terminal(&pool, "lock_emit_cancel", *task_id, Duration::from_secs(5)) + .await?; + } + + worker.shutdown().await; + + // The cancelled task should be cancelled, others should be completed + let cancelled_state = get_task_state(&pool, "lock_emit_cancel", cancel_task_id).await?; + assert_eq!( + cancelled_state, + Some("cancelled".to_string()), + "Cancelled task should be cancelled" + ); + + for task_id in &task_ids[1..] 
{ + let state = get_task_state(&pool, "lock_emit_cancel", *task_id).await?; + assert_eq!( + state, + Some("completed".to_string()), + "Non-cancelled task should be completed" + ); + } + + Ok(()) +} From e6eb86ccd93cd46cd4872fbf0a240e8dffba8b5d Mon Sep 17 00:00:00 2001 From: Viraj Mehta Date: Sat, 20 Dec 2025 19:40:09 -0500 Subject: [PATCH 3/7] pass task ID to more endpoints so they can lock in order efficiently --- .../20251202002136_initial_setup.sql | 63 ++++++------------- src/worker.rs | 29 ++++++--- 2 files changed, 41 insertions(+), 51 deletions(-) diff --git a/src/postgres/migrations/20251202002136_initial_setup.sql b/src/postgres/migrations/20251202002136_initial_setup.sql index f2971fc..b00a09d 100644 --- a/src/postgres/migrations/20251202002136_initial_setup.sql +++ b/src/postgres/migrations/20251202002136_initial_setup.sql @@ -420,6 +420,7 @@ begin loop perform durable.fail_run( p_queue_name, + v_expired_run.task_id, v_expired_run.run_id, jsonb_strip_nulls(jsonb_build_object( 'name', '$ClaimTimeout', @@ -522,6 +523,7 @@ $$; -- Marks a run as completed create function durable.complete_run ( p_queue_name text, + p_task_id uuid, p_run_id uuid, p_state jsonb default null ) @@ -529,28 +531,12 @@ create function durable.complete_run ( language plpgsql as $$ declare - v_task_id uuid; v_task_id_locked uuid; v_run_task_id uuid; v_state text; v_now timestamptz := durable.current_time(); begin -- Lock task first to keep a consistent task -> run lock order. - -- Find task for this run (no lock). - execute format( - 'select task_id - from durable.%I - where run_id = $1', - 'r_' || p_queue_name - ) - into v_task_id - using p_run_id; - - if v_task_id is null then - raise exception 'Run "%" not found in queue "%"', p_run_id, p_queue_name; - end if; - - -- Lock the task execute format( 'select task_id from durable.%I @@ -559,10 +545,10 @@ begin 't_' || p_queue_name ) into v_task_id_locked - using v_task_id; + using p_task_id; if v_task_id_locked is null then - raise exception 'Task "%" not found in queue "%"', v_task_id, p_queue_name; + raise exception 'Task "%" not found in queue "%"', p_task_id, p_queue_name; end if; -- Lock the run after the task lock @@ -580,8 +566,8 @@ begin raise exception 'Run "%" not found in queue "%"', p_run_id, p_queue_name; end if; - if v_run_task_id <> v_task_id then - raise exception 'Run "%" does not belong to task "%"', p_run_id, v_task_id; + if v_run_task_id <> p_task_id then + raise exception 'Run "%" does not belong to task "%"', p_run_id, p_task_id; end if; if v_state <> 'running' then @@ -606,12 +592,12 @@ begin last_attempt_run = $3 where task_id = $1', 't_' || p_queue_name - ) using v_task_id, p_state, p_run_id; + ) using p_task_id, p_state, p_run_id; -- Cleanup: delete waiters and emit completion event for parent perform durable.cleanup_task_terminal( p_queue_name, - v_task_id, + p_task_id, 'completed', jsonb_build_object('result', p_state), false -- don't cascade cancel children for completed tasks @@ -824,6 +810,7 @@ $$; create function durable.fail_run ( p_queue_name text, + p_task_id uuid, p_run_id uuid, p_reason jsonb, p_retry_at timestamptz default null @@ -832,7 +819,6 @@ create function durable.fail_run ( language plpgsql as $$ declare - v_task_id uuid; v_run_task_id uuid; v_attempt integer; v_retry_strategy jsonb; @@ -857,21 +843,6 @@ declare v_cancelled_at timestamptz := null; begin -- Lock task first to keep a consistent task -> run lock order. - -- Find task for this run (no lock). 
- execute format( - 'select task_id - from durable.%I - where run_id = $1', - 'r_' || p_queue_name - ) - into v_task_id - using p_run_id; - - if v_task_id is null then - raise exception 'Run "%" cannot be failed in queue "%"', p_run_id, p_queue_name; - end if; - - -- Lock task and get retry strategy execute format( 'select retry_strategy, max_attempts, first_started_at, cancellation, state from durable.%I @@ -880,7 +851,11 @@ begin 't_' || p_queue_name ) into v_retry_strategy, v_max_attempts, v_first_started, v_cancellation, v_task_state - using v_task_id; + using p_task_id; + + if v_task_state is null then + raise exception 'Task "%" not found in queue "%"', p_task_id, p_queue_name; + end if; -- Lock run after task and ensure it's still eligible execute format( @@ -898,8 +873,8 @@ begin raise exception 'Run "%" cannot be failed in queue "%"', p_run_id, p_queue_name; end if; - if v_run_task_id <> v_task_id then - raise exception 'Run "%" does not belong to task "%"', p_run_id, v_task_id; + if v_run_task_id <> p_task_id then + raise exception 'Run "%" does not belong to task "%"', p_run_id, p_task_id; end if; -- Actually fail the run @@ -965,7 +940,7 @@ begin 'r_' || p_queue_name, v_task_state_after ) - using v_new_run_id, v_task_id, v_next_attempt, v_next_available; + using v_new_run_id, p_task_id, v_next_attempt, v_next_available; end if; end if; @@ -985,7 +960,7 @@ begin where task_id = $1', 't_' || p_queue_name, v_task_state_after - ) using v_task_id, v_task_state_after, v_recorded_attempt, v_last_attempt_run, v_cancelled_at; + ) using p_task_id, v_task_state_after, v_recorded_attempt, v_last_attempt_run, v_cancelled_at; -- Delete wait registrations for this run execute format( @@ -997,7 +972,7 @@ begin if v_task_state_after in ('failed', 'cancelled') then perform durable.cleanup_task_terminal( p_queue_name, - v_task_id, + p_task_id, v_task_state_after, jsonb_build_object('error', p_reason), true -- cascade cancel children diff --git a/src/worker.rs b/src/worker.rs index ce6ead9..d99e49b 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -339,7 +339,7 @@ impl Worker { Ok(ctx) => ctx, Err(e) => { tracing::error!("Failed to create task context: {}", e); - Self::fail_run(&pool, &queue_name, task.run_id, &e.into()).await; + Self::fail_run(&pool, &queue_name, task.task_id, task.run_id, &e.into()).await; return; } }; @@ -353,6 +353,7 @@ impl Worker { Self::fail_run( &pool, &queue_name, + task.task_id, task.run_id, &TaskError::Validation { message: format!("Unknown task: {}", task.task_name), @@ -493,7 +494,7 @@ impl Worker { { outcome = "completed"; } - Self::complete_run(&pool, &queue_name, task.run_id, output).await; + Self::complete_run(&pool, &queue_name, task.task_id, task.run_id, output).await; #[cfg(feature = "telemetry")] crate::telemetry::record_task_completed(&queue_name_for_metrics, &task_name); @@ -521,7 +522,7 @@ impl Worker { outcome = "failed"; } tracing::error!("Task {} failed: {}", task_label, e); - Self::fail_run(&pool, &queue_name, task.run_id, e).await; + Self::fail_run(&pool, &queue_name, task.task_id, task.run_id, e).await; #[cfg(feature = "telemetry")] crate::telemetry::record_task_failed( @@ -545,10 +546,17 @@ impl Worker { } } - async fn complete_run(pool: &PgPool, queue_name: &str, run_id: Uuid, result: JsonValue) { - let query = "SELECT durable.complete_run($1, $2, $3)"; + async fn complete_run( + pool: &PgPool, + queue_name: &str, + task_id: Uuid, + run_id: Uuid, + result: JsonValue, + ) { + let query = "SELECT durable.complete_run($1, $2, $3, $4)"; if let Err(e) = 
sqlx::query(query) .bind(queue_name) + .bind(task_id) .bind(run_id) .bind(&result) .execute(pool) @@ -558,11 +566,18 @@ impl Worker { } } - async fn fail_run(pool: &PgPool, queue_name: &str, run_id: Uuid, error: &TaskError) { + async fn fail_run( + pool: &PgPool, + queue_name: &str, + task_id: Uuid, + run_id: Uuid, + error: &TaskError, + ) { let error_json = serialize_task_error(error); - let query = "SELECT durable.fail_run($1, $2, $3, $4)"; + let query = "SELECT durable.fail_run($1, $2, $3, $4, $5)"; if let Err(e) = sqlx::query(query) .bind(queue_name) + .bind(task_id) .bind(run_id) .bind(&error_json) .bind(None::>) From 3b1b099a2bfaaf20ef157e9812570b41657bbe00 Mon Sep 17 00:00:00 2001 From: Viraj Mehta Date: Sat, 20 Dec 2025 20:09:05 -0500 Subject: [PATCH 4/7] make checkpoint writing atomically check the version --- .../20251202002136_initial_setup.sql | 39 +++++++------------ 1 file changed, 15 insertions(+), 24 deletions(-) diff --git a/src/postgres/migrations/20251202002136_initial_setup.sql b/src/postgres/migrations/20251202002136_initial_setup.sql index b00a09d..a053015 100644 --- a/src/postgres/migrations/20251202002136_initial_setup.sql +++ b/src/postgres/migrations/20251202002136_initial_setup.sql @@ -999,8 +999,6 @@ as $$ declare v_now timestamptz := durable.current_time(); v_new_attempt integer; - v_existing_attempt integer; - v_existing_owner uuid; v_task_state text; begin if p_step_name is null or length(trim(p_step_name)) = 0 then @@ -1042,29 +1040,22 @@ begin end if; execute format( - 'select c.owner_run_id, - r.attempt - from durable.%I c - left join durable.%I r on r.run_id = c.owner_run_id - where c.task_id = $1 - and c.checkpoint_name = $2', + 'insert into durable.%I (task_id, checkpoint_name, state, owner_run_id, updated_at) + values ($1, $2, $3, $4, $5) + on conflict (task_id, checkpoint_name) + do update set state = excluded.state, + owner_run_id = excluded.owner_run_id, + updated_at = excluded.updated_at + where $6 >= coalesce( + (select r.attempt + from durable.%I r + where r.run_id = durable.%I.owner_run_id), + $6 + )', 'c_' || p_queue_name, - 'r_' || p_queue_name - ) - into v_existing_owner, v_existing_attempt - using p_task_id, p_step_name; - - if v_existing_owner is null or v_existing_attempt is null or v_new_attempt >= v_existing_attempt then - execute format( - 'insert into durable.%I (task_id, checkpoint_name, state, owner_run_id, updated_at) - values ($1, $2, $3, $4, $5) - on conflict (task_id, checkpoint_name) - do update set state = excluded.state, - owner_run_id = excluded.owner_run_id, - updated_at = excluded.updated_at', - 'c_' || p_queue_name - ) using p_task_id, p_step_name, p_state, p_owner_run, v_now; - end if; + 'r_' || p_queue_name, + 'c_' || p_queue_name + ) using p_task_id, p_step_name, p_state, p_owner_run, v_now, v_new_attempt; end; $$; From 4edacc2a3c274460a56e399837219bfca7b5fdc9 Mon Sep 17 00:00:00 2001 From: Viraj Mehta Date: Wed, 24 Dec 2025 11:56:36 -0500 Subject: [PATCH 5/7] added test that covers lock order case for events / claim task --- tests/lock_order_test.rs | 137 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 135 insertions(+), 2 deletions(-) diff --git a/tests/lock_order_test.rs b/tests/lock_order_test.rs index 45a0d6e..e773359 100644 --- a/tests/lock_order_test.rs +++ b/tests/lock_order_test.rs @@ -22,8 +22,10 @@ use common::tasks::{ DoubleParams, DoubleTask, FailingParams, FailingTask, SleepParams, SleepingTask, }; use durable::{Durable, MIGRATOR, RetryStrategy, SpawnOptions, WorkerOptions}; -use 
sqlx::{AssertSqlSafe, PgPool}; -use std::time::Duration; +use sqlx::postgres::{PgConnectOptions, PgConnection}; +use sqlx::{AssertSqlSafe, Connection, PgPool}; +use std::time::{Duration, Instant}; +use uuid::Uuid; async fn create_client(pool: PgPool, queue_name: &str) -> Durable { Durable::builder() @@ -363,3 +365,134 @@ async fn test_concurrent_emit_and_cancel(pool: PgPool) -> sqlx::Result<()> { Ok(()) } + +/// Regression test: claim_task uses SKIP LOCKED to avoid deadlock with emit_event. +/// +/// emit_event locks tasks first (locked_tasks CTE with FOR UPDATE), then updates runs. +/// claim_task joins runs+tasks with FOR UPDATE SKIP LOCKED. +/// +/// This test verifies that when a task is locked (simulating emit_event holding the lock), +/// claim_task skips that task instead of blocking (which would cause deadlock). +/// +/// We make the test deterministic by: +/// - Creating a task and making it claimable (pending state) +/// - Holding a FOR UPDATE lock on the task row in a separate connection +/// - Calling claim_task - it should complete immediately with 0 results (not block) +#[sqlx::test(migrator = "MIGRATOR")] +async fn test_claim_task_skips_locked_tasks_no_deadlock(pool: PgPool) -> sqlx::Result<()> { + let queue = "skip_locked_test"; + + // Setup: Create queue and spawn a task + sqlx::query("SELECT durable.create_queue($1)") + .bind(queue) + .execute(&pool) + .await?; + + let (task_id, run_id): (Uuid, Uuid) = + sqlx::query_as("SELECT task_id, run_id FROM durable.spawn_task($1, $2, $3, $4)") + .bind(queue) + .bind("test-task") + .bind(serde_json::json!({})) + .bind(serde_json::json!({})) + .fetch_one(&pool) + .await?; + + // Verify task is pending and claimable + let state = get_task_state(&pool, queue, task_id).await?; + assert_eq!(state, Some("pending".to_string())); + + // Get connect options from pool for creating separate connections + let connect_opts: PgConnectOptions = (*pool.connect_options()).clone(); + + // Open lock connection and hold FOR UPDATE lock on the task row + // This simulates emit_event's locked_tasks CTE holding the lock mid-transaction + let lock_opts = connect_opts.clone().application_name("durable-task-locker"); + let mut lock_conn = PgConnection::connect_with(&lock_opts).await?; + + sqlx::query("BEGIN").execute(&mut lock_conn).await?; + sqlx::query(AssertSqlSafe(format!( + "SELECT 1 FROM durable.t_{} WHERE task_id = $1 FOR UPDATE", + queue + ))) + .bind(task_id) + .execute(&mut lock_conn) + .await?; + + // Wait until the lock is confirmed held by checking pg_stat_activity + let deadline = Instant::now() + Duration::from_secs(5); + loop { + let row: Option<(String,)> = sqlx::query_as( + "SELECT state FROM pg_stat_activity WHERE application_name = $1", + ) + .bind("durable-task-locker") + .fetch_optional(&pool) + .await?; + + if let Some((ref state,)) = row + && state == "idle in transaction" + { + break; + } + assert!( + Instant::now() < deadline, + "Lock connection did not reach expected state" + ); + tokio::time::sleep(Duration::from_millis(10)).await; + } + + // Now call claim_task from another connection + // If SKIP LOCKED works correctly, it should complete immediately with 0 results + // If SKIP LOCKED didn't apply to the task table, it would block and timeout + let claim_opts = connect_opts.clone().application_name("durable-claimer"); + let mut claim_conn = PgConnection::connect_with(&claim_opts).await?; + + // Set a short statement timeout - if claim_task blocks, it will fail + sqlx::query("SET statement_timeout = '500ms'") + .execute(&mut 
claim_conn) + .await?; + + let claim_result: Vec<(Uuid,)> = + sqlx::query_as("SELECT run_id FROM durable.claim_task($1, $2, $3, $4)") + .bind(queue) + .bind("worker") + .bind(60) + .bind(1) + .fetch_all(&mut claim_conn) + .await?; + + // claim_task should have completed (not timed out) and returned 0 results + // because the task was locked and SKIP LOCKED caused it to be skipped + assert!( + claim_result.is_empty(), + "claim_task should skip locked task, but got {} results", + claim_result.len() + ); + + // Reset statement timeout + sqlx::query("SET statement_timeout = 0") + .execute(&mut claim_conn) + .await?; + + // Release the lock + sqlx::query("ROLLBACK").execute(&mut lock_conn).await?; + drop(lock_conn); + + // Now claim_task should be able to claim the task + let claim_result2: Vec<(Uuid,)> = + sqlx::query_as("SELECT run_id FROM durable.claim_task($1, $2, $3, $4)") + .bind(queue) + .bind("worker") + .bind(60) + .bind(1) + .fetch_all(&mut claim_conn) + .await?; + + assert_eq!( + claim_result2.len(), + 1, + "claim_task should claim the task after lock is released" + ); + assert_eq!(claim_result2[0].0, run_id); + + Ok(()) +} From 2a2d4c1383524f6974a0088420dbdda440782dd5 Mon Sep 17 00:00:00 2001 From: Viraj Mehta Date: Wed, 24 Dec 2025 13:44:35 -0500 Subject: [PATCH 6/7] fmtted --- tests/lock_order_test.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/lock_order_test.rs b/tests/lock_order_test.rs index e773359..f2f84fd 100644 --- a/tests/lock_order_test.rs +++ b/tests/lock_order_test.rs @@ -421,12 +421,11 @@ async fn test_claim_task_skips_locked_tasks_no_deadlock(pool: PgPool) -> sqlx::R // Wait until the lock is confirmed held by checking pg_stat_activity let deadline = Instant::now() + Duration::from_secs(5); loop { - let row: Option<(String,)> = sqlx::query_as( - "SELECT state FROM pg_stat_activity WHERE application_name = $1", - ) - .bind("durable-task-locker") - .fetch_optional(&pool) - .await?; + let row: Option<(String,)> = + sqlx::query_as("SELECT state FROM pg_stat_activity WHERE application_name = $1") + .bind("durable-task-locker") + .fetch_optional(&pool) + .await?; if let Some((ref state,)) = row && state == "idle in transaction" From 72beb59257cf80ba9476ef7d40fadfa29e42b615 Mon Sep 17 00:00:00 2001 From: Viraj Mehta Date: Wed, 24 Dec 2025 14:26:16 -0500 Subject: [PATCH 7/7] fixed fmt --- tests/lock_order_test.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/lock_order_test.rs b/tests/lock_order_test.rs index 4498c6b..9056154 100644 --- a/tests/lock_order_test.rs +++ b/tests/lock_order_test.rs @@ -90,7 +90,9 @@ async fn test_fail_run_with_lock_ordering(pool: PgPool) -> sqlx::Result<()> { error_message: "intentional failure".to_string(), }, SpawnOptions { - retry_strategy: Some(RetryStrategy::Fixed { base_delay: Duration::from_secs(0) }), + retry_strategy: Some(RetryStrategy::Fixed { + base_delay: Duration::from_secs(0), + }), max_attempts: Some(2), ..Default::default() },
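
For reference, the SKIP LOCKED behavior that the regression test above depends on can be reproduced in isolation. The sketch below is illustrative only: the `jobs` table and its columns are invented for the example and are not part of the durable schema, and the two psql sessions stand in for the lock-holding connection and the claiming connection in the test.

-- Illustrative sketch only; `jobs` is a throwaway table, not part of the durable schema.
create table jobs (id bigserial primary key, state text not null default 'pending');
insert into jobs (state) values ('pending');

-- Session 1 (plays the role of the lock-holding connection in the test):
begin;
select id from jobs where id = 1 for update;  -- row lock held until commit/rollback

-- Session 2 (plays the role of claim_task):
-- a plain FOR UPDATE would block here until session 1 ends its transaction,
-- which is the blocking behavior the test guards against;
select id from jobs where state = 'pending' for update skip locked;
-- with SKIP LOCKED the query returns immediately with zero rows instead of waiting.

-- Session 1: release the lock; re-running the SKIP LOCKED query in session 2
-- now returns the row, mirroring the second claim_task call in the test.
rollback;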