|
| 1 | +"""Test Pydantic BaseModel state serialization with Redis checkpoint saver. |
| 2 | +
|
| 3 | +This tests the fix in PR #126 which changes the order of serde methods |
| 4 | +to prefer _revive_if_needed over _reviver for better Pydantic compatibility. |
| 5 | +""" |
| 6 | + |
| 7 | +from contextlib import contextmanager |
| 8 | +from typing import Any, List, Optional |
| 9 | + |
| 10 | +import pytest |
| 11 | +from pydantic import BaseModel, Field |
| 12 | +from testcontainers.redis import RedisContainer |
| 13 | + |
| 14 | +from langgraph.checkpoint.redis import RedisSaver |
| 15 | + |
| 16 | + |
class Address(BaseModel):
    """Postal address; the innermost nested model used by these tests."""

    street: str
    city: str
    # Placeholder postal code applied when the caller does not supply one.
    zip_code: str = "00000"
| 23 | + |
| 24 | + |
class Person(BaseModel):
    """Person record that may embed an ``Address`` and carries a tag list."""

    name: str
    age: int
    # Nested model; stays ``None`` when no address is given.
    address: Optional[Address] = None
    # ``default_factory`` builds a new empty list for each instance.
    tags: List[str] = Field(default_factory=list)
| 32 | + |
| 33 | + |
class ChatState(BaseModel):
    """Chat session state: message dicts, an optional user, free-form metadata."""

    # Each message is kept as a plain dict (e.g. ``{"role": ..., "content": ...}``).
    messages: List[dict] = Field(default_factory=list)
    # The participant this chat belongs to, if any.
    user: Optional[Person] = None
    # Arbitrary key/value data; a new empty dict per instance.
    metadata: dict = Field(default_factory=dict)
| 40 | + |
| 41 | + |
@contextmanager
def _saver(redis_url: str):
    """Yield a ``RedisSaver`` for *redis_url* on which ``setup()`` has been called.

    The saver comes from ``RedisSaver.from_conn_string`` and is only valid
    while the caller remains inside its ``with`` block.
    """
    connection = RedisSaver.from_conn_string(redis_url)
    with connection as checkpointer:
        # Run the saver's setup step before handing it to the test.
        checkpointer.setup()
        yield checkpointer
| 48 | + |
| 49 | + |
@pytest.fixture(scope="module")
def redis_url():
    """Start a throwaway Redis container and yield its connection URL.

    The fixture is module-scoped, so one ``redis:8`` container is shared by
    every test in this module.

    Yields:
        A ``redis://host:port`` URL pointing at the container's mapped
        6379 port.
    """
    # The container's context manager pairs start() with stop(), so the
    # container is torn down even when a dependent test fails.
    with RedisContainer("redis:8") as redis_container:
        host = redis_container.get_container_host_ip()
        port = redis_container.get_exposed_port(6379)
        yield f"redis://{host}:{port}"
| 59 | + |
| 60 | + |
def test_pydantic_basemodel_in_checkpoint(redis_url: str) -> None:
    """Round-trip Pydantic-derived state through ``put`` and ``get``.

    Regression test accompanying PR #126. Both models are converted with
    ``model_dump()`` before they are stored, so the saver receives plain
    dicts and the assertions read the result back with dict indexing.
    """
    alice = Person(
        name="Alice",
        age=30,
        address=Address(street="123 Main St", city="NYC", zip_code="10001"),
        tags=["developer", "python"],
    )
    chat = ChatState(
        messages=[
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ],
        user=Person(name="Bob", age=25),
        metadata={"session_id": "abc123"},
    )

    config = {
        "configurable": {
            "thread_id": "pydantic-test-1",
            "checkpoint_ns": "",
        }
    }
    checkpoint = {
        "v": 1,
        "ts": "2024-01-01T00:00:00.000000+00:00",
        "id": "checkpoint-pydantic-1",
        "channel_values": {
            "state": {
                "person": alice.model_dump(),
                "chat": chat.model_dump(),
            }
        },
        "channel_versions": {"state": "1"},
        "versions_seen": {},
        "pending_sends": [],
    }
    metadata = {"source": "test", "step": 1}

    with _saver(redis_url) as saver:
        # Write the checkpoint, then read it back via the config put() returns.
        next_config = saver.put(config, checkpoint, metadata, {"state": "1"})
        retrieved = saver.get(next_config)

    assert retrieved is not None
    assert "state" in retrieved["channel_values"]
    state = retrieved["channel_values"]["state"]

    # Person data, including the nested address, survived.
    stored_person = state["person"]
    assert stored_person["name"] == "Alice"
    assert stored_person["age"] == 30
    assert stored_person["address"]["city"] == "NYC"

    # Chat state, including its nested user, survived.
    stored_chat = state["chat"]
    assert len(stored_chat["messages"]) == 2
    assert stored_chat["user"]["name"] == "Bob"
