From b75642f5a199c595344904785bb307c2854ed464 Mon Sep 17 00:00:00 2001 From: Spark Lab Scout Date: Fri, 6 Mar 2026 20:47:48 +0800 Subject: [PATCH 1/2] fix: list/set sampling optimization bug in _recursive_datetime_check The original code had a bug in the list/set sampling optimization: - When list elements are BaseModel objects, _recursive_datetime_check modifies them in-place (via __dict__), not returning a new object - This caused 'first_checked is first_item' to always be True - The code incorrectly assumed no conversion was needed, skipping all elements after the first one Fix: - Track converted datetime fields in BaseModel processing - Add '_datetime_converted_fields' marker to detect actual changes - Check the marker for BaseModel items in list/set sampling logic - This ensures all elements in a list are properly processed --- src/core/oxm/mongo/document_base.py | 46 ++++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/src/core/oxm/mongo/document_base.py b/src/core/oxm/mongo/document_base.py index 0021ecae..33a94bed 100644 --- a/src/core/oxm/mongo/document_base.py +++ b/src/core/oxm/mongo/document_base.py @@ -69,11 +69,18 @@ def _recursive_datetime_check(self, obj, path: str = "", depth: int = 0): # Case 2: Object is BaseModel if isinstance(obj, BaseModel): + # Track if any datetime fields were converted (for list/set sampling optimization) + converted_fields = [] for field_name, value in obj: new_path = f"{path}.{field_name}" if path else field_name new_value = self._recursive_datetime_check(value, new_path, depth + 1) # Directly update value using __dict__ to avoid triggering validators obj.__dict__[field_name] = new_value + # Check if the field value changed (for datetime conversion detection) + if new_value is not value: + converted_fields.append(field_name) + # Attach conversion info for upstream list/set sampling optimization + obj.__dict__["_datetime_converted_fields"] = converted_fields return obj # Case 3: Object is list, tuple, or set (performance optimization) @@ -82,16 +89,29 @@ def _recursive_datetime_check(self, obj, path: str = "", depth: int = 0): if not obj: return obj - # List: only check the first element + # Helper function to check if a BaseModel has converted datetime fields + def has_converted_datetimes(item): + if isinstance(item, BaseModel): + return bool(item.__dict__.get("_datetime_converted_fields", [])) + return item is not self._recursive_datetime_check(item, "", depth + 2) + + # List: only check the first element, but handle BaseModel specially if isinstance(obj, list): first_item = obj[0] first_checked = self._recursive_datetime_check( first_item, f"{path}[0]", depth + 2 ) - # If the first element hasn't changed, assume the whole list doesn't need conversion - if first_checked is first_item: - return obj + # For BaseModel, check if any datetime fields were converted + # For other types, check if reference changed + if isinstance(first_item, BaseModel): + # BaseModel is modified in-place, so we check the conversion marker + if not has_converted_datetimes(first_item): + return obj + else: + # If the first element hasn't changed, assume the whole list doesn't need conversion + if first_checked is first_item: + return obj # Set: check any one element (set is unordered, take the first one) elif isinstance(obj, set): @@ -100,9 +120,14 @@ def _recursive_datetime_check(self, obj, path: str = "", depth: int = 0): sample_item, f"{path}[sample]", depth + 2 ) - # If the sampled element hasn't changed, assume the whole set doesn't need conversion - if sample_checked is sample_item: - return obj + # For BaseModel, check if any datetime fields were converted + if isinstance(sample_item, BaseModel): + if not has_converted_datetimes(sample_item): + return obj + else: + # If the sampled element hasn't changed, assume the whole set doesn't need conversion + if sample_checked is sample_item: + return obj # Tuple: only check the first 3 elements elif isinstance(obj, tuple): @@ -115,7 +140,12 @@ def _recursive_datetime_check(self, obj, path: str = "", depth: int = 0): checked = self._recursive_datetime_check( item, f"{path}[{idx}]", depth + 2 ) - if checked is not item: + # For BaseModel, check conversion marker; for others, check reference + if isinstance(item, BaseModel): + if has_converted_datetimes(item): + need_transform = True + break + elif checked is not item: need_transform = True break From 7cea8c999ddc900cc6ea82bb31bb42ac642cda92 Mon Sep 17 00:00:00 2001 From: Spark Lab Scout Date: Sat, 7 Mar 2026 08:15:25 +0800 Subject: [PATCH 2/2] fix: search all requested memory_types instead of only first one - Fix get_keyword_search_results to iterate over all supported memory types - Fix get_vector_search_results to iterate over all supported memory types - Add MILVUS_REPO_MAP for consistent memory type mapping - Skip unsupported types (e.g., profile) with warning logs instead of failing - Handle individual search failures gracefully to continue with other types Fixes issue: Search API only uses memory_types[0], silently ignoring all other types --- src/agentic_layer/memory_manager.py | 232 ++++++++++++++++------------ 1 file changed, 134 insertions(+), 98 deletions(-) diff --git a/src/agentic_layer/memory_manager.py b/src/agentic_layer/memory_manager.py index df42b05a..0ab7b864 100644 --- a/src/agentic_layer/memory_manager.py +++ b/src/agentic_layer/memory_manager.py @@ -94,6 +94,13 @@ MemoryType.EPISODIC_MEMORY: EpisodicMemoryEsRepository, } +# Milvus repository mapping - same types as ES but for vector search +MILVUS_REPO_MAP = { + MemoryType.FORESIGHT: ForesightMilvusRepository, + MemoryType.EVENT_LOG: EventLogMilvusRepository, + MemoryType.EPISODIC_MEMORY: EpisodicMemoryMilvusRepository, +} + @dataclass class EventLogCandidate: @@ -337,8 +344,10 @@ async def get_keyword_search_results( retrieve_mem_request: 'RetrieveMemRequest', retrieve_method: str = RetrieveMethod.KEYWORD.value, ) -> List[Dict[str, Any]]: - """Keyword search with stage-level metrics""" + """Keyword search with stage-level metrics - searches all supported memory types""" stage_start = time.perf_counter() + + # Get all memory types for metrics recording (use first one as representative) memory_type = ( retrieve_mem_request.memory_types[0].value if retrieve_mem_request.memory_types @@ -375,32 +384,45 @@ async def get_keyword_search_results( if end_time is not None: date_range["lte"] = end_time - mem_type = memory_types[0] - - repo_class = ES_REPO_MAP.get(mem_type) - if not repo_class: - logger.warning(f"Unsupported memory_type: {mem_type}") + # Filter to only supported memory types (exclude profile, etc. that aren't in ES) + supported_types = [mt for mt in memory_types if mt in ES_REPO_MAP] + + if not supported_types: + logger.warning(f"No supported memory_types for keyword search. Requested: {[mt.value for mt in memory_types]}") return [] - es_repo = get_bean_by_type(repo_class) - logger.debug(f"Using {repo_class.__name__} for {mem_type}") - - results = await es_repo.multi_search( - query=query_words, - user_id=user_id, - group_id=group_id, - size=top_k, - from_=0, - date_range=date_range, - ) + # Search each supported memory type and collect all results + all_results = [] + for mem_type in supported_types: + repo_class = ES_REPO_MAP.get(mem_type) + if not repo_class: + logger.info(f"Skipping unsupported memory_type in keyword search: {mem_type}") + continue + + es_repo = get_bean_by_type(repo_class) + logger.debug(f"Using {repo_class.__name__} for {mem_type}") + + try: + results = await es_repo.multi_search( + query=query_words, + user_id=user_id, + group_id=group_id, + size=top_k, + from_=0, + date_range=date_range, + ) - # Mark memory_type, search_source, and unified score - if results: - for r in results: - r['memory_type'] = mem_type.value - r['_search_source'] = RetrieveMethod.KEYWORD.value - r['id'] = r.get('_id', '') # Unify ES '_id' to 'id' - r['score'] = r.get('_score', 0.0) # Unified score field + # Mark memory_type, search_source, and unified score + if results: + for r in results: + r['memory_type'] = mem_type.value + r['_search_source'] = RetrieveMethod.KEYWORD.value + r['id'] = r.get('_id', '') # Unify ES '_id' to 'id' + r['score'] = r.get('_score', 0.0) # Unified score field + all_results.extend(results) + except Exception as e: + logger.warning(f"Keyword search failed for {mem_type}: {e}") + continue # Record stage metrics record_retrieve_stage( @@ -410,7 +432,7 @@ async def get_keyword_search_results( duration_seconds=time.perf_counter() - stage_start, ) - return results or [] + return all_results except Exception as e: record_retrieve_stage( retrieve_method=retrieve_method, @@ -472,7 +494,7 @@ async def get_vector_search_results( retrieve_mem_request: 'RetrieveMemRequest', retrieve_method: str = RetrieveMethod.VECTOR.value, ) -> List[Dict[str, Any]]: - """Vector search with stage-level metrics (embedding + milvus_search)""" + """Vector search with stage-level metrics - searches all supported memory types""" memory_type = ( retrieve_mem_request.memory_types[0].value if retrieve_mem_request.memory_types @@ -497,7 +519,7 @@ async def get_vector_search_results( top_k = retrieve_mem_request.top_k start_time = retrieve_mem_request.start_time end_time = retrieve_mem_request.end_time - mem_type = retrieve_mem_request.memory_types[0] + memory_types = retrieve_mem_request.memory_types logger.debug( f"retrieve_mem_vector called with query: {query}, user_id: {user_id}, group_id: {group_id}, top_k: {top_k}" @@ -521,74 +543,93 @@ async def get_vector_search_results( f"Query text vectorization completed, vector dimension: {len(query_vector_list)}" ) - # Select Milvus repository based on memory type - match mem_type: - case MemoryType.FORESIGHT: - milvus_repo = get_bean_by_type(ForesightMilvusRepository) - case MemoryType.EVENT_LOG: - milvus_repo = get_bean_by_type(EventLogMilvusRepository) - case MemoryType.EPISODIC_MEMORY: - milvus_repo = get_bean_by_type(EpisodicMemoryMilvusRepository) - case _: - raise ValueError(f"Unsupported memory type: {mem_type}") + # Filter to only supported memory types (exclude profile, etc. that aren't in Milvus) + supported_types = [mt for mt in memory_types if mt in MILVUS_REPO_MAP] + + if not supported_types: + logger.warning(f"No supported memory_types for vector search. Requested: {[mt.value for mt in memory_types]}") + return [] - # Handle time range filter conditions - start_time_dt = None - end_time_dt = None - current_time_dt = None + # Search each supported memory type and collect all results + all_results = [] + for mem_type in supported_types: + # Select Milvus repository based on memory type + milvus_repo_class = MILVUS_REPO_MAP.get(mem_type) + if not milvus_repo_class: + logger.info(f"Skipping unsupported memory_type in vector search: {mem_type}") + continue + + milvus_repo = get_bean_by_type(milvus_repo_class) + + # Handle time range filter conditions + start_time_dt = None + end_time_dt = None + current_time_dt = None + + if start_time is not None: + start_time_dt = ( + from_iso_format(start_time) + if isinstance(start_time, str) + else start_time + ) - if start_time is not None: - start_time_dt = ( - from_iso_format(start_time) - if isinstance(start_time, str) - else start_time - ) + if end_time is not None: + if isinstance(end_time, str): + end_time_dt = from_iso_format(end_time) + # If date only format, set to end of day + if len(end_time) == 10: + end_time_dt = end_time_dt.replace(hour=23, minute=59, second=59) + else: + end_time_dt = end_time + + # Handle foresight time range (only valid for foresight) + if mem_type == MemoryType.FORESIGHT: + if retrieve_mem_request.start_time: + start_time_dt = from_iso_format(retrieve_mem_request.start_time) + if retrieve_mem_request.end_time: + end_time_dt = from_iso_format(retrieve_mem_request.end_time) + if retrieve_mem_request.current_time: + current_time_dt = from_iso_format(retrieve_mem_request.current_time) + + # Call Milvus vector search (pass different parameters based on memory type) + milvus_start = time.perf_counter() + try: + if mem_type == MemoryType.FORESIGHT: + # Foresight: supports time range and validity filtering, supports radius parameter + search_results = await milvus_repo.vector_search( + query_vector=query_vector_list, + user_id=user_id, + group_id=group_id, + start_time=start_time_dt, + end_time=end_time_dt, + current_time=current_time_dt, + limit=top_k, + score_threshold=0.0, + radius=retrieve_mem_request.radius, + ) + else: + # Episodic memory and event log: use timestamp filtering, supports radius parameter + search_results = await milvus_repo.vector_search( + query_vector=query_vector_list, + user_id=user_id, + group_id=group_id, + start_time=start_time_dt, + end_time=end_time_dt, + limit=top_k, + score_threshold=0.0, + radius=retrieve_mem_request.radius, + ) + + # Mark memory_type and search_source + if search_results: + for r in search_results: + r['memory_type'] = mem_type.value + r['_search_source'] = RetrieveMethod.VECTOR.value + all_results.extend(search_results) + except Exception as e: + logger.warning(f"Vector search failed for {mem_type}: {e}") + continue - if end_time is not None: - if isinstance(end_time, str): - end_time_dt = from_iso_format(end_time) - # If date only format, set to end of day - if len(end_time) == 10: - end_time_dt = end_time_dt.replace(hour=23, minute=59, second=59) - else: - end_time_dt = end_time - - # Handle foresight time range (only valid for foresight) - if mem_type == MemoryType.FORESIGHT: - if retrieve_mem_request.start_time: - start_time_dt = from_iso_format(retrieve_mem_request.start_time) - if retrieve_mem_request.end_time: - end_time_dt = from_iso_format(retrieve_mem_request.end_time) - if retrieve_mem_request.current_time: - current_time_dt = from_iso_format(retrieve_mem_request.current_time) - - # Call Milvus vector search (pass different parameters based on memory type) - milvus_start = time.perf_counter() - if mem_type == MemoryType.FORESIGHT: - # Foresight: supports time range and validity filtering, supports radius parameter - search_results = await milvus_repo.vector_search( - query_vector=query_vector_list, - user_id=user_id, - group_id=group_id, - start_time=start_time_dt, - end_time=end_time_dt, - current_time=current_time_dt, - limit=top_k, - score_threshold=0.0, - radius=retrieve_mem_request.radius, - ) - else: - # Episodic memory and event log: use timestamp filtering, supports radius parameter - search_results = await milvus_repo.vector_search( - query_vector=query_vector_list, - user_id=user_id, - group_id=group_id, - start_time=start_time_dt, - end_time=end_time_dt, - limit=top_k, - score_threshold=0.0, - radius=retrieve_mem_request.radius, - ) record_retrieve_stage( retrieve_method=retrieve_method, stage='milvus_search', @@ -596,12 +637,7 @@ async def get_vector_search_results( duration_seconds=time.perf_counter() - milvus_start, ) - for r in search_results: - r['memory_type'] = mem_type.value - r['_search_source'] = RetrieveMethod.VECTOR.value - # Milvus already uses 'score', no need to rename - - return search_results + return all_results except Exception as e: record_retrieve_stage( retrieve_method=retrieve_method,