Skip to content

Commit 889129e

Browse files
committed
fix: Update mock_redis to handle session compression in tests
- Add SmartRedisMock class that correctly handles compressed (:z suffix) and uncompressed session keys in tests - Add set_uncompressed_return_value() for security tests on session data - Add set_compressed_return_value() for testing compressed session retrieval - Update test_redis_client_security.py to use new mock pattern - Update test_get_session_compressed to use set_compressed_return_value() - All 1142 tests now pass (was 10 failing due to mock not handling session compression logic correctly)
1 parent d5e03c7 commit 889129e

3 files changed

Lines changed: 87 additions & 15 deletions

File tree

python/tests/conftest.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,83 @@ def sample_documents():
137137
]
138138

139139

140+
class SmartRedisMock:
141+
"""Custom mock for Redis get() that handles session compression logic.
142+
143+
The redis_client tries compressed keys (ending in :z) first,
144+
then falls back to uncompressed. This mock returns None for compressed
145+
keys by default.
146+
147+
For session tests (uncompressed), use: mock.get.set_uncompressed_return_value(value)
148+
For session tests (compressed), use: mock.get.set_compressed_return_value(value)
149+
For other tests (template/query/embedding), use: mock.get.return_value = value
150+
For error tests, use: mock.get.side_effect = Exception(...)
151+
"""
152+
153+
def __init__(self):
154+
self._direct_value = None
155+
self._session_value = None
156+
self._compressed_value = None
157+
self._side_effect = None
158+
self.call_args_list = []
159+
160+
@property
161+
def return_value(self):
162+
return self._direct_value
163+
164+
@return_value.setter
165+
def return_value(self, value):
166+
self._direct_value = value
167+
168+
@property
169+
def side_effect(self):
170+
return self._side_effect
171+
172+
@side_effect.setter
173+
def side_effect(self, value):
174+
self._side_effect = value
175+
176+
def set_uncompressed_return_value(self, value):
177+
"""Set return value for uncompressed session keys."""
178+
self._session_value = value
179+
180+
def set_compressed_return_value(self, value):
181+
"""Set return value for compressed session keys (:z suffix)."""
182+
self._compressed_value = value
183+
184+
def __call__(self, key):
185+
self.call_args_list.append((key,))
186+
187+
# If side_effect is set, raise/call it
188+
if self._side_effect is not None:
189+
if isinstance(self._side_effect, BaseException):
190+
raise self._side_effect
191+
elif callable(self._side_effect):
192+
return self._side_effect(key)
193+
else:
194+
raise self._side_effect
195+
196+
# Compressed session keys (:z suffix)
197+
if key.endswith(":z"):
198+
if self._compressed_value is not None:
199+
return self._compressed_value
200+
return None
201+
# Session keys use session value if set
202+
if ":session:" in key and self._session_value is not None:
203+
return self._session_value
204+
# All other keys use direct value
205+
return self._direct_value
206+
207+
140208
@pytest.fixture
141209
def mock_redis():
142210
"""Mock Redis client for tests without Redis."""
143211
mock = MagicMock()
144212
mock.ping.return_value = True
145-
mock.get.return_value = None
213+
214+
# Use smart mock for get() to handle session compression
215+
mock.get = SmartRedisMock()
216+
146217
mock.setex.return_value = True
147218
mock.delete.return_value = 1
148219
mock.keys.return_value = []

python/tests/test_redis_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ def test_get_session_compressed(self, redis_client, sample_session_state, mock_r
113113
compressed = zlib.compress(data.encode(), level=6)
114114
encoded = base64.b64encode(compressed).decode()
115115

116-
# Mock get to return compressed data on first call (compressed key)
117-
mock_redis.get.return_value = encoded
116+
# Mock get to return compressed data for compressed key (:z suffix)
117+
mock_redis.get.set_compressed_return_value(encoded)
118118
result = redis_client.get_session(sample_session_state.session_id)
119119
assert result is not None
120120
assert result.session_id == sample_session_state.session_id

python/tests/test_redis_client_security.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ def redis_client(self, test_config, mock_redis):
3939

4040
def test_get_session_valid_data(self, redis_client, sample_session_state, mock_redis):
4141
"""Test get_session with valid data."""
42-
mock_redis.get.return_value = json.dumps(sample_session_state.to_dict())
42+
# Set return value for uncompressed key (compressed key returns None)
43+
mock_redis.get.set_uncompressed_return_value(json.dumps(sample_session_state.to_dict()))
4344

4445
result = redis_client.get_session(sample_session_state.session_id)
4546

@@ -60,15 +61,15 @@ def test_get_session_invalid_session_id_format(self, redis_client, mock_redis):
6061
"created_at": now,
6162
"updated_at": now,
6263
}
63-
mock_redis.get.return_value = json.dumps(invalid_data)
64+
mock_redis.get.set_uncompressed_return_value(json.dumps(invalid_data))
6465

