diff --git a/src/cashet/async_executor.py b/src/cashet/async_executor.py
index f37c545..3eefcc6 100644
--- a/src/cashet/async_executor.py
+++ b/src/cashet/async_executor.py
@@ -46,7 +46,12 @@ async def _async_store_lock(
 
 
 def _is_stale_claim(commit: Commit, ttl: timedelta) -> bool:
-    return datetime.now(UTC) - commit.claimed_at > ttl
+    # Measure staleness from the latest claim, not from creation: a claim that is
+    # still fresh must not be reclaimed just because the task was created long ago.
+    # claimed_at defaults to the creation time, so never-reclaimed commits still age
+    # out; created_at is only a fallback for a missing claim timestamp.
+    last_claim = commit.claimed_at or commit.created_at
+    return datetime.now(UTC) - last_claim > ttl
 
 
 class AsyncLocalExecutor:
diff --git a/src/cashet/models.py b/src/cashet/models.py
index 1e36b30..f21434c 100644
--- a/src/cashet/models.py
+++ b/src/cashet/models.py
@@ -62,6 +62,9 @@ class Commit:
     output_ref: ObjectRef | None = None
     parent_hash: str | None = None
     status: TaskStatus = TaskStatus.PENDING
+    # created_at records when the commit was first built and never changes.
+    # claimed_at records the most recent claim and is the anchor used for
+    # stale-claim detection; it starts out equal to the creation time.
     created_at: datetime = field(default_factory=lambda: datetime.now(UTC))
     claimed_at: datetime = field(default_factory=lambda: datetime.now(UTC))
     error: str | None = None
diff --git a/tests/test_async_client.py b/tests/test_async_client.py
index 04da39a..b6a1fd8 100644
--- a/tests/test_async_client.py
+++ b/tests/test_async_client.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from datetime import UTC, datetime, timedelta
 from pathlib import Path
 
 import pytest
@@ -281,6 +282,29 @@ def non_cached() -> int:
         assert await ref1.load() == 1
         assert await ref2.load() == 2
 
+    async def test_fresh_claim_is_not_stale_despite_old_created_at(self) -> None:
+        import cashet.dag as dag
+        import cashet.hashing as hashing
+        from cashet.async_executor import _is_stale_claim
+
+        def work() -> int:
+            return 42
+
+        task_def = hashing.build_task_def(work, (), {})
+        input_refs = dag.resolve_input_refs((), {})
+        commit = dag.build_commit(task_def, input_refs)
+        ttl = timedelta(seconds=300)
+
+        # A recently refreshed claim must survive no matter how long ago the
+        # task was created; otherwise live long-running work is executed twice.
+        commit.created_at = datetime.now(UTC) - timedelta(seconds=400)
+        commit.claimed_at = datetime.now(UTC) - timedelta(seconds=5)
+        assert not _is_stale_claim(commit, ttl)
+
+        # Once the claim itself is older than the TTL it becomes reclaimable.
+        commit.claimed_at = datetime.now(UTC) - timedelta(seconds=400)
+        assert _is_stale_claim(commit, ttl)
+
     async def test_task_decorator_callable_returns_async_result_ref(
         self, async_client: AsyncClient
     ) -> None: