Skip to content

Commit 303d4e7

Browse files
authored
feat: Enable custom prefixes for checkpoint savers (#125) (#127)
* Reuse `STORE_PREFIX` and `STORE_VECTOR_PREFIX` constants * Implement custom prefixes for checkpoint savers (closes #125) * Fix default values for `AsyncRedisSaver.from_conn_string` prefixes * Clean up index schema organization * Name indexes using custom prefix * Have both Savers exist concurrently in isolation tests
1 parent 5ce5acd commit 303d4e7

File tree

13 files changed

+731
-262
lines changed

13 files changed

+731
-262
lines changed

langgraph/checkpoint/redis/__init__.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,20 @@ def __init__(
6767
redis_client: Optional[Union[Redis, RedisCluster]] = None,
6868
connection_args: Optional[Dict[str, Any]] = None,
6969
ttl: Optional[Dict[str, Any]] = None,
70+
checkpoint_prefix: str = CHECKPOINT_PREFIX,
71+
checkpoint_blob_prefix: str = CHECKPOINT_BLOB_PREFIX,
72+
checkpoint_write_prefix: str = CHECKPOINT_WRITE_PREFIX,
7073
) -> None:
7174
super().__init__(
7275
redis_url=redis_url,
7376
redis_client=redis_client,
7477
connection_args=connection_args,
7578
ttl=ttl,
79+
checkpoint_prefix=checkpoint_prefix,
80+
checkpoint_blob_prefix=checkpoint_blob_prefix,
81+
checkpoint_write_prefix=checkpoint_write_prefix,
7682
)
77-
# Pre-compute common prefixes for performance
78-
self._checkpoint_prefix = CHECKPOINT_PREFIX
79-
self._checkpoint_blob_prefix = CHECKPOINT_BLOB_PREFIX
80-
self._checkpoint_write_prefix = CHECKPOINT_WRITE_PREFIX
83+
# Prefixes are now set in BaseRedisSaver.__init__
8184
self._separator = REDIS_KEY_SEPARATOR
8285

8386
# Instance-level cache for frequently used keys (limited size to prevent memory issues)
@@ -116,13 +119,13 @@ def configure_client(
116119

117120
def create_indexes(self) -> None:
118121
self.checkpoints_index = SearchIndex.from_dict(
119-
self.SCHEMAS[0], redis_client=self._redis
122+
self.checkpoints_schema, redis_client=self._redis
120123
)
121124
self.checkpoint_blobs_index = SearchIndex.from_dict(
122-
self.SCHEMAS[1], redis_client=self._redis
125+
self.blobs_schema, redis_client=self._redis
123126
)
124127
self.checkpoint_writes_index = SearchIndex.from_dict(
125-
self.SCHEMAS[2], redis_client=self._redis
128+
self.writes_schema, redis_client=self._redis
126129
)
127130

128131
def _make_redis_checkpoint_key_cached(
@@ -848,7 +851,7 @@ def _get_write_keys_from_search(
848851
write_results = self.checkpoint_writes_index.search(write_query)
849852

850853
return [
851-
BaseRedisSaver._make_redis_checkpoint_writes_key(
854+
self._make_redis_checkpoint_writes_key(
852855
to_storage_safe_id(thread_id),
853856
to_storage_safe_str(checkpoint_ns),
854857
to_storage_safe_id(checkpoint_id),
@@ -1119,6 +1122,9 @@ def from_conn_string(
11191122
redis_client: Optional[Union[Redis, RedisCluster]] = None,
11201123
connection_args: Optional[Dict[str, Any]] = None,
11211124
ttl: Optional[Dict[str, Any]] = None,
1125+
checkpoint_prefix: str = CHECKPOINT_PREFIX,
1126+
checkpoint_blob_prefix: str = CHECKPOINT_BLOB_PREFIX,
1127+
checkpoint_write_prefix: str = CHECKPOINT_WRITE_PREFIX,
11221128
) -> Iterator[RedisSaver]:
11231129
"""Create a new RedisSaver instance."""
11241130
saver: Optional[RedisSaver] = None
@@ -1128,6 +1134,9 @@ def from_conn_string(
11281134
redis_client=redis_client,
11291135
connection_args=connection_args,
11301136
ttl=ttl,
1137+
checkpoint_prefix=checkpoint_prefix,
1138+
checkpoint_blob_prefix=checkpoint_blob_prefix,
1139+
checkpoint_write_prefix=checkpoint_write_prefix,
11311140
)
11321141

11331142
yield saver
@@ -1615,7 +1624,7 @@ def delete_thread(self, thread_id: str) -> None:
16151624
channel = getattr(doc, "channel", "")
16161625
version = getattr(doc, "version", "")
16171626

1618-
blob_key = BaseRedisSaver._make_redis_checkpoint_blob_key(
1627+
blob_key = self._make_redis_checkpoint_blob_key(
16191628
storage_safe_thread_id, checkpoint_ns, channel, version
16201629
)
16211630
keys_to_delete.append(blob_key)
@@ -1635,7 +1644,7 @@ def delete_thread(self, thread_id: str) -> None:
16351644
task_id = getattr(doc, "task_id", "")
16361645
idx = getattr(doc, "idx", 0)
16371646

1638-
write_key = BaseRedisSaver._make_redis_checkpoint_writes_key(
1647+
write_key = self._make_redis_checkpoint_writes_key(
16391648
storage_safe_thread_id, checkpoint_ns, checkpoint_id, task_id, idx
16401649
)
16411650
keys_to_delete.append(write_key)

langgraph/checkpoint/redis/aio.py

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,13 @@
4040
from redisvl.query.filter import Num, Tag
4141
from ulid import ULID
4242

43-
from langgraph.checkpoint.redis.base import BaseRedisSaver
43+
from langgraph.checkpoint.redis.base import (
44+
BaseRedisSaver,
45+
CHECKPOINT_BLOB_PREFIX,
46+
CHECKPOINT_PREFIX,
47+
CHECKPOINT_WRITE_PREFIX,
48+
REDIS_KEY_SEPARATOR,
49+
)
4450
from langgraph.checkpoint.redis.key_registry import (
4551
AsyncCheckpointKeyRegistry as AsyncKeyRegistry,
4652
)
@@ -81,30 +87,25 @@ def __init__(
8187
redis_client: Optional[Union[AsyncRedis, AsyncRedisCluster]] = None,
8288
connection_args: Optional[Dict[str, Any]] = None,
8389
ttl: Optional[Dict[str, Any]] = None,
90+
checkpoint_prefix: str = CHECKPOINT_PREFIX,
91+
checkpoint_blob_prefix: str = CHECKPOINT_BLOB_PREFIX,
92+
checkpoint_write_prefix: str = CHECKPOINT_WRITE_PREFIX,
8493
) -> None:
8594
super().__init__(
8695
redis_url=redis_url,
8796
redis_client=redis_client,
8897
connection_args=connection_args,
8998
ttl=ttl,
99+
checkpoint_prefix=checkpoint_prefix,
100+
checkpoint_blob_prefix=checkpoint_blob_prefix,
101+
checkpoint_write_prefix=checkpoint_write_prefix,
90102
)
91103
self.loop = asyncio.get_running_loop()
92104

93105
# Instance-level cache for frequently used keys (limited size to prevent memory issues)
94106
self._key_cache: Dict[str, str] = {}
95107
self._key_cache_max_size = 1000 # Configurable limit
96108

97-
# Pre-compute common prefixes for performance
98-
from langgraph.checkpoint.redis.base import (
99-
CHECKPOINT_BLOB_PREFIX,
100-
CHECKPOINT_PREFIX,
101-
CHECKPOINT_WRITE_PREFIX,
102-
REDIS_KEY_SEPARATOR,
103-
)
104-
105-
self._checkpoint_prefix = CHECKPOINT_PREFIX
106-
self._checkpoint_blob_prefix = CHECKPOINT_BLOB_PREFIX
107-
self._checkpoint_write_prefix = CHECKPOINT_WRITE_PREFIX
108109
self._separator = REDIS_KEY_SEPARATOR
109110

110111
def configure_client(
@@ -128,13 +129,13 @@ def configure_client(
128129
def create_indexes(self) -> None:
129130
"""Create indexes without connecting to Redis."""
130131
self.checkpoints_index = AsyncSearchIndex.from_dict(
131-
self.SCHEMAS[0], redis_client=self._redis
132+
self.checkpoints_schema, redis_client=self._redis
132133
)
133134
self.checkpoint_blobs_index = AsyncSearchIndex.from_dict(
134-
self.SCHEMAS[1], redis_client=self._redis
135+
self.blobs_schema, redis_client=self._redis
135136
)
136137
self.checkpoint_writes_index = AsyncSearchIndex.from_dict(
137-
self.SCHEMAS[2], redis_client=self._redis
138+
self.writes_schema, redis_client=self._redis
138139
)
139140

140141
def _make_redis_checkpoint_key_cached(
@@ -375,7 +376,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
375376
storage_safe_checkpoint_id = to_storage_safe_id(checkpoint_id)
376377

377378
# Construct direct key for checkpoint data
378-
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
379+
checkpoint_key = self._make_redis_checkpoint_key(
379380
storage_safe_thread_id,
380381
storage_safe_checkpoint_ns,
381382
storage_safe_checkpoint_id,
@@ -476,7 +477,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
476477
# If we didn't get TTL from pipeline (i.e., came from else branch), fetch it now
477478
if "current_ttl" not in locals():
478479
# Get the checkpoint key
479-
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
480+
checkpoint_key = self._make_redis_checkpoint_key(
480481
to_storage_safe_id(doc_thread_id),
481482
to_storage_safe_str(doc_checkpoint_ns),
482483
to_storage_safe_id(doc_checkpoint_id),
@@ -1054,7 +1055,7 @@ async def aput(
10541055
}
10551056

10561057
# Prepare checkpoint key
1057-
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
1058+
checkpoint_key = self._make_redis_checkpoint_key(
10581059
storage_safe_thread_id,
10591060
storage_safe_checkpoint_ns,
10601061
storage_safe_checkpoint_id,
@@ -1441,12 +1442,18 @@ async def from_conn_string(
14411442
redis_client: Optional[Union[AsyncRedis, AsyncRedisCluster]] = None,
14421443
connection_args: Optional[Dict[str, Any]] = None,
14431444
ttl: Optional[Dict[str, Any]] = None,
1445+
checkpoint_prefix: str = CHECKPOINT_PREFIX,
1446+
checkpoint_blob_prefix: str = CHECKPOINT_BLOB_PREFIX,
1447+
checkpoint_write_prefix: str = CHECKPOINT_WRITE_PREFIX,
14441448
) -> AsyncIterator[AsyncRedisSaver]:
14451449
async with cls(
14461450
redis_url=redis_url,
14471451
redis_client=redis_client,
14481452
connection_args=connection_args,
14491453
ttl=ttl,
1454+
checkpoint_prefix=checkpoint_prefix,
1455+
checkpoint_blob_prefix=checkpoint_blob_prefix,
1456+
checkpoint_write_prefix=checkpoint_write_prefix,
14501457
) as saver:
14511458
yield saver
14521459

@@ -1980,7 +1987,7 @@ async def adelete_thread(self, thread_id: str) -> None:
19801987
checkpoint_namespaces.add(checkpoint_ns)
19811988

19821989
# Delete checkpoint key
1983-
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
1990+
checkpoint_key = self._make_redis_checkpoint_key(
19841991
storage_safe_thread_id, checkpoint_ns, checkpoint_id
19851992
)
19861993
keys_to_delete.append(checkpoint_key)
@@ -2004,7 +2011,7 @@ async def adelete_thread(self, thread_id: str) -> None:
20042011
channel = getattr(doc, "channel", "")
20052012
version = getattr(doc, "version", "")
20062013

2007-
blob_key = BaseRedisSaver._make_redis_checkpoint_blob_key(
2014+
blob_key = self._make_redis_checkpoint_blob_key(
20082015
storage_safe_thread_id, checkpoint_ns, channel, version
20092016
)
20102017
keys_to_delete.append(blob_key)
@@ -2024,7 +2031,7 @@ async def adelete_thread(self, thread_id: str) -> None:
20242031
task_id = getattr(doc, "task_id", "")
20252032
idx = getattr(doc, "idx", 0)
20262033

2027-
write_key = BaseRedisSaver._make_redis_checkpoint_writes_key(
2034+
write_key = self._make_redis_checkpoint_writes_key(
20282035
storage_safe_thread_id, checkpoint_ns, checkpoint_id, task_id, idx
20292036
)
20302037
keys_to_delete.append(write_key)

langgraph/checkpoint/redis/ashallow.py

Lines changed: 17 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -38,52 +38,6 @@
3838
to_storage_safe_str,
3939
)
4040

41-
SCHEMAS = [
42-
{
43-
"index": {
44-
"name": "checkpoints",
45-
"prefix": CHECKPOINT_PREFIX + REDIS_KEY_SEPARATOR,
46-
"storage_type": "json",
47-
},
48-
"fields": [
49-
{"name": "thread_id", "type": "tag"},
50-
{"name": "checkpoint_ns", "type": "tag"},
51-
{"name": "source", "type": "tag"},
52-
{"name": "step", "type": "numeric"},
53-
],
54-
},
55-
{
56-
"index": {
57-
"name": "checkpoints_blobs",
58-
"prefix": CHECKPOINT_BLOB_PREFIX + REDIS_KEY_SEPARATOR,
59-
"storage_type": "json",
60-
},
61-
"fields": [
62-
{"name": "thread_id", "type": "tag"},
63-
{"name": "checkpoint_ns", "type": "tag"},
64-
{"name": "channel", "type": "tag"},
65-
{"name": "type", "type": "tag"},
66-
],
67-
},
68-
{
69-
"index": {
70-
"name": "checkpoint_writes",
71-
"prefix": CHECKPOINT_WRITE_PREFIX + REDIS_KEY_SEPARATOR,
72-
"storage_type": "json",
73-
},
74-
"fields": [
75-
{"name": "thread_id", "type": "tag"},
76-
{"name": "checkpoint_ns", "type": "tag"},
77-
{"name": "checkpoint_id", "type": "tag"},
78-
{"name": "task_id", "type": "tag"},
79-
{"name": "idx", "type": "numeric"},
80-
{"name": "channel", "type": "tag"},
81-
{"name": "type", "type": "tag"},
82-
],
83-
},
84-
]
85-
86-
8741
class AsyncShallowRedisSaver(BaseRedisSaver[AsyncRedis, AsyncSearchIndex]):
8842
"""Async Redis implementation that only stores the most recent checkpoint."""
8943

@@ -101,12 +55,18 @@ def __init__(
10155
redis_client: Optional[AsyncRedis] = None,
10256
connection_args: Optional[dict[str, Any]] = None,
10357
ttl: Optional[dict[str, Any]] = None,
58+
checkpoint_prefix: str = CHECKPOINT_PREFIX,
59+
checkpoint_blob_prefix: str = CHECKPOINT_BLOB_PREFIX,
60+
checkpoint_write_prefix: str = CHECKPOINT_WRITE_PREFIX,
10461
) -> None:
10562
super().__init__(
10663
redis_url=redis_url,
10764
redis_client=redis_client,
10865
connection_args=connection_args,
10966
ttl=ttl,
67+
checkpoint_prefix=checkpoint_prefix,
68+
checkpoint_blob_prefix=checkpoint_blob_prefix,
69+
checkpoint_write_prefix=checkpoint_write_prefix,
11070
)
11171
self.loop = asyncio.get_running_loop()
11272

@@ -115,9 +75,6 @@ def __init__(
11575
self._key_cache_max_size = 1000 # Configurable limit
11676
self._channel_cache: Dict[str, Any] = {}
11777

118-
# Cache commonly used prefixes
119-
self._checkpoint_prefix = CHECKPOINT_PREFIX
120-
self._checkpoint_write_prefix = CHECKPOINT_WRITE_PREFIX
12178
self._separator = REDIS_KEY_SEPARATOR
12279

12380
async def __aenter__(self) -> AsyncShallowRedisSaver:
@@ -158,13 +115,19 @@ async def from_conn_string(
158115
redis_client: Optional[AsyncRedis] = None,
159116
connection_args: Optional[dict[str, Any]] = None,
160117
ttl: Optional[dict[str, Any]] = None,
118+
checkpoint_prefix: str = CHECKPOINT_PREFIX,
119+
checkpoint_blob_prefix: str = CHECKPOINT_BLOB_PREFIX,
120+
checkpoint_write_prefix: str = CHECKPOINT_WRITE_PREFIX,
161121
) -> AsyncIterator[AsyncShallowRedisSaver]:
162122
"""Create a new AsyncShallowRedisSaver instance."""
163123
async with cls(
164124
redis_url=redis_url,
165125
redis_client=redis_client,
166126
connection_args=connection_args,
167127
ttl=ttl,
128+
checkpoint_prefix=checkpoint_prefix,
129+
checkpoint_blob_prefix=checkpoint_blob_prefix,
130+
checkpoint_write_prefix=checkpoint_write_prefix,
168131
) as saver:
169132
yield saver
170133

@@ -733,14 +696,14 @@ def configure_client(
733696
def create_indexes(self) -> None:
734697
"""Create indexes without connecting to Redis."""
735698
self.checkpoints_index = AsyncSearchIndex.from_dict(
736-
self.SCHEMAS[0], redis_client=self._redis
699+
self.checkpoints_schema, redis_client=self._redis
737700
)
738701
# Shallow implementation doesn't use blobs, but base class requires the attribute
739702
self.checkpoint_blobs_index = AsyncSearchIndex.from_dict(
740-
self.SCHEMAS[1], redis_client=self._redis
703+
self.blobs_schema, redis_client=self._redis
741704
)
742705
self.checkpoint_writes_index = AsyncSearchIndex.from_dict(
743-
self.SCHEMAS[2], redis_client=self._redis
706+
self.writes_schema, redis_client=self._redis
744707
)
745708

746709
def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
@@ -837,7 +800,7 @@ def _make_redis_checkpoint_writes_key_cached(
837800
)
838801
if cache_key not in self._key_cache:
839802
self._key_cache[cache_key] = (
840-
BaseRedisSaver._make_redis_checkpoint_writes_key(
803+
self._make_redis_checkpoint_writes_key(
841804
thread_id, checkpoint_ns, checkpoint_id, task_id, idx
842805
)
843806
)
@@ -884,7 +847,7 @@ def _make_shallow_redis_checkpoint_blob_key_cached(
884847
if len(self._key_cache) >= self._key_cache_max_size:
885848
# Remove oldest entry when cache is full
886849
self._key_cache.pop(next(iter(self._key_cache)))
887-
self._key_cache[cache_key] = BaseRedisSaver._make_redis_checkpoint_blob_key(
850+
self._key_cache[cache_key] = self._make_redis_checkpoint_blob_key(
888851
thread_id, checkpoint_ns, channel, version
889852
)
890853
return self._key_cache[cache_key]

0 commit comments

Comments
 (0)