import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from .Reranker_Base import RerankerBase
from ..Retrieval.RetrieverBase import Document

class Qwen3Reranker(RerankerBase):
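    """Pointwise reranker built on a Qwen3 causal LM.

    Each (query, document) pair is wrapped in a chat-style prompt and scored by
    the probability that the model answers "yes" rather than "no" to a relevance
    question; documents are then sorted by that probability.
    """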
    def __init__(self, model_name_or_path: str, max_length: int = 4096, instruction=None, attn_type='causal', device_id="cuda:0", **kwargs):
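        """Load tokenizer and model and pre-tokenize the fixed prompt prefix/suffix.

        Args:
            model_name_or_path: Hugging Face hub id or local path of the Qwen3 reranker checkpoint.
            max_length: total token budget per (query, document) pair, prefix and suffix included.
            instruction: default task instruction used when none is passed at scoring time.
            attn_type: accepted for interface compatibility; not used below.
            device_id: torch device string, e.g. "cuda:0".
        """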
        super().__init__()
        device = torch.device(device_id)
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, padding_side='left')
        self.lm = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype=torch.float16)
        self.lm = self.lm.to(device).eval()
        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
        self.prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
        self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        self.prefix_tokens = self.tokenizer.encode(self.prefix, add_special_tokens=False)
        self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)
        self.instruction = instruction or "Given the user query, retrieve the relevant passages"
        self.device = device

    def format_instruction(self, instruction, query, doc):
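        """Render one (query, doc) pair into the <Instruct>/<Query>/<Document> prompt body."""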
        if instruction is None:
            instruction = self.instruction
        output = f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"
        return output

    def process_inputs(self, pairs):
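        """Tokenize the prompt bodies, truncating them to the budget left after the
        fixed prefix/suffix tokens, wrap each with that prefix/suffix, then pad the
        batch and move it to the model device."""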
        out = self.tokenizer(
            pairs, padding=False, truncation='longest_first',
            return_attention_mask=False, max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
        )
        for i, ele in enumerate(out['input_ids']):
            out['input_ids'][i] = self.prefix_tokens + ele + self.suffix_tokens
        out = self.tokenizer.pad(out, padding=True, return_tensors="pt", max_length=self.max_length)
        for key in out:
            out[key] = out[key].to(self.lm.device)
        return out

    @torch.no_grad()
    def compute_logits(self, inputs, **kwargs):
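        """Score a tokenized batch: take the logits at the final position, keep only
        the "yes"/"no" token logits, and return softmax P("yes") for each pair."""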
        batch_scores = self.lm(**inputs).logits[:, -1, :]
        true_vector = batch_scores[:, self.token_true_id]
        false_vector = batch_scores[:, self.token_false_id]
        batch_scores = torch.stack([false_vector, true_vector], dim=1)
        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
        scores = batch_scores[:, 1].exp().tolist()
        return scores

    def compute_scores(self, pairs, instruction=None, **kwargs):
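        """Format, tokenize, and score a list of (query, document) string pairs."""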
        pairs = [self.format_instruction(instruction, query, doc) for query, doc in pairs]
        inputs = self.process_inputs(pairs)
        scores = self.compute_logits(inputs)
        return scores

    def rerank(self, query: str, documents: list[Document], k: int = None, batch_size: int = 8, **kwargs) -> list[Document]:
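        """Score every document against the query in batches, sort by relevance
        (highest first), and return the top-k documents (all of them if k is None)."""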
        # 1. Build (query, doc.content) pairs
        pairs = [(query, doc.content) for doc in documents]

        # 2. Compute relevance scores in batches
        all_scores = []
        for i in range(0, len(pairs), batch_size):
            batch_pairs = pairs[i:i + batch_size]
            batch_scores = self.compute_scores(batch_pairs)
            all_scores.extend(batch_scores)
        scores = all_scores

        # 3. Sort documents by score, descending
        doc_score_pairs = list(zip(documents, scores))
        doc_score_pairs.sort(key=lambda x: x[1], reverse=True)
        reranked_docs = [doc for doc, score in doc_score_pairs]
        if k is not None:
            reranked_docs = reranked_docs[:k]
        return reranked_docs
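
# Minimal usage sketch (hypothetical checkpoint name and Document construction;
# this module only relies on Document exposing a `.content` attribute):
#
#     reranker = Qwen3Reranker("Qwen/Qwen3-Reranker-0.6B", device_id="cuda:0")
#     docs = [Document(content="Paris is the capital of France."),
#             Document(content="The Eiffel Tower is on the Champ de Mars in Paris.")]
#     top = reranker.rerank("Where is the Eiffel Tower?", docs, k=1)
#     print(top[0].content)  # highest-scoring document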