diff --git a/src/cashet/_client_base.py b/src/cashet/_client_base.py index e63b5f2..59e6110 100644 --- a/src/cashet/_client_base.py +++ b/src/cashet/_client_base.py @@ -1,7 +1,7 @@ from __future__ import annotations import os -from collections.abc import Callable +from collections.abc import Callable, Mapping from datetime import timedelta from pathlib import Path from typing import Any @@ -85,6 +85,16 @@ def resolve_status(status: TaskStatus | str | None) -> TaskStatus | None: return status +def normalize_tag_filters( + tags: Mapping[str, str | None] | None, +) -> dict[str, str | None]: + if not tags: + return {} +if not tags: + return {} +return {str(k): str(v) if v is not None else None for k, v in tags.items()} + + async def load_result(commit: Commit, store: AsyncStore, serializer: Serializer) -> Any: if commit.output_ref is None: return None diff --git a/src/cashet/redis_store.py b/src/cashet/redis_store.py index 3bb0951..0b4d4f5 100644 --- a/src/cashet/redis_store.py +++ b/src/cashet/redis_store.py @@ -9,6 +9,7 @@ from redis.exceptions import WatchError +from cashet._client_base import normalize_tag_filters from cashet._runner import BlockingAsyncRunner from cashet.models import Commit, ObjectRef, StorageTier, TaskDef, TaskStatus @@ -412,6 +413,7 @@ async def list_commits( status: TaskStatus | None = None, tags: dict[str, str | None] | None = None, ) -> list[Commit]: + tag_filters = normalize_tag_filters(tags) if func_name: hashes = await self._redis.zrevrange(_func_key(func_name), 0, -1) else: @@ -424,7 +426,7 @@ async def list_commits( continue if status is not None and commit.status != status: continue - if tags is not None and not _matches_tags(commit, tags): + if tags is not None and not _matches_tags(commit, tag_filters): continue commits.append(commit) if len(commits) >= limit: @@ -579,8 +581,11 @@ async def delete_commit(self, hash: str) -> bool: return await self._delete_commit(hash) async def delete_by_tags(self, tags: dict[str, str | None]) -> int: + tag_filters = normalize_tag_filters(tags) + if not tag_filters: + return 0 set_keys: list[str] = [] - for key, val in tags.items(): + for key, val in tag_filters.items(): set_keys.append(_tag_key(key) if val is None else _tag_value_key(key, val)) if len(set_keys) == 1: hashes = await self._redis.smembers(set_keys[0]) diff --git a/src/cashet/store.py b/src/cashet/store.py index a80c60a..b433ad8 100644 --- a/src/cashet/store.py +++ b/src/cashet/store.py @@ -12,6 +12,7 @@ from pathlib import Path from typing import Any +from cashet._client_base import normalize_tag_filters from cashet._runner import BlockingAsyncRunner from cashet.models import Commit, ObjectRef, StorageTier, TaskDef, TaskStatus @@ -390,6 +391,7 @@ def list_commits( tags: dict[str, str | None] | None = None, ) -> list[Commit]: conn = self._connect() + tag_filters = normalize_tag_filters(tags) query = "SELECT * FROM commits WHERE 1=1" params: list[Any] = [] if func_name: @@ -398,8 +400,8 @@ def list_commits( if status: query += " AND status = ?" params.append(status.value) - if tags: - for key, val in tags.items(): + if tag_filters: + for key, val in tag_filters.items(): if val is None: query += " AND json_extract(tags, ?) IS NOT NULL" params.append(f"$.{key}") @@ -665,9 +667,13 @@ def delete_commit(self, hash: str) -> bool: def delete_by_tags(self, tags: dict[str, str | None]) -> int: conn = self._connect(immediate=True) + tag_filters = normalize_tag_filters(tags) + if not tag_filters: + conn.execute("ROLLBACK") + return 0 query = "SELECT hash, output_hash, input_refs FROM commits WHERE 1=1" params: list[Any] = [] - for key, val in tags.items(): + for key, val in tag_filters.items(): if val is None: query += " AND json_extract(tags, ?) IS NOT NULL" params.append(f"$.{key}")