Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 134 additions & 98 deletions src/agentic_layer/memory_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@
MemoryType.EPISODIC_MEMORY: EpisodicMemoryEsRepository,
}

# Lookup table from MemoryType to the Milvus repository class that serves it.
# Vector search uses it twice: membership decides whether a requested memory
# type is searchable at all, and the mapped class is the one resolved through
# get_bean_by_type for the actual query.
MILVUS_REPO_MAP = {
    MemoryType.FORESIGHT: ForesightMilvusRepository,
    MemoryType.EVENT_LOG: EventLogMilvusRepository,
    MemoryType.EPISODIC_MEMORY: EpisodicMemoryMilvusRepository,
}


@dataclass
class EventLogCandidate:
Expand Down Expand Up @@ -337,8 +344,10 @@ async def get_keyword_search_results(
retrieve_mem_request: 'RetrieveMemRequest',
retrieve_method: str = RetrieveMethod.KEYWORD.value,
) -> List[Dict[str, Any]]:
"""Keyword search with stage-level metrics"""
"""Keyword search with stage-level metrics - searches all supported memory types"""
stage_start = time.perf_counter()

# Stage metrics take a single memory_type label: use the first requested type as the representative
memory_type = (
retrieve_mem_request.memory_types[0].value
if retrieve_mem_request.memory_types
Expand Down Expand Up @@ -375,32 +384,45 @@ async def get_keyword_search_results(
if end_time is not None:
date_range["lte"] = end_time

mem_type = memory_types[0]

repo_class = ES_REPO_MAP.get(mem_type)
if not repo_class:
logger.warning(f"Unsupported memory_type: {mem_type}")
# Filter to only supported memory types (exclude profile, etc. that aren't in ES)
supported_types = [mt for mt in memory_types if mt in ES_REPO_MAP]

if not supported_types:
logger.warning(f"No supported memory_types for keyword search. Requested: {[mt.value for mt in memory_types]}")
return []

es_repo = get_bean_by_type(repo_class)
logger.debug(f"Using {repo_class.__name__} for {mem_type}")

results = await es_repo.multi_search(
query=query_words,
user_id=user_id,
group_id=group_id,
size=top_k,
from_=0,
date_range=date_range,
)
# Search each supported memory type and collect all results
all_results = []
for mem_type in supported_types:
repo_class = ES_REPO_MAP.get(mem_type)
if not repo_class:
logger.info(f"Skipping unsupported memory_type in keyword search: {mem_type}")
continue

es_repo = get_bean_by_type(repo_class)
logger.debug(f"Using {repo_class.__name__} for {mem_type}")

try:
results = await es_repo.multi_search(
query=query_words,
user_id=user_id,
group_id=group_id,
size=top_k,
from_=0,
date_range=date_range,
)

# Mark memory_type, search_source, and unified score
if results:
for r in results:
r['memory_type'] = mem_type.value
r['_search_source'] = RetrieveMethod.KEYWORD.value
r['id'] = r.get('_id', '') # Unify ES '_id' to 'id'
r['score'] = r.get('_score', 0.0) # Unified score field
# Mark memory_type, search_source, and unified score
if results:
for r in results:
r['memory_type'] = mem_type.value
r['_search_source'] = RetrieveMethod.KEYWORD.value
r['id'] = r.get('_id', '') # Unify ES '_id' to 'id'
r['score'] = r.get('_score', 0.0) # Unified score field
all_results.extend(results)
except Exception as e:
logger.warning(f"Keyword search failed for {mem_type}: {e}")
continue

# Record stage metrics
record_retrieve_stage(
Expand All @@ -410,7 +432,7 @@ async def get_keyword_search_results(
duration_seconds=time.perf_counter() - stage_start,
)

return results or []
return all_results
except Exception as e:
record_retrieve_stage(
retrieve_method=retrieve_method,
Expand Down Expand Up @@ -472,7 +494,7 @@ async def get_vector_search_results(
retrieve_mem_request: 'RetrieveMemRequest',
retrieve_method: str = RetrieveMethod.VECTOR.value,
) -> List[Dict[str, Any]]:
"""Vector search with stage-level metrics (embedding + milvus_search)"""
"""Vector search with stage-level metrics - searches all supported memory types"""
memory_type = (
retrieve_mem_request.memory_types[0].value
if retrieve_mem_request.memory_types
Expand All @@ -497,7 +519,7 @@ async def get_vector_search_results(
top_k = retrieve_mem_request.top_k
start_time = retrieve_mem_request.start_time
end_time = retrieve_mem_request.end_time
mem_type = retrieve_mem_request.memory_types[0]
memory_types = retrieve_mem_request.memory_types

logger.debug(
f"retrieve_mem_vector called with query: {query}, user_id: {user_id}, group_id: {group_id}, top_k: {top_k}"
Expand All @@ -521,87 +543,101 @@ async def get_vector_search_results(
f"Query text vectorization completed, vector dimension: {len(query_vector_list)}"
)

# Select Milvus repository based on memory type
match mem_type:
case MemoryType.FORESIGHT:
milvus_repo = get_bean_by_type(ForesightMilvusRepository)
case MemoryType.EVENT_LOG:
milvus_repo = get_bean_by_type(EventLogMilvusRepository)
case MemoryType.EPISODIC_MEMORY:
milvus_repo = get_bean_by_type(EpisodicMemoryMilvusRepository)
case _:
raise ValueError(f"Unsupported memory type: {mem_type}")
# Filter to only supported memory types (exclude profile, etc. that aren't in Milvus)
supported_types = [mt for mt in memory_types if mt in MILVUS_REPO_MAP]

if not supported_types:
logger.warning(f"No supported memory_types for vector search. Requested: {[mt.value for mt in memory_types]}")
return []

# Handle time range filter conditions
start_time_dt = None
end_time_dt = None
current_time_dt = None
# Search each supported memory type and collect all results
all_results = []
for mem_type in supported_types:
# Select Milvus repository based on memory type
milvus_repo_class = MILVUS_REPO_MAP.get(mem_type)
if not milvus_repo_class:
logger.info(f"Skipping unsupported memory_type in vector search: {mem_type}")
continue

milvus_repo = get_bean_by_type(milvus_repo_class)

# Handle time range filter conditions
start_time_dt = None
end_time_dt = None
current_time_dt = None

if start_time is not None:
start_time_dt = (
from_iso_format(start_time)
if isinstance(start_time, str)
else start_time
)

if start_time is not None:
start_time_dt = (
from_iso_format(start_time)
if isinstance(start_time, str)
else start_time
)
if end_time is not None:
if isinstance(end_time, str):
end_time_dt = from_iso_format(end_time)
# If date only format, set to end of day
if len(end_time) == 10:
end_time_dt = end_time_dt.replace(hour=23, minute=59, second=59)
else:
end_time_dt = end_time

# Handle foresight time range (only valid for foresight)
if mem_type == MemoryType.FORESIGHT:
if retrieve_mem_request.start_time:
start_time_dt = from_iso_format(retrieve_mem_request.start_time)
if retrieve_mem_request.end_time:
end_time_dt = from_iso_format(retrieve_mem_request.end_time)
if retrieve_mem_request.current_time:
current_time_dt = from_iso_format(retrieve_mem_request.current_time)

# Call Milvus vector search (pass different parameters based on memory type)
milvus_start = time.perf_counter()
try:
if mem_type == MemoryType.FORESIGHT:
# Foresight: supports time range and validity filtering, supports radius parameter
search_results = await milvus_repo.vector_search(
query_vector=query_vector_list,
user_id=user_id,
group_id=group_id,
start_time=start_time_dt,
end_time=end_time_dt,
current_time=current_time_dt,
limit=top_k,
score_threshold=0.0,
radius=retrieve_mem_request.radius,
)
else:
# Episodic memory and event log: use timestamp filtering, supports radius parameter
search_results = await milvus_repo.vector_search(
query_vector=query_vector_list,
user_id=user_id,
group_id=group_id,
start_time=start_time_dt,
end_time=end_time_dt,
limit=top_k,
score_threshold=0.0,
radius=retrieve_mem_request.radius,
)

# Mark memory_type and search_source
if search_results:
for r in search_results:
r['memory_type'] = mem_type.value
r['_search_source'] = RetrieveMethod.VECTOR.value
all_results.extend(search_results)
except Exception as e:
logger.warning(f"Vector search failed for {mem_type}: {e}")
continue

if end_time is not None:
if isinstance(end_time, str):
end_time_dt = from_iso_format(end_time)
# If date only format, set to end of day
if len(end_time) == 10:
end_time_dt = end_time_dt.replace(hour=23, minute=59, second=59)
else:
end_time_dt = end_time

# Handle foresight time range (only valid for foresight)
if mem_type == MemoryType.FORESIGHT:
if retrieve_mem_request.start_time:
start_time_dt = from_iso_format(retrieve_mem_request.start_time)
if retrieve_mem_request.end_time:
end_time_dt = from_iso_format(retrieve_mem_request.end_time)
if retrieve_mem_request.current_time:
current_time_dt = from_iso_format(retrieve_mem_request.current_time)

# Call Milvus vector search (pass different parameters based on memory type)
milvus_start = time.perf_counter()
if mem_type == MemoryType.FORESIGHT:
# Foresight: supports time range and validity filtering, supports radius parameter
search_results = await milvus_repo.vector_search(
query_vector=query_vector_list,
user_id=user_id,
group_id=group_id,
start_time=start_time_dt,
end_time=end_time_dt,
current_time=current_time_dt,
limit=top_k,
score_threshold=0.0,
radius=retrieve_mem_request.radius,
)
else:
# Episodic memory and event log: use timestamp filtering, supports radius parameter
search_results = await milvus_repo.vector_search(
query_vector=query_vector_list,
user_id=user_id,
group_id=group_id,
start_time=start_time_dt,
end_time=end_time_dt,
limit=top_k,
score_threshold=0.0,
radius=retrieve_mem_request.radius,
)
record_retrieve_stage(
retrieve_method=retrieve_method,
stage='milvus_search',
memory_type=memory_type,
duration_seconds=time.perf_counter() - milvus_start,
)

for r in search_results:
r['memory_type'] = mem_type.value
r['_search_source'] = RetrieveMethod.VECTOR.value
# Milvus already uses 'score', no need to rename

return search_results
return all_results
except Exception as e:
record_retrieve_stage(
retrieve_method=retrieve_method,
Expand Down
46 changes: 38 additions & 8 deletions src/core/oxm/mongo/document_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,18 @@ def _recursive_datetime_check(self, obj, path: str = "", depth: int = 0):

# Case 2: Object is BaseModel
if isinstance(obj, BaseModel):
# Track if any datetime fields were converted (for list/set sampling optimization)
converted_fields = []
for field_name, value in obj:
new_path = f"{path}.{field_name}" if path else field_name
new_value = self._recursive_datetime_check(value, new_path, depth + 1)
# Directly update value using __dict__ to avoid triggering validators
obj.__dict__[field_name] = new_value
# Check if the field value changed (for datetime conversion detection)
if new_value is not value:
converted_fields.append(field_name)
# Attach conversion info for upstream list/set sampling optimization
obj.__dict__["_datetime_converted_fields"] = converted_fields
return obj

# Case 3: Object is list, tuple, or set (performance optimization)
Expand All @@ -82,16 +89,29 @@ def _recursive_datetime_check(self, obj, path: str = "", depth: int = 0):
if not obj:
return obj

# List: only check the first element
# Helper function to check if a BaseModel has converted datetime fields
def has_converted_datetimes(item):
    """Report whether the datetime check replaced anything in ``item``.

    For a BaseModel, look up the ``_datetime_converted_fields`` entry in the
    instance ``__dict__`` (treated as an empty list when absent) and return
    True when it is non-empty. For any other value, run
    ``self._recursive_datetime_check`` on it and return True when a different
    object comes back.
    """
    if not isinstance(item, BaseModel):
        # NOTE(review): this branch runs the recursive check a second time,
        # with an empty path, purely to compare object identity. The call
        # sites visible in this change only pass BaseModel instances, so it
        # looks unreachable - confirm before relying on it.
        rechecked = self._recursive_datetime_check(item, "", depth + 2)
        return rechecked is not item

    # Models are updated in place, so identity comparison cannot detect a
    # conversion; the marker list stored on the instance is consulted instead.
    converted = item.__dict__.get("_datetime_converted_fields", [])
    return bool(converted)

# List: only check the first element, but handle BaseModel specially
if isinstance(obj, list):
first_item = obj[0]
first_checked = self._recursive_datetime_check(
first_item, f"{path}[0]", depth + 2
)

# If the first element hasn't changed, assume the whole list doesn't need conversion
if first_checked is first_item:
return obj
# For BaseModel, check if any datetime fields were converted
# For other types, check if reference changed
if isinstance(first_item, BaseModel):
# BaseModel is modified in-place, so we check the conversion marker
if not has_converted_datetimes(first_item):
return obj
else:
# If the first element hasn't changed, assume the whole list doesn't need conversion
if first_checked is first_item:
return obj

# Set: check any one element (set is unordered, take the first one)
elif isinstance(obj, set):
Expand All @@ -100,9 +120,14 @@ def _recursive_datetime_check(self, obj, path: str = "", depth: int = 0):
sample_item, f"{path}[sample]", depth + 2
)

# If the sampled element hasn't changed, assume the whole set doesn't need conversion
if sample_checked is sample_item:
return obj
# For BaseModel, check if any datetime fields were converted
if isinstance(sample_item, BaseModel):
if not has_converted_datetimes(sample_item):
return obj
else:
# If the sampled element hasn't changed, assume the whole set doesn't need conversion
if sample_checked is sample_item:
return obj

# Tuple: only check the first 3 elements
elif isinstance(obj, tuple):
Expand All @@ -115,7 +140,12 @@ def _recursive_datetime_check(self, obj, path: str = "", depth: int = 0):
checked = self._recursive_datetime_check(
item, f"{path}[{idx}]", depth + 2
)
if checked is not item:
# For BaseModel, check conversion marker; for others, check reference
if isinstance(item, BaseModel):
if has_converted_datetimes(item):
need_transform = True
break
elif checked is not item:
need_transform = True
break

Expand Down