Skip to content

Commit 963efbf

Browse files
committed
test: add Pydantic BaseModel state serialization tests
Add comprehensive test suite for Pydantic model state handling: - test_pydantic_basemodel_in_checkpoint: Basic Pydantic model storage/retrieval - test_pydantic_model_with_langchain_messages: Mixed Pydantic + LangChain messages
1 parent 3f3af65 commit 963efbf

File tree

1 file changed

+299
-0
lines changed

1 file changed

+299
-0
lines changed

tests/test_pydantic_state.py

Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
"""Test Pydantic BaseModel state serialization with Redis checkpoint saver.
2+
3+
This tests the fix in PR #126 which changes the order of serde methods
4+
to prefer _revive_if_needed over _reviver for better Pydantic compatibility.
5+
"""
6+
7+
from contextlib import contextmanager
8+
from typing import Any, List, Optional
9+
10+
import pytest
11+
from pydantic import BaseModel, Field
12+
from testcontainers.redis import RedisContainer
13+
14+
from langgraph.checkpoint.redis import RedisSaver
15+
16+
17+
class Address(BaseModel):
18+
"""Nested Pydantic model for testing."""
19+
20+
street: str
21+
city: str
22+
zip_code: str = Field(default="00000")
23+
24+
25+
class Person(BaseModel):
26+
"""Pydantic model with nested objects."""
27+
28+
name: str
29+
age: int
30+
address: Optional[Address] = None
31+
tags: List[str] = Field(default_factory=list)
32+
33+
34+
class ChatState(BaseModel):
35+
"""Pydantic model representing a chat state."""
36+
37+
messages: List[dict] = Field(default_factory=list)
38+
user: Optional[Person] = None
39+
metadata: dict = Field(default_factory=dict)
40+
41+
42+
@contextmanager
43+
def _saver(redis_url: str):
44+
"""Create a RedisSaver context manager."""
45+
with RedisSaver.from_conn_string(redis_url) as saver:
46+
saver.setup()
47+
yield saver
48+
49+
50+
@pytest.fixture(scope="module")
51+
def redis_url():
52+
"""Provide a Redis URL using TestContainers."""
53+
redis_container = RedisContainer("redis:8")
54+
redis_container.start()
55+
try:
56+
yield f"redis://{redis_container.get_container_host_ip()}:{redis_container.get_exposed_port(6379)}"
57+
finally:
58+
redis_container.stop()
59+
60+
61+
def test_pydantic_basemodel_in_checkpoint(redis_url: str) -> None:
62+
"""Test that Pydantic BaseModel objects can be stored and retrieved.
63+
64+
This is the key test for PR #126 - ensures that the _revive_if_needed
65+
method is called properly for Pydantic model reconstruction.
66+
"""
67+
with _saver(redis_url) as saver:
68+
config = {
69+
"configurable": {
70+
"thread_id": "pydantic-test-1",
71+
"checkpoint_ns": "",
72+
}
73+
}
74+
75+
# Create a checkpoint with Pydantic models in channel_values
76+
checkpoint = {
77+
"v": 1,
78+
"ts": "2024-01-01T00:00:00.000000+00:00",
79+
"id": "checkpoint-pydantic-1",
80+
"channel_values": {
81+
"state": {
82+
"person": Person(
83+
name="Alice",
84+
age=30,
85+
address=Address(
86+
street="123 Main St", city="NYC", zip_code="10001"
87+
),
88+
tags=["developer", "python"],
89+
).model_dump(),
90+
"chat": ChatState(
91+
messages=[
92+
{"role": "user", "content": "Hello"},
93+
{"role": "assistant", "content": "Hi there!"},
94+
],
95+
user=Person(name="Bob", age=25),
96+
metadata={"session_id": "abc123"},
97+
).model_dump(),
98+
}
99+
},
100+
"channel_versions": {"state": "1"},
101+
"versions_seen": {},
102+
"pending_sends": [],
103+
}
104+
105+
# Store the checkpoint
106+
next_config = saver.put(
107+
config, checkpoint, {"source": "test", "step": 1}, {"state": "1"}
108+
)
109+
110+
# Retrieve the checkpoint
111+
retrieved = saver.get(next_config)
112+
113+
assert retrieved is not None
114+
assert "state" in retrieved["channel_values"]
115+
state = retrieved["channel_values"]["state"]
116+
117+
# Verify the person data was preserved
118+
assert state["person"]["name"] == "Alice"
119+
assert state["person"]["age"] == 30
120+
assert state["person"]["address"]["city"] == "NYC"
121+
122+
# Verify the chat state was preserved
123+
assert len(state["chat"]["messages"]) == 2
124+
assert state["chat"]["user"]["name"] == "Bob"
125+
126+
127+
def test_nested_pydantic_models_roundtrip(redis_url: str) -> None:
128+
"""Test deeply nested Pydantic models can survive a checkpoint roundtrip."""
129+
with _saver(redis_url) as saver:
130+
config = {
131+
"configurable": {
132+
"thread_id": "nested-pydantic-test",
133+
"checkpoint_ns": "",
134+
}
135+
}
136+
137+
# Create deeply nested structure
138+
nested_state = {
139+
"level1": {
140+
"level2": {
141+
"person": Person(
142+
name="Charlie",
143+
age=35,
144+
address=Address(
145+
street="456 Elm St", city="Boston", zip_code="02101"
146+
),
147+
).model_dump(),
148+
"items": [
149+
{"id": 1, "data": Person(name="Dave", age=40).model_dump()},
150+
{"id": 2, "data": Person(name="Eve", age=28).model_dump()},
151+
],
152+
}
153+
}
154+
}
155+
156+
checkpoint = {
157+
"v": 1,
158+
"ts": "2024-01-01T00:00:00.000000+00:00",
159+
"id": "checkpoint-nested-1",
160+
"channel_values": {"state": nested_state},
161+
"channel_versions": {"state": "1"},
162+
"versions_seen": {},
163+
"pending_sends": [],
164+
}
165+
166+
next_config = saver.put(
167+
config, checkpoint, {"source": "test", "step": 1}, {"state": "1"}
168+
)
169+
170+
retrieved = saver.get(next_config)
171+
172+
assert retrieved is not None
173+
state = retrieved["channel_values"]["state"]
174+
175+
# Verify deeply nested data
176+
level2 = state["level1"]["level2"]
177+
assert level2["person"]["name"] == "Charlie"
178+
assert level2["person"]["address"]["city"] == "Boston"
179+
180+
# Verify list items
181+
assert len(level2["items"]) == 2
182+
assert level2["items"][0]["data"]["name"] == "Dave"
183+
assert level2["items"][1]["data"]["name"] == "Eve"
184+
185+
186+
def test_pydantic_model_with_langchain_messages(redis_url: str) -> None:
187+
"""Test Pydantic state with LangChain-style message objects.
188+
189+
This is the critical test case mentioned in PR #126 - when users
190+
use Pydantic BaseModel as state with LangChain message types.
191+
"""
192+
try:
193+
from langchain_core.messages import AIMessage, HumanMessage
194+
except ImportError:
195+
pytest.skip("langchain-core not installed")
196+
197+
with _saver(redis_url) as saver:
198+
config = {
199+
"configurable": {
200+
"thread_id": "langchain-pydantic-test",
201+
"checkpoint_ns": "",
202+
}
203+
}
204+
205+
# Simulate a state that mixes Pydantic with LangChain messages
206+
checkpoint = {
207+
"v": 1,
208+
"ts": "2024-01-01T00:00:00.000000+00:00",
209+
"id": "checkpoint-lc-pydantic-1",
210+
"channel_values": {
211+
"messages": [
212+
HumanMessage(content="Hello, how are you?"),
213+
AIMessage(content="I'm doing well, thank you!"),
214+
],
215+
"user_profile": Person(
216+
name="TestUser", age=25, tags=["test", "demo"]
217+
).model_dump(),
218+
},
219+
"channel_versions": {"messages": "1", "user_profile": "1"},
220+
"versions_seen": {},
221+
"pending_sends": [],
222+
}
223+
224+
next_config = saver.put(config, checkpoint, {"source": "test", "step": 1}, {})
225+
226+
retrieved = saver.get(next_config)
227+
228+
assert retrieved is not None
229+
230+
# Verify messages were deserialized correctly
231+
messages = retrieved["channel_values"]["messages"]
232+
assert len(messages) == 2
233+
234+
# Messages should be proper LangChain message objects after deserialization
235+
# with _revive_if_needed properly handling them
236+
if hasattr(messages[0], "content"):
237+
# Message objects were properly reconstructed
238+
assert messages[0].content == "Hello, how are you?"
239+
assert messages[1].content == "I'm doing well, thank you!"
240+
else:
241+
# If they're still dicts, verify the content is there
242+
assert messages[0].get("content") == "Hello, how are you?"
243+
244+
# Verify user profile data
245+
user_profile = retrieved["channel_values"]["user_profile"]
246+
assert user_profile["name"] == "TestUser"
247+
assert user_profile["age"] == 25
248+
249+
250+
def test_revive_if_needed_fallback_behavior(redis_url: str) -> None:
251+
"""Test that _revive_if_needed properly falls back when _reviver fails.
252+
253+
The PR #126 change ensures _revive_if_needed is tried first, which
254+
includes its own fallback to _reconstruct_from_constructor.
255+
"""
256+
with _saver(redis_url) as saver:
257+
config = {
258+
"configurable": {
259+
"thread_id": "fallback-test",
260+
"checkpoint_ns": "",
261+
}
262+
}
263+
264+
# Create a checkpoint with complex nested data
265+
checkpoint = {
266+
"v": 1,
267+
"ts": "2024-01-01T00:00:00.000000+00:00",
268+
"id": "checkpoint-fallback-1",
269+
"channel_values": {
270+
"complex_state": {
271+
"nested_dict": {
272+
"key1": [1, 2, 3],
273+
"key2": {"a": "b", "c": {"d": "e"}},
274+
},
275+
"simple_value": 42,
276+
"string_value": "test",
277+
"bool_value": True,
278+
"none_value": None,
279+
}
280+
},
281+
"channel_versions": {"complex_state": "1"},
282+
"versions_seen": {},
283+
"pending_sends": [],
284+
}
285+
286+
next_config = saver.put(config, checkpoint, {"source": "test", "step": 1}, {})
287+
288+
retrieved = saver.get(next_config)
289+
290+
assert retrieved is not None
291+
state = retrieved["channel_values"]["complex_state"]
292+
293+
# All values should be preserved exactly
294+
assert state["nested_dict"]["key1"] == [1, 2, 3]
295+
assert state["nested_dict"]["key2"]["c"]["d"] == "e"
296+
assert state["simple_value"] == 42
297+
assert state["string_value"] == "test"
298+
assert state["bool_value"] is True
299+
assert state["none_value"] is None

0 commit comments

Comments
 (0)