| 126 | + |
def test_nested_pydantic_models_roundtrip(redis_url: str) -> None:
    """Round-trip model data placed two dict levels deep and inside a list.

    Every model is converted with ``model_dump()`` before storage, so the
    checkpoint holds nested plain dicts and lists.
    """
    charlie = Person(
        name="Charlie",
        age=35,
        address=Address(street="456 Elm St", city="Boston", zip_code="02101"),
    )
    items = [
        {"id": 1, "data": Person(name="Dave", age=40).model_dump()},
        {"id": 2, "data": Person(name="Eve", age=28).model_dump()},
    ]
    nested_state = {
        "level1": {
            "level2": {
                "person": charlie.model_dump(),
                "items": items,
            }
        }
    }

    config = {
        "configurable": {
            "thread_id": "nested-pydantic-test",
            "checkpoint_ns": "",
        }
    }
    checkpoint = {
        "v": 1,
        "ts": "2024-01-01T00:00:00.000000+00:00",
        "id": "checkpoint-nested-1",
        "channel_values": {"state": nested_state},
        "channel_versions": {"state": "1"},
        "versions_seen": {},
        "pending_sends": [],
    }
    metadata = {"source": "test", "step": 1}

    with _saver(redis_url) as saver:
        next_config = saver.put(config, checkpoint, metadata, {"state": "1"})
        retrieved = saver.get(next_config)

    assert retrieved is not None
    level2 = retrieved["channel_values"]["state"]["level1"]["level2"]

    # The deeply nested person and its address.
    assert level2["person"]["name"] == "Charlie"
    assert level2["person"]["address"]["city"] == "Boston"

    # The list entries kept their order and payload.
    stored_items = level2["items"]
    assert len(stored_items) == 2
    assert stored_items[0]["data"]["name"] == "Dave"
    assert stored_items[1]["data"]["name"] == "Eve"
| 185 | + |
def test_pydantic_model_with_langchain_messages(redis_url: str) -> None:
    """Round-trip LangChain message objects alongside Pydantic-derived data.

    Covers the scenario described for PR #126: state that mixes LangChain
    message types with Pydantic data. The ``Person`` is stored via
    ``model_dump()``; the messages are stored as message objects.

    The test is skipped when ``langchain-core`` is not installed.
    """
    try:
        from langchain_core.messages import AIMessage, HumanMessage
    except ImportError:
        pytest.skip("langchain-core not installed")

    def _content(message: Any) -> Any:
        """Return ``content`` from a message object or from its dict form."""
        if hasattr(message, "content"):
            return message.content
        return message.get("content")

    expected_contents = ["Hello, how are you?", "I'm doing well, thank you!"]

    with _saver(redis_url) as saver:
        config = {
            "configurable": {
                "thread_id": "langchain-pydantic-test",
                "checkpoint_ns": "",
            }
        }

        # Simulate a state that mixes Pydantic data with LangChain messages.
        checkpoint = {
            "v": 1,
            "ts": "2024-01-01T00:00:00.000000+00:00",
            "id": "checkpoint-lc-pydantic-1",
            "channel_values": {
                "messages": [
                    HumanMessage(content=expected_contents[0]),
                    AIMessage(content=expected_contents[1]),
                ],
                "user_profile": Person(
                    name="TestUser", age=25, tags=["test", "demo"]
                ).model_dump(),
            },
            "channel_versions": {"messages": "1", "user_profile": "1"},
            "versions_seen": {},
            "pending_sends": [],
        }

        # Every channel is new in this first checkpoint, so new_versions
        # mirrors channel_versions (as the other tests in this module do).
        next_config = saver.put(
            config,
            checkpoint,
            {"source": "test", "step": 1},
            checkpoint["channel_versions"],
        )

        retrieved = saver.get(next_config)

        assert retrieved is not None

        messages = retrieved["channel_values"]["messages"]
        assert len(messages) == 2

        # Messages may come back as message objects or as dicts; either way
        # the content of BOTH messages must have survived.
        assert [_content(message) for message in messages] == expected_contents

        # Verify user profile data.
        user_profile = retrieved["channel_values"]["user_profile"]
        assert user_profile["name"] == "TestUser"
        assert user_profile["age"] == 25
| 249 | + |
def test_revive_if_needed_fallback_behavior(redis_url: str) -> None:
    """Round-trip plain nested data and require it to come back unchanged.

    Written alongside the PR #126 serde ordering change. The checkpoint
    holds only plain values (dicts, lists, int, str, bool, ``None``), so
    this guards that ordinary data is not altered on the way through the
    saver.
    NOTE(review): no ``_reviver`` failure is provoked here, so the fallback
    path named in the test title is not directly exercised; confirm
    whether a dedicated case is needed.
    """
    complex_state = {
        "nested_dict": {
            "key1": [1, 2, 3],
            "key2": {"a": "b", "c": {"d": "e"}},
        },
        "simple_value": 42,
        "string_value": "test",
        "bool_value": True,
        "none_value": None,
    }

    with _saver(redis_url) as saver:
        config = {
            "configurable": {
                "thread_id": "fallback-test",
                "checkpoint_ns": "",
            }
        }

        checkpoint = {
            "v": 1,
            "ts": "2024-01-01T00:00:00.000000+00:00",
            "id": "checkpoint-fallback-1",
            "channel_values": {"complex_state": complex_state},
            "channel_versions": {"complex_state": "1"},
            "versions_seen": {},
            "pending_sends": [],
        }

        # Every channel is new in this first checkpoint, so new_versions
        # mirrors channel_versions (as the other tests in this module do).
        next_config = saver.put(
            config,
            checkpoint,
            {"source": "test", "step": 1},
            checkpoint["channel_versions"],
        )

        retrieved = saver.get(next_config)

        assert retrieved is not None
        state = retrieved["channel_values"]["complex_state"]

        # The whole structure must be preserved exactly.
        assert state == complex_state

        # Spot checks on nested containers and scalar values.
        assert state["nested_dict"]["key1"] == [1, 2, 3]
        assert state["nested_dict"]["key2"]["c"]["d"] == "e"
        assert state["simple_value"] == 42
        assert state["string_value"] == "test"
        # Identity checks: ``==`` alone would accept 1 for True.
        assert state["bool_value"] is True
        assert state["none_value"] is None
0 commit comments