Skip to content

Commit fd4762e

Browse files
committed
Fixed tests
1 parent 493044d commit fd4762e

3 files changed

Lines changed: 106 additions & 219 deletions

File tree

packages/fetchcraft-core/src/fetchcraft/vector_store/chroma_store.py

Lines changed: 99 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
"""
22
ChromaDB vector store implementation.
33
"""
4-
from typing import List, Dict, Any, Optional, Type, Union
4+
from typing import List, Dict, Any, Optional, Type, Union, Literal
55
from uuid import uuid4
66

77
from pydantic import BaseModel, Field, ConfigDict
88

99
from .base import VectorStore, D
1010
from .chroma_filter_translator import ChromaFilterTranslator
11-
from ..node import Node, DocumentNode, Chunk, SymNode
11+
from ..node import Node, DocumentNode, Chunk, SymNode, ObjectNode
1212
from ..filters import MetadataFilter
1313

1414

@@ -24,6 +24,8 @@ class ChromaConfig(BaseModel):
2424
collection_name: str = "documents"
2525
persist_directory: Optional[str] = None
2626
distance: str = "cosine" # Can be "cosine", "l2", or "ip" (inner product)
27+
enable_hybrid: bool = False # Enable hybrid search (not fully supported by ChromaDB)
28+
fusion_method: Literal["rrf", "dbsf"] = "rrf" # Fusion method for hybrid search (for API compatibility)
2729

2830

2931
class ChromaVectorStore(VectorStore[D]):
@@ -38,6 +40,8 @@ class ChromaVectorStore(VectorStore[D]):
3840
collection_name: str = Field(description="Name of the collection")
3941
document_class: Optional[Type[D]] = Field(default=None, description="Document class type")
4042
distance: str = Field(default="cosine", description="Distance metric (cosine, l2, or ip)")
43+
enable_hybrid: bool = Field(default=False, description="Enable hybrid search (note: limited support in ChromaDB)")
44+
fusion_method: Literal["rrf", "dbsf", "mmr"] = Field(default="rrf", description="Fusion method for hybrid search (for API compatibility)")
4145
_collection: Any = None # ChromaDB collection instance
4246

4347
model_config = ConfigDict(
@@ -52,6 +56,8 @@ def __init__(
5256
embeddings: Optional[Any] = None,
5357
document_class: Optional[Type[D]] = None,
5458
distance: str = "cosine",
59+
enable_hybrid: bool = False,
60+
fusion_method: Literal["rrf", "dbsf"] = "rrf",
5561
**kwargs
5662
):
5763
"""
@@ -63,6 +69,8 @@ def __init__(
6369
embeddings: Embeddings model for generating document embeddings
6470
document_class: The document model class (defaults to Node if not provided)
6571
distance: Distance metric to use ("cosine", "l2", or "ip")
72+
enable_hybrid: Enable hybrid search (note: limited support in ChromaDB)
73+
fusion_method: Fusion method for hybrid search ("rrf" or "dbsf", for API compatibility)
6674
"""
6775
if not CHROMADB_AVAILABLE:
6876
raise ImportError(
@@ -75,6 +83,8 @@ def __init__(
7583
collection_name=collection_name,
7684
document_class=document_class or Node, # type: ignore
7785
distance=distance,
86+
enable_hybrid=enable_hybrid,
87+
fusion_method=fusion_method,
7888
**kwargs
7989
)
8090
self._embeddings = embeddings
@@ -120,10 +130,94 @@ def _get_doc_class(self, class_name: Optional[str]) -> Type[D]:
120130
return DocumentNode # type: ignore
121131
elif class_name == 'Node':
122132
return Node # type: ignore
133+
elif class_name == 'ObjectNode':
134+
return ObjectNode # type: ignore
123135
else:
124136
# Fall back to the default document class
125137
return self.document_class
126138

139+
async def find(self, key: str, value: str, limit: int = 10) -> "List[D]":
    """
    Find documents whose stored metadata field ``key`` equals ``value``.

    The lookup is a metadata-only ``get`` on the Chroma collection (no
    vector similarity is involved). Each returned record is rebuilt into a
    document instance of the class named by its ``_doc_class`` metadata
    entry, resolved through ``self._get_doc_class``.

    Args:
        key: Metadata field name, used verbatim as the ``where`` filter key.
            NOTE(review): user metadata is read back from keys prefixed with
            ``metadata.``, so callers presumably must pass the prefixed name
            to match such fields - confirm against the insert path.
        value: Value the field must be equal to (``$eq`` comparison).
        limit: Maximum number of records requested from the collection.

    Returns:
        List of reconstructed documents; empty when nothing matches.
    """
    # Imported once per call instead of once per metadata key.
    import json

    def _decode(raw):
        """Return ``raw`` JSON-decoded when it is a JSON string, else unchanged."""
        if not isinstance(raw, str):
            return raw
        try:
            return json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            # Plain (non-JSON) strings are kept as they were stored.
            return raw

    results = self._collection.get(
        where={key: {"$eq": value}},
        limit=limit,
        include=["embeddings", "metadatas", "documents"]
    )

    if not results or not results['ids']:
        return []

    def _column(name):
        """Return the result column ``name``, or None when absent or empty."""
        column = results.get(name)
        # Explicit None/len checks: a column may be an array-like whose
        # truth value is ambiguous.
        if column is not None and len(column) > 0:
            return column
        return None

    # Column presence does not change per row, so resolve it once.
    metadatas = _column('metadatas')
    embeddings = _column('embeddings')
    documents = _column('documents')

    metadata_prefix = 'metadata.'
    output = []
    for i, doc_id in enumerate(results['ids']):
        metadata = metadatas[i] if metadatas is not None else {}
        if metadata is None:
            # A record stored without metadata would otherwise crash on
            # .items()/.get() below.
            metadata = {}
        embedding = embeddings[i] if embeddings is not None else None
        text = documents[i] if documents is not None else ""

        doc_dict = {'text': text}
        user_metadata = {}  # Un-flattened "metadata.*" fields

        for meta_key, meta_value in metadata.items():
            if meta_key.startswith('_'):
                continue  # Internal bookkeeping fields are not document fields
            if meta_key.startswith(metadata_prefix):
                user_metadata[meta_key[len(metadata_prefix):]] = _decode(meta_value)
            else:
                doc_dict[meta_key] = _decode(meta_value)

        if user_metadata:
            doc_dict['metadata'] = user_metadata

        # Set after the loop so the raw stored id wins over any JSON-decoded
        # "id" value; fall back to the Chroma record id.
        doc_dict['id'] = metadata.get('id', doc_id)
        if embedding is not None:
            doc_dict['embedding'] = embedding

        doc_class = self._get_doc_class(metadata.get('_doc_class'))
        output.append(doc_class(**doc_dict))

    return output
220+
127221
async def insert_nodes(self, documents: List[D], index_id: Optional[str] = None, show_progress: bool = False) -> List[str]:
128222
"""
129223
Add documents to the Chroma collection.
@@ -504,5 +598,7 @@ def from_config(cls, config: Union[Dict[str, Any], ChromaConfig], embeddings: Op
504598
collection_name=config.collection_name,
505599
embeddings=embeddings,
506600
document_class=Node, # Defaults to Node
507-
distance=config.distance
601+
distance=config.distance,
602+
enable_hybrid=config.enable_hybrid,
603+
fusion_method=config.fusion_method
508604
)

packages/fetchcraft-core/tests/test_hash_update.py

Lines changed: 0 additions & 209 deletions
This file was deleted.

0 commit comments

Comments
 (0)