Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ pytest-mock = "^3.14.0"
ruff = "==0.11.9"
setuptools = "^78.1.1"
types-simplejson = "^3.20.0.20250326"
types-python-dateutil = "^2.9.0.20250516"

[build-system]
requires = ["poetry-core"]
Expand Down
1 change: 1 addition & 0 deletions settings/dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
ENABLE_CLEAN_UP_OLD_TASKS = True
ENABLE_TASK_PROCESSOR_HEALTH_CHECK = True
RECURRING_TASK_RUN_RETENTION_DAYS = 15
TASK_BACKOFF_DEFAULT_DELAY_SECONDS = 5
Comment thread
matthewelwell marked this conversation as resolved.
TASK_DELETE_BATCH_SIZE = 2000
TASK_DELETE_INCLUDE_FAILED_TASKS = False
TASK_DELETE_RETENTION_DAYS = 15
Expand Down
18 changes: 18 additions & 0 deletions src/task_processor/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from datetime import datetime


class TaskProcessingError(Exception):
    """
    Root of the task processor's own error hierarchy.

    More specific errors in this module (e.g. ``InvalidArgumentsError``,
    ``TaskBackoffError``) subclass it.
    """

Expand All @@ -6,5 +9,20 @@ class InvalidArgumentsError(TaskProcessingError):
pass


class TaskBackoffError(TaskProcessingError):
    """
    Signal, from inside a running task, that the task should be retried later.

    Meant for transient failures such as a network error or a dependency
    that is temporarily unavailable.

    :param delay_until: moment at which the retry should be attempted. When
        left as ``None`` the handler of this exception picks the delay.
    """

    def __init__(self, delay_until: datetime | None = None) -> None:
        # The requested retry time travels on the exception instance so that
        # whoever catches it can read it back.
        self.delay_until = delay_until
        super().__init__()


class TaskQueueFullError(Exception):
    """
    Error type for a full task queue.

    NOTE: this derives directly from ``Exception`` rather than from
    ``TaskProcessingError``, so handlers catching the latter will not see it.
    """
23 changes: 7 additions & 16 deletions src/task_processor/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from django.db import models
from django.utils import timezone

from task_processor.exceptions import TaskProcessingError, TaskQueueFullError
from task_processor.exceptions import TaskQueueFullError
from task_processor.managers import RecurringTaskManager, TaskManager
from task_processor.task_registry import registered_tasks
from task_processor.task_registry import get_task, registered_tasks
from task_processor.types import TaskCallable

_django_json_encoder_default = DjangoJSONEncoder().default
Expand All @@ -36,13 +36,11 @@ class Meta:
abstract = True

@property
def args(self) -> typing.List[typing.Any]:
def args(self) -> tuple[typing.Any, ...]:
if self.serialized_args:
args = self.deserialize_data(self.serialized_args)
if typing.TYPE_CHECKING:
assert isinstance(args, list)
return args
return []
return tuple(args)
return ()

@property
def kwargs(self) -> typing.Dict[str, typing.Any]:
Expand Down Expand Up @@ -75,15 +73,8 @@ def run(self) -> None:

@property
def callable(self) -> TaskCallable[typing.Any]:
try:
task = registered_tasks[self.task_identifier]
return task.task_function
except KeyError as e:
raise TaskProcessingError(
"No task registered with identifier '%s'. Ensure your task is "
"decorated with @register_task_handler.",
self.task_identifier,
) from e
task = get_task(self.task_identifier)
return task.task_function


class Task(AbstractBaseTask):
Expand Down
25 changes: 22 additions & 3 deletions src/task_processor/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from django.utils import timezone

from task_processor import metrics
from task_processor.exceptions import TaskBackoffError
from task_processor.managers import RecurringTaskManager, TaskManager
from task_processor.models import (
AbstractBaseTask,
Expand All @@ -18,7 +19,7 @@
TaskResult,
TaskRun,
)
from task_processor.task_registry import get_task
from task_processor.task_registry import TaskType, get_task

T = typing.TypeVar("T", bound=AbstractBaseTask)
AnyTaskRun = TaskRun | RecurringTaskRun
Expand Down Expand Up @@ -50,7 +51,7 @@ def run_tasks(database: str, num_tasks: int = 1) -> list[TaskRun]:
if executed_tasks:
Task.objects.using(database).bulk_update(
executed_tasks,
fields=["completed", "num_failures", "is_locked"],
fields=["completed", "num_failures", "is_locked", "scheduled_for"],
)

