diff --git a/src/task_processor/admin.py b/src/task_processor/admin.py index 6284a768..30c30064 100644 --- a/src/task_processor/admin.py +++ b/src/task_processor/admin.py @@ -16,7 +16,11 @@ class RecurringTaskAdmin(admin.ModelAdmin[RecurringTask]): "last_run_status", "last_run_finished_at", "is_locked", + "is_disabled", + "num_consecutive_failures", ) + list_filter = ("is_disabled",) + actions = ("unlock", "enable") readonly_fields = ("args", "kwargs") def last_run_status(self, instance: RecurringTask) -> str | None: @@ -36,3 +40,11 @@ def unlock( queryset: QuerySet[RecurringTask], ) -> None: queryset.update(is_locked=False) + + @admin.action(description="Re-enable selected tasks") + def enable( + self, + request: HttpRequest, + queryset: QuerySet[RecurringTask], + ) -> None: + queryset.update(is_disabled=False, num_consecutive_failures=0) diff --git a/src/task_processor/migrations/0015_add_is_disabled.py b/src/task_processor/migrations/0015_add_is_disabled.py new file mode 100644 index 00000000..482ac03c --- /dev/null +++ b/src/task_processor/migrations/0015_add_is_disabled.py @@ -0,0 +1,37 @@ +import os + +from django.db import migrations, models + +from common.migrations.helpers import PostgresOnlyRunSQL + + +class Migration(migrations.Migration): + + dependencies = [ + ("task_processor", "0014_add_trace_context"), + ] + + operations = [ + migrations.AddField( + model_name="recurringtask", + name="is_disabled", + field=models.BooleanField(default=False), + ), + migrations.AddField( + model_name="recurringtask", + name="num_consecutive_failures", + field=models.IntegerField(default=0), + ), + PostgresOnlyRunSQL.from_sql_file( + os.path.join( + os.path.dirname(__file__), + "sql", + "0015_get_recurringtasks_to_process.sql", + ), + reverse_sql=os.path.join( + os.path.dirname(__file__), + "sql", + "0013_get_recurringtasks_to_process.sql", + ), + ), + ] diff --git a/src/task_processor/migrations/sql/0015_get_recurringtasks_to_process.sql 
b/src/task_processor/migrations/sql/0015_get_recurringtasks_to_process.sql new file mode 100644 index 00000000..bfb078a0 --- /dev/null +++ b/src/task_processor/migrations/sql/0015_get_recurringtasks_to_process.sql @@ -0,0 +1,33 @@ +CREATE OR REPLACE FUNCTION get_recurringtasks_to_process() +RETURNS SETOF task_processor_recurringtask AS $$ +DECLARE + row_to_return task_processor_recurringtask; +BEGIN + -- Select the tasks that need to be processed + FOR row_to_return IN + SELECT * + FROM task_processor_recurringtask + -- Skip disabled tasks; add one minute to the timeout as a grace period for overhead (a lock is stale only once locked_at + timeout + 1 minute has passed) + WHERE is_disabled = FALSE + AND (is_locked = FALSE OR (locked_at IS NOT NULL AND locked_at < NOW() - (timeout + INTERVAL '1 minute'))) + ORDER BY last_picked_at NULLS FIRST + LIMIT 1 + -- Select for update to ensure that no other workers can select these tasks while in this transaction block + FOR UPDATE SKIP LOCKED + LOOP + -- Lock every selected task (by updating `is_locked` to true) + UPDATE task_processor_recurringtask + -- Lock this row by setting is_locked True, so that no other workers can select these tasks after this + -- transaction is complete (but the tasks are still being executed by the current worker) + SET is_locked = TRUE, locked_at = NOW(), last_picked_at = NOW() + WHERE id = row_to_return.id; + -- If we don't explicitly update the columns here, the client will receive a row + -- that is locked but still shows `is_locked` as `False` and `locked_at` as `None`. 
+ row_to_return.is_locked := TRUE; + row_to_return.locked_at := NOW(); + RETURN NEXT row_to_return; + END LOOP; + + RETURN; +END; +$$ LANGUAGE plpgsql diff --git a/src/task_processor/models.py b/src/task_processor/models.py index de9069ba..0992fc94 100644 --- a/src/task_processor/models.py +++ b/src/task_processor/models.py @@ -154,6 +154,8 @@ def mark_success(self) -> None: class RecurringTask(AbstractBaseTask): + MAX_CONSECUTIVE_FAILURES = 4 + run_every = models.DurationField() first_run_time = models.TimeField(blank=True, null=True) @@ -161,6 +163,8 @@ class RecurringTask(AbstractBaseTask): timeout = models.DurationField(default=timedelta(minutes=30)) last_picked_at = models.DateTimeField(blank=True, null=True) + is_disabled = models.BooleanField(default=False) + num_consecutive_failures = models.IntegerField(default=0) objects: RecurringTaskManager = RecurringTaskManager() class Meta: @@ -196,6 +200,21 @@ def reconcile_abandoned_run(self) -> None: abandoned_run.error_details, ) + def mark_failure(self) -> None: + super().mark_failure() + self.num_consecutive_failures += 1 + if self.num_consecutive_failures >= self.MAX_CONSECUTIVE_FAILURES: + self.is_disabled = True + logger.error( + "Recurring task '%s' auto-disabled after %d consecutive failures", + self.task_identifier, + self.num_consecutive_failures, + ) + + def mark_success(self) -> None: + super().mark_success() + self.num_consecutive_failures = 0 + @property def should_execute(self) -> bool: now = timezone.now() diff --git a/src/task_processor/processor.py b/src/task_processor/processor.py index 114b71d0..2d1ea321 100644 --- a/src/task_processor/processor.py +++ b/src/task_processor/processor.py @@ -108,7 +108,15 @@ def run_recurring_task(database: str) -> RecurringTaskRun | None: else: task.unlock() - task.save(using=database, update_fields=["is_locked", "locked_at"]) + task.save( + using=database, + update_fields=[ + "is_locked", + "locked_at", + "is_disabled", + "num_consecutive_failures", + ], + ) if 
task_run: task_run.save(using=database) diff --git a/tests/unit/task_processor/test_unit_task_processor_processor.py b/tests/unit/task_processor/test_unit_task_processor_processor.py index 953fccd9..8fdd2aba 100644 --- a/tests/unit/task_processor/test_unit_task_processor_processor.py +++ b/tests/unit/task_processor/test_unit_task_processor_processor.py @@ -629,6 +629,102 @@ def _raise_exception(organisation_name: str) -> None: assert task_run.error_details is not None +@pytest.mark.multi_database +@pytest.mark.task_processor_mode +def test_run_recurring_task__disabled_task__not_picked_up( + current_database: str, +) -> None: + # Given + @register_recurring_task(run_every=timedelta(seconds=1)) + def _dummy_recurring_task() -> None: ... + + initialise() + + task = RecurringTask.objects.using(current_database).get( + task_identifier="test_unit_task_processor_processor._dummy_recurring_task", + ) + task.is_disabled = True + task.save(using=current_database) + + # When + task_run = run_recurring_task(current_database) + + # Then + assert task_run is None + assert ( + RecurringTaskRun.objects.using(current_database).filter(task=task).count() == 0 + ) + + +@pytest.mark.multi_database(transaction=True) +@pytest.mark.task_processor_mode +def test_run_recurring_task__four_consecutive_failures__auto_disables( + current_database: str, +) -> None: + # Given - a task that always fails + task_identifier = "test_unit_task_processor_processor._auto_disable_raise_exception" + + @register_recurring_task(run_every=timedelta(seconds=1)) + def _auto_disable_raise_exception() -> None: + raise RuntimeError("test exception") + + initialise() + + task = RecurringTask.objects.using(current_database).get( + task_identifier=task_identifier, + ) + + # When - we run the failing task 4 times + for _ in range(RecurringTask.MAX_CONSECUTIVE_FAILURES): + run_recurring_task(current_database) + + # Then - the task is disabled and the counter reflects every failure + 
task.refresh_from_db(using=current_database) + assert task.is_disabled is True + assert task.num_consecutive_failures == RecurringTask.MAX_CONSECUTIVE_FAILURES + assert ( + RecurringTaskRun.objects.using(current_database).filter(task=task).count() + == RecurringTask.MAX_CONSECUTIVE_FAILURES + ) + + # And a subsequent pickup attempt is skipped at the SQL layer + assert run_recurring_task(current_database) is None + assert ( + RecurringTaskRun.objects.using(current_database).filter(task=task).count() + == RecurringTask.MAX_CONSECUTIVE_FAILURES + ) + + +@pytest.mark.multi_database(transaction=True) +@pytest.mark.task_processor_mode +def test_run_recurring_task__success__resets_consecutive_failures( + current_database: str, +) -> None: + # Given - a registered task with prior failures recorded on the row + @register_recurring_task(run_every=timedelta(seconds=1)) + def _dummy_recurring_task() -> None: + cache.set(DEFAULT_CACHE_KEY, DEFAULT_CACHE_VALUE) + + initialise() + + task = RecurringTask.objects.using(current_database).get( + task_identifier="test_unit_task_processor_processor._dummy_recurring_task", + ) + task.num_consecutive_failures = 2 + task.save(using=current_database) + + # When - the task runs successfully + task_run = run_recurring_task(current_database) + + # Then - the failure counter is cleared and the task stays enabled + assert task_run is not None + assert task_run.result == TaskResult.SUCCESS.value + + task.refresh_from_db(using=current_database) + assert task.num_consecutive_failures == 0 + assert task.is_disabled is False + + @pytest.mark.multi_database @pytest.mark.task_processor_mode def test_run_task__no_tasks__does_nothing(current_database: str) -> None: