Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions taskiq_postgresql/result_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from taskiq import AsyncResultBackend, TaskiqResult
from taskiq.abc.serializer import TaskiqSerializer
from taskiq.compat import model_dump, model_validate
from taskiq.depends.progress_tracker import TaskProgress
from taskiq.serializers.pickle import PickleSerializer

from taskiq_postgresql.abc.driver import QueryDriver
Expand All @@ -13,6 +15,8 @@

_ReturnType = TypeVar("_ReturnType")

PROGRESS_KEY_SUFFIX = "__progress"


@dataclass
class Table:
Expand Down Expand Up @@ -204,3 +208,81 @@ async def delete_by_date(
to_date (datetime | date | None): Date to which to delete results.
"""
await self.driver.delete_by_date(from_date, to_date)


async def set_progress(
self,
task_id: Any,
progress: TaskProgress[_ReturnType],
) -> None:
"""
Store task progress.

Args:
task_id: ID of the task.
progress: Progress payload.
"""
await self.driver.insert_or_update(
[
self.columns.primary_key,
self.columns.result,
],
[
f"{task_id}{PROGRESS_KEY_SUFFIX}",
self.serializer.dumpb(model_dump(progress)),
],
[
self.columns.primary_key,
],
[
self.columns.result,
],
)

async def get_progress(
self,
task_id: Any,
) -> TaskProgress[_ReturnType] | None:
"""
Retrieve task progress.

Args:
task_id: ID of the task.

Returns:
TaskProgress instance or None.
"""
data = await self.driver.select(
[
self.columns.result,
],
[
self.columns.primary_key,
],
[f"{task_id}{PROGRESS_KEY_SUFFIX}"],
)

if not data:
return None

progress_bytes = data[0]["result"]

if progress_bytes is None:
return None

return model_validate(
TaskProgress[_ReturnType],
self.serializer.loadb(progress_bytes),
)

async def delete_progress(self, task_id: Any) -> None:
"""
Delete stored progress for a task.

Args:
task_id: ID of the task.
"""
await self.driver.delete(
self.columns.primary_key,
f"{task_id}{PROGRESS_KEY_SUFFIX}",
)