
Commit 9e5d43d

Merge pull request #8 from DataArcTech/main
merge from main
2 parents 2653718 + fb91761 commit 9e5d43d

7 files changed

Lines changed: 174 additions & 47 deletions


rag_factory/Retrieval/Retriever/Retriever_BM25.py

Lines changed: 63 additions & 43 deletions
@@ -9,12 +9,12 @@
 from pydantic import ConfigDict, Field, model_validator

 logger = logging.getLogger(__name__)
-
+import numpy as np
 from rag_factory.Retrieval.RetrieverBase import BaseRetriever, Document


 def default_preprocessing_func(text: str) -> List[str]:
-    """Default text preprocessing function
+    """Default text preprocessing function; only effective for English text

     Args:
         text: the input text
@@ -25,33 +25,51 @@ def default_preprocessing_func(text: str) -> List[str]:
     return text.split()


-def chinese_preprocessing_func(text: str) -> List[str]:
-    """Chinese text preprocessing function
-
-    Args:
-        text: the input Chinese text
-
-    Returns:
-        the list of tokens after word segmentation
-    """
-    try:
-        import jieba
-        return list(jieba.cut(text))
-    except ImportError:
-        logger.warning("jieba is not installed; falling back to the default tokenizer. Install it with: pip install jieba")
-        return text.split()


 class BM25Retriever(BaseRetriever):
-    """BM25 retriever implementation
-
-    A document retriever based on the BM25 algorithm.
-    Uses the rank_bm25 library for efficient BM25 search.
-
-    Note: BM25 is best suited to relatively static document collections. Dynamic
-    adding/deleting of documents is supported, but each operation rebuilds the
-    whole index, which can be a performance problem on large collections.
-    For frequently updated corpora, VectorStoreRetriever is recommended.
-
+    """
+    BM25Retriever is a document retriever based on the BM25 algorithm, suited to efficient text relevance ranking in information retrieval, question answering, and knowledge-base scenarios.
+
+    The class wraps the rank_bm25 library to provide BM25 retrieval over a document collection, supporting dynamic addition and deletion of documents as well as batch index building.
+    It fits scenarios where the collection is relatively static and retrieval speed matters. For frequent additions and deletions, vector retrieval (e.g. VectorStoreRetriever) is recommended.
+
+    Key features:
+    - Build a BM25 retriever quickly from a list of texts or Document objects.
+    - Custom tokenization/preprocessing functions to fit different languages and segmentation needs.
+    - Dynamic addition and deletion of documents (each operation rebuilds the index; suited to small and medium datasets).
+    - Access to retrieval scores, top-k documents with scores, and retriever configuration info.
+    - Async document addition/deletion for large-scale data processing.
+    - Parameters validated via Pydantic for safe configuration.
+
+    Main parameters:
+        vectorizer (Any): the BM25 vectorizer instance (usually BM25Okapi).
+        docs (List[Document]): the Document objects currently held by the retriever.
+        k (int): the default number of relevant documents to return.
+        preprocess_func (Callable): the tokenization/preprocessing function; defaults to whitespace splitting.
+        bm25_params (Dict): parameters passed through to BM25Okapi (e.g. k1, b).
+
+    Core methods:
+    - from_texts/from_documents: build a retriever from raw texts or Documents.
+    - _get_relevant_documents: retrieve the k documents most relevant to a query.
+    - get_scores: get the BM25 scores of a query against all documents.
+    - get_top_k_with_scores: get the top-k documents together with their scores.
+    - add_documents/delete_documents: dynamically add/delete documents and rebuild the index.
+    - get_bm25_info: get retriever configuration and statistics.
+    - update_k: adjust the number of returned documents at runtime.
+
+    Performance notes:
+    - Every document addition/deletion rebuilds the BM25 index; best for small or infrequently updated collections.
+    - For large or frequently updated collections, prefer a vector-store solution.
+    - Async operations are supported for large-scale data processing.
+
+    Typical usage:
+        >>> retriever = BM25Retriever.from_texts(["text 1", "text 2"], k=3)
+        >>> results = retriever._get_relevant_documents("query string")
+        >>> retriever.add_documents([Document(content="new document")])
+        >>> retriever.delete_documents(ids=["doc_id"])
+        >>> info = retriever.get_bm25_info()
+
     Attributes:
         vectorizer: BM25 vectorizer instance
         docs: document list
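
Since this commit removes chinese_preprocessing_func, languages without whitespace word boundaries now need an explicit tokenizer. A minimal sketch (not part of the commit, and assuming from_texts forwards preprocess_func, as the parameter list above suggests):

# Illustrative sketch: restoring jieba-based segmentation by passing a
# custom preprocess_func, now that chinese_preprocessing_func is gone.
# Assumes from_texts accepts a preprocess_func keyword argument.
import jieba
from rag_factory.Retrieval.Retriever.Retriever_BM25 import BM25Retriever

def chinese_tokenizer(text: str) -> list[str]:
    return list(jieba.cut(text))

retriever = BM25Retriever.from_texts(
    ["文档一的内容", "文档二的内容"],
    preprocess_func=chinese_tokenizer,
    k=2,
)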
@@ -125,7 +143,7 @@ def validate_params(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         Returns:
             the validated values
         """
-        k = values.get("k", 4)
+        k = values.get("k", 5)
         if k <= 0:
             raise ValueError(f"k must be greater than 0, current value: {k}")

@@ -259,46 +277,48 @@ def from_documents(
         )

     def _get_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]:
-        """Get the documents relevant to the query
-
+        """Get the top-k documents relevant to the query
+
         Args:
             query: the query string
             **kwargs: extra parameters; may contain 'k' to override the default number of results
-
+
         Returns:
             the list of relevant documents
-
+
         Raises:
             ValueError: if the vectorizer has not been initialized
         """
         if self.vectorizer is None:
             raise ValueError("The BM25 vectorizer has not been initialized")
-
+
         if not self.docs:
             logger.warning("Document list is empty; returning empty result")
             return []
-
+
         # Number of documents to return
         k = kwargs.get('k', self.k)
         k = min(k, len(self.docs))  # make sure it does not exceed the total document count
-
+
         try:
             # Preprocess the query
             processed_query = self.preprocess_func(query)
             logger.debug(f"Preprocessed query: {processed_query}")
+
+            # Score all documents
+            scores = self.vectorizer.get_scores(processed_query)
+            # Indices of the k highest-scoring documents

-            # Fetch the relevant documents
-            relevant_docs = self.vectorizer.get_top_n(
-                processed_query, self.docs, n=k
-            )
-
-            logger.debug(f"Found {len(relevant_docs)} relevant documents")
-            return relevant_docs
-
+            top_indices = np.argsort(scores)[::-1][:k]
+            # Return the top-k documents
+            top_docs = [self.docs[idx] for idx in top_indices]
+            logger.debug(f"Found {len(top_docs)} relevant documents")
+            return top_docs
+
         except Exception as e:
             logger.error(f"Error during BM25 search: {e}")
             raise
-
+
     def get_scores(self, query: str) -> List[float]:
         """Get the BM25 scores of the query against all documents
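
For reference, the new scoring path maps directly onto rank_bm25's API. A standalone sketch of the same top-k selection (illustrative, not part of the commit):

# Illustrative sketch: the top-k selection _get_relevant_documents now
# performs, written against rank_bm25 directly.
import numpy as np
from rank_bm25 import BM25Okapi

corpus = ["the quick brown fox", "jumps over the lazy dog", "bm25 ranks documents"]
tokenized = [doc.split() for doc in corpus]  # mirrors default_preprocessing_func
bm25 = BM25Okapi(tokenized)

scores = bm25.get_scores("quick fox".split())  # one score per document
top_indices = np.argsort(scores)[::-1][:2]     # highest-scoring documents first
print([corpus[i] for i in top_indices])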

rag_factory/Retrieval/Retriever/Retriever_VectorStore.py

Lines changed: 2 additions & 2 deletions
@@ -10,8 +10,8 @@
 import logging

 from pydantic import ConfigDict, Field, model_validator
-from Retrieval.RetrieverBase import BaseRetriever, Document
-from Store.VectorStore.VectorStoreBase import VectorStore
+from ..RetrieverBase import BaseRetriever, Document
+from ...Store.VectorStore.VectorStoreBase import VectorStore

 logger = logging.getLogger(__name__)


rag_factory/Store/VectorStore/VectorStore_Faiss.py

Lines changed: 2 additions & 1 deletion
@@ -6,10 +6,11 @@
 import numpy as np
 from typing import Any, Optional, Callable
 from .VectorStoreBase import VectorStore, Document
-from Embed import Embeddings
+from ...Embed.Embedding_Base import Embeddings
 import asyncio
 from concurrent.futures import ThreadPoolExecutor

+# TODO: GPU support is needed to improve speed

 def _mmr_select(
     docs_and_scores: list[tuple[Document, float]],
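
The new TODO asks for GPU support. With the faiss-gpu package installed, a CPU index can be cloned onto a GPU; a sketch under that assumption (not part of the commit):

# Illustrative sketch: moving a FAISS index onto GPU 0, as the TODO suggests.
# Requires the faiss-gpu package.
import faiss
import numpy as np

index = faiss.IndexFlatL2(128)                       # CPU index over 128-d vectors
index.add(np.random.rand(1000, 128).astype("float32"))
res = faiss.StandardGpuResources()
gpu_index = faiss.index_cpu_to_gpu(res, 0, index)    # clone onto device 0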

rag_factory/Store/VectorStore/registry.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 # VectorStore/registry.py
 from typing import Dict, Type, Any, Optional
 from .VectorStoreBase import VectorStore
-from Embed.Embedding_Base import Embeddings
+from ...Embed.Embedding_Base import Embeddings
 from .VectorStore_Faiss import FaissVectorStore


rag_factory/rerankers/Reranker_Base.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+from abc import ABC, abstractmethod
+from ..Retrieval import Document
+import warnings
+
+class RerankerBase(ABC):
+    """
+    Base class for rerankers. Every reranker should inherit from this class and implement the rerank method.
+    Instantiating this class directly is discouraged.
+
+    Usage:
+        class MyReranker(RerankerBase):
+            def rerank(self, query: str, documents: list[Document], **kwargs) -> list[Document]:
+                # implement the concrete reranking logic
+                ...
+    """
+    def __init__(self):
+        if type(self) is RerankerBase:
+            warnings.warn("RerankerBase is an abstract base class and should not be instantiated directly. Inherit from it and implement rerank.", UserWarning)
+
+    @abstractmethod
+    def rerank(self, query: str, documents: list[Document], **kwargs) -> list[Document]:
+        """
+        Rerank the documents based on the query.
+        Must be implemented by a subclass.
+        """
+        warnings.warn("The unimplemented rerank method was called. Implement it in a subclass.", UserWarning)
+        raise NotImplementedError("Subclasses must implement the rerank method.")
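
A minimal subclass sketch (not part of the commit; it assumes Document exposes a content attribute, which the Qwen3 reranker below also relies on):

# Illustrative sketch: a trivial RerankerBase subclass that orders
# documents by token overlap with the query.
from rag_factory.rerankers import RerankerBase
from rag_factory.Retrieval import Document

class OverlapReranker(RerankerBase):
    def rerank(self, query: str, documents: list[Document], **kwargs) -> list[Document]:
        query_tokens = set(query.split())
        return sorted(
            documents,
            key=lambda doc: len(query_tokens & set(doc.content.split())),
            reverse=True,
        )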
rag_factory/rerankers/Reranker_Qwen3.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from .Reranker_Base import RerankerBase
+from ..Retrieval.RetrieverBase import Document
+
+class Qwen3Reranker(RerankerBase):
+    def __init__(self, model_name_or_path: str, max_length: int = 4096, instruction=None, attn_type='causal', device_id="cuda:0", **kwargs):
+        super().__init__()
+        device = torch.device(device_id)
+        self.max_length = max_length
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, padding_side='left')
+        self.lm = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype=torch.float16)
+        self.lm = self.lm.to(device).eval()
+        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
+        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
+        self.prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
+        self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+        self.prefix_tokens = self.tokenizer.encode(self.prefix, add_special_tokens=False)
+        self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)
+        self.instruction = instruction or "Given the user query, retrieve the relevant passages"
+        self.device = device
+
+    def format_instruction(self, instruction, query, doc):
+        if instruction is None:
+            instruction = self.instruction
+        output = f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"
+        return output
+
+    def process_inputs(self, pairs):
+        out = self.tokenizer(
+            pairs, padding=False, truncation='longest_first',
+            return_attention_mask=False, max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
+        )
+        for i, ele in enumerate(out['input_ids']):
+            out['input_ids'][i] = self.prefix_tokens + ele + self.suffix_tokens
+        out = self.tokenizer.pad(out, padding=True, return_tensors="pt", max_length=self.max_length)
+        for key in out:
+            out[key] = out[key].to(self.lm.device)
+        return out
+
+    @torch.no_grad()
+    def compute_logits(self, inputs, **kwargs):
+        batch_scores = self.lm(**inputs).logits[:, -1, :]
+        true_vector = batch_scores[:, self.token_true_id]
+        false_vector = batch_scores[:, self.token_false_id]
+        batch_scores = torch.stack([false_vector, true_vector], dim=1)
+        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
+        scores = batch_scores[:, 1].exp().tolist()
+        return scores
+
+    def compute_scores(self, pairs, instruction=None, **kwargs):
+        pairs = [self.format_instruction(instruction, query, doc) for query, doc in pairs]
+        inputs = self.process_inputs(pairs)
+        scores = self.compute_logits(inputs)
+        return scores
+
+    def rerank(self, query: str, documents: list[Document], k: int = None, batch_size: int = 8, **kwargs) -> list[Document]:
+        # 1. Build (query, doc.content) pairs
+        pairs = [(query, doc.content) for doc in documents]
+
+        # 2. Compute scores in batches
+        all_scores = []
+        for i in range(0, len(pairs), batch_size):
+            batch_pairs = pairs[i:i+batch_size]
+            batch_scores = self.compute_scores(batch_pairs)
+            all_scores.extend(batch_scores)
+        scores = all_scores
+
+        # 3. Sort by score
+        doc_score_pairs = list(zip(documents, scores))
+        doc_score_pairs.sort(key=lambda x: x[1], reverse=True)
+        reranked_docs = [doc for doc, score in doc_score_pairs]
+        if k is not None:
+            reranked_docs = reranked_docs[:k]
+        return reranked_docs
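
A usage sketch (not part of the commit; the checkpoint name is a placeholder for whichever Qwen3 reranker checkpoint is used):

# Illustrative sketch: reranking candidate Documents with Qwen3Reranker.
# The model path below is a placeholder.
from rag_factory.rerankers import Qwen3Reranker
from rag_factory.Retrieval.RetrieverBase import Document

reranker = Qwen3Reranker("Qwen/Qwen3-Reranker-0.6B", device_id="cuda:0")
candidates = [
    Document(content="BM25 is a bag-of-words ranking function."),
    Document(content="FAISS provides vector similarity search."),
]
best = reranker.rerank("what is bm25", candidates, k=1, batch_size=8)
print(best[0].content)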

rag_factory/rerankers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+from .Reranker_Base import RerankerBase
+from .Reranker_Qwen3 import Qwen3Reranker
+
+__all__ = ["RerankerBase", "Qwen3Reranker"]
