Skip to content

Commit f0ab606

Browse files
committed
fix(store): address PR review feedback for message serialization
Improvements based on PR #129 review:
- Move json import to top of file for consistency
- Add exact key check (set equality) to prevent collisions with user data that might contain __serde_type__ and __serde_data__ keys
- Add comprehensive error handling for deserialization failures with logging
- Add hex validation for non-json type deserialization
- Fix async store to also apply value serialization in _aprepare_batch_PUT_queries
- Make type handling explicit in serialization (json vs bytes/msgpack)

New tests added:
- Backward compatibility test for plain JSON values (filters continue to work)
- Serde key collision prevention test
- Bytes value serialization test
- Full async store test suite mirroring sync tests
1 parent 49ae3aa commit f0ab606

File tree

4 files changed

+318
-20
lines changed

4 files changed

+318
-20
lines changed

langgraph/store/redis/aio.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ async def _aprepare_batch_PUT_queries(
578578
doc = RedisDocument(
579579
prefix=_namespace_to_text(op.namespace),
580580
key=op.key,
581-
value=op.value,
581+
value=self._serialize_value(op.value),
582582
created_at=now,
583583
updated_at=now,
584584
ttl_minutes=ttl_minutes,

langgraph/store/redis/base.py

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import copy
6+
import json
67
import logging
78
import threading
89
from collections import defaultdict
@@ -374,16 +375,17 @@ def _serialize_value(self, value: Any) -> Any:
374375
The method is smart about serialization:
375376
- If the value is a simple JSON-serializable dict/list, it's stored as-is
376377
- If the value contains complex objects (HumanMessage, etc.), it uses
377-
the serde wrapper format
378+
the serde wrapper format with __serde_type__ and __serde_data__ keys
379+
380+
Note: Values containing LangChain messages will be wrapped in a serde format,
381+
which means filters on nested fields won't work for such values.
378382
379383
Args:
380384
value: The value to serialize (can contain HumanMessage, AIMessage, etc.)
381385
382386
Returns:
383387
A JSON-serializable representation of the value
384388
"""
385-
import json
386-
387389
if value is None:
388390
return None
389391

@@ -393,18 +395,23 @@ def _serialize_value(self, value: Any) -> Any:
393395
# Value is already JSON-serializable, return as-is for backward
394396
# compatibility and to preserve filter functionality
395397
return value
396-
except (TypeError, ValueError):
398+
except TypeError:
397399
# Value contains non-JSON-serializable objects, use serde wrapper
398400
pass
399401

400402
# Use the serializer to handle complex objects
401403
type_str, data_bytes = self._serde.dumps_typed(value)
402404
# Store the serialized data with type info for proper deserialization
405+
# Handle different type formats explicitly for clarity
406+
if type_str == "json":
407+
data_encoded = data_bytes.decode("utf-8")
408+
else:
409+
# bytes, bytearray, msgpack, and other types are hex-encoded
410+
data_encoded = data_bytes.hex()
411+
403412
return {
404413
"__serde_type__": type_str,
405-
"__serde_data__": (
406-
data_bytes.decode("utf-8") if type_str == "json" else data_bytes.hex()
407-
),
414+
"__serde_data__": data_encoded,
408415
}
409416

410417
def _deserialize_value(self, value: Any) -> Any:
@@ -423,21 +430,40 @@ def _deserialize_value(self, value: Any) -> Any:
423430
return None
424431

425432
# Check if this is a serialized value (new format)
426-
if (
427-
isinstance(value, dict)
428-
and "__serde_type__" in value
429-
and "__serde_data__" in value
430-
):
433+
# Use exact key check to prevent collisions with user data
434+
if isinstance(value, dict) and set(value.keys()) == {
435+
"__serde_type__",
436+
"__serde_data__",
437+
}:
431438
type_str = value["__serde_type__"]
432439
data_str = value["__serde_data__"]
433440

434-
# Convert back to bytes
435-
if type_str == "json":
436-
data_bytes = data_str.encode("utf-8")
437-
else:
438-
data_bytes = bytes.fromhex(data_str)
439-
440-
return self._serde.loads_typed((type_str, data_bytes))
441+
try:
442+
# Convert back to bytes based on type
443+
if type_str == "json":
444+
data_bytes = data_str.encode("utf-8")
445+
else:
446+
# bytes, bytearray, msgpack types are hex-encoded
447+
data_bytes = bytes.fromhex(data_str)
448+
449+
return self._serde.loads_typed((type_str, data_bytes))
450+
except (ValueError, TypeError) as e:
451+
# Handle hex decoding errors or deserialization failures
452+
logger.error(
453+
"Failed to deserialize value from Redis: type=%r, error=%s",
454+
type_str,
455+
e,
456+
)
457+
# Return None to indicate deserialization failure
458+
return None
459+
except Exception as e:
460+
# Handle any other unexpected errors during deserialization
461+
logger.error(
462+
"Unexpected error deserializing value from Redis: type=%r, error=%s",
463+
type_str,
464+
e,
465+
)
466+
return None
441467

442468
# Legacy format: value is stored as-is (plain JSON-serializable data)
443469
# Return as-is for backward compatibility
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
"""Async tests for issue #128: HumanMessage serialization in AsyncRedisStore.
2+
3+
This test suite mirrors the sync tests to ensure the async store implementation
4+
properly handles LangChain message serialization.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
from typing import AsyncIterator
10+
11+
import pytest
12+
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
13+
14+
from langgraph.store.redis.aio import AsyncRedisStore
15+
16+
17+
@pytest.fixture
18+
async def async_store(redis_url: str) -> AsyncIterator[AsyncRedisStore]:
19+
"""Fixture to create an async Redis store."""
20+
async with AsyncRedisStore.from_conn_string(redis_url) as store:
21+
await store.setup()
22+
yield store
23+
24+
25+
class TestIssue128AsyncMessageSerialization:
26+
"""Async test suite for issue #128: HumanMessage serialization."""
27+
28+
@pytest.mark.asyncio
29+
async def test_store_human_message_in_value(
30+
self, async_store: AsyncRedisStore
31+
) -> None:
32+
"""Test storing a HumanMessage object in async store."""
33+
namespace = ("test", "async_messages")
34+
key = "message1"
35+
36+
value = {
37+
"messages": [
38+
HumanMessage(content="Hello from async!"),
39+
]
40+
}
41+
42+
await async_store.aput(namespace, key, value)
43+
44+
item = await async_store.aget(namespace, key)
45+
assert item is not None
46+
assert "messages" in item.value
47+
assert len(item.value["messages"]) == 1
48+
49+
retrieved_message = item.value["messages"][0]
50+
assert isinstance(retrieved_message, HumanMessage)
51+
assert retrieved_message.content == "Hello from async!"
52+
53+
@pytest.mark.asyncio
54+
async def test_store_multiple_message_types(
55+
self, async_store: AsyncRedisStore
56+
) -> None:
57+
"""Test storing multiple message types in async store."""
58+
namespace = ("test", "async_conversation")
59+
key = "conv1"
60+
61+
value = {
62+
"messages": [
63+
SystemMessage(content="You are a helpful assistant."),
64+
HumanMessage(content="What is Python?"),
65+
AIMessage(content="Python is a programming language."),
66+
]
67+
}
68+
69+
await async_store.aput(namespace, key, value)
70+
71+
item = await async_store.aget(namespace, key)
72+
assert item is not None
73+
assert len(item.value["messages"]) == 3
74+
75+
messages = item.value["messages"]
76+
assert isinstance(messages[0], SystemMessage)
77+
assert isinstance(messages[1], HumanMessage)
78+
assert isinstance(messages[2], AIMessage)
79+
80+
@pytest.mark.asyncio
81+
async def test_store_ai_message_with_tool_calls(
82+
self, async_store: AsyncRedisStore
83+
) -> None:
84+
"""Test storing AIMessage with tool calls in async store."""
85+
namespace = ("test", "async_tools")
86+
key = "tool_call"
87+
88+
value = {
89+
"messages": [
90+
AIMessage(
91+
content="",
92+
tool_calls=[
93+
{
94+
"id": "call_123",
95+
"name": "get_weather",
96+
"args": {"location": "NYC"},
97+
}
98+
],
99+
),
100+
]
101+
}
102+
103+
await async_store.aput(namespace, key, value)
104+
105+
item = await async_store.aget(namespace, key)
106+
assert item is not None
107+
108+
retrieved_message = item.value["messages"][0]
109+
assert isinstance(retrieved_message, AIMessage)
110+
assert len(retrieved_message.tool_calls) == 1
111+
assert retrieved_message.tool_calls[0]["name"] == "get_weather"
112+
113+
@pytest.mark.asyncio
114+
async def test_search_with_message_values(
115+
self, async_store: AsyncRedisStore
116+
) -> None:
117+
"""Test async search for items containing message values."""
118+
namespace = ("test", "async_searchable")
119+
120+
for i in range(3):
121+
await async_store.aput(
122+
namespace,
123+
f"msg{i}",
124+
{
125+
"topic": f"topic_{i}",
126+
"messages": [HumanMessage(content=f"Async Message {i}")],
127+
},
128+
)
129+
130+
results = await async_store.asearch(namespace)
131+
assert len(results) == 3
132+
133+
for result in results:
134+
assert "messages" in result.value
135+
assert isinstance(result.value["messages"][0], HumanMessage)
136+
137+
@pytest.mark.asyncio
138+
async def test_backward_compatibility_plain_json(
139+
self, async_store: AsyncRedisStore
140+
) -> None:
141+
"""Test that plain JSON values work correctly in async store."""
142+
namespace = ("test", "async_plain_json")
143+
key = "simple_value"
144+
145+
value = {
146+
"name": "async_test",
147+
"count": 42,
148+
"tags": ["a", "b", "c"],
149+
}
150+
151+
await async_store.aput(namespace, key, value)
152+
153+
item = await async_store.aget(namespace, key)
154+
assert item is not None
155+
assert item.value == value
156+
157+
# Verify filters work on plain JSON values
158+
results = await async_store.asearch(namespace, filter={"name": "async_test"})
159+
assert len(results) == 1
160+
assert results[0].key == key
161+
162+
@pytest.mark.asyncio
163+
async def test_serde_key_collision_prevention(
164+
self, async_store: AsyncRedisStore
165+
) -> None:
166+
"""Test that user data with serde-like keys is handled correctly."""
167+
namespace = ("test", "async_collision")
168+
key = "user_data_with_serde_keys"
169+
170+
value = {
171+
"__serde_type__": "user_defined_type",
172+
"__serde_data__": "user_defined_data",
173+
"extra_key": "this makes it not a serde wrapper",
174+
}
175+
176+
await async_store.aput(namespace, key, value)
177+
178+
item = await async_store.aget(namespace, key)
179+
assert item is not None
180+
assert item.value == value
181+
assert item.value["extra_key"] == "this makes it not a serde wrapper"
182+
183+
@pytest.mark.asyncio
184+
async def test_bytes_value_serialization(
185+
self, async_store: AsyncRedisStore
186+
) -> None:
187+
"""Test that bytes values are properly serialized in async store."""
188+
namespace = ("test", "async_bytes")
189+
key = "binary_data"
190+
191+
value = {
192+
"data": b"async binary content",
193+
"name": "test",
194+
}
195+
196+
await async_store.aput(namespace, key, value)
197+
198+
item = await async_store.aget(namespace, key)
199+
assert item is not None
200+
assert item.value["data"] == b"async binary content"
201+
assert item.value["name"] == "test"

tests/test_issue_128_store_message_serialization.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,3 +209,74 @@ def test_search_with_message_values(self, store: RedisStore) -> None:
209209
assert item.value["topic"] == "topic_1"
210210
assert isinstance(item.value["messages"][0], HumanMessage)
211211
assert item.value["messages"][0].content == "Message 1"
212+
213+
def test_backward_compatibility_plain_json(self, store: RedisStore) -> None:
214+
"""Test that plain JSON-serializable values are stored as-is.
215+
216+
This ensures backward compatibility: simple dict/list values should
217+
be stored without the serde wrapper, preserving filter functionality.
218+
"""
219+
namespace = ("test", "plain_json")
220+
key = "simple_value"
221+
222+
# Plain JSON-serializable value (no LangChain objects)
223+
value = {
224+
"name": "test",
225+
"count": 42,
226+
"tags": ["a", "b", "c"],
227+
"nested": {"foo": "bar"},
228+
}
229+
230+
store.put(namespace, key, value)
231+
232+
# Verify we can retrieve it
233+
item = store.get(namespace, key)
234+
assert item is not None
235+
assert item.value == value
236+
237+
# Verify filters work on plain JSON values
238+
results = store.search(namespace, filter={"name": "test"})
239+
assert len(results) == 1
240+
assert results[0].key == key
241+
242+
def test_serde_key_collision_prevention(self, store: RedisStore) -> None:
243+
"""Test that user data with serde-like keys is not incorrectly deserialized.
244+
245+
If a user stores a dict with __serde_type__ and __serde_data__ keys
246+
but also other keys, it should NOT be treated as serialized data.
247+
"""
248+
namespace = ("test", "collision")
249+
key = "user_data_with_serde_keys"
250+
251+
# User data that happens to have serde-like keys plus other keys
252+
value = {
253+
"__serde_type__": "user_defined_type",
254+
"__serde_data__": "user_defined_data",
255+
"extra_key": "this makes it not a serde wrapper",
256+
}
257+
258+
store.put(namespace, key, value)
259+
260+
# Verify the value is retrieved as-is (not deserialized)
261+
item = store.get(namespace, key)
262+
assert item is not None
263+
assert item.value == value
264+
assert item.value["extra_key"] == "this makes it not a serde wrapper"
265+
266+
def test_bytes_value_serialization(self, store: RedisStore) -> None:
267+
"""Test that bytes values are properly serialized and deserialized."""
268+
namespace = ("test", "bytes")
269+
key = "binary_data"
270+
271+
# Value containing bytes (not JSON-serializable)
272+
value = {
273+
"data": b"binary content here",
274+
"name": "test",
275+
}
276+
277+
store.put(namespace, key, value)
278+
279+
item = store.get(namespace, key)
280+
assert item is not None
281+
assert item.value["data"] == b"binary content here"
282+
assert item.value["name"] == "test"

0 commit comments

Comments
 (0)