# embed_rerank_model.py — OpenAI-compatible embedding and reranking model clients.
# (Header reconstructed from page-scrape metadata: 97 lines (81 loc) · 3.82 KB.)
from openai import OpenAI
import yaml
import logging
import numpy as np
class BaseModel:
    """Shared base for OpenAI-compatible model clients.

    Loads a YAML config file, selects the ``config[model_type][model_name]``
    section, and builds an ``openai.OpenAI`` client pointed at the local
    endpoint described there (``local_api_key`` / ``local_base_url``).
    """

    def __init__(self, model_type, model_name, config_path='./configs/models.yaml', logger=None):
        # Resolve this model's section of the YAML configuration.
        self.config = self.load_config(config_path)
        self.model_config = self.config[model_type][model_name]
        self.model = self.model_config['model']
        # Fall back to a record-discarding logger when the caller supplies none.
        self.logger = self._default_logger(model_type) if logger is None else logger
        self.local_api_key = self.model_config['local_api_key']
        self.local_base_url = self.model_config['local_base_url']
        self.logger.info(f'[{model_type}] Initializing Model: {model_name}')
        self.client = OpenAI(api_key=self.local_api_key, base_url=self.local_base_url)

    def load_config(self, path):
        """Parse and return the YAML config at *path*."""
        with open(path, 'r', encoding='utf-8') as handle:
            return yaml.safe_load(handle)

    def _default_logger(self, model_type):
        """Return a logger named ``{model_type}Logger`` with a NullHandler attached."""
        fallback = logging.getLogger(f"{model_type}Logger")
        fallback.addHandler(logging.NullHandler())
        return fallback
class EmbeddingModel(BaseModel):
    """Client for an OpenAI-compatible embedding endpoint."""

    def __init__(self, embedding_model, config_path='./configs/models.yaml', logger=None):
        super().__init__('embedding_models', embedding_model, config_path, logger)

    def embed_query(self, query=None):
        """Embed a single query string; returns its embedding vector."""
        result = self.client.embeddings.create(model=self.model, input=query)
        return result.data[0].embedding

    def embed_documents(self, documents=None):
        """Embed a batch of documents; returns one embedding vector per document."""
        result = self.client.embeddings.create(model=self.model, input=documents)
        return [item.embedding for item in result.data]
class RerankingModel(BaseModel):
    """Client for an OpenAI-compatible reranking endpoint.

    NOTE(review): scores are fetched through ``embeddings.create`` with the
    query passed via ``extra_body`` — presumably a convention of the local
    serving stack; confirm against the server's API.
    """

    # Default config_path aligned with BaseModel/EmbeddingModel
    # ('./configs/models.yaml'); the previous './llm_tools/configs/models.yaml'
    # default was inconsistent with the rest of this module and with the
    # module's own usage example, and looked like a copy-paste leftover.
    def __init__(self, reranking_model, config_path='./configs/models.yaml', logger=None):
        super().__init__('reranking_models', reranking_model, config_path, logger)

    def rerank_query(self, input=None, query=None):
        """Score a single *input* text against *query*; returns the raw score payload."""
        response = self.client.embeddings.create(
            model=self.model,
            input=input,
            extra_body={"query": query}
        )
        return response.data[0].embedding

    def rerank_documents(self, documents=None, query=None):
        """Score each text in *documents* against *query*; one score payload per document."""
        response = self.client.embeddings.create(
            model=self.model,
            input=documents,
            extra_body={"query": query}
        )
        return [item.embedding for item in response.data]
# Usage example
if __name__ == '__main__':
    #### EmbeddingModel ####
    embmodel = EmbeddingModel(embedding_model="m3e-base", config_path='./configs/models.yaml')

    # Single query embedding
    text = "The food was delicious and the waiter was friendly."
    query_embedding = np.array(embmodel.embed_query(text))
    print(len(query_embedding))

    # Batch document embedding
    docs = ["The food was delicious and the waiter was friendly.",
            "The service was slow and the food was not very good."]
    document_embeddings = np.array(embmodel.embed_documents(docs))
    print(document_embeddings.shape)

    #### RerankingModel ####
    rerankmodel = RerankingModel(reranking_model="bge-reranker-large", config_path='./configs/models.yaml')

    # Single-text reranking
    text = "The food was delicious and the waiter was friendly."
    query = "The food was delicious and the waiter was friendly."
    print(rerankmodel.rerank_query(text, query))

    # Batch document reranking
    docs = ["The food was delicious and the waiter was friendly.",
            "The service was slow and the food was not very good."]
    query = "The food was delicious and the waiter was friendly."
    print(np.array(rerankmodel.rerank_documents(docs, query)))