From 886ee0ddcc88eaca1b3444e135ef7c21d3810697 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Jolovi=C4=87?= Date: Sat, 9 May 2026 01:11:26 +0200 Subject: [PATCH 1/2] Normalize tag filters across stores --- src/cashet/_client_base.py | 10 +++++++++- src/cashet/redis_store.py | 9 +++++++-- src/cashet/store.py | 12 +++++++++--- 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/src/cashet/_client_base.py b/src/cashet/_client_base.py index e63b5f2..66c55d0 100644 --- a/src/cashet/_client_base.py +++ b/src/cashet/_client_base.py @@ -1,7 +1,7 @@ from __future__ import annotations import os -from collections.abc import Callable +from collections.abc import Callable, Mapping from datetime import timedelta from pathlib import Path from typing import Any @@ -85,6 +85,14 @@ def resolve_status(status: TaskStatus | str | None) -> TaskStatus | None: return status +def normalize_tag_filters( + tags: Mapping[str, str | None] | None, +) -> dict[str, str | None]: + if not tags: + return {} + return {str(key): str(value) for key, value in tags.items() if value is not None} + + async def load_result(commit: Commit, store: AsyncStore, serializer: Serializer) -> Any: if commit.output_ref is None: return None diff --git a/src/cashet/redis_store.py b/src/cashet/redis_store.py index 3bb0951..0b4d4f5 100644 --- a/src/cashet/redis_store.py +++ b/src/cashet/redis_store.py @@ -9,6 +9,7 @@ from redis.exceptions import WatchError +from cashet._client_base import normalize_tag_filters from cashet._runner import BlockingAsyncRunner from cashet.models import Commit, ObjectRef, StorageTier, TaskDef, TaskStatus @@ -412,6 +413,7 @@ async def list_commits( status: TaskStatus | None = None, tags: dict[str, str | None] | None = None, ) -> list[Commit]: + tag_filters = normalize_tag_filters(tags) if func_name: hashes = await self._redis.zrevrange(_func_key(func_name), 0, -1) else: @@ -424,7 +426,7 @@ async def list_commits( continue if status is not None and commit.status != status: continue - if tags is not None and not _matches_tags(commit, tags): + if tags is not None and not _matches_tags(commit, tag_filters): continue commits.append(commit) if len(commits) >= limit: @@ -579,8 +581,11 @@ async def delete_commit(self, hash: str) -> bool: return await self._delete_commit(hash) async def delete_by_tags(self, tags: dict[str, str | None]) -> int: + tag_filters = normalize_tag_filters(tags) + if not tag_filters: + return 0 set_keys: list[str] = [] - for key, val in tags.items(): + for key, val in tag_filters.items(): set_keys.append(_tag_key(key) if val is None else _tag_value_key(key, val)) if len(set_keys) == 1: hashes = await self._redis.smembers(set_keys[0]) diff --git a/src/cashet/store.py b/src/cashet/store.py index a80c60a..b433ad8 100644 --- a/src/cashet/store.py +++ b/src/cashet/store.py @@ -12,6 +12,7 @@ from pathlib import Path from typing import Any +from cashet._client_base import normalize_tag_filters from cashet._runner import BlockingAsyncRunner from cashet.models import Commit, ObjectRef, StorageTier, TaskDef, TaskStatus @@ -390,6 +391,7 @@ def list_commits( tags: dict[str, str | None] | None = None, ) -> list[Commit]: conn = self._connect() + tag_filters = normalize_tag_filters(tags) query = "SELECT * FROM commits WHERE 1=1" params: list[Any] = [] if func_name: @@ -398,8 +400,8 @@ def list_commits( if status: query += " AND status = ?" params.append(status.value) - if tags: - for key, val in tags.items(): + if tag_filters: + for key, val in tag_filters.items(): if val is None: query += " AND json_extract(tags, ?) IS NOT NULL" params.append(f"$.{key}") @@ -665,9 +667,13 @@ def delete_commit(self, hash: str) -> bool: def delete_by_tags(self, tags: dict[str, str | None]) -> int: conn = self._connect(immediate=True) + tag_filters = normalize_tag_filters(tags) + if not tag_filters: + conn.execute("ROLLBACK") + return 0 query = "SELECT hash, output_hash, input_refs FROM commits WHERE 1=1" params: list[Any] = [] - for key, val in tags.items(): + for key, val in tag_filters.items(): if val is None: query += " AND json_extract(tags, ?) IS NOT NULL" params.append(f"$.{key}") From c848d00ed9796d11fd3ac644d349a874c46f6e29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Jolovi=C4=87?= Date: Sat, 9 May 2026 01:18:53 +0200 Subject: [PATCH 2/2] Update src/cashet/_client_base.py Thanks Co-authored-by: ds-review[bot] <279337960+ds-review[bot]@users.noreply.github.com> --- src/cashet/_client_base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/cashet/_client_base.py b/src/cashet/_client_base.py index 66c55d0..59e6110 100644 --- a/src/cashet/_client_base.py +++ b/src/cashet/_client_base.py @@ -90,7 +90,9 @@ def normalize_tag_filters( ) -> dict[str, str | None]: if not tags: return {} - return {str(key): str(value) for key, value in tags.items() if value is not None} +if not tags: + return {} +return {str(k): str(v) if v is not None else None for k, v in tags.items()} async def load_result(commit: Commit, store: AsyncStore, serializer: Serializer) -> Any: