Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 53 additions & 29 deletions src/praisonai-agents/praisonaiagents/session/hierarchy.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ class HierarchicalSessionStore(DefaultSessionStore):
def __init__(self, *args, **kwargs):
    """Initialise the base store, then create the empty per-session caches.

    All positional and keyword arguments are forwarded unchanged to the
    parent initialiser.
    """
    super().__init__(*args, **kwargs)
    # File modification time recorded for each cached session, keyed by
    # session id; used to decide whether a cached entry is still current.
    self._cache_mtimes: Dict[str, float] = {}
    # In-memory ExtendedSessionData objects, keyed by session id.
    self._extended_cache: Dict[str, ExtendedSessionData] = {}

def _load_session_from_disk(self, session_id: str, filepath: str) -> ExtendedSessionData:
"""Load extended session JSON from disk (caller must hold FileLock)."""
Expand Down Expand Up @@ -173,6 +174,43 @@ def _modify_session_locked(
if isinstance(cached, ExtendedSessionData):
self._extended_cache[session_id] = cached
return result

def _is_cache_valid(self, session_id: str) -> bool:
    """Report whether the cached copy of a session still matches its file.

    A cache entry is trusted only when the session is present in
    ``_extended_cache`` and the file's current modification time is exactly
    the one recorded in ``_cache_mtimes`` when the entry was stored.

    Args:
        session_id: Identifier of the session to check.

    Returns:
        True when the cached entry can be served as-is. False when the
        session is not cached, no mtime was recorded for it, the session
        file cannot be stat'ed (e.g. it does not exist), or the file's
        mtime differs from the recorded one.
    """
    if session_id not in self._extended_cache:
        return False

    cached_mtime = self._cache_mtimes.get(session_id)
    if cached_mtime is None:
        # Cached object without a recorded mtime: nothing to compare against.
        return False

    filepath = self._get_session_path(session_id)
    try:
        # EAFP: a missing file raises FileNotFoundError (an OSError), so a
        # separate exists() probe would only add a check-then-use race.
        current_mtime = os.path.getmtime(filepath)
    except OSError:
        return False

    # Any difference invalidates the entry. An mtime that moved *backwards*
    # (file restored or replaced with an older one, clock adjustment) means
    # the file changed just as much as a newer mtime does.
    return current_mtime == cached_mtime

def _read_session_fresh(self, session_id: str) -> ExtendedSessionData:
    """Reload a session from disk and refresh every in-memory cache.

    The file's modification time is sampled *before* the parent performs
    its read. The parent's read finishes before control returns here, so
    sampling afterwards could pair old data with the mtime of a write that
    landed in between, and that stale entry would later look valid.
    Sampling first means a concurrent write can only make the recorded
    mtime older than the file's, which costs at most one extra reload.

    Args:
        session_id: Identifier of the session to reload.

    Returns:
        The freshly loaded session as ``ExtendedSessionData``; the same
        object is stored in ``_cache`` and ``_extended_cache``.
    """
    filepath = self._get_session_path(session_id)
    try:
        mtime_before_read = os.path.getmtime(filepath)
    except OSError:
        # No readable file yet. Record 0.0 rather than the wall clock so a
        # file created later can never compare as "not newer" than this
        # entry.
        mtime_before_read = 0.0

    session = super()._read_session_fresh(session_id)
    if not isinstance(session, ExtendedSessionData):
        session = ExtendedSessionData.from_session_data(session)

    # One critical section so both caches and the mtime change together.
    with self._lock:
        self._cache[session_id] = session
        self._extended_cache[session_id] = session
        self._cache_mtimes[session_id] = mtime_before_read

    return session
Comment on lines +194 to +213
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Mtime captured outside FileLock — same race it aims to prevent

super()._read_session_fresh() acquires and releases FileLock internally. The mtime snapshot is taken after that lock is gone. In the window between the lock release and the getmtime call, a concurrent writer can complete a full locked write, advancing the file's mtime to T1. The cache then stores T0 data paired with the T1 mtime. On the next _is_cache_valid check, current_mtime <= cached_mtime holds (both are T1), so the check returns True and the stale T0 data is served — the exact bug this PR intends to close.

The mtime must be sampled inside the same FileLock that protects the read. Because super()._read_session_fresh owns the lock and returns after releasing it, the safe fix is to bypass super() and reproduce its logic directly, capturing os.path.getmtime while the lock is still held.


def add_message(
self,
Expand Down Expand Up @@ -204,35 +242,14 @@ def _apply(session: SessionData) -> None:
)

def _load_extended_session(self, session_id: str, force_reload: bool = False) -> ExtendedSessionData:
"""Load extended session from disk."""
filepath = self._get_session_path(session_id)

# Check cache first (unless force reload)
if not force_reload:
with self._lock:
if session_id in self._extended_cache:
return self._extended_cache[session_id]

# Load from disk
if not os.path.exists(filepath):
session = ExtendedSessionData(session_id=session_id)
with self._lock:
self._extended_cache[session_id] = session
return session

with FileLock(filepath, self.lock_timeout):
try:
with open(filepath, "r", encoding="utf-8") as f:
data = json.load(f)
session = ExtendedSessionData.from_dict(data)
except (json.JSONDecodeError, IOError) as e:
logger.warning(f"Failed to load session {session_id}: {e}")
session = ExtendedSessionData(session_id=session_id)
"""Load extended session with smart caching based on file modification time."""
# Force reload bypasses cache validation
if force_reload or not self._is_cache_valid(session_id):
return self._read_session_fresh(session_id)

# Cache is valid, return cached version
with self._lock:
self._extended_cache[session_id] = session

return session
return self._extended_cache[session_id]

def _save_extended_session(self, session: ExtendedSessionData) -> bool:
"""Save extended session to disk."""
Expand Down Expand Up @@ -260,8 +277,15 @@ def _save_extended_session(self, session: ExtendedSessionData) -> bool:

os.replace(temp_path, filepath)

# Update cache with current file mtime after successful write
try:
mtime = os.path.getmtime(filepath)
except (OSError, IOError):
mtime = time.time()

with self._lock:
self._extended_cache[session.session_id] = session
self._cache_mtimes[session.session_id] = mtime

return True
except (IOError, OSError) as e:
Expand Down Expand Up @@ -602,8 +626,8 @@ async def auto_title(self, session_id: str) -> bool:
return False

def get_extended_session(self, session_id: str) -> ExtendedSessionData:
"""Get extended session data."""
return self._load_extended_session(session_id)
"""Get extended session data with smart caching."""
return self._load_extended_session(session_id, force_reload=False)

def export_session(self, session_id: str) -> Dict[str, Any]:
"""
Expand Down
88 changes: 88 additions & 0 deletions src/praisonai-agents/tests/unit/session/test_hierarchy.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,94 @@ def test_export_import_with_custom_id(self):
new_id = self.store.import_session(exported, new_session_id=custom_id)

assert new_id == custom_id

def test_set_title_does_not_drop_messages_after_external_write(self):
    """
    Regression test for the stale cache bug.

    Reproduces the scenario where:
    1. Process A loads session (warms cache)
    2. Process B writes new messages
    3. Process A calls set_title() → should NOT drop Process B's messages

    The external write's mtime is advanced explicitly with os.utime rather
    than relying on a sleep, so the test does not depend on the
    filesystem's timestamp granularity.
    """
    import json
    import os

    # Create session with initial messages
    session_id = self.store.create_session(title="Test Session")
    self.store.add_message(session_id, "user", "Message 1")
    self.store.add_message(session_id, "assistant", "Response 1")

    # Process A: Load session (warms cache)
    session_a = self.store.get_extended_session(session_id)
    assert len(session_a.messages) == 2

    # Process B: Simulate external write by directly modifying file
    # This mimics another process/store instance writing to the same session
    filepath = self.store._get_session_path(session_id)
    stat_before_write = os.stat(filepath)
    with open(filepath, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Add messages from "Process B"
    data["messages"].extend([
        {"role": "user", "content": "Message 2", "timestamp": time.time(), "metadata": {}},
        {"role": "assistant", "content": "Response 2", "timestamp": time.time(), "metadata": {}},
    ])
    data["updated_at"] = time.time()

    # Write the updated data (simulating external process write)
    with open(filepath, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)

    # Make the external write strictly newer than anything cached. Two
    # seconds exceeds even coarse (2 s) filesystem timestamp resolution.
    os.utime(
        filepath,
        ns=(
            stat_before_write.st_atime_ns,
            stat_before_write.st_mtime_ns + 2_000_000_000,
        ),
    )

    # Process A: Call set_title() - this should detect the external write
    # and reload fresh data instead of using stale cache
    result = self.store.set_title(session_id, "Updated Title")
    assert result is True

    # Verify no messages were lost - should have all 4 messages
    final_session = self.store.get_extended_session(session_id)
    assert len(final_session.messages) == 4, f"Expected 4 messages, got {len(final_session.messages)}"
    assert final_session.title == "Updated Title"
    assert final_session.messages[0].content == "Message 1"
    assert final_session.messages[1].content == "Response 1"
    assert final_session.messages[2].content == "Message 2"
    assert final_session.messages[3].content == "Response 2"

def test_cache_performance_with_unchanged_files(self):
    """
    Reads of a session whose file has not changed are served from the cache.
    """
    session_id = self.store.create_session(title="Cache Test")
    self.store.add_message(session_id, "user", "Test message")

    # Initial read populates the cache from disk.
    first_read = self.store.get_extended_session(session_id)
    assert len(first_read.messages) == 1

    # Nothing touched the file since, so the cache entry must still be
    # considered valid...
    cache_ok = self.store._is_cache_valid(session_id)
    assert cache_ok is True

    # ...and a repeat read must hand back the very same cached object.
    second_read = self.store.get_extended_session(session_id)
    assert len(second_read.messages) == 1
    assert second_read is first_read  # Should be same cached object

def test_force_reload_bypasses_cache(self):
    """Test that force_reload=True always loads from disk.

    The cached in-memory object is deliberately altered without saving, so
    it disagrees with the file. A forced reload must return the on-disk
    state; returning the cached object would fail the assertions below.
    """
    session_id = self.store.create_session(title="Force Reload Test")
    self.store.add_message(session_id, "user", "Message 1")

    # Load and cache
    cached = self.store._load_extended_session(session_id, force_reload=False)
    assert len(cached.messages) == 1

    # Diverge the in-memory copy from disk (never persisted).
    cached.messages.clear()

    # Force reload should bypass cache and reflect what is on disk
    fresh = self.store._load_extended_session(session_id, force_reload=True)

    assert len(fresh.messages) == 1
    assert fresh.messages[0].content == "Message 1"
    assert fresh.session_id == cached.session_id


class TestGlobalHierarchicalStore:
Expand Down
Loading