-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
fix: HierarchicalSessionStore stale extended cache after cross-instance writes #1785
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -146,6 +146,7 @@ def __init__(self, *args, **kwargs): | |
| self._extended_cache: Dict[str, ExtendedSessionData] = {} | ||
| self._cache_mtimes: Dict[str, float] = {} # Track file modification times | ||
|
|
||
|
|
||
| def _load_session_from_disk(self, session_id: str, filepath: str) -> ExtendedSessionData: | ||
| """Load extended session JSON from disk (caller must hold FileLock).""" | ||
| if os.path.exists(filepath): | ||
|
|
@@ -330,13 +331,11 @@ def create_session( | |
|
|
||
| # Update parent's children list without clobbering concurrent message writes | ||
| if parent_id: | ||
| def _link_child(parent: SessionData) -> None: | ||
| if sid not in parent.children_ids: | ||
| parent.children_ids.append(sid) | ||
|
|
||
| self._modify_session_locked( | ||
| parent_id, _link_child, error_label="link child session" | ||
| ) | ||
| def _apply(parent_session: SessionData) -> None: | ||
| assert isinstance(parent_session, ExtendedSessionData) | ||
| if sid not in parent_session.children_ids: | ||
| parent_session.children_ids.append(sid) | ||
| self._modify_session_locked(parent_id, _apply, error_label="update parent children") | ||
|
|
||
| self._save_extended_session(session) | ||
| return sid | ||
|
|
@@ -477,28 +476,30 @@ def revert_to_snapshot(self, session_id: str, snapshot_id: str) -> bool: | |
| Returns: | ||
| True if successful | ||
| """ | ||
| # Read-only lookup to find snapshot without triggering unnecessary writes | ||
| session = self._read_session_fresh(session_id) | ||
| snapshot = None | ||
| for s in session.snapshots: | ||
| if s.id == snapshot_id: | ||
| snapshot = s | ||
| break | ||
|
|
||
| if snapshot is None: | ||
| logger.warning(f"Snapshot {snapshot_id} not found") | ||
| return False | ||
|
|
||
| # Now perform the actual revert in a single locked operation | ||
| def _revert(session: SessionData) -> None: | ||
| def _apply(session: SessionData) -> None: | ||
| assert isinstance(session, ExtendedSessionData) | ||
|
|
||
| # Find the snapshot | ||
| snapshot = None | ||
| for s in session.snapshots: | ||
| if s.id == snapshot_id: | ||
| snapshot = s | ||
| break | ||
|
|
||
| if snapshot is None: | ||
| logger.warning(f"Snapshot {snapshot_id} not found") | ||
| raise ValueError(f"Snapshot {snapshot_id} not found") | ||
|
|
||
| # Revert messages | ||
| if snapshot.message_index >= 0: | ||
| session.messages = session.messages[: snapshot.message_index + 1] | ||
| session.messages = session.messages[:snapshot.message_index + 1] | ||
| else: | ||
| session.messages = [] | ||
|
|
||
| return self._modify_session_locked( | ||
| session_id, _revert, error_label="revert to snapshot" | ||
| ) | ||
|
|
||
| try: | ||
| return self._modify_session_locked(session_id, _apply, error_label="revert to snapshot") | ||
| except ValueError: | ||
| return False | ||
|
|
||
| def revert_to_message(self, session_id: str, message_index: int) -> bool: | ||
| """ | ||
|
|
@@ -511,37 +512,33 @@ def revert_to_message(self, session_id: str, message_index: int) -> bool: | |
| Returns: | ||
| True if successful | ||
| """ | ||
| # Validate message index before writing | ||
| session = self._read_session_fresh(session_id) | ||
| if message_index < 0 or message_index >= len(session.messages): | ||
| logger.warning(f"Invalid message index {message_index}") | ||
| def _apply(session: SessionData) -> None: | ||
| assert isinstance(session, ExtendedSessionData) | ||
|
|
||
| if message_index < 0 or message_index >= len(session.messages): | ||
| logger.warning(f"Invalid message index {message_index}") | ||
| raise ValueError(f"Invalid message index {message_index}") | ||
|
|
||
| session.messages = session.messages[:message_index + 1] | ||
|
|
||
| try: | ||
| return self._modify_session_locked(session_id, _apply, error_label="revert to message") | ||
| except ValueError: | ||
| return False | ||
|
|
||
| # Valid index, proceed with locked revert | ||
| def _revert(session: SessionData) -> None: | ||
| session.messages = session.messages[: message_index + 1] | ||
|
|
||
| return self._modify_session_locked( | ||
| session_id, _revert, error_label="revert to message" | ||
| ) | ||
|
|
||
| def share_session(self, session_id: str) -> bool: | ||
| """Mark a session as shared.""" | ||
| def _share(session: SessionData) -> None: | ||
| def _apply(session: SessionData) -> None: | ||
| assert isinstance(session, ExtendedSessionData) | ||
| session.is_shared = True | ||
|
|
||
| return self._modify_session_locked( | ||
| session_id, _share, error_label="share session" | ||
| ) | ||
| return self._modify_session_locked(session_id, _apply, error_label="share session") | ||
|
|
||
| def unshare_session(self, session_id: str) -> bool: | ||
| """Mark a session as not shared.""" | ||
| def _unshare(session: SessionData) -> None: | ||
| def _apply(session: SessionData) -> None: | ||
| assert isinstance(session, ExtendedSessionData) | ||
| session.is_shared = False | ||
|
|
||
| return self._modify_session_locked( | ||
| session_id, _unshare, error_label="unshare session" | ||
| ) | ||
| return self._modify_session_locked(session_id, _apply, error_label="unshare session") | ||
|
|
||
| def is_shared(self, session_id: str) -> bool: | ||
| """Check if a session is shared.""" | ||
|
|
@@ -550,12 +547,10 @@ def is_shared(self, session_id: str) -> bool: | |
|
|
||
| def set_title(self, session_id: str, title: str) -> bool: | ||
| """Set session title.""" | ||
| def _set_title(session: SessionData) -> None: | ||
| def _apply(session: SessionData) -> None: | ||
| assert isinstance(session, ExtendedSessionData) | ||
| session.title = title | ||
|
|
||
| return self._modify_session_locked( | ||
| session_id, _set_title, error_label="set session title" | ||
| ) | ||
| return self._modify_session_locked(session_id, _apply, error_label="set session title") | ||
|
|
||
| async def auto_title(self, session_id: str) -> bool: | ||
| """Generate and set title automatically from first exchange. | ||
|
|
@@ -610,12 +605,19 @@ async def auto_title(self, session_id: str) -> bool: | |
| title = await generate_title_async(user_msg, assistant_msg) | ||
|
|
||
| if title and title.strip(): | ||
| # Reload session to avoid overwriting concurrent updates | ||
| fresh_session = await asyncio.to_thread(self._load_extended_session, session_id) | ||
| # Only set title if it's still empty | ||
| if not fresh_session.title or not fresh_session.title.strip(): | ||
| fresh_session.title = title.strip() | ||
| return await asyncio.to_thread(self._save_extended_session, fresh_session) | ||
| # Use locked read-modify-write to avoid overwriting concurrent updates | ||
| def _apply(fresh_session: SessionData) -> None: | ||
| assert isinstance(fresh_session, ExtendedSessionData) | ||
| # Only set title if it's still empty | ||
| if not fresh_session.title or not fresh_session.title.strip(): | ||
| fresh_session.title = title.strip() | ||
|
|
||
| return await asyncio.to_thread( | ||
| self._modify_session_locked, | ||
| session_id, | ||
| _apply, | ||
| error_label="auto-title session" | ||
| ) | ||
|
|
||
| except Exception as e: | ||
| # Title generation failed - log with context instead of silent failure | ||
|
|
@@ -626,8 +628,18 @@ async def auto_title(self, session_id: str) -> bool: | |
| return False | ||
|
|
||
| def get_extended_session(self, session_id: str) -> ExtendedSessionData: | ||
| """Get extended session data with smart caching.""" | ||
| return self._load_extended_session(session_id, force_reload=False) | ||
| """Get extended session data.""" | ||
| return self._read_session_fresh(session_id) | ||
|
Comment on lines
630
to
+632
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 1. Cached write overwrites updates Multiple mutating methods still do read-then-write via _load_extended_session() (which can return stale _extended_cache data) followed by _save_extended_session(), which can overwrite newer messages/fields written by another process. get_extended_session() now refreshes from disk, but write paths like create_session(parent update), set_title/share_session/unshare_session, and auto_title still risk session truncation in multi-worker deployments. Agent Prompt
|
||
|
|
||
| def invalidate_cache(self, session_id: Optional[str] = None) -> None: | ||
| """Invalidate base and extended in-memory caches atomically.""" | ||
| with self._lock: | ||
| if session_id: | ||
| self._cache.pop(session_id, None) | ||
| self._extended_cache.pop(session_id, None) | ||
| else: | ||
| self._cache.clear() | ||
| self._extended_cache.clear() | ||
|
Comment on lines
+634
to
+642
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Make cache invalidation atomic across Line 592 clears base cache before Line 593 acquires the lock for Suggested fix def invalidate_cache(self, session_id: Optional[str] = None) -> None:
"""Invalidate base and extended in-memory caches."""
- super().invalidate_cache(session_id)
- with self._lock:
- if session_id:
- self._extended_cache.pop(session_id, None)
- else:
- self._extended_cache.clear()
+ with self._lock:
+ if session_id:
+ self._cache.pop(session_id, None)
+ self._extended_cache.pop(session_id, None)
+ else:
+ self._cache.clear()
+ self._extended_cache.clear()🤖 Prompt for AI Agents |
||
|
|
||
| def export_session(self, session_id: str) -> Dict[str, Any]: | ||
| """ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -117,50 +117,43 @@ def test_add_message_preserves_concurrent_writes(self): | |
| assert len(history) == 2 | ||
| assert history[1]["content"] == "second" | ||
|
|
||
| def test_fork_session_preserves_concurrent_messages(self): | ||
| """Registering a fork must not clobber messages added on the parent.""" | ||
| import threading | ||
| import time | ||
|
|
||
| def test_get_extended_session_sees_writes_from_other_store(self): | ||
| """Extended reads must reload from disk, not stale _extended_cache.""" | ||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| # Use two separate store instances to simulate concurrent processes | ||
| store1 = HierarchicalSessionStore(session_dir=tmpdir) | ||
| store2 = HierarchicalSessionStore(session_dir=tmpdir) | ||
|
|
||
| # Create session and add initial message | ||
| session_id = store1.create_session(title="Parent") | ||
| store1.add_user_message(session_id, "first") | ||
|
|
||
| # Use threading to create deterministic interleaving | ||
| fork_started = threading.Event() | ||
| fork_completed = threading.Event() | ||
|
|
||
| def concurrent_fork(): | ||
| # Signal that fork has started | ||
| fork_started.set() | ||
| # Small delay to allow message to be added | ||
| time.sleep(0.05) | ||
| fork_id = store1.fork_session(session_id) | ||
| assert fork_id | ||
| fork_completed.set() | ||
| return fork_id | ||
| writer = HierarchicalSessionStore(session_dir=tmpdir) | ||
| reader = HierarchicalSessionStore(session_dir=tmpdir) | ||
|
|
||
| writer.add_user_message("session-1", "first") | ||
| reader._load_extended_session("session-1") | ||
| writer.add_user_message("session-1", "second") | ||
|
|
||
| session = reader.get_extended_session("session-1") | ||
| assert len(session.messages) == 2 | ||
| assert session.messages[1].content == "second" | ||
|
|
||
|
Comment on lines
+120
to
+133
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion | 🟠 Major | 🏗️ Heavy lift Add an agentic integration/e2e test for this feature path. This new unit test is useful, but guideline-mandated feature coverage also requires a real agent flow ( As per coding guidelines, "Real agentic tests are MANDATORY for every feature: Agent must call agent.start() with a real prompt, call the LLM, and produce actual text response—not just smoke tests of object construction." 🤖 Prompt for AI Agents |
||
| def test_stale_cache_write_preserves_concurrent_updates(self): | ||
| """Metadata writes must not truncate messages written by other processes.""" | ||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| writer = HierarchicalSessionStore(session_dir=tmpdir) | ||
| reader = HierarchicalSessionStore(session_dir=tmpdir) | ||
|
|
||
| # Start fork operation in background thread | ||
| fork_thread = threading.Thread(target=concurrent_fork) | ||
| fork_thread.start() | ||
| # Create session and warm reader's cache | ||
| writer.create_session("session-1", title="Original") | ||
| reader.get_extended_session("session-1") # Warms cache | ||
|
|
||
| # Wait for fork to start, then add concurrent message | ||
| fork_started.wait() | ||
| store2.add_user_message(session_id, "concurrent_message") | ||
| # Writer adds messages, reader has stale cache | ||
| writer.add_user_message("session-1", "first message") | ||
| writer.add_assistant_message("session-1", "first response") | ||
|
|
||
| # Wait for fork to complete | ||
| fork_thread.join() | ||
| fork_completed.wait() | ||
| # Reader performs metadata-only write with stale cache | ||
| reader.set_title("session-1", "Updated Title") | ||
|
|
||
| # Both messages should be preserved | ||
| history = store1.get_chat_history(session_id) | ||
| assert len(history) == 2 | ||
| assert any(msg["content"] == "concurrent_message" for msg in history) | ||
| # Verify messages are preserved | ||
| session = writer.get_extended_session("session-1") | ||
| assert session.title == "Updated Title" | ||
| assert len(session.messages) == 2 | ||
| assert session.messages[0].content == "first message" | ||
| assert session.messages[1].content == "first response" | ||
|
|
||
| def test_update_session_metadata_preserves_extended_fields(self): | ||
| """Metadata updates must not strip parent_id, snapshots, etc.""" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
auto_titlenow returnsTrueeven when no title was setThe refactored implementation returns the result of
_modify_session_locked, which isTruewhenever the locked write succeeds — regardless of whether_applyactually changed anything. If another process sets the title between the early-exit check (line 545) and the locked write,_applysilently no-ops but_modify_session_lockedstill returnsTrue, violating the documented contract ("True if title was generated and set"). The old code fell through toreturn Falsein that concurrent case. Additionally, every invocation now always writes the session back to disk (updatingupdated_at) even when_applymakes no change.