11"""
22ChromaDB vector store implementation.
33"""
4- from typing import List , Dict , Any , Optional , Type , Union
4+ from typing import List , Dict , Any , Optional , Type , Union , Literal
55from uuid import uuid4
66
77from pydantic import BaseModel , Field , ConfigDict
88
99from .base import VectorStore , D
1010from .chroma_filter_translator import ChromaFilterTranslator
11- from ..node import Node , DocumentNode , Chunk , SymNode
11+ from ..node import Node , DocumentNode , Chunk , SymNode , ObjectNode
1212from ..filters import MetadataFilter
1313
1414
@@ -24,6 +24,8 @@ class ChromaConfig(BaseModel):
2424 collection_name : str = "documents"
2525 persist_directory : Optional [str ] = None
2626 distance : str = "cosine" # Can be "cosine", "l2", or "ip" (inner product)
27+ enable_hybrid : bool = False # Enable hybrid search (not fully supported by ChromaDB)
28+ fusion_method : Literal ["rrf" , "dbsf" ] = "rrf" # Fusion method for hybrid search (for API compatibility)
2729
2830
2931class ChromaVectorStore (VectorStore [D ]):
@@ -38,6 +40,8 @@ class ChromaVectorStore(VectorStore[D]):
3840 collection_name : str = Field (description = "Name of the collection" )
3941 document_class : Optional [Type [D ]] = Field (default = None , description = "Document class type" )
4042 distance : str = Field (default = "cosine" , description = "Distance metric (cosine, l2, or ip)" )
43+ enable_hybrid : bool = Field (default = False , description = "Enable hybrid search (note: limited support in ChromaDB)" )
44+ fusion_method : Literal ["rrf" , "dbsf" , "mmr" ] = Field (default = "rrf" , description = "Fusion method for hybrid search (for API compatibility)" )
4145 _collection : Any = None # ChromaDB collection instance
4246
4347 model_config = ConfigDict (
@@ -52,6 +56,8 @@ def __init__(
5256 embeddings : Optional [Any ] = None ,
5357 document_class : Optional [Type [D ]] = None ,
5458 distance : str = "cosine" ,
59+ enable_hybrid : bool = False ,
60+ fusion_method : Literal ["rrf" , "dbsf" ] = "rrf" ,
5561 ** kwargs
5662 ):
5763 """
@@ -63,6 +69,8 @@ def __init__(
6369 embeddings: Embeddings model for generating document embeddings
6470 document_class: The document model class (defaults to Node if not provided)
6571 distance: Distance metric to use ("cosine", "l2", or "ip")
72+ enable_hybrid: Enable hybrid search (note: limited support in ChromaDB)
73+ fusion_method: Fusion method for hybrid search ("rrf" or "dbsf", for API compatibility)
6674 """
6775 if not CHROMADB_AVAILABLE :
6876 raise ImportError (
@@ -75,6 +83,8 @@ def __init__(
7583 collection_name = collection_name ,
7684 document_class = document_class or Node , # type: ignore
7785 distance = distance ,
86+ enable_hybrid = enable_hybrid ,
87+ fusion_method = fusion_method ,
7888 ** kwargs
7989 )
8090 self ._embeddings = embeddings
@@ -120,10 +130,94 @@ def _get_doc_class(self, class_name: Optional[str]) -> Type[D]:
120130 return DocumentNode # type: ignore
121131 elif class_name == 'Node' :
122132 return Node # type: ignore
133+ elif class_name == 'ObjectNode' :
134+ return ObjectNode # type: ignore
123135 else :
124136 # Fall back to the default document class
125137 return self .document_class
126138
async def find(self, key: str, value: str, limit: int = 10) -> "List[D]":
    """
    Find documents by a specific key-value pair.

    Queries the underlying collection with an ``$eq`` where-filter on
    ``key`` and rebuilds one document instance per returned record.

    Reconstruction rules for the stored metadata of each record:
      * keys starting with ``_`` are internal and are not copied;
      * keys starting with ``metadata.`` are collected, without the prefix,
        into the document's ``metadata`` dict;
      * every other key becomes a top-level document field;
      * string values are JSON-decoded when they parse as JSON and are
        kept verbatim otherwise.

    Args:
        key: The key to search by (used verbatim as the filter key)
        value: The value to search for
        limit: Maximum number of results to return

    Returns:
        List of documents that match the search criteria (empty if none)
    """
    # Local import kept (the original imported here too), but executed once
    # per call instead of once per returned record.
    import json

    def _decode(raw):
        # Only strings can carry JSON-encoded payloads; other types pass through.
        if not isinstance(raw, str):
            return raw
        try:
            return json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return raw

    results = self._collection.get(
        where={key: {"$eq": value}},
        limit=limit,
        include=["embeddings", "metadatas", "documents"],
    )

    if not results or not results['ids']:
        return []

    def _column(name):
        # A column that is missing, None or empty is treated as absent.
        # `is not None` + `len` (rather than truthiness) keeps this safe for
        # array-like columns whose truth value is ambiguous.
        values = results.get(name)
        if values is not None and len(values) > 0:
            return values
        return None

    metadatas = _column('metadatas')
    embeddings = _column('embeddings')
    texts = _column('documents')

    output = []
    for i, doc_id in enumerate(results['ids']):
        # Guard against a None metadata entry for an individual record:
        # treat it as empty instead of failing on `.items()`.
        metadata = (metadatas[i] if metadatas is not None else None) or {}
        embedding = embeddings[i] if embeddings is not None else None
        text = texts[i] if texts is not None else ""

        doc_dict = {'text': text}
        user_metadata = {}  # flattened "metadata.<name>" fields, prefix removed

        for meta_key, meta_value in metadata.items():
            if meta_key.startswith('_'):
                continue  # internal bookkeeping fields are not document fields
            if meta_key.startswith('metadata.'):
                user_metadata[meta_key[len('metadata.'):]] = _decode(meta_value)
            else:
                doc_dict[meta_key] = _decode(meta_value)

        if user_metadata:
            doc_dict['metadata'] = user_metadata

        # The raw stored id (not its JSON-decoded form) wins; fall back to
        # the collection's record id when no 'id' field was stored.
        doc_dict['id'] = metadata.get('id', doc_id)
        if embedding is not None:
            doc_dict['embedding'] = embedding

        # '_doc_class' names the class the record was stored as.
        doc_class = self._get_doc_class(metadata.get('_doc_class'))
        output.append(doc_class(**doc_dict))

    return output
220+
127221 async def insert_nodes (self , documents : List [D ], index_id : Optional [str ] = None , show_progress : bool = False ) -> List [str ]:
128222 """
129223 Add documents to the Chroma collection.
@@ -504,5 +598,7 @@ def from_config(cls, config: Union[Dict[str, Any], ChromaConfig], embeddings: Op
504598 collection_name = config .collection_name ,
505599 embeddings = embeddings ,
506600 document_class = Node , # Defaults to Node
507- distance = config .distance
601+ distance = config .distance ,
602+ enable_hybrid = config .enable_hybrid ,
603+ fusion_method = config .fusion_method
508604 )
0 commit comments