diff --git a/openviking/retrieve/hierarchical_retriever.py b/openviking/retrieve/hierarchical_retriever.py index a0a54b2f..689f0b2c 100644 --- a/openviking/retrieve/hierarchical_retriever.py +++ b/openviking/retrieve/hierarchical_retriever.py @@ -108,13 +108,30 @@ async def retrieve( final_metadata_filter = {"op": "and", "conds": filters_to_merge} + if not await self.storage.collection_exists(collection): + logger.warning(f"[RecursiveSearch] Collection {collection} does not exist") + return QueryResult( + query=query, + matched_contexts=[], + searched_directories=[], + ) + + # Generate query vectors once to avoid duplicate embedding calls + query_vector = None + sparse_query_vector = None + if self.embedder: + result: EmbedResult = self.embedder.embed(query.query) + query_vector = result.dense_vector + sparse_query_vector = result.sparse_vector + # Step 1: Determine starting directories based on context_type root_uris = self._get_root_uris_for_type(query.context_type) # Step 2: Global vector search to supplement starting points global_results = await self._global_vector_search( - query=query.query, collection=collection, + query_vector=query_vector, + sparse_query_vector=sparse_query_vector, limit=self.GLOBAL_SEARCH_TOPK, filter=final_metadata_filter, ) @@ -125,8 +142,10 @@ async def retrieve( # Step 4: Recursive search candidates = await self._recursive_search( query=query.query, - starting_points=starting_points, collection=collection, + query_vector=query_vector, + sparse_query_vector=sparse_query_vector, + starting_points=starting_points, limit=limit, mode=mode, threshold=effective_threshold, @@ -145,21 +164,16 @@ async def retrieve( async def _global_vector_search( self, - query: str, collection: str, + query_vector: Optional[List[float]], + sparse_query_vector: Optional[Dict[str, float]], limit: int, filter: Optional[Dict[str, Any]] = None, ) -> List[Dict[str, Any]]: """Global vector search to locate initial directories.""" - if not self.embedder: - return [] - if not await self.storage.collection_exists(collection): - return [] - result: EmbedResult = self.embedder.embed(query) - query_vector = result.dense_vector if not query_vector: return [] - sparse_query_vector = result.sparse_vector or {} + sparse_query_vector = sparse_query_vector or {} global_filter = { "op": "and", @@ -215,8 +229,10 @@ def _merge_starting_points( async def _recursive_search( self, query: str, - starting_points: List[Tuple[str, float]], collection: str, + query_vector: Optional[List[float]], + sparse_query_vector: Optional[Dict[str, float]], + starting_points: List[Tuple[str, float]], limit: int, mode: str, threshold: Optional[float] = None, @@ -249,18 +265,7 @@ def merge_filter(base_filter: Dict, extra_filter: Optional[Dict]) -> Dict: return base_filter return {"op": "and", "conds": [base_filter, extra_filter]} - # Generate query vectors - query_vector = None - sparse_query_vector = None - - if self.embedder: - result: EmbedResult = self.embedder.embed(query) - query_vector = result.dense_vector - sparse_query_vector = result.sparse_vector - - if not await self.storage.collection_exists(collection): - logger.warning(f"[RecursiveSearch] Collection {collection} does not exist") - return [] + sparse_query_vector = sparse_query_vector or None collected: List[Dict[str, Any]] = [] # Collected results (directories and leaves) dir_queue: List[tuple] = [] # Priority queue: (-score, uri)