6566
result = redis_client.get_session("test")
6667

6768
assert result is None # Validation failed
6869

6970
def test_get_session_malformed_json(self, redis_client, mock_redis, caplog):
7071
"""Test get_session handles malformed JSON gracefully."""
71-
mock_redis.get.return_value = "{ invalid json"
72+
mock_redis.get.set_uncompressed_return_value("{ invalid json")
7273

7374
with caplog.at_level(logging.ERROR):
7475
result = redis_client.get_session("test-session")
@@ -83,7 +84,7 @@ def test_get_session_missing_required_field(self, redis_client, mock_redis, capl
8384
"project_path": "path",
8485
# Missing created_at, updated_at, etc.
8586
}
86-
mock_redis.get.return_value = json.dumps(incomplete_data)
87+
mock_redis.get.set_uncompressed_return_value(json.dumps(incomplete_data))
8788

8889
with caplog.at_level(logging.ERROR):
8990
result = redis_client.get_session("test")
@@ -103,7 +104,7 @@ def test_get_session_invalid_timestamp(self, redis_client, mock_redis, caplog):
103104
"created_at": "invalid-timestamp", # Invalid format
104105
"updated_at": now,
105106
}
106-
mock_redis.get.return_value = json.dumps(invalid_data)
107+
mock_redis.get.set_uncompressed_return_value(json.dumps(invalid_data))
107108

108109
with caplog.at_level(logging.ERROR):
109110
result = redis_client.get_session("test")
@@ -123,7 +124,7 @@ def test_get_session_path_traversal_attempt(self, redis_client, mock_redis, capl
123124
"created_at": now,
124125
"updated_at": now,
125126
}
126-
mock_redis.get.return_value = json.dumps(traversal_data)
127+
mock_redis.get.set_uncompressed_return_value(json.dumps(traversal_data))
127128

128129
with caplog.at_level(logging.ERROR):
129130
result = redis_client.get_session("test")
@@ -143,7 +144,7 @@ def test_get_session_path_traversal_in_files(self, redis_client, mock_redis, cap
143144
"created_at": now,
144145
"updated_at": now,
145146
}
146-
mock_redis.get.return_value = json.dumps(traversal_data)
147+
mock_redis.get.set_uncompressed_return_value(json.dumps(traversal_data))
147148

148149
with caplog.at_level(logging.ERROR):
149150
result = redis_client.get_session("test")
@@ -165,7 +166,7 @@ def test_get_session_extra_fields_rejected(self, redis_client, mock_redis, caplo
165166
"malicious_field": "should not be here", # Extra field
166167
"another_injection": {"nested": "data"},
167168
}
168-
mock_redis.get.return_value = json.dumps(data_with_extras)
169+
mock_redis.get.set_uncompressed_return_value(json.dumps(data_with_extras))
169170

170171
with caplog.at_level(logging.ERROR):
171172
result = redis_client.get_session("test")
@@ -505,7 +506,7 @@ def test_session_id_redis_flush_injection(self, redis_client, mock_redis, caplog
505506
"created_at": now,
506507
"updated_at": now,
507508
}
508-
mock_redis.get.return_value = json.dumps(redis_injection)
509+
mock_redis.get.set_uncompressed_return_value(json.dumps(redis_injection))
509510

510511
with caplog.at_level(logging.ERROR):
511512
result = redis_client.get_session("whatever")
@@ -525,7 +526,7 @@ def test_session_id_shell_injection(self, redis_client, mock_redis, caplog):
525526
"created_at": now,
526527
"updated_at": now,
527528
}
528-
mock_redis.get.return_value = json.dumps(shell_injection)
529+
mock_redis.get.set_uncompressed_return_value(json.dumps(shell_injection))
529530

530531
with caplog.at_level(logging.ERROR):
531532
result = redis_client.get_session("whatever")
@@ -544,7 +545,7 @@ def test_session_id_backtick_injection(self, redis_client, mock_redis, caplog):
544545
"created_at": now,
545546
"updated_at": now,
546547
}
547-
mock_redis.get.return_value = json.dumps(backtick_injection)
548+
mock_redis.get.set_uncompressed_return_value(json.dumps(backtick_injection))
548549

549550
with caplog.at_level(logging.ERROR):
550551
result = redis_client.get_session("whatever")
@@ -567,7 +568,7 @@ def test_redis_error_logging(self, redis_client, mock_redis, caplog):
567568

568569
def test_session_state_maintains_type_safety(self, redis_client, sample_session_state, mock_redis):
569570
"""Test SessionState object has correct types after deserialization."""
570-
mock_redis.get.return_value = json.dumps(sample_session_state.to_dict())
571+
mock_redis.get.set_uncompressed_return_value(json.dumps(sample_session_state.to_dict()))
571572

572573
session = redis_client.get_session(sample_session_state.session_id)
573574

0 commit comments

Comments
 (0)