 from qdrant_client import models as qdrant_models
 from qdrant_client.http import exceptions as qdrant_exceptions
 from qdrant_client.http import models as http_models
-from sentence_transformers import SentenceTransformer
 from sklearn.metrics.pairwise import cosine_similarity
+import torch
+from transformers import AutoModel, AutoTokenizer
 
 from src.app.models.collections import Collection
 from src.app.models.search import (
@@ -139,10 +140,16 @@ def _get_model(self, curr_model: str) -> dict:
         try:
             time_start = time.time()
             # TODO: path should be an env variable
-            model = SentenceTransformer(f"../models/embedding/{curr_model}/")
+            model_path = f"../models/embedding/{curr_model}/"
+            model = AutoModel.from_pretrained(model_path)
+            tokenizer = AutoTokenizer.from_pretrained(model_path)
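+            # Put the model in inference mode (disables dropout).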
+            model.eval()
             self.model[curr_model] = {
-                "max_seq_length": model.get_max_seq_length(),
148+ "max_seq_length" : tokenizer . model_max_length ,
145149 "instance" : model ,
150+ "tokenizer" : tokenizer ,
146151 }
147152 time_end = time .time ()
148153
@@ -179,6 +186,19 @@ def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]:
 
         return inputs
 
+    @log_time_and_error_sync
+    def _compute_embeddings(self, model, tokenizer, inputs: list[str]) -> np.ndarray:
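+        # Inference only: no gradient tracking needed.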
+        with torch.no_grad():
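+            # Tokenize as a batch: pad to the longest input, truncate at the model limit.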
+            tokenized_inputs = tokenizer(inputs, padding=True, truncation=True, return_tensors='pt')
+            model_output = model(**tokenized_inputs)
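+            # CLS pooling: keep the first token of the last hidden state.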
+            embeddings = model_output[0][:, 0]
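+            # L2-normalize so dot products behave like cosine similarity.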
+            embeddings = torch.nn.functional.normalize(embeddings, dim=1).numpy()
+        return embeddings
+
     @log_time_and_error_sync
     async def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray:
         logger.debug("Creating embeddings model=%s", curr_model)
@@ -188,11 +208,12 @@ async def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray:
 
         seq_len = self.model[curr_model]["max_seq_length"]
         model = self.model[curr_model]["instance"]
+        tokenizer = self.model[curr_model]["tokenizer"]
         inputs = self._split_input_seq_len(seq_len, search_input)
 
         try:
-            embeddings = await run_in_threadpool(model.encode, inputs)
-            # embeddings = model.encode(sentences=inputs)
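+            # Run the blocking forward pass in a worker thread to keep the event loop free.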
+            embeddings = await run_in_threadpool(self._compute_embeddings, model, tokenizer, inputs)
             embeddings = np.mean(embeddings, axis=0)
         except Exception as ex:
             logger.error("api_error=EMBED_ERROR model=%s", curr_model)
@@ -210,8 +231,12 @@ async def simple_search_handler(self, qp: EnhancedSearchQuery):
         model = await run_in_threadpool(
             self._get_model, curr_model="granite-embedding-107m-multilingual"
         )
+
         model_instance = model["instance"]
-        embedding = await run_in_threadpool(model_instance.encode, qp.query)
+        tokenizer = model["tokenizer"]
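+        # _compute_embeddings expects a batch, so wrap a lone query string in a list.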
+        embedding_input = qp.query if isinstance(qp.query, list) else [qp.query]
+        embedding = await run_in_threadpool(self._compute_embeddings, model_instance, tokenizer, embedding_input)
         result = await self.search(
             collection_info="collection_welearn_mul_granite-embedding-107m-multilingual",
             embedding=embedding,