diff --git a/poetry.lock b/poetry.lock index 835f190d..4a5b092a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. [[package]] name = "annotated-types" @@ -900,7 +900,6 @@ files = [ {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bb89f0a835bcfc1d42ccd5f41f04870c1b936d8507c6df12b7737febc40f0909"}, {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f0c2d907a1e102526dd2986df638343388b94c33860ff3bbe1384130828714b1"}, {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f8157bed2f51db683f31306aa497311b560f2265998122abe1dce6428bd86567"}, - {file = "psycopg2_binary-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:27422aa5f11fbcd9b18da48373eb67081243662f9b46e6fd07c3eb46e4535142"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:eb09aa7f9cecb45027683bb55aebaaf45a0df8bf6de68801a6afdc7947bb09d4"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b73d6d7f0ccdad7bc43e6d34273f70d587ef62f824d7261c4ae9b8b1b6af90e8"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce5ab4bf46a211a8e924d307c1b1fcda82368586a19d0a24f8ae166f5c784864"}, @@ -1540,6 +1539,18 @@ files = [ dev = ["build", "hatch"] doc = ["sphinx"] +[[package]] +name = "types-python-dateutil" +version = "2.9.0.20250516" +description = "Typing stubs for python-dateutil" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "types_python_dateutil-2.9.0.20250516-py3-none-any.whl", hash = "sha256:2b2b3f57f9c6a61fba26a9c0ffb9ea5681c9b83e69cd897c6b5f668d9c0cab93"}, + {file = "types_python_dateutil-2.9.0.20250516.tar.gz", hash = "sha256:13e80d6c9c47df23ad773d54b2826bd52dbbb41be87c3f339381c1700ad21ee5"}, +] + [[package]] name = "types-pyyaml" version = "6.0.12.20241230" @@ -1679,4 +1690,4 @@ test-tools = ["pyfakefs"] [metadata] lock-version = "2.1" python-versions = ">=3.11,<4.0" -content-hash = "18369d05529bfce4aad0393b73baa2f6d32394acbd3f3965d2a4e4f96da98078" +content-hash = "38e85051fccc3bf8d25b47f20b680f2d81c388381198df7d2e8a153b9896f9f7" diff --git a/pyproject.toml b/pyproject.toml index 9fa33091..5846849c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,6 +79,7 @@ pytest-mock = "^3.14.0" ruff = "==0.11.9" setuptools = "^78.1.1" types-simplejson = "^3.20.0.20250326" +types-python-dateutil = "^2.9.0.20250516" [build-system] requires = ["poetry-core"] diff --git a/settings/dev.py b/settings/dev.py index 086a0eb3..73b0520f 100644 --- a/settings/dev.py +++ b/settings/dev.py @@ -58,6 +58,7 @@ ENABLE_CLEAN_UP_OLD_TASKS = True ENABLE_TASK_PROCESSOR_HEALTH_CHECK = True RECURRING_TASK_RUN_RETENTION_DAYS = 15 +TASK_BACKOFF_DEFAULT_DELAY_SECONDS = 5 TASK_DELETE_BATCH_SIZE = 2000 TASK_DELETE_INCLUDE_FAILED_TASKS = False TASK_DELETE_RETENTION_DAYS = 15 diff --git a/src/task_processor/exceptions.py b/src/task_processor/exceptions.py index 7f697a6e..b09bc688 100644 --- a/src/task_processor/exceptions.py +++ b/src/task_processor/exceptions.py @@ -1,3 +1,6 @@ +from datetime import datetime + + class TaskProcessingError(Exception): pass @@ -6,5 +9,20 @@ class InvalidArgumentsError(TaskProcessingError): pass +class TaskBackoffError(TaskProcessingError): + """ + Raise this exception inside a task to indicate that it should be retried after a delay. + This is typically used when a task fails due to a temporary issue, such as + a network error or a service being unavailable. + """ + + def __init__( + self, + delay_until: datetime | None = None, + ) -> None: + super().__init__() + self.delay_until = delay_until + + class TaskQueueFullError(Exception): pass diff --git a/src/task_processor/models.py b/src/task_processor/models.py index 116b7a6e..e7782fe1 100644 --- a/src/task_processor/models.py +++ b/src/task_processor/models.py @@ -7,9 +7,9 @@ from django.db import models from django.utils import timezone -from task_processor.exceptions import TaskProcessingError, TaskQueueFullError +from task_processor.exceptions import TaskQueueFullError from task_processor.managers import RecurringTaskManager, TaskManager -from task_processor.task_registry import registered_tasks +from task_processor.task_registry import get_task, registered_tasks from task_processor.types import TaskCallable _django_json_encoder_default = DjangoJSONEncoder().default @@ -36,13 +36,11 @@ class Meta: abstract = True @property - def args(self) -> typing.List[typing.Any]: + def args(self) -> tuple[typing.Any, ...]: if self.serialized_args: args = self.deserialize_data(self.serialized_args) - if typing.TYPE_CHECKING: - assert isinstance(args, list) - return args - return [] + return tuple(args) + return () @property def kwargs(self) -> typing.Dict[str, typing.Any]: @@ -75,15 +73,8 @@ def run(self) -> None: @property def callable(self) -> TaskCallable[typing.Any]: - try: - task = registered_tasks[self.task_identifier] - return task.task_function - except KeyError as e: - raise TaskProcessingError( - "No task registered with identifier '%s'. Ensure your task is " - "decorated with @register_task_handler.", - self.task_identifier, - ) from e + task = get_task(self.task_identifier) + return task.task_function class Task(AbstractBaseTask): diff --git a/src/task_processor/processor.py b/src/task_processor/processor.py index d6aa414b..501c9934 100644 --- a/src/task_processor/processor.py +++ b/src/task_processor/processor.py @@ -9,6 +9,7 @@ from django.utils import timezone from task_processor import metrics +from task_processor.exceptions import TaskBackoffError from task_processor.managers import RecurringTaskManager, TaskManager from task_processor.models import ( AbstractBaseTask, @@ -18,7 +19,7 @@ TaskResult, TaskRun, ) -from task_processor.task_registry import get_task +from task_processor.task_registry import TaskType, get_task T = typing.TypeVar("T", bound=AbstractBaseTask) AnyTaskRun = TaskRun | RecurringTaskRun @@ -50,7 +51,7 @@ def run_tasks(database: str, num_tasks: int = 1) -> list[TaskRun]: if executed_tasks: Task.objects.using(database).bulk_update( executed_tasks, - fields=["completed", "num_failures", "is_locked"], + fields=["completed", "num_failures", "is_locked", "scheduled_for"], ) if task_runs: @@ -120,6 +121,7 @@ def _run_task( ctx.enter_context(timer) task_identifier = task.task_identifier + registered_task = get_task(task_identifier) logger.debug( f"Running task {task_identifier} id={task.pk} args={task.args} kwargs={task.kwargs}" @@ -157,9 +159,26 @@ def _run_task( exc_info=True, ) + if isinstance(e, TaskBackoffError): + assert registered_task.task_type == TaskType.STANDARD, ( + "Attempt to back off a recurring task (currently not supported)" + ) + if typing.TYPE_CHECKING: + assert isinstance(task, Task) + if task.num_failures <= 3: + delay_until = e.delay_until or timezone.now() + timedelta( + seconds=settings.TASK_BACKOFF_DEFAULT_DELAY_SECONDS, + ) + task.scheduled_for = delay_until + logger.info( + "Backoff requested. Task '%s' set to retry at %s", + task_identifier, + delay_until, + ) + labels = { "task_identifier": task_identifier, - "task_type": get_task(task_identifier).task_type.value.lower(), + "task_type": registered_task.task_type.value.lower(), "result": result.lower(), } diff --git a/src/task_processor/task_registry.py b/src/task_processor/task_registry.py index c64a0898..eb1d3c45 100644 --- a/src/task_processor/task_registry.py +++ b/src/task_processor/task_registry.py @@ -3,6 +3,7 @@ import typing from dataclasses import dataclass +from task_processor.exceptions import TaskProcessingError from task_processor.types import TaskCallable logger = logging.getLogger(__name__) @@ -43,7 +44,14 @@ def initialise() -> None: def get_task(task_identifier: str) -> RegisteredTask: global registered_tasks - return registered_tasks[task_identifier] + try: + return registered_tasks[task_identifier] + except KeyError: + raise TaskProcessingError( + "No task registered with identifier '%s'. Ensure your task is " + "decorated with @register_task_handler.", + task_identifier, + ) def register_task( diff --git a/tests/unit/task_processor/test_unit_task_processor_decorators.py b/tests/unit/task_processor/test_unit_task_processor_decorators.py index 9091aa35..7b528da2 100644 --- a/tests/unit/task_processor/test_unit_task_processor_decorators.py +++ b/tests/unit/task_processor/test_unit_task_processor_decorators.py @@ -14,7 +14,7 @@ register_recurring_task, register_task_handler, ) -from task_processor.exceptions import InvalidArgumentsError +from task_processor.exceptions import InvalidArgumentsError, TaskProcessingError from task_processor.models import RecurringTask, Task, TaskPriority from task_processor.task_registry import get_task, initialise from task_processor.task_run_method import TaskRunMethod @@ -143,7 +143,7 @@ def some_function(first_arg: str, second_arg: str) -> None: # Then assert not RecurringTask.objects.filter(task_identifier=task_identifier).exists() - with pytest.raises(KeyError): + with pytest.raises(TaskProcessingError): assert get_task(task_identifier) diff --git a/tests/unit/task_processor/test_unit_task_processor_models.py b/tests/unit/task_processor/test_unit_task_processor_models.py index 311b9bcd..0adc96b1 100644 --- a/tests/unit/task_processor/test_unit_task_processor_models.py +++ b/tests/unit/task_processor/test_unit_task_processor_models.py @@ -40,6 +40,17 @@ def test_task_run(mocker: MockerFixture) -> None: mock.assert_called_once_with(*args, **kwargs) +def test_task_args__no_data__return_expected() -> None: + # Given + task = Task( + task_identifier="test_task", + scheduled_for=timezone.now(), + ) + + # When & Then + assert task.args == () + + @pytest.mark.parametrize( "input, expected_output", ( diff --git a/tests/unit/task_processor/test_unit_task_processor_processor.py b/tests/unit/task_processor/test_unit_task_processor_processor.py index 11a0ac4e..52b05f39 100644 --- a/tests/unit/task_processor/test_unit_task_processor_processor.py +++ b/tests/unit/task_processor/test_unit_task_processor_processor.py @@ -1,7 +1,7 @@ import logging import time import uuid -from datetime import timedelta +from datetime import datetime, timedelta from threading import Thread import pytest @@ -11,12 +11,13 @@ from pytest_django.fixtures import SettingsWrapper from pytest_mock import MockerFixture -from common.test_tools import AssertMetricFixture +from common.test_tools.types import AssertMetricFixture from task_processor.decorators import ( TaskHandler, register_recurring_task, register_task_handler, ) +from task_processor.exceptions import TaskBackoffError from task_processor.models import ( RecurringTask, RecurringTaskRun, @@ -500,7 +501,7 @@ def test_run_task_runs_task_and_creates_task_run_object_when_failure( ), ( logging.DEBUG, - f"Running task {task.task_identifier} id={task.id} args=['{msg}'] kwargs={{}}", + f"Running task {task.task_identifier} id={task.id} args=('{msg}',) kwargs={{}}", ), ( logging.ERROR, @@ -636,6 +637,111 @@ def test_run_task_runs_tasks_in_correct_priority( assert task_runs_3[0].task == task_2 +@pytest.mark.parametrize( + "exception, expected_scheduled_for", + [ + (TaskBackoffError(), datetime.fromisoformat("2023-12-08T06:05:57+00:00")), + ( + TaskBackoffError( + delay_until=datetime.fromisoformat("2023-12-08T06:15:52+00:00") + ), + datetime.fromisoformat("2023-12-08T06:15:52+00:00"), + ), + ], +) +@pytest.mark.freeze_time("2023-12-08T06:05:47+00:00") +@pytest.mark.multi_database +@pytest.mark.task_processor_mode +def test_run_task__backoff__persists_expected( + exception: TaskBackoffError, + expected_scheduled_for: datetime, + current_database: str, + settings: SettingsWrapper, + caplog: pytest.LogCaptureFixture, +) -> None: + # Given + settings.TASK_BACKOFF_DEFAULT_DELAY_SECONDS = 10 + + @register_task_handler() + def backoff_task() -> None: + raise exception + + task = Task.create( + backoff_task.task_identifier, + scheduled_for=timezone.now(), + args=(), + priority=TaskPriority.HIGH, + ) + task.save(using=current_database) + + caplog.set_level(logging.INFO) + expected_log_message = f"Backoff requested. Task '{backoff_task.task_identifier}' set to retry at {expected_scheduled_for}" + + # When + run_tasks(current_database) + + # Then + assert [ + record.message for record in caplog.records if record.levelno == logging.INFO + ] == [expected_log_message] + task.refresh_from_db(using=current_database) + assert task.scheduled_for == expected_scheduled_for + + +@pytest.mark.multi_database +@pytest.mark.task_processor_mode +def test_run_task__backoff__recurring__raises_expected( + current_database: str, +) -> None: + # Given + @register_recurring_task(run_every=timedelta(seconds=1)) + def backoff_task() -> None: + raise TaskBackoffError() + + initialise() + + # When & Then + with pytest.raises(AssertionError) as exc_info: + run_recurring_tasks(current_database) + + assert ( + str(exc_info.value) + == "Attempt to back off a recurring task (currently not supported)" + ) + + +@pytest.mark.multi_database +@pytest.mark.task_processor_mode +def test_run_task__backoff__max_num_failures__noop( + current_database: str, + caplog: pytest.LogCaptureFixture, +) -> None: + # Given + @register_task_handler() + def backoff_task() -> None: + raise TaskBackoffError() + + expected_scheduled_for = timezone.now() + task = Task.create( + backoff_task.task_identifier, + scheduled_for=expected_scheduled_for, + args=(), + priority=TaskPriority.HIGH, + ) + task.num_failures = 4 + task.save(using=current_database) + + caplog.set_level(logging.INFO) + + # When + run_tasks(current_database) + + # Then + task.refresh_from_db(using=current_database) + assert task.scheduled_for == expected_scheduled_for + assert not [record for record in caplog.records if record.levelno == logging.INFO] + + @pytest.mark.multi_database def test_run_tasks__fails_if_not_in_task_processor_mode( current_database: str,