Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/task_processor/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,13 @@ def __init__(

class TaskQueueFullError(Exception):
pass


class TaskAbandonedError(TaskProcessingError):
    """
    Marker error for recurring task runs whose worker died before
    recording the result (process killed, OOM, host evicted, DB
    connection lost during the post-execution save).

    Never raised. Only the class name is used: it is written as the
    prefix of the run's `error_details` (see
    `RecurringTask.reconcile_abandoned_run`) so monitoring and log
    scrapers can match on a single authoritative class name.
    """
26 changes: 25 additions & 1 deletion src/task_processor/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import typing
import uuid
from datetime import datetime, timedelta
Expand All @@ -7,11 +8,13 @@
from django.db import models
from django.utils import timezone

from task_processor.exceptions import TaskQueueFullError
from task_processor.exceptions import TaskAbandonedError, TaskQueueFullError
from task_processor.managers import RecurringTaskManager, TaskManager
from task_processor.task_registry import get_task, registered_tasks
from task_processor.types import TaskCallable, TraceContext

logger = logging.getLogger(__name__)

_django_json_encoder_default = DjangoJSONEncoder().default


Expand Down Expand Up @@ -172,6 +175,27 @@ def unlock(self) -> None:
self.is_locked = False
self.locked_at = None

def reconcile_abandoned_run(self) -> None:
    """
    Record an explicit failure for a run of this task that never got a result.

    A run whose `result` is still null means the worker died before
    writing the task run result. The first such run (if any) gets
    `finished_at` set to the current time, `result` set to
    `TaskResult.FAILURE`, and `error_details` set to a message prefixed
    with the `TaskAbandonedError` class name; the failure is then logged
    at error level. When no such run exists, nothing is changed.
    """
    stale_run = self.task_runs.filter(result__isnull=True).first()
    if stale_run is not None:
        failure_details = (
            f"{TaskAbandonedError.__name__}: "
            "no result was written before the SQL reaper unlocked the task"
        )
        stale_run.finished_at = timezone.now()
        stale_run.result = TaskResult.FAILURE.value
        stale_run.error_details = failure_details
        # Only the three columns changed here are written back.
        stale_run.save(
            update_fields=["finished_at", "result", "error_details"],
        )
        logger.error(
            "Recurring task '%s' was abandoned: %s",
            self.task_identifier,
            failure_details,
        )

@property
def should_execute(self) -> bool:
now = timezone.now()
Expand Down
17 changes: 13 additions & 4 deletions src/task_processor/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def run_recurring_task(database: str) -> RecurringTaskRun | None:

logger.debug(f"Running recurring task '{task.task_identifier}'")

task.reconcile_abandoned_run()

if not task.is_task_registered:
# This is necessary to ensure that old instances of the task processor,
# which may still be running during deployment, do not remove tasks added by new instances.
Expand All @@ -91,9 +93,14 @@ def run_recurring_task(database: str) -> RecurringTaskRun | None:

task_run: RecurringTaskRun | None = None
if task.should_execute:
task, run = _run_task(task)
assert isinstance(run, RecurringTaskRun)
task_run = run
# Persist the task run before execution so that, if the worker is
# killed mid-task, a row is left behind that `reconcile_abandoned_run`
# can later mark as failed once the timeout-based reaper in
# `get_recurringtasks_to_process` has unlocked the task.
task_run = RecurringTaskRun(started_at=timezone.now(), task=task)
task_run.save(using=database)
task, run = _run_task(task, task_run=task_run)
assert run is task_run
# task.run() may have idled the DB connection past the server's
# session timeout; drop stale connections so the saves below open
# a fresh one. See Sentry FLAGSMITH-API-5EM.
Expand All @@ -113,6 +120,7 @@ def run_recurring_task(database: str) -> RecurringTaskRun | None:

def _run_task(
task: T,
task_run: AnyTaskRun | None = None,
) -> typing.Tuple[T, AnyTaskRun]:
assert settings.TASK_PROCESSOR_MODE, (
"Attempt to run tasks in a non-task-processor environment"
Expand All @@ -128,7 +136,8 @@ def _run_task(
logger.debug(
f"Running task {task_identifier} id={task.pk} args={task.args} kwargs={task.kwargs}"
)
task_run: AnyTaskRun = task.task_runs.model(started_at=timezone.now(), task=task) # type: ignore[attr-defined]
if task_run is None:
task_run = task.task_runs.model(started_at=timezone.now(), task=task) # type: ignore[attr-defined]
result: str
executor = None

Expand Down
64 changes: 63 additions & 1 deletion tests/unit/task_processor/test_unit_task_processor_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pytest_mock import MockerFixture

from task_processor.decorators import register_task_handler
from task_processor.models import RecurringTask, Task
from task_processor.models import RecurringTask, RecurringTaskRun, Task, TaskResult
from task_processor.task_registry import initialise

now = timezone.now()
Expand Down Expand Up @@ -144,3 +144,65 @@ def test_task_create__trace_context__persists_expected(
# Then
task.refresh_from_db()
assert task.trace_context == trace_context


@pytest.mark.django_db
def test_recurring_task_reconcile_abandoned_run__no_abandoned_run__noop() -> None:
    """Reconciling a task whose only run already has a result changes nothing."""
    # Given - the task's single run finished successfully, so there is
    # no row with a null result
    recurring_task = RecurringTask.objects.create(
        task_identifier="test_recurring_task",
        run_every=timedelta(seconds=1),
    )
    completed_at = timezone.now()
    run_fields = {
        "task": recurring_task,
        "started_at": completed_at - timedelta(seconds=1),
        "finished_at": completed_at,
        "result": TaskResult.SUCCESS.value,
    }
    completed_run = RecurringTaskRun.objects.create(**run_fields)

    # When
    recurring_task.reconcile_abandoned_run()

    # Then - the stored row still carries its original values
    completed_run.refresh_from_db()
    assert completed_run.result == TaskResult.SUCCESS.value
    assert completed_run.finished_at == completed_at
    assert completed_run.error_details is None


@pytest.mark.django_db
def test_recurring_task_reconcile_abandoned_run__finished_run_present__only_abandoned_touched() -> (
    None
):
    """Only the run without a result is rewritten; a finished run is left alone."""
    # Given - one task owning a successfully finished run and a run
    # that never had its result recorded
    recurring_task = RecurringTask.objects.create(
        task_identifier="test_recurring_task",
        run_every=timedelta(seconds=1),
    )
    completed_started_at = timezone.now() - timedelta(hours=2)
    completed_at = timezone.now() - timedelta(hours=1)
    completed_run = RecurringTaskRun.objects.create(
        task=recurring_task,
        started_at=completed_started_at,
        finished_at=completed_at,
        result=TaskResult.SUCCESS.value,
    )
    stale_run = RecurringTaskRun.objects.create(
        task=recurring_task,
        started_at=timezone.now() - timedelta(minutes=30),
    )

    # When
    recurring_task.reconcile_abandoned_run()

    # Then - the run without a result is now an explicit FAILURE
    stale_run.refresh_from_db()
    assert stale_run.result == TaskResult.FAILURE.value
    assert stale_run.finished_at is not None
    assert stale_run.error_details

    # ...and the finished run keeps every value it was created with
    completed_run.refresh_from_db()
    assert completed_run.result == TaskResult.SUCCESS.value
    assert completed_run.finished_at == completed_at
    assert completed_run.error_details is None
45 changes: 41 additions & 4 deletions tests/unit/task_processor/test_unit_task_processor_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
register_recurring_task,
register_task_handler,
)
from task_processor.exceptions import TaskBackoffError
from task_processor.exceptions import TaskAbandonedError, TaskBackoffError
from task_processor.models import (
RecurringTask,
RecurringTaskRun,
Expand Down Expand Up @@ -288,6 +288,43 @@ def _dummy_recurring_task() -> None:
assert task.locked_at is None


@pytest.mark.multi_database(transaction=True)
@pytest.mark.task_processor_mode
def test_run_recurring_task__abandoned_run__reconciled_as_failure(
    current_database: str,
) -> None:
    """A run left without a result is marked FAILURE by the next `run_recurring_task`."""

    # Given - a registered recurring task...
    @register_recurring_task(run_every=timedelta(seconds=1))
    def _dummy_recurring_task() -> None:
        pass

    initialise()

    recurring_task = RecurringTask.objects.using(current_database).get(
        task_identifier="test_unit_task_processor_processor._dummy_recurring_task",
    )
    # ...with a pre-saved run row whose result/finished_at are still
    # null, as left behind by a worker that died mid-task...
    stale_run = RecurringTaskRun.objects.using(current_database).create(
        task=recurring_task,
        started_at=timezone.now() - timedelta(hours=1),
    )
    # ...and a lock that was taken an hour ago and never released.
    recurring_task.is_locked = True
    recurring_task.locked_at = timezone.now() - timedelta(hours=1)
    recurring_task.save(using=current_database)

    # When
    run_recurring_task(current_database)

    # Then - the leftover row is an explicit FAILURE whose error
    # details name the marker exception class
    stale_run.refresh_from_db(using=current_database)
    assert stale_run.result == TaskResult.FAILURE.value
    assert stale_run.finished_at is not None
    assert stale_run.error_details is not None
    assert TaskAbandonedError.__name__ in stale_run.error_details


@pytest.mark.multi_database(transaction=True)
@pytest.mark.task_processor_mode
def test_run_recurring_task__multiple_runs__executes_expected_times(
Expand Down Expand Up @@ -344,15 +381,15 @@ def test_run_recurring_task__multiple_tasks__loops_over_all(
settings: SettingsWrapper,
) -> None:
# Given, Three recurring tasks
@register_recurring_task(run_every=timedelta(milliseconds=200))
@register_recurring_task(run_every=timedelta(hours=1))
def _dummy_recurring_task_1() -> None:
pass

@register_recurring_task(run_every=timedelta(milliseconds=200))
@register_recurring_task(run_every=timedelta(hours=1))
def _dummy_recurring_task_2() -> None:
pass

@register_recurring_task(run_every=timedelta(milliseconds=200))
@register_recurring_task(run_every=timedelta(hours=1))
def _dummy_recurring_task_3() -> None:
pass

Expand Down