51 changes: 28 additions & 23 deletions openviking/retrieve/hierarchical_retriever.py
@@ -108,13 +108,30 @@ async def retrieve(

final_metadata_filter = {"op": "and", "conds": filters_to_merge}

if not await self.storage.collection_exists(collection):
Collaborator: Internally the collection defaults to context, but it doesn't hurt to verify it here as well.

logger.warning(f"[RecursiveSearch] Collection {collection} does not exist")
return QueryResult(
query=query,
matched_contexts=[],
searched_directories=[],
)

# Generate query vectors once to avoid duplicate embedding calls
query_vector = None
sparse_query_vector = None
if self.embedder:
result: EmbedResult = self.embedder.embed(query.query)
query_vector = result.dense_vector
sparse_query_vector = result.sparse_vector

# Step 1: Determine starting directories based on context_type
root_uris = self._get_root_uris_for_type(query.context_type)

# Step 2: Global vector search to supplement starting points
global_results = await self._global_vector_search(
query=query.query,
collection=collection,
query_vector=query_vector,
sparse_query_vector=sparse_query_vector,
limit=self.GLOBAL_SEARCH_TOPK,
filter=final_metadata_filter,
)
@@ -125,8 +142,10 @@ async def retrieve(
# Step 4: Recursive search
candidates = await self._recursive_search(
query=query.query,
starting_points=starting_points,
collection=collection,
query_vector=query_vector,
sparse_query_vector=sparse_query_vector,
starting_points=starting_points,
limit=limit,
mode=mode,
threshold=effective_threshold,
Expand All @@ -145,21 +164,16 @@ async def retrieve(

async def _global_vector_search(
self,
query: str,
collection: str,
query_vector: Optional[List[float]],
sparse_query_vector: Optional[Dict[str, float]],
limit: int,
filter: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
"""Global vector search to locate initial directories."""
if not self.embedder:
return []
if not await self.storage.collection_exists(collection):
return []
result: EmbedResult = self.embedder.embed(query)
query_vector = result.dense_vector
if not query_vector:
return []
sparse_query_vector = result.sparse_vector or {}
sparse_query_vector = sparse_query_vector or {}

global_filter = {
"op": "and",
@@ -215,8 +229,10 @@ def _merge_starting_points(
async def _recursive_search(
self,
query: str,
starting_points: List[Tuple[str, float]],
collection: str,
query_vector: Optional[List[float]],
sparse_query_vector: Optional[Dict[str, float]],
starting_points: List[Tuple[str, float]],
limit: int,
mode: str,
threshold: Optional[float] = None,
@@ -249,18 +265,7 @@ def merge_filter(base_filter: Dict, extra_filter: Optional[Dict]) -> Dict:
return base_filter
return {"op": "and", "conds": [base_filter, extra_filter]}

# Generate query vectors
query_vector = None
sparse_query_vector = None

if self.embedder:
result: EmbedResult = self.embedder.embed(query)
query_vector = result.dense_vector
sparse_query_vector = result.sparse_vector

if not await self.storage.collection_exists(collection):
logger.warning(f"[RecursiveSearch] Collection {collection} does not exist")
return []
sparse_query_vector = sparse_query_vector or None

collected: List[Dict[str, Any]] = [] # Collected results (directories and leaves)
dir_queue: List[tuple] = [] # Priority queue: (-score, uri)
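
The core of this diff is hoisting the embedding call (and the collection-existence check) out of _global_vector_search and _recursive_search into retrieve(), so the query is embedded once and the dense/sparse vectors are passed down as parameters. Below is a minimal, self-contained sketch of that pattern; Embedder, EmbedResult, Retriever and the two helpers are simplified stand-ins rather than the real openviking classes, and only the control flow mirrors the change.

# Sketch of the "embed the query once, pass the vectors down" refactor.
# All names here are simplified stand-ins, not the real openviking API.
import asyncio
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class EmbedResult:
    dense_vector: List[float] = field(default_factory=list)
    sparse_vector: Dict[str, float] = field(default_factory=dict)


class Embedder:
    def embed(self, text: str) -> EmbedResult:
        # Placeholder embedding; a real embedder would run a model here.
        return EmbedResult(dense_vector=[float(len(text))], sparse_vector={text[:1]: 1.0})


class Retriever:
    def __init__(self, embedder: Optional[Embedder]) -> None:
        self.embedder = embedder

    async def retrieve(self, query: str) -> List[Dict[str, Any]]:
        # Embed exactly once, then hand both vectors to every downstream search step.
        query_vector: Optional[List[float]] = None
        sparse_query_vector: Optional[Dict[str, float]] = None
        if self.embedder:
            result = self.embedder.embed(query)
            query_vector = result.dense_vector
            sparse_query_vector = result.sparse_vector

        hits = await self._global_search(query_vector, sparse_query_vector)
        hits += await self._recursive_search(query_vector, sparse_query_vector)
        return hits

    async def _global_search(
        self,
        query_vector: Optional[List[float]],
        sparse_query_vector: Optional[Dict[str, float]],
    ) -> List[Dict[str, Any]]:
        # The helper no longer embeds; it only consumes the vectors handed to it.
        if not query_vector:
            return []
        return [{"stage": "global", "sparse_terms": len(sparse_query_vector or {})}]

    async def _recursive_search(
        self,
        query_vector: Optional[List[float]],
        sparse_query_vector: Optional[Dict[str, float]],
    ) -> List[Dict[str, Any]]:
        if not query_vector:
            return []
        return [{"stage": "recursive", "sparse_terms": len(sparse_query_vector or {})}]


if __name__ == "__main__":
    print(asyncio.run(Retriever(Embedder()).retrieve("hierarchical retrieval")))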