-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathquery_documents.py
More file actions
71 lines (60 loc) · 1.97 KB
/
query_documents.py
File metadata and controls
71 lines (60 loc) · 1.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from pylate import indexes, models, retrieve
import time
import json
from datetime import datetime
# 1. Load the ColBERT model used to embed search queries.
model = models.ColBERT(
    model_name_or_path="lightonai/Reason-ModernColBERT",
)

# 2. Open the existing Voyager index stored under ./pylate_index.
index = indexes.Voyager(
    index_folder="./pylate_index",
    index_name="my_docs_index",
)

# 3. Build a ColBERT retriever on top of that index.
retriever = retrieve.ColBERT(index=index)

# Mapping of document id -> filename. It is a JSON object, so its keys are
# always strings.
with open("doc_id_map.json", "r", encoding="utf-8") as map_file:
    doc_id_map = json.load(map_file)
# Append-only log: one JSON object per line, one line per query.
log_file = "search_log.txt"

# 4. Interactive query loop: read a query, search the index, print the hits,
# and append a record of the search to `log_file`.
while True:
    try:
        query = input("\nEnter your search query (type 'exit' to quit): ").strip()
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C at the prompt (or exhausted piped stdin) ends the
        # session cleanly instead of crashing with a traceback.
        print()
        break
    if query.lower() == "exit":
        break
    if not query:
        # Nothing to search for; re-prompt rather than encode an empty string.
        continue

    # perf_counter is monotonic and intended for elapsed-time measurement;
    # time.time() can jump if the wall clock is adjusted.
    start_time = time.perf_counter()

    # Encode the single query (pylate's is_query flag).
    query_embedding = model.encode(
        [query],
        is_query=True
    )

    # Retrieve the top 5 hits; results[0] holds the hits for our one query.
    results = retriever.retrieve(query_embedding, k=5)
    # Stop the clock before printing, so the debug dump is not counted.
    elapsed = time.perf_counter() - start_time

    print("Raw results:", results)  # Debug print

    # Display results
    print(f"\nResults for: '{query}'")
    print(f"Search took: {elapsed:.2f} seconds")
    print("=" * 50)

    log_entry = {
        # Local time with an explicit UTC offset so log lines are unambiguous.
        "timestamp": datetime.now().astimezone().isoformat(),
        "query": query,
        "results": [],
    }

    for hit in results[0]:
        doc_id = hit.get('id', 'N/A')
        score = hit.get('score', 'N/A')
        # JSON object keys are always strings, so normalise the id before the
        # lookup; otherwise a non-string id could never match the mapping.
        filename = doc_id_map.get(str(doc_id), 'Unknown file')
        try:
            score_str = f"{float(score):.4f}"
        except (ValueError, TypeError):
            # Missing or non-numeric score: show it verbatim.
            score_str = str(score)
        print(f"Filename: {filename} | Document ID: {doc_id} | Score: {score_str}")
        print("-" * 50)
        log_entry["results"].append({
            "filename": filename,
            "doc_id": doc_id,
            "score": score_str,
        })

    # Write log entry to file (one JSON object per line).
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(json.dumps(log_entry) + "\n")