-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathqdrant_utils.py
More file actions
109 lines (86 loc) · 3.87 KB
/
qdrant_utils.py
File metadata and controls
109 lines (86 loc) · 3.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import VectorParams
from langchain_qdrant import QdrantVectorStore
import time
# ========================================================================
# Module-wide configuration shared by every helper in this file.
# Base URL of the Qdrant HTTP endpoint each helper connects to.
QDRANT_URL = 'http://localhost:6333'
# Vector dimensionality passed to VectorParams when a collection is created.
# NOTE(review): this presumably has to equal the output dimension of the
# embedding model configured below rather than depend on how much data is
# embedded - confirm before changing either one.
VECTOR_SIZE = 1024 # if we plan to embed large data in the future this must be updated
# Step used to slice texts/metadata/embeddings into upsert batches.
BATCH_SIZE = 100
# Seconds passed to time.sleep() by the batch loader between upserts.
EMBEDDING_DELAY_SECONDS = 5
# ========================================================================
# TODO: create a single instance of this to be shared with all other rags
# Embedding model instance used both when loading documents and when
# embedding search queries, so that both sides share the same vector space.
# It is constructed at import time.
embeddings = HuggingFaceBgeEmbeddings(
model_name="BAAI/bge-large-en",
# the model is configured to run on the CPU
model_kwargs={'device': 'cpu'},
# embeddings are left un-normalized by the encoder
encode_kwargs={'normalize_embeddings': False}
)
def create_collections(collections: list[str]) -> None:
    """Create each named Qdrant collection unless it already exists.

    This is best-effort: a failure for one collection is printed and the
    remaining collections are still processed. New collections are created
    with ``VECTOR_SIZE`` dimensions and cosine distance.

    Args:
        collections: Names of the collections that should exist.
    """
    client = QdrantClient(
        url=QDRANT_URL, prefer_grpc=False
    )
    try:
        # Names already present on the server. Fetched lazily inside the
        # per-collection try block so that a failed lookup is reported (and
        # retried for the next collection) instead of aborting the whole call.
        existing_names = None
        for collection in collections:
            try:
                if existing_names is None:
                    existing_names = {
                        qd_collection.name
                        for qd_collection in client.get_collections().collections
                    }
                if collection in existing_names:
                    print(f'[QDRANT] Collection {collection} already exists')
                    continue
                client.create_collection(
                    collection_name=collection,
                    vectors_config=VectorParams(
                        size=VECTOR_SIZE,
                        distance=models.Distance.COSINE,
                    ),
                )
                # Remember the new collection so a duplicate name later in
                # the list is reported as existing rather than re-created.
                existing_names.add(collection)
                print(f"[QDRANT] Successfully created collection: {collection}")
            except Exception as e:
                print(f"[ERROR] Error during creation of collection: {collection}\n {e}")
    finally:
        # Always release the connection, even on an unexpected error.
        client.close()
def load_embeddings_custom_metadata(texts: list[str], metadata: list[dict], collection: str):
    """Embed ``texts`` and upsert them into ``collection`` with their metadata.

    Points are sent in batches of ``BATCH_SIZE`` with a pause of
    ``EMBEDDING_DELAY_SECONDS`` between batches (batching is used to prevent
    file descriptor errors). Each point's payload stores the metadata dict
    under ``'metadata'`` and the text under ``'page_content'``.

    Args:
        texts: Documents to embed.
        metadata: One metadata dict per text. If a dict has an ``'id'`` key
            its value is used as the point id; otherwise the point's position
            in ``texts`` is used.
        collection: Name of the target Qdrant collection.
    """
    # texts and metadata should always have the same length, but in case they
    # do not, keep only the leading pairs both lists provide. Truncating up
    # front keeps the text, metadata and embedding slices aligned in every
    # batch and avoids embedding texts that would never be uploaded.
    min_size = min(len(texts), len(metadata))
    paired_texts = texts[:min_size]
    paired_metadata = metadata[:min_size]

    text_embeddings = embeddings.embed_documents(paired_texts)
    print('[QDRANT] Vector embeddings created')

    client = QdrantClient(
        url=QDRANT_URL, prefer_grpc=False
    )
    try:
        for i in range(0, min_size, BATCH_SIZE):
            batch = zip(
                paired_texts[i:i + BATCH_SIZE],
                paired_metadata[i:i + BATCH_SIZE],
                text_embeddings[i:i + BATCH_SIZE],
            )
            client.upsert(
                collection_name=collection,
                points=[
                    models.PointStruct(
                        # i is already the offset of this batch, so i + n is
                        # the point's global position in the input lists.
                        id=point_metadata.get('id', i + n),
                        vector=embedding,
                        payload={
                            'metadata': point_metadata,
                            'page_content': text,
                        },
                    )
                    for n, (text, point_metadata, embedding) in enumerate(batch)
                ],
            )
            loaded = min(i + BATCH_SIZE, min_size)
            print(f'[QDRANT] Batch {loaded} of {min_size}')
            # Pause only between batches; nothing follows the last one.
            if loaded < min_size:
                time.sleep(EMBEDDING_DELAY_SECONDS)
    finally:
        # Release the connection even if an upsert fails part-way through.
        client.close()
    print("[QDRANT] Embeddings successfully loaded into collection")
def retrieve_relevant_context_ids(query: str, num_matches: int, collection: str) -> list[int]:
    """Run a similarity search and return the ids of the best matches.

    Args:
        query: Text to search for.
        num_matches: Maximum number of matches to request.
        collection: Name of the Qdrant collection to search.

    Returns:
        The ``'id'`` metadata value of each match, in the order the search
        returned them. Matches whose metadata has no ``'id'`` key are
        skipped, so the list may be shorter than ``num_matches``.
    """
    client = QdrantClient(
        url=QDRANT_URL, prefer_grpc=False
    )
    try:
        vector_store = QdrantVectorStore(client=client, embedding=embeddings, collection_name=collection)
        points = vector_store.similarity_search_with_score(query=query, k=num_matches)
    finally:
        # Release the connection even if the search raises.
        client.close()

    # Each hit is unpacked as a (document, score) pair; the score is not used.
    content = [
        document.metadata['id']
        for document, _score in points
        if 'id' in document.metadata
    ]
    print(f'[QDRANT] Retrieved {len(content)} points of exploit data')
    return content