
Commit e39613a

feat(search): remove sentence transformers (#116)
* Update poetry requirements
* Switch from sentence-transformers to transformers only
* New lock

Co-authored-by: Jean-Marc SEVIN <jean-marc.sevin@learningplanetinstitute.org>
1 parent 2ba96e5 commit e39613a

3 files changed: 38 additions & 21 deletions


poetry.lock

Lines changed: 15 additions & 14 deletions
Generated lockfile; the diff is not rendered by default.

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -24,7 +24,6 @@ python-multipart = "^0.0.18"
 qdrant-client = "^1.15.0"
 requests = "^2.32.4"
 scikit-learn = "^1.5.1"
-sentence-transformers = "^3.4.1"
 sqlalchemy = "^2.0.35"
 transformers="^4.50.0"
 torch = {version = "^2.2.2+cpu", source = "pytorch_cpu"}
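
With sentence-transformers dropped from the dependencies, the service loads models through transformers directly. A minimal sketch of that loading path, assuming the local directory layout hard-coded in _get_model in the search.py diff below (../models/embedding/<model-name>/) and the granite-embedding-107m-multilingual checkpoint that the handler requests:

# Sketch only: MODEL_PATH mirrors the hard-coded path in _get_model; adjust it
# to wherever the exported model files actually live.
from transformers import AutoModel, AutoTokenizer

MODEL_PATH = "../models/embedding/granite-embedding-107m-multilingual/"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModel.from_pretrained(MODEL_PATH)
model.eval()  # inference only: disable dropout and other training-time behaviour

# Replaces SentenceTransformer.get_max_seq_length(); note that some tokenizers
# report a very large sentinel value here rather than the model's true context
# length, so the loaded value is worth checking.
max_seq_length = tokenizer.model_max_length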

src/app/services/search.py

Lines changed: 23 additions & 6 deletions
@@ -11,8 +11,9 @@
 from qdrant_client import models as qdrant_models
 from qdrant_client.http import exceptions as qdrant_exceptions
 from qdrant_client.http import models as http_models
-from sentence_transformers import SentenceTransformer
 from sklearn.metrics.pairwise import cosine_similarity
+import torch
+from transformers import AutoModel, AutoTokenizer
 
 from src.app.models.collections import Collection
 from src.app.models.search import (
@@ -139,10 +140,14 @@ def _get_model(self, curr_model: str) -> dict:
         try:
             time_start = time.time()
             # TODO: path should be an env variable
-            model = SentenceTransformer(f"../models/embedding/{curr_model}/")
+            model_path = f"../models/embedding/{curr_model}/"
+            model = AutoModel.from_pretrained(model_path)
+            tokenizer = AutoTokenizer.from_pretrained(model_path)
+            model.eval()
             self.model[curr_model] = {
-                "max_seq_length": model.get_max_seq_length(),
+                "max_seq_length": tokenizer.model_max_length,
                 "instance": model,
+                "tokenizer": tokenizer,
             }
             time_end = time.time()
 
@@ -179,6 +184,15 @@ def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]:
 
         return inputs
 
+    @log_time_and_error_sync
+    def _compute_embeddings(self, model, tokenizer, inputs: list[str]) -> np.ndarray:
+        with torch.no_grad():
+            tokenized_inputs = tokenizer(inputs, padding=True, truncation=True, return_tensors='pt')
+            model_output = model(**tokenized_inputs)
+            embeddings = model_output[0][:, 0]
+            embeddings = torch.nn.functional.normalize(embeddings, dim=1).numpy()
+        return embeddings
+
     @log_time_and_error_sync
     async def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray:
         logger.debug("Creating embeddings model=%s", curr_model)
@@ -188,11 +202,11 @@ async def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray:
 
         seq_len = self.model[curr_model]["max_seq_length"]
         model = self.model[curr_model]["instance"]
+        tokenizer = self.model[curr_model]["tokenizer"]
         inputs = self._split_input_seq_len(seq_len, search_input)
 
         try:
-            embeddings = await run_in_threadpool(model.encode, inputs)
-            # embeddings = model.encode(sentences=inputs)
+            embeddings = await run_in_threadpool(self._compute_embeddings, model, tokenizer, inputs)
             embeddings = np.mean(embeddings, axis=0)
         except Exception as ex:
             logger.error("api_error=EMBED_ERROR model=%s", curr_model)
@@ -210,8 +224,11 @@ async def simple_search_handler(self, qp: EnhancedSearchQuery):
         model = await run_in_threadpool(
             self._get_model, curr_model="granite-embedding-107m-multilingual"
         )
+
         model_instance = model["instance"]
-        embedding = await run_in_threadpool(model_instance.encode, qp.query)
+        tokenizer = model["tokenizer"]
+        embedding_input = qp.query if isinstance(qp.query, list) else [qp.query]
+        embedding = await run_in_threadpool(self._compute_embeddings, model_instance, tokenizer, embedding_input)
         result = await self.search(
             collection_info="collection_welearn_mul_granite-embedding-107m-multilingual",
             embedding=embedding,
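
Taken together, the new path tokenizes each input, takes the first (CLS) token of the last hidden state, L2-normalises it, and, for queries longer than the model's sequence length, averages the chunk embeddings into a single vector. A standalone sketch of that flow; split_by_seq_len is a simplified, hypothetical stand-in for _split_input_seq_len, whose body is not part of this diff:

import numpy as np
import torch

def compute_embeddings(model, tokenizer, inputs: list[str]) -> np.ndarray:
    # Mirrors _compute_embeddings: CLS pooling + L2 normalisation.
    with torch.no_grad():
        batch = tokenizer(inputs, padding=True, truncation=True, return_tensors="pt")
        output = model(**batch)
        cls = output[0][:, 0]  # first token of the last hidden state
        return torch.nn.functional.normalize(cls, dim=1).numpy()

def split_by_seq_len(text: str, seq_len: int) -> list[str]:
    # Naive whitespace chunking; the service's real splitter may differ.
    words = text.split()
    step = max(int(seq_len), 1)
    return [" ".join(words[i:i + step]) for i in range(0, len(words), step)] or [text]

def embed_query(model, tokenizer, query: str, seq_len: int) -> np.ndarray:
    # Mirrors _embed_query: embed each chunk, then mean-pool to one vector.
    chunks = split_by_seq_len(query, seq_len)
    return np.mean(compute_embeddings(model, tokenizer, chunks), axis=0)

As in simple_search_handler, a single query string should be wrapped in a list before being passed to compute_embeddings.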
