Skip to content

Commit 719f0fd

Browse files
committed
fix(store): add value serialization for LangChain message objects (#128)
RedisStore was passing values directly to redisvl.SearchIndex.load(), which uses standard json.dumps() internally. This caused failures when storing values containing LangChain message objects (HumanMessage, AIMessage, etc.).

Changes:
- Add _serialize_value() method to BaseRedisStore that uses JsonPlusRedisSerializer
- Add _deserialize_value() method for proper revival of complex objects
- Update _row_to_item() and _row_to_search_item() to accept a deserialize function
- Update all call sites in RedisStore and AsyncRedisStore

The serialization is smart: simple JSON-serializable values are stored as-is for backward compatibility, while complex objects get the serde wrapper.

Closes #128
1 parent 87791a1 commit 719f0fd

File tree

3 files changed

+146
-10
lines changed

3 files changed

+146
-10
lines changed

langgraph/store/redis/__init__.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,9 @@ def _batch_get_ops(
255255
for idx, key in items:
256256
if key in key_to_row:
257257
data, doc_id = key_to_row[key]
258-
results[idx] = _row_to_item(namespace, data)
258+
results[idx] = _row_to_item(
259+
namespace, data, deserialize_fn=self._deserialize_value
260+
)
259261

260262
# Find the corresponding operation by looking it up in the operation list
261263
# This is needed because idx is the index in the overall operation list
@@ -585,6 +587,7 @@ def _batch_search_ops(
585587
_decode_ns(store_doc["prefix"]),
586588
store_doc,
587589
score=score,
590+
deserialize_fn=self._deserialize_value,
588591
)
589592
)
590593

@@ -651,7 +654,13 @@ def _batch_search_ops(
651654
)
652655
refresh_keys.append(vector_key)
653656

654-
items.append(_row_to_search_item(_decode_ns(data["prefix"]), data))
657+
items.append(
658+
_row_to_search_item(
659+
_decode_ns(data["prefix"]),
660+
data,
661+
deserialize_fn=self._deserialize_value,
662+
)
663+
)
655664

656665
# Refresh TTL if requested
657666
if op.refresh_ttl and refresh_keys and self.ttl_config:

langgraph/store/redis/aio.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,9 @@ async def _batch_get_ops(
470470
for idx, key in items:
471471
if key in key_to_row:
472472
data, doc_id = key_to_row[key]
473-
results[idx] = _row_to_item(namespace, data)
473+
results[idx] = _row_to_item(
474+
namespace, data, deserialize_fn=self._deserialize_value
475+
)
474476

475477
# Find the corresponding operation by looking it up in the operation list
476478
# This is needed because idx is the index in the overall operation list
@@ -870,6 +872,7 @@ async def _batch_search_ops(
870872
_decode_ns(store_doc["prefix"]),
871873
store_doc,
872874
score=score,
875+
deserialize_fn=self._deserialize_value,
873876
)
874877
)
875878

@@ -937,7 +940,13 @@ async def _batch_search_ops(
937940
)
938941
refresh_keys.append(vector_key)
939942

940-
items.append(_row_to_search_item(_decode_ns(data["prefix"]), data))
943+
items.append(
944+
_row_to_search_item(
945+
_decode_ns(data["prefix"]),
946+
data,
947+
deserialize_fn=self._deserialize_value,
948+
)
949+
)
941950

942951
# Refresh TTL if requested
943952
if op.refresh_ttl and refresh_keys and self.ttl_config:

langgraph/store/redis/base.py

Lines changed: 124 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from datetime import datetime, timedelta, timezone
1010
from typing import (
1111
Any,
12+
Callable,
1213
Dict,
1314
Generic,
1415
Iterable,
@@ -40,6 +41,8 @@
4041
from redisvl.query.filter import Tag, Text
4142
from redisvl.utils.token_escaper import TokenEscaper
4243

44+
from langgraph.checkpoint.redis.jsonplus_redis import JsonPlusRedisSerializer
45+
4346
from .token_unescaper import TokenUnescaper
4447
from .types import IndexType, RedisClientType
4548

@@ -124,6 +127,9 @@ class BaseRedisStore(Generic[RedisClientType, IndexType]):
124127
supports_ttl: bool = True
125128
ttl_config: Optional[TTLConfig] = None
126129

130+
# Serializer for handling complex objects like LangChain messages
131+
_serde: JsonPlusRedisSerializer
132+
127133
def _apply_ttl_to_keys(
128134
self,
129135
main_key: str,
@@ -223,6 +229,8 @@ def __init__(
223229
self._redis = conn
224230
# Store cluster_mode; None means auto-detect in RedisStore or AsyncRedisStore
225231
self.cluster_mode = cluster_mode
232+
# Initialize the serializer for handling complex objects like LangChain messages
233+
self._serde = JsonPlusRedisSerializer()
226234

227235
# Store custom prefixes
228236
self.store_prefix = store_prefix
@@ -357,6 +365,84 @@ async def aset_client_info(self) -> None:
357365
# Silently fail if even echo doesn't work
358366
pass
359367

368+
def _serialize_value(self, value: Any) -> Any:
369+
"""Serialize a value for storage in Redis.
370+
371+
This method handles complex objects like LangChain messages by
372+
serializing them to a JSON-compatible format.
373+
374+
The method is smart about serialization:
375+
- If the value is a simple JSON-serializable dict/list, it's stored as-is
376+
- If the value contains complex objects (HumanMessage, etc.), it uses
377+
the serde wrapper format
378+
379+
Args:
380+
value: The value to serialize (can contain HumanMessage, AIMessage, etc.)
381+
382+
Returns:
383+
A JSON-serializable representation of the value
384+
"""
385+
import json
386+
387+
if value is None:
388+
return None
389+
390+
# First, try standard JSON serialization to check if it's needed
391+
try:
392+
json.dumps(value)
393+
# Value is already JSON-serializable, return as-is for backward
394+
# compatibility and to preserve filter functionality
395+
return value
396+
except (TypeError, ValueError):
397+
# Value contains non-JSON-serializable objects, use serde wrapper
398+
pass
399+
400+
# Use the serializer to handle complex objects
401+
type_str, data_bytes = self._serde.dumps_typed(value)
402+
# Store the serialized data with type info for proper deserialization
403+
return {
404+
"__serde_type__": type_str,
405+
"__serde_data__": (
406+
data_bytes.decode("utf-8") if type_str == "json" else data_bytes.hex()
407+
),
408+
}
409+
410+
def _deserialize_value(self, value: Any) -> Any:
411+
"""Deserialize a value from Redis storage.
412+
413+
This method handles both new serialized format and legacy plain values
414+
for backward compatibility.
415+
416+
Args:
417+
value: The value from Redis (may be serialized or plain)
418+
419+
Returns:
420+
The deserialized value with proper Python objects (HumanMessage, etc.)
421+
"""
422+
if value is None:
423+
return None
424+
425+
# Check if this is a serialized value (new format)
426+
if (
427+
isinstance(value, dict)
428+
and "__serde_type__" in value
429+
and "__serde_data__" in value
430+
):
431+
type_str = value["__serde_type__"]
432+
data_str = value["__serde_data__"]
433+
434+
# Convert back to bytes
435+
if type_str == "json":
436+
data_bytes = data_str.encode("utf-8")
437+
else:
438+
data_bytes = bytes.fromhex(data_str)
439+
440+
return self._serde.loads_typed((type_str, data_bytes))
441+
442+
# Legacy format: value is stored as-is (plain JSON-serializable data)
443+
# Return as-is for backward compatibility
444+
return value
445+
360446
def _get_batch_GET_ops_queries(
361447
self,
362448
get_ops: Sequence[tuple[int, GetOp]],
@@ -433,7 +519,7 @@ def _prepare_batch_PUT_queries(
433519
doc = RedisDocument(
434520
prefix=_namespace_to_text(op.namespace),
435521
key=op.key,
436-
value=op.value,
522+
value=self._serialize_value(op.value),
437523
created_at=now,
438524
updated_at=now,
439525
ttl_minutes=ttl_minutes,
@@ -568,10 +654,27 @@ def _decode_ns(ns: str) -> tuple[str, ...]:
568654
return tuple(_token_unescaper.unescape(ns).split("."))
569655

570656

571-
def _row_to_item(namespace: tuple[str, ...], row: dict[str, Any]) -> Item:
572-
"""Convert a row from Redis to an Item."""
657+
def _row_to_item(
658+
namespace: tuple[str, ...],
659+
row: dict[str, Any],
660+
deserialize_fn: Optional[Callable[[Any], Any]] = None,
661+
) -> Item:
662+
"""Convert a row from Redis to an Item.
663+
664+
Args:
665+
namespace: The namespace tuple for this item
666+
row: The raw row data from Redis
667+
deserialize_fn: Optional function to deserialize the value (handles
668+
LangChain messages, etc.)
669+
670+
Returns:
671+
An Item with properly deserialized value
672+
"""
673+
value = row["value"]
674+
if deserialize_fn is not None:
675+
value = deserialize_fn(value)
573676
return Item(
574-
value=row["value"],
677+
value=value,
575678
key=row["key"],
576679
namespace=namespace,
577680
created_at=datetime.fromtimestamp(row["created_at"] / 1_000_000, timezone.utc),
@@ -583,10 +686,25 @@ def _row_to_search_item(
583686
namespace: tuple[str, ...],
584687
row: dict[str, Any],
585688
score: Optional[float] = None,
689+
deserialize_fn: Optional[Callable[[Any], Any]] = None,
586690
) -> SearchItem:
587-
"""Convert a row from Redis to a SearchItem."""
691+
"""Convert a row from Redis to a SearchItem.
692+
693+
Args:
694+
namespace: The namespace tuple for this item
695+
row: The raw row data from Redis
696+
score: Optional similarity score from vector search
697+
deserialize_fn: Optional function to deserialize the value (handles
698+
LangChain messages, etc.)
699+
700+
Returns:
701+
A SearchItem with properly deserialized value
702+
"""
703+
value = row["value"]
704+
if deserialize_fn is not None:
705+
value = deserialize_fn(value)
588706
return SearchItem(
589-
value=row["value"],
707+
value=value,
590708
key=row["key"],
591709
namespace=namespace,
592710
created_at=datetime.fromtimestamp(row["created_at"] / 1_000_000, timezone.utc),

0 commit comments

Comments
 (0)