diff --git a/agent_memory_server/config.py b/agent_memory_server/config.py index c394e550..d8c7a2d1 100644 --- a/agent_memory_server/config.py +++ b/agent_memory_server/config.py @@ -401,6 +401,11 @@ class Settings(BaseSettings): redisvl_index_prefix: str = "memory_idx" redisvl_indexing_algorithm: str = "HNSW" + # Working Memory Index Settings + # Used for listing sessions via Redis Search instead of sorted sets + working_memory_index_name: str = "working_memory_idx" + working_memory_index_prefix: str = "working_memory:" + # Deduplication Settings (Store-Time) # Distance threshold for semantic similarity when deduplicating at store time # 0.35 works well for catching paraphrased content while avoiding false positives diff --git a/agent_memory_server/main.py b/agent_memory_server/main.py index 945b1a00..706edb6f 100644 --- a/agent_memory_server/main.py +++ b/agent_memory_server/main.py @@ -86,6 +86,11 @@ async def lifespan(app: FastAPI): await check_and_set_migration_status(redis_conn) + # Ensure working memory search index exists for session listing + from agent_memory_server.working_memory_index import ensure_working_memory_index + + await ensure_working_memory_index(redis_conn) + # Initialize Docket for background tasks if enabled if settings.use_docket: logger.info("Attempting to initialize Docket for background tasks.") diff --git a/agent_memory_server/utils/keys.py b/agent_memory_server/utils/keys.py index cf3bed5d..c5d856f8 100644 --- a/agent_memory_server/utils/keys.py +++ b/agent_memory_server/utils/keys.py @@ -38,9 +38,19 @@ def messages_key(session_id: str, namespace: str | None = None) -> str: @staticmethod def sessions_key(namespace: str | None = None) -> str: - """Get the sessions key for a namespace.""" + """Get the sessions key for a namespace. + + DEPRECATED: This method is deprecated. Session listing now uses + Redis Search index on working memory JSON documents instead of + sorted sets. The index automatically handles TTL expiration. + """ return f"sessions:{namespace}" if namespace else "sessions" + @staticmethod + def working_memory_index_name() -> str: + """Return the name of the working memory search index.""" + return settings.working_memory_index_name + @staticmethod def memory_key(id: str) -> str: """Get the memory key for an ID.""" diff --git a/agent_memory_server/working_memory.py b/agent_memory_server/working_memory.py index 39fc0f03..03a3db35 100644 --- a/agent_memory_server/working_memory.py +++ b/agent_memory_server/working_memory.py @@ -232,38 +232,100 @@ async def list_sessions( user_id: str | None = None, ) -> tuple[int, list[str]]: """ - List sessions + List sessions using Redis Search index. + + Uses FT.SEARCH on the working memory index to list sessions. This approach + ensures that expired sessions (via TTL) are automatically excluded since + Redis Search removes deleted keys from the index. Args: redis: Redis client limit: Maximum number of sessions to return offset: Offset for pagination namespace: Optional namespace filter - user_id: Optional user ID filter (not yet implemented - sessions are stored in sorted sets) + user_id: Optional user ID filter Returns: Tuple of (total_count, session_ids) - - Note: - The user_id parameter is accepted for API compatibility but filtering by user_id - is not yet implemented. This would require changing how sessions are stored to - enable efficient user_id-based filtering. """ - # Calculate start and end indices (0-indexed start, inclusive end) - start = offset - end = offset + limit - 1 + # Build filter query parts + filter_parts = [] + + if namespace: + # Escape special characters in TAG values + escaped_namespace = _escape_tag_value(namespace) + filter_parts.append(f"@namespace:{{{escaped_namespace}}}") - # TODO: This should take a user_id - sessions_key = Keys.sessions_key(namespace=namespace) + if user_id: + escaped_user_id = _escape_tag_value(user_id) + filter_parts.append(f"@user_id:{{{escaped_user_id}}}") - async with redis.pipeline() as pipe: - pipe.zcard(sessions_key) - pipe.zrange(sessions_key, start, end) - total, session_ids = await pipe.execute() + # Combine filters or use wildcard for all + filter_str = " ".join(filter_parts) if filter_parts else "*" - return total, [ - s.decode("utf-8") if isinstance(s, bytes) else s for s in session_ids - ] + try: + # Execute FT.SEARCH query + result = await redis.execute_command( + "FT.SEARCH", + settings.working_memory_index_name, + filter_str, + "RETURN", + "1", + "$.session_id", + "SORTBY", + "created_at", + "DESC", + "LIMIT", + str(offset), + str(limit), + ) + + # Parse FT.SEARCH response + # Format: [total_count, key1, [field, value], key2, [field, value], ...] + total = result[0] + session_ids = [] + + # Iterate through results (skip the total count at index 0) + i = 1 + while i < len(result): + # Skip the key name + i += 1 + if i < len(result): + # Get the field-value pairs + fields = result[i] + if fields and len(fields) >= 2: + # fields is [field_name, value] + session_id = fields[1] + if isinstance(session_id, bytes): + session_id = session_id.decode("utf-8") + # Remove JSON quotes if present + if session_id.startswith('"') and session_id.endswith('"'): + session_id = session_id[1:-1] + session_ids.append(session_id) + i += 1 + + return total, session_ids + + except Exception as e: + logger.error(f"Error listing sessions: {e}") + # Return empty results on error + return 0, [] + + +def _escape_tag_value(value: str) -> str: + """ + Escape special characters in Redis Search TAG values. + + TAG field values need certain characters escaped to be parsed correctly. + Redis Search requires backslash escaping for special characters in TAG queries. + """ + # First escape backslashes (must be done first to avoid double-escaping) + result = value.replace("\\", "\\\\") + # Characters that need escaping in TAG queries (excluding backslash, already handled) + special_chars = ["-", "@", ":", "{", "}", "(", ")", "[", "]", "'", '"', "|", " "] + for char in special_chars: + result = result.replace(char, f"\\{char}") + return result async def get_working_memory( @@ -455,12 +517,10 @@ async def set_working_memory( try: # Use Redis native JSON storage + # The working memory search index automatically indexes this document + # for session listing (no need for separate sorted set) await redis_client.json().set(key, "$", data) - # Index session in sorted set for listing - sessions_key = Keys.sessions_key(namespace=working_memory.namespace) - await redis_client.zadd(sessions_key, {working_memory.session_id: time.time()}) - if working_memory.ttl_seconds is not None: # Set TTL separately for JSON keys await redis_client.expire(key, working_memory.ttl_seconds) @@ -501,10 +561,9 @@ async def delete_working_memory( ) try: + # Delete the JSON key - the working memory search index automatically + # removes the document from the index when the key is deleted await redis_client.delete(key) - # Remove session from sorted set index - sessions_key = Keys.sessions_key(namespace=namespace) - await redis_client.zrem(sessions_key, session_id) logger.info(f"Deleted working memory for session {session_id}") except Exception as e: diff --git a/agent_memory_server/working_memory_index.py b/agent_memory_server/working_memory_index.py new file mode 100644 index 00000000..2ef10611 --- /dev/null +++ b/agent_memory_server/working_memory_index.py @@ -0,0 +1,135 @@ +"""Working memory search index for session listing. + +This module provides Redis Search index creation and management for working memory +JSON documents. Using a search index instead of sorted sets ensures that when +working memory expires via TTL, the session is automatically removed from the index. +""" + +import logging + +from redis.asyncio import Redis +from redis.exceptions import ResponseError + +from agent_memory_server.config import settings + + +logger = logging.getLogger(__name__) + +# Index name constant +WORKING_MEMORY_INDEX_NAME = settings.working_memory_index_name +WORKING_MEMORY_INDEX_PREFIX = settings.working_memory_index_prefix + + +async def ensure_working_memory_index(redis_client: Redis) -> bool: + """ + Ensure the working memory search index exists. + + Creates a Redis Search index on JSON documents with prefix 'working_memory:' + if it doesn't already exist. The index enables efficient session listing + with filtering by namespace and user_id. + + Args: + redis_client: Redis client instance + + Returns: + True if index was created, False if it already existed + """ + index_name = WORKING_MEMORY_INDEX_NAME + prefix = WORKING_MEMORY_INDEX_PREFIX + + try: + # Check if index already exists + await redis_client.execute_command("FT.INFO", index_name) + logger.info(f"Working memory index '{index_name}' already exists") + return False + except ResponseError as e: + error_msg = str(e).lower() + # Handle both "unknown index name" and "no such index" error messages + if "unknown index name" not in error_msg and "no such index" not in error_msg: + # Some other error occurred + raise + + # Create the index + # Schema indexes the JSON fields we need for filtering and sorting + try: + await redis_client.execute_command( + "FT.CREATE", + index_name, + "ON", + "JSON", + "PREFIX", + "1", + prefix, + "SCHEMA", + "$.session_id", + "AS", + "session_id", + "TAG", + "SORTABLE", + "$.namespace", + "AS", + "namespace", + "TAG", + "SORTABLE", + "$.user_id", + "AS", + "user_id", + "TAG", + "SORTABLE", + "$.created_at", + "AS", + "created_at", + "NUMERIC", + "SORTABLE", + "$.updated_at", + "AS", + "updated_at", + "NUMERIC", + "SORTABLE", + ) + logger.info( + f"Created working memory index '{index_name}' with prefix '{prefix}'" + ) + return True + except ResponseError as e: + logger.error(f"Failed to create working memory index: {e}") + raise + + +async def drop_working_memory_index(redis_client: Redis) -> bool: + """ + Drop the working memory search index. + + Args: + redis_client: Redis client instance + + Returns: + True if index was dropped, False if it didn't exist + """ + index_name = WORKING_MEMORY_INDEX_NAME + + try: + await redis_client.execute_command("FT.DROPINDEX", index_name) + logger.info(f"Dropped working memory index '{index_name}'") + return True + except ResponseError as e: + error_msg = str(e).lower() + # Handle both "unknown index name" and "no such index" error messages + if "unknown index name" in error_msg or "no such index" in error_msg: + logger.info(f"Working memory index '{index_name}' does not exist") + return False + raise + + +async def rebuild_working_memory_index(redis_client: Redis) -> bool: + """ + Rebuild the working memory search index by dropping and recreating it. + + Args: + redis_client: Redis client instance + + Returns: + True if index was rebuilt successfully + """ + await drop_working_memory_index(redis_client) + return await ensure_working_memory_index(redis_client) diff --git a/tests/conftest.py b/tests/conftest.py index 3f59ede3..f3f8df37 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -98,6 +98,25 @@ async def search_index(async_redis_client): ) +@pytest.fixture(autouse=True) +async def working_memory_index(async_redis_client): + """Ensure working memory search index exists for session listing tests.""" + from agent_memory_server.working_memory_index import ( + WORKING_MEMORY_INDEX_NAME, + ensure_working_memory_index, + ) + + await ensure_working_memory_index(async_redis_client) + + yield + + # Clean up after tests + with contextlib.suppress(Exception): + await async_redis_client.execute_command( + "FT.DROPINDEX", WORKING_MEMORY_INDEX_NAME + ) + + @pytest.fixture() async def session(use_test_redis_connection, async_redis_client, request): """Set up a test session with Redis data for testing""" @@ -135,10 +154,8 @@ async def session(use_test_redis_connection, async_redis_client, request): redis_client=use_test_redis_connection, ) - # Also add session to sessions list for compatibility - sessions_key = Keys.sessions_key(namespace=namespace) - current_time = int(time.time()) - await use_test_redis_connection.zadd(sessions_key, {session_id: current_time}) + # Note: Session is now automatically indexed via Redis Search index + # on the working memory JSON document (no sorted set needed) # Index the messages as long-term memories directly without background tasks from redisvl.utils.vectorize import OpenAITextVectorizer diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml index f08d94ed..ea1e911e 100644 --- a/tests/docker-compose.yml +++ b/tests/docker-compose.yml @@ -1,7 +1,6 @@ services: redis: image: "${REDIS_IMAGE:-redis:8.4}" - platform: linux/amd64 ports: - "6379" environment: diff --git a/tests/test_working_memory.py b/tests/test_working_memory.py index 46570a96..b9458cf3 100644 --- a/tests/test_working_memory.py +++ b/tests/test_working_memory.py @@ -707,8 +707,10 @@ async def test_migration_skipped_when_env_variable_set( monkeypatch.setattr(config.settings, "working_memory_migration_complete", False) @pytest.mark.asyncio - async def test_set_working_memory_indexes_session(self, async_redis_client): - """Test that set_working_memory adds session to the sessions sorted set.""" + async def test_set_working_memory_indexes_session_in_search_index( + self, async_redis_client + ): + """Test that set_working_memory adds session to the search index.""" session_id = "test_session_index_123" namespace = "test_namespace_index" @@ -723,18 +725,22 @@ async def test_set_working_memory_indexes_session(self, async_redis_client): # Save working memory await set_working_memory(working_mem, redis_client=async_redis_client) - # Verify session is in sorted set - sessions_key = Keys.sessions_key(namespace=namespace) - score = await async_redis_client.zscore(sessions_key, session_id) + # Verify session is findable via list_sessions (uses search index) + total, sessions = await list_sessions( + redis=async_redis_client, + namespace=namespace, + limit=10, + offset=0, + ) - assert score is not None, "Session should be indexed in sorted set" - assert score > 0, "Score should be a positive timestamp" + assert session_id in sessions, "Session should be indexed in search index" + assert total >= 1, "Should have at least one session" @pytest.mark.asyncio - async def test_delete_working_memory_removes_session_from_index( + async def test_delete_working_memory_removes_session_from_search_index( self, async_redis_client ): - """Test that delete_working_memory removes session from the sessions sorted set.""" + """Test that delete_working_memory removes session from the search index.""" session_id = "test_session_to_delete" namespace = "test_namespace_delete" @@ -748,9 +754,13 @@ async def test_delete_working_memory_removes_session_from_index( await set_working_memory(working_mem, redis_client=async_redis_client) # Verify session is indexed - sessions_key = Keys.sessions_key(namespace=namespace) - score_before = await async_redis_client.zscore(sessions_key, session_id) - assert score_before is not None, "Session should be indexed before delete" + total_before, sessions_before = await list_sessions( + redis=async_redis_client, + namespace=namespace, + limit=10, + offset=0, + ) + assert session_id in sessions_before, "Session should be indexed before delete" # Delete working memory await delete_working_memory( @@ -758,8 +768,15 @@ async def test_delete_working_memory_removes_session_from_index( ) # Verify session is removed from index - score_after = await async_redis_client.zscore(sessions_key, session_id) - assert score_after is None, "Session should be removed from index after delete" + total_after, sessions_after = await list_sessions( + redis=async_redis_client, + namespace=namespace, + limit=10, + offset=0, + ) + assert ( + session_id not in sessions_after + ), "Session should be removed from index after delete" @pytest.mark.asyncio async def test_list_sessions_returns_indexed_sessions(self, async_redis_client): @@ -789,3 +806,46 @@ async def test_list_sessions_returns_indexed_sessions(self, async_redis_client): assert set(listed_sessions) == set( session_ids ), f"Expected {session_ids}, got {listed_sessions}" + + @pytest.mark.asyncio + async def test_list_sessions_filters_by_user_id(self, async_redis_client): + """Test that list_sessions can filter by user_id.""" + namespace = "user_filter_namespace" + user1_sessions = ["user1_session_a", "user1_session_b"] + user2_sessions = ["user2_session_a"] + + # Create sessions for user1 + for session_id in user1_sessions: + working_mem = WorkingMemory( + session_id=session_id, + namespace=namespace, + user_id="user1", + messages=[], + memories=[], + ) + await set_working_memory(working_mem, redis_client=async_redis_client) + + # Create sessions for user2 + for session_id in user2_sessions: + working_mem = WorkingMemory( + session_id=session_id, + namespace=namespace, + user_id="user2", + messages=[], + memories=[], + ) + await set_working_memory(working_mem, redis_client=async_redis_client) + + # List sessions for user1 only + total, listed_sessions = await list_sessions( + redis=async_redis_client, + namespace=namespace, + user_id="user1", + limit=10, + offset=0, + ) + + assert total == 2, f"Expected 2 sessions for user1, got {total}" + assert set(listed_sessions) == set( + user1_sessions + ), f"Expected {user1_sessions}, got {listed_sessions}"