diff --git a/src/praisonai-agents/praisonaiagents/session/hierarchy.py b/src/praisonai-agents/praisonaiagents/session/hierarchy.py index 58075d9d4..9938aa984 100644 --- a/src/praisonai-agents/praisonaiagents/session/hierarchy.py +++ b/src/praisonai-agents/praisonaiagents/session/hierarchy.py @@ -144,6 +144,7 @@ class HierarchicalSessionStore(DefaultSessionStore): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._extended_cache: Dict[str, ExtendedSessionData] = {} + self._cache_mtimes: Dict[str, float] = {} # Track file modification times def _load_session_from_disk(self, session_id: str, filepath: str) -> ExtendedSessionData: """Load extended session JSON from disk (caller must hold FileLock).""" @@ -173,6 +174,43 @@ def _modify_session_locked( if isinstance(cached, ExtendedSessionData): self._extended_cache[session_id] = cached return result + + def _is_cache_valid(self, session_id: str) -> bool: + """Check if cached session is still valid based on file mtime.""" + if session_id not in self._extended_cache: + return False + + filepath = self._get_session_path(session_id) + if not os.path.exists(filepath): + return False + + try: + current_mtime = os.path.getmtime(filepath) + cached_mtime = self._cache_mtimes.get(session_id, 0) + return current_mtime <= cached_mtime + except (OSError, IOError): + return False + + def _read_session_fresh(self, session_id: str) -> ExtendedSessionData: + """Reload from disk and keep _cache and _extended_cache in sync.""" + session = super()._read_session_fresh(session_id) + if not isinstance(session, ExtendedSessionData): + session = ExtendedSessionData.from_session_data(session) + with self._lock: + self._cache[session_id] = session + + # Update cache with fresh file mtime + filepath = self._get_session_path(session_id) + try: + mtime = os.path.getmtime(filepath) if os.path.exists(filepath) else time.time() + except (OSError, IOError): + mtime = time.time() + + with self._lock: + self._extended_cache[session_id] = session + self._cache_mtimes[session_id] = mtime + + return session def add_message( self, @@ -204,35 +242,14 @@ def _apply(session: SessionData) -> None: ) def _load_extended_session(self, session_id: str, force_reload: bool = False) -> ExtendedSessionData: - """Load extended session from disk.""" - filepath = self._get_session_path(session_id) - - # Check cache first (unless force reload) - if not force_reload: - with self._lock: - if session_id in self._extended_cache: - return self._extended_cache[session_id] - - # Load from disk - if not os.path.exists(filepath): - session = ExtendedSessionData(session_id=session_id) - with self._lock: - self._extended_cache[session_id] = session - return session - - with FileLock(filepath, self.lock_timeout): - try: - with open(filepath, "r", encoding="utf-8") as f: - data = json.load(f) - session = ExtendedSessionData.from_dict(data) - except (json.JSONDecodeError, IOError) as e: - logger.warning(f"Failed to load session {session_id}: {e}") - session = ExtendedSessionData(session_id=session_id) + """Load extended session with smart caching based on file modification time.""" + # Force reload bypasses cache validation + if force_reload or not self._is_cache_valid(session_id): + return self._read_session_fresh(session_id) + # Cache is valid, return cached version with self._lock: - self._extended_cache[session_id] = session - - return session + return self._extended_cache[session_id] def _save_extended_session(self, session: ExtendedSessionData) -> bool: """Save extended session to disk.""" @@ -260,8 +277,15 @@ def _save_extended_session(self, session: ExtendedSessionData) -> bool: os.replace(temp_path, filepath) + # Update cache with current file mtime after successful write + try: + mtime = os.path.getmtime(filepath) + except (OSError, IOError): + mtime = time.time() + with self._lock: self._extended_cache[session.session_id] = session + self._cache_mtimes[session.session_id] = mtime return True except (IOError, OSError) as e: @@ -602,8 +626,8 @@ async def auto_title(self, session_id: str) -> bool: return False def get_extended_session(self, session_id: str) -> ExtendedSessionData: - """Get extended session data.""" - return self._load_extended_session(session_id) + """Get extended session data with smart caching.""" + return self._load_extended_session(session_id, force_reload=False) def export_session(self, session_id: str) -> Dict[str, Any]: """ diff --git a/src/praisonai-agents/tests/unit/session/test_hierarchy.py b/src/praisonai-agents/tests/unit/session/test_hierarchy.py index 39df04f5c..f10556a60 100644 --- a/src/praisonai-agents/tests/unit/session/test_hierarchy.py +++ b/src/praisonai-agents/tests/unit/session/test_hierarchy.py @@ -369,6 +369,94 @@ def test_export_import_with_custom_id(self): new_id = self.store.import_session(exported, new_session_id=custom_id) assert new_id == custom_id + + def test_set_title_does_not_drop_messages_after_external_write(self): + """ + Regression test for the stale cache bug. + + Reproduces the scenario where: + 1. Process A loads session (warms cache) + 2. Process B writes new messages + 3. Process A calls set_title() → should NOT drop Process B's messages + """ + import json + + # Create session with initial messages + session_id = self.store.create_session(title="Test Session") + self.store.add_message(session_id, "user", "Message 1") + self.store.add_message(session_id, "assistant", "Response 1") + + # Process A: Load session (warms cache) + session_a = self.store.get_extended_session(session_id) + assert len(session_a.messages) == 2 + + # Process B: Simulate external write by directly modifying file + # This mimics another process/store instance writing to the same session + filepath = self.store._get_session_path(session_id) + with open(filepath, "r", encoding="utf-8") as f: + data = json.load(f) + + # Add messages from "Process B" + data["messages"].extend([ + {"role": "user", "content": "Message 2", "timestamp": time.time(), "metadata": {}}, + {"role": "assistant", "content": "Response 2", "timestamp": time.time(), "metadata": {}} + ]) + data["updated_at"] = time.time() + + # Write the updated data (simulating external process write) + with open(filepath, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + + # Brief sleep to ensure file mtime is different + time.sleep(0.1) + + # Process A: Call set_title() - this should detect the external write + # and reload fresh data instead of using stale cache + result = self.store.set_title(session_id, "Updated Title") + assert result is True + + # Verify no messages were lost - should have all 4 messages + final_session = self.store.get_extended_session(session_id) + assert len(final_session.messages) == 4, f"Expected 4 messages, got {len(final_session.messages)}" + assert final_session.title == "Updated Title" + assert final_session.messages[0].content == "Message 1" + assert final_session.messages[1].content == "Response 1" + assert final_session.messages[2].content == "Message 2" + assert final_session.messages[3].content == "Response 2" + + def test_cache_performance_with_unchanged_files(self): + """ + Test that performance optimization works - reads from cache when file hasn't changed. + """ + session_id = self.store.create_session(title="Cache Test") + self.store.add_message(session_id, "user", "Test message") + + # First read - loads from disk and caches + session1 = self.store.get_extended_session(session_id) + assert len(session1.messages) == 1 + + # Second read should use cache (file hasn't changed) + # We can't easily test this directly, but we can verify the cache is valid + assert self.store._is_cache_valid(session_id) is True + + session2 = self.store.get_extended_session(session_id) + assert len(session2.messages) == 1 + assert session2 is session1 # Should be same cached object + + def test_force_reload_bypasses_cache(self): + """Test that force_reload=True always loads from disk.""" + session_id = self.store.create_session(title="Force Reload Test") + self.store.add_message(session_id, "user", "Message 1") + + # Load and cache + session1 = self.store._load_extended_session(session_id, force_reload=False) + + # Force reload should bypass cache + session2 = self.store._load_extended_session(session_id, force_reload=True) + + # Both should have same data but force_reload ensures fresh read + assert len(session1.messages) == len(session2.messages) + assert session1.session_id == session2.session_id class TestGlobalHierarchicalStore: