-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathredis_utils.py
More file actions
64 lines (52 loc) · 2.07 KB
/
redis_utils.py
File metadata and controls
64 lines (52 loc) · 2.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""
Redis utilities for tracking chat completions via insertion IDs.
"""
import logging
from typing import Set, cast
import redis
logger = logging.getLogger(__name__)
# Default lifetime of a rollout's insertion_id set: 24 hours, expressed in
# seconds (used as the Redis EXPIRE timeout by register_insertion_id).
DEFAULT_ROLLOUT_TTL_SECONDS = 60 * 60 * 24
def register_insertion_id(
    redis_client: redis.Redis, rollout_id: str, insertion_id: str, ttl_seconds: int = DEFAULT_ROLLOUT_TTL_SECONDS
) -> bool:
    """Register an insertion_id for a rollout_id in Redis.

    Tracks all expected completion insertion_ids for this rollout by adding
    ``insertion_id`` to the Redis set keyed by ``rollout_id`` and refreshing
    that set's expiry, both issued through a single pipeline.

    Args:
        redis_client: Redis client used to issue the commands.
        rollout_id: The rollout ID; used directly as the Redis set key.
        insertion_id: Unique identifier for this specific completion.
        ttl_seconds: Expiry (re)applied to the rollout's set, in seconds.
            Truncated with ``int()``; must be at least 1 after truncation.

    Returns:
        True if successful, False otherwise (including when the truncated
        TTL is below 1 second).
    """
    try:
        ttl = int(ttl_seconds)
        # Redis deletes the key outright when EXPIRE receives a non-positive
        # timeout, which would silently drop every insertion_id already
        # registered for this rollout. Refuse instead of reporting success.
        if ttl <= 0:
            logger.error(
                "Refusing to register insertion_id for %s: ttl_seconds must be >= 1, got %r",
                rollout_id,
                ttl_seconds,
            )
            return False
        # The context manager resets the pipeline even if queuing or
        # execution raises.
        with redis_client.pipeline() as pipe:
            pipe.sadd(rollout_id, insertion_id)
            pipe.expire(rollout_id, ttl)
            pipe.execute()
        logger.info("Registered insertion_id %s for rollout %s", insertion_id, rollout_id)
        return True
    except Exception as e:
        # Best-effort contract: callers get False rather than an exception.
        logger.error("Failed to register insertion_id for %s: %s", rollout_id, e)
        return False
def get_insertion_ids(redis_client: redis.Redis, rollout_id: str) -> Set[str]:
    """Get all expected insertion_ids for a rollout_id from Redis.

    Reads the Redis set keyed by ``rollout_id`` and normalizes its members to
    ``str``: bytes members are decoded as UTF-8, and any other member is
    converted with ``str()``.

    Args:
        redis_client: Redis client used to read the set.
        rollout_id: The rollout ID (the Redis set key) to get insertion_ids for.

    Returns:
        Set of insertion_id strings; empty if none are found or on error.
        Members that are not valid UTF-8 are skipped with a warning.
    """
    try:
        raw = redis_client.smembers(rollout_id)
        # Typing in redis stubs may be Awaitable[Set[Any]] | Set[Any]; at runtime this is a Set[bytes]
        raw_ids = cast(Set[object], raw)

        insertion_ids: Set[str] = set()
        for member in raw_ids:
            if isinstance(member, (bytes, bytearray)):
                try:
                    insertion_ids.add(member.decode("utf-8"))
                except UnicodeDecodeError:
                    # Skip the bad member but leave a trace instead of dropping it silently.
                    logger.warning("Skipping non-UTF-8 insertion_id %r for rollout %s", member, rollout_id)
            else:
                # Already-decoded members (e.g. str from a client configured to
                # decode responses) are normalized with str() to honor Set[str].
                insertion_ids.add(str(member))

        logger.debug("Found %d expected insertion_ids for rollout %s", len(insertion_ids), rollout_id)
        return insertion_ids
    except Exception as e:
        # Best-effort contract: callers get an empty set rather than an exception.
        logger.error("Failed to get insertion_ids for %s: %s", rollout_id, e)
        return set()