if task_runs:
Expand Down Expand Up @@ -120,6 +121,7 @@ def _run_task(
ctx.enter_context(timer)

task_identifier = task.task_identifier
registered_task = get_task(task_identifier)

logger.debug(
f"Running task {task_identifier} id={task.pk} args={task.args} kwargs={task.kwargs}"
Expand Down Expand Up @@ -157,9 +159,26 @@ def _run_task(
exc_info=True,
)

if isinstance(e, TaskBackoffError):
assert registered_task.task_type == TaskType.STANDARD, (
"Attempt to back off a recurring task (currently not supported)"
)
if typing.TYPE_CHECKING:
assert isinstance(task, Task)
if task.num_failures <= 3:
delay_until = e.delay_until or timezone.now() + timedelta(
seconds=settings.TASK_BACKOFF_DEFAULT_DELAY_SECONDS,
)
task.scheduled_for = delay_until
logger.info(
"Backoff requested. Task '%s' set to retry at %s",
task_identifier,
delay_until,
)

labels = {
"task_identifier": task_identifier,
"task_type": get_task(task_identifier).task_type.value.lower(),
"task_type": registered_task.task_type.value.lower(),
"result": result.lower(),
}

Expand Down
10 changes: 9 additions & 1 deletion src/task_processor/task_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import typing
from dataclasses import dataclass

from task_processor.exceptions import TaskProcessingError
from task_processor.types import TaskCallable

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -43,7 +44,14 @@ def initialise() -> None:
def get_task(task_identifier: str) -> RegisteredTask:
    """
    Look up a registered task by its identifier.

    :param task_identifier: key under which the task was registered.
    :return: the matching ``RegisteredTask`` entry.
    :raises TaskProcessingError: if nothing is registered under
        ``task_identifier``; the original ``KeyError`` is attached as the cause.
    """
    try:
        return registered_tasks[task_identifier]
    except KeyError as e:
        # Exception constructors do not apply logging-style "%s" arguments,
        # so the identifier has to be interpolated into the message here.
        raise TaskProcessingError(
            f"No task registered with identifier '{task_identifier}'. "
            "Ensure your task is decorated with @register_task_handler."
        ) from e


def register_task(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
register_recurring_task,
register_task_handler,
)
from task_processor.exceptions import InvalidArgumentsError
from task_processor.exceptions import InvalidArgumentsError, TaskProcessingError
from task_processor.models import RecurringTask, Task, TaskPriority
from task_processor.task_registry import get_task, initialise
from task_processor.task_run_method import TaskRunMethod
Expand Down Expand Up @@ -143,7 +143,7 @@ def some_function(first_arg: str, second_arg: str) -> None:

# Then
assert not RecurringTask.objects.filter(task_identifier=task_identifier).exists()
with pytest.raises(KeyError):
with pytest.raises(TaskProcessingError):
assert get_task(task_identifier)


Expand Down
11 changes: 11 additions & 0 deletions tests/unit/task_processor/test_unit_task_processor_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@ def test_task_run(mocker: MockerFixture) -> None:
mock.assert_called_once_with(*args, **kwargs)


def test_task_args__no_data__return_expected() -> None:
    """A ``Task`` built without any args exposes ``args`` as an empty tuple."""
    # Given
    task_without_args = Task(
        scheduled_for=timezone.now(),
        task_identifier="test_task",
    )

    # When
    actual_args = task_without_args.args

    # Then
    assert actual_args == ()


@pytest.mark.parametrize(
"input, expected_output",
(
Expand Down
112 changes: 109 additions & 3 deletions tests/unit/task_processor/test_unit_task_processor_processor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import time
import uuid
from datetime import timedelta
from datetime import datetime, timedelta
from threading import Thread

import pytest
Expand All @@ -11,12 +11,13 @@
from pytest_django.fixtures import SettingsWrapper
from pytest_mock import MockerFixture

from common.test_tools import AssertMetricFixture
from common.test_tools.types import AssertMetricFixture
from task_processor.decorators import (
TaskHandler,
register_recurring_task,
register_task_handler,
)
from task_processor.exceptions import TaskBackoffError
from task_processor.models import (
RecurringTask,
RecurringTaskRun,
Expand Down Expand Up @@ -500,7 +501,7 @@ def test_run_task_runs_task_and_creates_task_run_object_when_failure(
),
(
logging.DEBUG,
f"Running task {task.task_identifier} id={task.id} args=['{msg}'] kwargs={{}}",
f"Running task {task.task_identifier} id={task.id} args=('{msg}',) kwargs={{}}",
),
(
logging.ERROR,
Expand Down Expand Up @@ -636,6 +637,111 @@ def test_run_task_runs_tasks_in_correct_priority(
assert task_runs_3[0].task == task_2


@pytest.mark.parametrize(
    "exception, expected_scheduled_for",
    [
        # No explicit delay: frozen "now" (06:05:47) plus the 10s default below.
        (TaskBackoffError(), datetime.fromisoformat("2023-12-08T06:05:57+00:00")),
        # Explicit delay: the requested timestamp is used verbatim.
        (
            TaskBackoffError(
                delay_until=datetime.fromisoformat("2023-12-08T06:15:52+00:00")
            ),
            datetime.fromisoformat("2023-12-08T06:15:52+00:00"),
        ),
    ],
)
@pytest.mark.freeze_time("2023-12-08T06:05:47+00:00")
@pytest.mark.multi_database
@pytest.mark.task_processor_mode
def test_run_task__backoff__persists_expected(
    exception: TaskBackoffError,
    expected_scheduled_for: datetime,
    current_database: str,
    settings: SettingsWrapper,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """A task raising ``TaskBackoffError`` is rescheduled and the retry is logged."""
    # Given
    settings.TASK_BACKOFF_DEFAULT_DELAY_SECONDS = 10

    @register_task_handler()
    def backoff_task() -> None:
        raise exception

    queued_task = Task.create(
        backoff_task.task_identifier,
        scheduled_for=timezone.now(),
        args=(),
        priority=TaskPriority.HIGH,
    )
    queued_task.save(using=current_database)

    caplog.set_level(logging.INFO)
    expected_log_message = f"Backoff requested. Task '{backoff_task.task_identifier}' set to retry at {expected_scheduled_for}"

    # When
    run_tasks(current_database)

    # Then
    info_messages = [
        record.message for record in caplog.records if record.levelno == logging.INFO
    ]
    assert info_messages == [expected_log_message]

    queued_task.refresh_from_db(using=current_database)
    assert queued_task.scheduled_for == expected_scheduled_for


@pytest.mark.multi_database
@pytest.mark.task_processor_mode
def test_run_task__backoff__recurring__raises_expected(
    current_database: str,
) -> None:
    """Requesting a backoff from a recurring task fails with an ``AssertionError``."""
    # Given
    @register_recurring_task(run_every=timedelta(seconds=1))
    def backoff_task() -> None:
        raise TaskBackoffError()

    initialise()
    expected_message = "Attempt to back off a recurring task (currently not supported)"

    # When & Then
    with pytest.raises(AssertionError) as exc_info:
        run_recurring_tasks(current_database)

    actual_message = str(exc_info.value)
    assert actual_message == expected_message


@pytest.mark.multi_database
@pytest.mark.task_processor_mode
def test_run_task__backoff__max_num_failures__noop(
    current_database: str,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """A task that starts with ``num_failures = 4`` is not rescheduled on backoff."""
    # Given
    @register_task_handler()
    def backoff_task() -> None:
        raise TaskBackoffError()

    original_scheduled_for = timezone.now()
    exhausted_task = Task.create(
        backoff_task.task_identifier,
        scheduled_for=original_scheduled_for,
        args=(),
        priority=TaskPriority.HIGH,
    )
    exhausted_task.num_failures = 4
    exhausted_task.save(using=current_database)

    caplog.set_level(logging.INFO)

    # When
    run_tasks(current_database)

    # Then
    exhausted_task.refresh_from_db(using=current_database)
    assert exhausted_task.scheduled_for == original_scheduled_for

    # No "Backoff requested" (or any other) INFO record may have been emitted.
    info_records = [
        record for record in caplog.records if record.levelno == logging.INFO
    ]
    assert not info_records


@pytest.mark.multi_database
def test_run_tasks__fails_if_not_in_task_processor_mode(
current_database: str,
Expand Down