Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/cashet/_client_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import os
from collections.abc import Callable
from collections.abc import Callable, Mapping
from datetime import timedelta
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -85,6 +85,16 @@ def resolve_status(status: TaskStatus | str | None) -> TaskStatus | None:
return status


def normalize_tag_filters(
tags: Mapping[str, str | None] | None,
) -> dict[str, str | None]:
if not tags:
return {}
if not tags:
return {}
return {str(k): str(v) if v is not None else None for k, v in tags.items()}


async def load_result(commit: Commit, store: AsyncStore, serializer: Serializer) -> Any:
if commit.output_ref is None:
return None
Expand Down
9 changes: 7 additions & 2 deletions src/cashet/redis_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from redis.exceptions import WatchError

from cashet._client_base import normalize_tag_filters
from cashet._runner import BlockingAsyncRunner
from cashet.models import Commit, ObjectRef, StorageTier, TaskDef, TaskStatus

Expand Down Expand Up @@ -412,6 +413,7 @@ async def list_commits(
status: TaskStatus | None = None,
tags: dict[str, str | None] | None = None,
) -> list[Commit]:
tag_filters = normalize_tag_filters(tags)
if func_name:
hashes = await self._redis.zrevrange(_func_key(func_name), 0, -1)
else:
Expand All @@ -424,7 +426,7 @@ async def list_commits(
continue
if status is not None and commit.status != status:
continue
if tags is not None and not _matches_tags(commit, tags):
if tags is not None and not _matches_tags(commit, tag_filters):
continue
commits.append(commit)
if len(commits) >= limit:
Expand Down Expand Up @@ -579,8 +581,11 @@ async def delete_commit(self, hash: str) -> bool:
return await self._delete_commit(hash)

async def delete_by_tags(self, tags: dict[str, str | None]) -> int:
tag_filters = normalize_tag_filters(tags)
if not tag_filters:
return 0
set_keys: list[str] = []
for key, val in tags.items():
for key, val in tag_filters.items():
set_keys.append(_tag_key(key) if val is None else _tag_value_key(key, val))
if len(set_keys) == 1:
hashes = await self._redis.smembers(set_keys[0])
Expand Down
12 changes: 9 additions & 3 deletions src/cashet/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pathlib import Path
from typing import Any

from cashet._client_base import normalize_tag_filters
from cashet._runner import BlockingAsyncRunner
from cashet.models import Commit, ObjectRef, StorageTier, TaskDef, TaskStatus

Expand Down Expand Up @@ -390,6 +391,7 @@ def list_commits(
tags: dict[str, str | None] | None = None,
) -> list[Commit]:
conn = self._connect()
tag_filters = normalize_tag_filters(tags)
query = "SELECT * FROM commits WHERE 1=1"
params: list[Any] = []
if func_name:
Expand All @@ -398,8 +400,8 @@ def list_commits(
if status:
query += " AND status = ?"
params.append(status.value)
if tags:
for key, val in tags.items():
if tag_filters:
for key, val in tag_filters.items():
if val is None:
query += " AND json_extract(tags, ?) IS NOT NULL"
params.append(f"$.{key}")
Expand Down Expand Up @@ -665,9 +667,13 @@ def delete_commit(self, hash: str) -> bool:

def delete_by_tags(self, tags: dict[str, str | None]) -> int:
conn = self._connect(immediate=True)
tag_filters = normalize_tag_filters(tags)
if not tag_filters:
conn.execute("ROLLBACK")
return 0
query = "SELECT hash, output_hash, input_refs FROM commits WHERE 1=1"
params: list[Any] = []
for key, val in tags.items():
for key, val in tag_filters.items():
if val is None:
query += " AND json_extract(tags, ?) IS NOT NULL"
params.append(f"$.{key}")
Expand Down
Loading