4 changes: 0 additions & 4 deletions CLAUDE.md
@@ -53,9 +53,6 @@ uv run --group docs zensical build
 - `_on_complete()` callback handles task completion and LRU eviction
 - **ThreadPoolBackend**: Uses `ThreadPoolExecutor` for I/O-bound tasks
 - **ProcessPoolBackend**: Uses `ProcessPoolExecutor` for CPU-bound tasks
-  - Validates pickling of arguments upfront
-- **_execute_task()**: Module-level wrapper that sets ContextVar and invokes the task function
-- **current_result_id**: ContextVar allowing tasks to access their own result ID
 
 **Backend capabilities:**
 
@@ -68,7 +65,6 @@ uv run --group docs zensical build
 
 1. **Shared state registry**: Multiple backend instances with the same alias share executor and result storage via `ExecutorState`
 2. **Task status from Future**: READY/RUNNING determined by `Future.running()` at `get_result()` time
-3. **ContextVar in both backends**: Works by setting the ContextVar inside the worker thread/process
 4. **Function path resolution**: Tasks are imported by path string for ProcessPool pickling compatibility
 
 ## Key Limitation
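The `ExecutorState` registry that CLAUDE.md references lives in `django_tasks_local/state.py`, which this diff does not touch. A minimal sketch of how such an alias-keyed registry might look — the names `ExecutorState` and `get_executor_state` come from the imports in `backend.py` below; the fields, factory argument, and locking are assumptions:

```python
import threading
from concurrent.futures import Executor, Future

_registry: dict[str, "ExecutorState"] = {}
_registry_lock = threading.Lock()


class ExecutorState:
    """Per-alias shared state: one executor plus result/future maps."""

    def __init__(self, executor: Executor):
        self.executor = executor
        self.results: dict[str, object] = {}  # result_id -> TaskResult
        self.futures: dict[str, Future] = {}  # result_id -> in-flight Future
        self.lock = threading.Lock()


def get_executor_state(alias: str, executor_factory) -> ExecutorState:
    # Two backend instances constructed with the same alias receive the
    # same ExecutorState, so they share the executor and result storage.
    with _registry_lock:
        if alias not in _registry:
            _registry[alias] = ExecutorState(executor_factory())
        return _registry[alias]
```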
2 changes: 0 additions & 2 deletions django_tasks_local/__init__.py
@@ -10,11 +10,9 @@
 from .backend import (
     ProcessPoolBackend,
     ThreadPoolBackend,
-    current_result_id,
 )
 
 __all__ = [
     "ThreadPoolBackend",
     "ProcessPoolBackend",
-    "current_result_id",
 ]
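For downstream code that imported `current_result_id`, the replacement is the context object the task now receives. A hedged before/after sketch — it assumes `@task(takes_context=True)` is the opt-in spelling for the `takes_context` flag that `_execute_task` checks in the backend.py diff below, and that `TaskContext` exposes the running `task_result`:

```python
from django.tasks import task

# Before (removed API):
#     from django_tasks_local import current_result_id
#     result_id = current_result_id.get()

# After: opt into the task context and read the ID from it.
@task(takes_context=True)
def my_task(context):
    result_id = context.task_result.id  # same value current_result_id held
    ...
```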
149 changes: 68 additions & 81 deletions django_tasks_local/backend.py
@@ -18,53 +18,67 @@
 """
 
 import logging
-import pickle
 import traceback
 import uuid
 from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
-from contextvars import ContextVar
 from datetime import datetime, timezone
 from typing import Any
 
 from django.tasks import TaskResult, TaskResultStatus
 from django.tasks.backends.base import BaseTaskBackend
-from django.tasks.base import Task, TaskError
+from django.tasks.base import Task, TaskError, TaskContext
+from django.tasks.signals import task_enqueued, task_finished, task_started
 from django.tasks.exceptions import TaskResultDoesNotExist
 from django.utils.module_loading import import_string
+from django.utils.json import normalize_json
 
 from .state import get_executor_state, shutdown_executor
 
 logger = logging.getLogger(__name__)
 
-# Context variable for tasks to access their own result ID
-current_result_id: ContextVar[str] = ContextVar("current_result_id")
+
+class PicklableTaskResult(TaskResult):
+    """
+    A modified TaskResult which smuggles the Task instance as a string when pickled.
+    """
+
+    def __getstate__(self):
+        state = super().__getstate__()
+        assert isinstance(state[0], Task)
+        state[0] = state[0].module_path
+        return state
+
+    def __setstate__(self, state):
+        state[0] = import_string(state[0])
+        assert isinstance(state[0], Task)
+        return super().__setstate__(state)
 
-def _execute_task(func_path: str, args: tuple, kwargs: dict, result_id: str):
-    """Execute task in worker thread/process.
 
-    Sets ContextVar for task access. Works in both backends:
-    - ThreadPoolExecutor: shared memory, ContextVar set in worker thread
-    - ProcessPoolExecutor: separate memory, ContextVar set in child process
+def _execute_task(backend: type[BaseTaskBackend], name: str, task_result: PicklableTaskResult):
+    """Execute task in worker thread/process.
 
     Must be module-level (not a method) for ProcessPoolExecutor pickling.
     """
-    token = current_result_id.set(result_id)
-    try:
-        obj = import_string(func_path)
-        # If it's a Task wrapper, get the underlying function
-        func = getattr(obj, "func", obj)
-        return func(*args, **kwargs)
-    finally:
-        current_result_id.reset(token)
+    task_result.worker_ids.append(name)
+    task_started.send(backend, task_result=task_result)
+    if task_result.task.takes_context:
+        raw_return_value = task_result.task.call(
+            TaskContext(task_result=task_result),
+            *task_result.args,
+            **task_result.kwargs,
+        )
+    else:
+        raw_return_value = task_result.task.call(
+            *task_result.args, **task_result.kwargs
+        )
+    return normalize_json(raw_return_value)
 
 
 class FuturesBackend(BaseTaskBackend):
     """Base class for concurrent.futures-based backends."""
 
     supports_defer = False
     supports_async_task = False
-    supports_get_result = True
+    supports_get_result = False
     supports_priority = False
 
     executor_class: type = None  # Subclasses must set
@@ -102,14 +116,12 @@ def enqueue(
         to check for RUNNING, SUCCESSFUL, or FAILED status.
         """
         self.validate_task(task)
-        self._validate_pickleable(task, args, kwargs)
 
         result_id = str(uuid.uuid4())
-        func_path = f"{task.func.__module__}.{task.func.__qualname__}"
 
         # Store initial result before submitting (for immediate get_result calls)
         now = datetime.now(timezone.utc)
-        initial_result = TaskResult(
+        initial_result = PicklableTaskResult(
             id=result_id,
             task=task,
             status=TaskResultStatus.READY,
@@ -128,9 +140,7 @@ def enqueue(
             self._state.results[result_id] = initial_result
 
         # Submit to executor
-        future = self._state.executor.submit(
-            _execute_task, func_path, args or (), kwargs or {}, result_id
-        )
+        future = self._state.executor.submit(_execute_task, type(self), self._name, initial_result)
 
         with self._state.lock:
             self._state.futures[result_id] = future
@@ -140,63 +150,53 @@ def enqueue(
             lambda f: self._on_complete(result_id, task, initial_result, f)
         )
 
-        logger.debug(
-            "Task '%s' enqueued: result_id=%s",
-            task.name,
-            result_id,
-        )
+        task_enqueued.send(type(self), task_result=initial_result)
 
        return initial_result
 
     def _on_complete(
         self,
         result_id: str,
         task: Task,
-        initial_result: TaskResult,
+        initial_result: PicklableTaskResult,
         future: Future,
     ) -> None:
         """Called when future completes. Runs in parent process/thread."""
         finished_at = datetime.now(timezone.utc)
 
+        final_result = TaskResult(
+            id=result_id,
+            task=task,
+            status=TaskResultStatus.SUCCESSFUL,
+            enqueued_at=initial_result.enqueued_at,
+            started_at=initial_result.enqueued_at,  # Approximation
+            finished_at=finished_at,
+            last_attempted_at=initial_result.enqueued_at,
+            args=initial_result.args,
+            kwargs=initial_result.kwargs,
+            backend=self.alias,
+            errors=[],
+            worker_ids=[self._name],
+        )
+
         try:
-            return_value = future.result()
-            final_result = TaskResult(
-                id=result_id,
-                task=task,
-                status=TaskResultStatus.SUCCESSFUL,
-                enqueued_at=initial_result.enqueued_at,
-                started_at=initial_result.enqueued_at,  # Approximation
-                finished_at=finished_at,
-                last_attempted_at=initial_result.enqueued_at,
-                args=initial_result.args,
-                kwargs=initial_result.kwargs,
-                backend=self.alias,
-                errors=[],
-                worker_ids=[],
-            )
             # TaskResult is frozen, use object.__setattr__ to set _return_value
-            object.__setattr__(final_result, "_return_value", return_value)
+            object.__setattr__(final_result, "_return_value", future.result())
         except Exception as e:
             logger.exception("Task %s failed: %s", task.name, e)
-            error = TaskError(
-                exception_class_path=f"{type(e).__module__}.{type(e).__qualname__}",
-                traceback=traceback.format_exc(),
-            )
-            final_result = TaskResult(
-                id=result_id,
-                task=task,
-                status=TaskResultStatus.FAILED,
-                enqueued_at=initial_result.enqueued_at,
-                started_at=initial_result.enqueued_at,
-                finished_at=finished_at,
-                last_attempted_at=initial_result.enqueued_at,
-                args=initial_result.args,
-                kwargs=initial_result.kwargs,
-                backend=self.alias,
-                errors=[error],
-                worker_ids=[],
-            )
+            final_result.errors.append(
+                TaskError(
+                    exception_class_path=f"{type(e).__module__}.{type(e).__qualname__}",
+                    traceback=traceback.format_exc(),
+                )
+            )
+
+            object.__setattr__(final_result, "status", TaskResultStatus.FAILED)
+
+            # Called inside the exception handler so the signal can access the exception
+            task_finished.send(type(self), task_result=final_result)
+        else:
+            task_finished.send(type(self), task_result=final_result)
 
         with self._state.lock:
             self._state.results[result_id] = final_result
             self._state.futures.pop(result_id, None)
@@ -221,6 +221,7 @@ def get_result(self, result_id: str) -> TaskResult:
         stored_result = self._state.results.get(result_id)
 
         if stored_result is None:
+            print(self._state)
             raise TaskResultDoesNotExist(result_id)
 
         # If future still in flight, determine current status
@@ -244,15 +245,11 @@ def get_result(self, result_id: str) -> TaskResult:
             kwargs=stored_result.kwargs,
             backend=self.alias,
             errors=[],
-            worker_ids=[],
+            worker_ids=[self._name],
         )
 
         return stored_result
 
-    def _validate_pickleable(self, task: Task, args: tuple | None, kwargs: dict | None):
-        """Validate arguments can be pickled. Override in ProcessPoolBackend."""
-        pass
-
     def close(self) -> None:
         """Shutdown the shared executor.
 
@@ -279,17 +276,7 @@ class ProcessPoolBackend(FuturesBackend):
     Tasks run in a ProcessPoolExecutor, bypassing the GIL.
 
     Constraints:
-    - Task arguments and return values must be pickleable
     - No shared memory with main process (global state changes don't persist)
     """
 
     executor_class = ProcessPoolExecutor
-
-    def _validate_pickleable(self, task: Task, args: tuple | None, kwargs: dict | None):
-        """Fail fast if arguments can't be pickled."""
-        try:
-            pickle.dumps((args, kwargs))
-        except (pickle.PicklingError, TypeError, AttributeError) as e:
-            raise ValueError(
-                f"Task arguments must be pickleable for ProcessPoolBackend: {e}"
-            ) from e
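One consequence of moving signal dispatch into `_execute_task` and `_on_complete` is that the task lifecycle is now observable through the standard `django.tasks.signals`. A small sketch of wiring a receiver, using only the signal names imported in the diff above; the receiver signature follows Django's usual signal API:

```python
from django.tasks.signals import task_finished


def log_finished(sender, task_result, **kwargs):
    # Runs in the parent process/thread when _on_complete resolves the future.
    print(task_result.id, task_result.status, task_result.worker_ids)


task_finished.connect(log_finished)
```

Note that `task_started.send()` executes inside the worker; with `ProcessPoolBackend` that is a child process with separate memory, so receivers connected in the parent will not observe it.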
26 changes: 1 addition & 25 deletions docs/api.md
@@ -20,8 +20,6 @@ from django_tasks_local import ProcessPoolBackend
 
 Executes tasks in a `ProcessPoolExecutor`. Best for CPU-bound tasks.
 
-**Constraint:** Arguments and return values must be pickleable.
-
 ## Backend Capabilities
 
 | Attribute | Value | Description |
@@ -39,7 +37,7 @@ Enqueue a task for background execution.
 
 **Returns:** `TaskResult` with initial status `READY`.
 
-**Raises:** `ValueError` if using ProcessPoolBackend with unpickleable arguments.
+**Raises:** `TypeError` if arguments cannot be converted to JSON.
 
 ### `get_result(result_id)`
 
@@ -54,25 +52,3 @@ Retrieve a task result by its UUID string.
 Shut down the executor.
 
 **Warning:** This shuts down the executor for ALL backend instances using the same alias. Only call during application shutdown.
-
-## Context Variable
-
-### `current_result_id`
-
-```python
-from django_tasks_local import current_result_id
-```
-
-A `ContextVar[str]` holding the current task's result ID. Only available within a running task.
-
-```python
-from django.tasks import task
-from django_tasks_local import current_result_id
-
-@task
-def my_task():
-    result_id = current_result_id.get()
-    # Use for logging, caching progress, etc.
-```
-
-Works in both ThreadPoolBackend and ProcessPoolBackend.
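End-to-end, the documented `enqueue`/`get_result` flow looks roughly like this. The `TASKS` settings shape and the `default_task_backend` accessor are assumptions based on Django's tasks-framework conventions; the `@task` decorator, `TaskResultStatus`, and `get_result()` appear in the diffs above:

```python
# settings.py (assumed settings shape)
TASKS = {
    "default": {"BACKEND": "django_tasks_local.ThreadPoolBackend"},
}

# tasks.py
from django.tasks import task


@task
def add(a: int, b: int) -> int:
    return a + b


# caller
from django.tasks import TaskResultStatus, default_task_backend  # accessor name assumed

result = add.enqueue(2, 3)  # initial status: READY
later = default_task_backend.get_result(result.id)
if later.status == TaskResultStatus.SUCCESSFUL:
    print(later.return_value)  # 5
```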