-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
124 lines (106 loc) · 3.31 KB
/
model.py
File metadata and controls
124 lines (106 loc) · 3.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import urllib.request
import zipfile
import json
# from os.path import abspath, joins
import torch
import pandas as pd
from utils.data import BSARDataset
from utils.eval import BiEncoderEvaluator
from models.trainable_dense_models import BiEncoder
from sentence_transformers import util
# define class StoredModel
class StoredModel:
    """Wraps a trained bi-encoder together with pre-computed document
    embeddings so queries can be answered without re-encoding the corpus."""

    def __init__(
        self,
        model_path,
        full_text_path,
        model,
        document_ids,
        d_embeddings,
        device,
        documents,
        batch_size=2,
    ):
        # Checkpoint path the model was loaded from (kept for reference).
        self.model_path = model_path
        # Path of the CSV containing the full document texts.
        self.full_text_path = full_text_path
        # BiEncoder exposing q_encoder / d_encoder sub-models.
        self.model = model
        # Corpus document ids, aligned with the rows of d_embeddings.
        self.document_ids = document_ids
        # Pre-computed corpus embeddings used at query time.
        self.d_embeddings = d_embeddings
        self.device = device
        # Document payloads, looked up by corpus id in infer().
        self.documents = documents
        self.batch_size = batch_size

    def infer(self, prompt, k=10):
        """Return the top-``k`` documents most similar to ``prompt``.

        Each result is a dict with the corpus ``id`` and its ``data`` payload.
        """
        q_embeddings = self.model.q_encoder.encode(
            texts=[prompt], device=self.device, batch_size=self.batch_size
        )
        hits = util.semantic_search(
            query_embeddings=q_embeddings,
            corpus_embeddings=self.d_embeddings,
            top_k=k,
            score_function=util.dot_score,
        )
        # Single query issued, so only hits[0] is relevant.
        # (Removed a leftover debug print of each hit.)
        # NOTE(review): semantic_search's "corpus_id" is a positional index
        # into d_embeddings; this assumes self.documents is keyed the same
        # way — confirm the CSV "id" column matches row position.
        top_ids = [hit["corpus_id"] for hit in hits[0]]
        return [{"id": doc_id, "data": self.documents[doc_id]} for doc_id in top_ids]
def fetch_model(url, dest="./assets/models"):
    """Download a zipped model from ``url`` and extract it into ``dest``.

    The archive is saved to the current working directory under the last
    path segment of ``url``, extracted, and then deleted.

    Args:
        url: Location of the zip archive (any scheme urllib supports).
        dest: Directory the archive is extracted into (created by zipfile
            if missing). Defaults to the original hard-coded location.
    """
    # Name the local file after the last path segment of the URL.
    model_filename = url.split("/")[-1]
    urllib.request.urlretrieve(url, model_filename)
    try:
        with zipfile.ZipFile(model_filename, "r") as zip_ref:
            # NOTE: extractall trusts archive member names — only use with
            # archives from a trusted source (zip-slip risk otherwise).
            zip_ref.extractall(dest)
    finally:
        # Clean up the downloaded archive even if extraction fails.
        os.remove(model_filename)
def cap_string(text, cap=200):
    """Truncate ``text`` to at most ``cap`` characters, appending "..." when cut."""
    return text if len(text) <= cap else text[:cap] + "..."
def load_model(model_path, full_text_path, batch_size=2):
    """Load a BiEncoder and pre-compute embeddings for the full-text corpus.

    Args:
        model_path: Path of the saved BiEncoder checkpoint.
        full_text_path: CSV file with columns ``id``, ``content``, ``title``,
            ``cat_1``, ``cat_2``, ``book_no``, ``page_no``.
        batch_size: Batch size used when encoding the documents.

    Returns:
        A StoredModel ready for inference.
    """
    model = BiEncoder.load(model_path)
    documents_df = pd.read_csv(full_text_path)

    # id -> display payload, with content capped for presentation.
    documents_dict = {}
    for _, row in documents_df.iterrows():
        documents_dict[row["id"]] = {
            "id": row["id"],
            "content": cap_string(row["content"]),
            "title": row["title"],
            "cat1": row["cat_1"],
            "cat2": row["cat_2"],
            "book_no": row["book_no"],
            "page_no": row["page_no"],
        }

    # Full (uncapped) texts, ordered by id — this is what gets embedded.
    id_doc_pair = documents_df.set_index("id")["content"].to_dict()
    document_ids = list(id_doc_pair.keys())
    documents = [id_doc_pair[doc_id] for doc_id in document_ids]

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # BUG FIX: a stray `batch_size = 2` here previously clobbered the
    # caller-supplied batch_size parameter; it has been removed.
    d_embeddings = model.d_encoder.encode(
        texts=documents, device=device, batch_size=batch_size
    )
    return StoredModel(
        model_path,
        full_text_path,
        model,
        document_ids,
        d_embeddings,
        device,
        documents_dict,
        batch_size,
    )