diff --git a/rag_factory/Retrieval/Retriever/Retriever_BM25.py b/rag_factory/Retrieval/Retriever/Retriever_BM25.py
index 278b55e..09723f0 100644
--- a/rag_factory/Retrieval/Retriever/Retriever_BM25.py
+++ b/rag_factory/Retrieval/Retriever/Retriever_BM25.py
@@ -9,12 +9,12 @@
from pydantic import ConfigDict, Field, model_validator
logger = logging.getLogger(__name__)
-
+import numpy as np
from rag_factory.Retrieval.RetrieverBase import BaseRetriever, Document
def default_preprocessing_func(text: str) -> List[str]:
- """默认的文本预处理函数
+ """默认的文本预处理函数,仅在英文文本上有效
Args:
text: 输入文本
@@ -25,33 +25,51 @@ def default_preprocessing_func(text: str) -> List[str]:
return text.split()
-def chinese_preprocessing_func(text: str) -> List[str]:
- """中文文本预处理函数
-
- Args:
- text: 输入的中文文本
-
- Returns:
- 分词后的词语列表
- """
- try:
- import jieba
- return list(jieba.cut(text))
- except ImportError:
- logger.warning("jieba 未安装,使用默认分词方法。请安装: pip install jieba")
- return text.split()
class BM25Retriever(BaseRetriever):
- """BM25 检索器实现
-
- 基于 BM25 算法的文档检索器。
- 使用 rank_bm25 库实现高效的 BM25 搜索。
-
- 注意:BM25 算法适用于相对静态的文档集合。虽然支持动态添加/删除文档,
- 但每次操作都会重建整个索引,在大型文档集合上可能有性能问题。
- 对于频繁更新的场景,建议使用 VectorStoreRetriever。
-
+ """
+ BM25Retriever 是一个基于 BM25 算法的文档检索器,适用于信息检索、问答系统、知识库等场景下的高效文本相关性排序。
+
+ 该类通过集成 rank_bm25 库,实现了对文档集合的 BM25 检索,支持文档的动态添加、删除、批量构建索引等操作。
+ 适合文档集合相对静态、检索速度要求较高的场景。对于频繁增删文档的场景,建议使用向量检索(如 VectorStoreRetriever)。
+
+ 主要特性:
+ - 支持从文本列表或 Document 对象列表快速构建 BM25 检索器。
+ - 支持自定义分词/预处理函数,适配不同语言和分词需求。
+ - 支持动态添加、删除文档(每次操作会重建索引,适合中小规模数据集)。
+ - 可获取检索分数、top-k 文档及分数、检索器配置信息等。
+ - 兼容异步文档添加/删除,便于大规模数据处理。
+ - 通过 Pydantic 校验参数,保证配置安全。
+
+ 主要参数:
+ vectorizer (Any): BM25 向量化器实例(通常为 BM25Okapi)。
+ docs (List[Document]): 当前检索器持有的文档对象列表。
+ k (int): 默认返回的相关文档数量。
+ preprocess_func (Callable): 文本分词/预处理函数,默认为空格分词。
+ bm25_params (Dict): 传递给 BM25Okapi 的参数(如 k1、b 等)。
+
+ 核心方法:
+ - from_texts/from_documents: 从原始文本或 Document 构建检索器。
+ - _get_relevant_documents: 检索与查询最相关的前 k 个文档。
+ - get_scores: 获取查询对所有文档的 BM25 分数。
+ - get_top_k_with_scores: 获取 top-k 文档及其分数。
+ - add_documents/delete_documents: 动态增删文档并重建索引。
+ - get_bm25_info: 获取检索器配置信息和统计。
+ - update_k: 动态调整返回文档数量。
+
+ 性能注意事项:
+ - 每次添加/删除文档都会重建 BM25 索引,适合文档量较小或更新不频繁的场景。
+ - 文档量较大或频繁更新时,建议使用向量检索方案。
+ - 支持异步操作,便于大规模数据处理。
+
+ 典型用法:
+ >>> retriever = BM25Retriever.from_texts(["文本1", "文本2"], k=3)
+ >>> results = retriever._get_relevant_documents("查询语句")
+ >>> retriever.add_documents([Document(content="新文档")])
+ >>> retriever.delete_documents(ids=["doc_id"])
+ >>> info = retriever.get_bm25_info()
+
Attributes:
vectorizer: BM25 向量化器实例
docs: 文档列表
@@ -125,7 +143,7 @@ def validate_params(cls, values: Dict[str, Any]) -> Dict[str, Any]:
Returns:
验证后的值
"""
- k = values.get("k", 4)
+ k = values.get("k", 5)
if k <= 0:
raise ValueError(f"k 必须大于 0,当前值: {k}")
@@ -259,46 +277,48 @@ def from_documents(
)
def _get_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]:
- """获取与查询相关的文档
-
+ """获取与查询相关的前k个文档
+
Args:
query: 查询字符串
**kwargs: 其他参数,可能包含 'k' 来覆盖默认的返回数量
-
+
Returns:
相关文档列表
-
+
Raises:
ValueError: 如果向量化器未初始化
"""
if self.vectorizer is None:
raise ValueError("BM25 向量化器未初始化")
-
+
if not self.docs:
logger.warning("文档列表为空,返回空结果")
return []
-
+
# 获取返回文档数量
k = kwargs.get('k', self.k)
k = min(k, len(self.docs)) # 确保不超过总文档数
-
+
try:
# 预处理查询
processed_query = self.preprocess_func(query)
logger.debug(f"预处理后的查询: {processed_query}")
+
+ # 获取所有文档的分数
+ scores = self.vectorizer.get_scores(processed_query)
+ # 获取分数最高的前k个文档索引
- # 获取相关文档
- relevant_docs = self.vectorizer.get_top_n(
- processed_query, self.docs, n=k
- )
-
- logger.debug(f"找到 {len(relevant_docs)} 个相关文档")
- return relevant_docs
-
+ top_indices = np.argsort(scores)[::-1][:k]
+ # 返回前k个文档
+ top_docs = [self.docs[idx] for idx in top_indices]
+ logger.debug(f"找到 {len(top_docs)} 个相关文档")
+ return top_docs
+
except Exception as e:
logger.error(f"BM25 搜索时发生错误: {e}")
raise
-
+
def get_scores(self, query: str) -> List[float]:
"""获取查询对所有文档的 BM25 分数
diff --git a/rag_factory/Retrieval/Retriever/Retriever_VectorStore.py b/rag_factory/Retrieval/Retriever/Retriever_VectorStore.py
index 734c7d8..b43bc0a 100644
--- a/rag_factory/Retrieval/Retriever/Retriever_VectorStore.py
+++ b/rag_factory/Retrieval/Retriever/Retriever_VectorStore.py
@@ -10,8 +10,8 @@
import logging
from pydantic import ConfigDict, Field, model_validator
-from Retrieval.RetrieverBase import BaseRetriever, Document
-from Store.VectorStore.VectorStoreBase import VectorStore
+from ..RetrieverBase import BaseRetriever, Document
+from ...Store.VectorStore.VectorStoreBase import VectorStore
logger = logging.getLogger(__name__)
diff --git a/rag_factory/Store/VectorStore/VectorStore_Faiss.py b/rag_factory/Store/VectorStore/VectorStore_Faiss.py
index 0e38e37..ab68f2d 100644
--- a/rag_factory/Store/VectorStore/VectorStore_Faiss.py
+++ b/rag_factory/Store/VectorStore/VectorStore_Faiss.py
@@ -6,10 +6,11 @@
import numpy as np
from typing import Any, Optional, Callable
from .VectorStoreBase import VectorStore, Document
-from Embed import Embeddings
+from ...Embed.Embedding_Base import Embeddings
import asyncio
from concurrent.futures import ThreadPoolExecutor
+# TODO 需要支持GPU,提高速度
def _mmr_select(
docs_and_scores: list[tuple[Document, float]],
diff --git a/rag_factory/Store/VectorStore/registry.py b/rag_factory/Store/VectorStore/registry.py
index 611b690..7aa4999 100644
--- a/rag_factory/Store/VectorStore/registry.py
+++ b/rag_factory/Store/VectorStore/registry.py
@@ -1,7 +1,7 @@
# VectorStore/registry.py
from typing import Dict, Type, Any, Optional
from .VectorStoreBase import VectorStore
-from Embed.Embedding_Base import Embeddings
+from ...Embed.Embedding_Base import Embeddings
from .VectorStore_Faiss import FaissVectorStore
diff --git a/rag_factory/rerankers/Reranker_Base.py b/rag_factory/rerankers/Reranker_Base.py
new file mode 100644
index 0000000..518b81e
--- /dev/null
+++ b/rag_factory/rerankers/Reranker_Base.py
@@ -0,0 +1,27 @@
+from abc import ABC, abstractmethod
+from ..Retrieval import Document
+import warnings
+
+class RerankerBase(ABC):
+ """
+ Reranker 基类,所有 Reranker 应该继承此类并实现 rerank 方法。
+ 不建议直接实例化本类。
+
+ 使用方法:
+ class MyReranker(RerankerBase):
+ def rerank(self, query: str, documents: list[str], **kwargs) -> list[float]:
+ # 实现具体的重排序逻辑
+ ...
+ """
+ def __init__(self):
+ if type(self) is RerankerBase:
+ warnings.warn("RerankerBase 是抽象基类,不应直接实例化。请继承并实现 rerank 方法。", UserWarning)
+
+ @abstractmethod
+ def rerank(self, query: str, documents: list[Document], **kwargs) -> list[Document]:
+ """
+ Rerank the documents based on the query.
+ 需要子类实现。
+ """
+ warnings.warn("调用了未实现的 rerank 方法。请在子类中实现该方法。", UserWarning)
+ raise NotImplementedError("子类必须实现 rerank 方法。")
diff --git a/rag_factory/rerankers/Reranker_Qwen3.py b/rag_factory/rerankers/Reranker_Qwen3.py
new file mode 100644
index 0000000..b909c2b
--- /dev/null
+++ b/rag_factory/rerankers/Reranker_Qwen3.py
@@ -0,0 +1,75 @@
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from .Reranker_Base import RerankerBase
+from ..Retrieval.RetrieverBase import Document
+
+class Qwen3Reranker(RerankerBase):
+ def __init__(self, model_name_or_path: str, max_length: int = 4096, instruction=None, attn_type='causal', device_id="cuda:0", **kwargs):
+ super().__init__()
+ device = torch.device(device_id)
+ self.max_length = max_length
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, padding_side='left')
+ self.lm = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype=torch.float16)
+ self.lm = self.lm.to(device).eval()
+ self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
+ self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
+ self.prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
+        self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+ self.prefix_tokens = self.tokenizer.encode(self.prefix, add_special_tokens=False)
+ self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)
+ self.instruction = instruction or "Given the user query, retrieval the relevant passages"
+ self.device = device
+
+ def format_instruction(self, instruction, query, doc):
+ if instruction is None:
+ instruction = self.instruction
+        output = f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"
+ return output
+
+ def process_inputs(self, pairs):
+ out = self.tokenizer(
+ pairs, padding=False, truncation='longest_first',
+ return_attention_mask=False, max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
+ )
+ for i, ele in enumerate(out['input_ids']):
+ out['input_ids'][i] = self.prefix_tokens + ele + self.suffix_tokens
+ out = self.tokenizer.pad(out, padding=True, return_tensors="pt", max_length=self.max_length)
+ for key in out:
+ out[key] = out[key].to(self.lm.device)
+ return out
+
+ @torch.no_grad()
+ def compute_logits(self, inputs, **kwargs):
+ batch_scores = self.lm(**inputs).logits[:, -1, :]
+ true_vector = batch_scores[:, self.token_true_id]
+ false_vector = batch_scores[:, self.token_false_id]
+ batch_scores = torch.stack([false_vector, true_vector], dim=1)
+ batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
+ scores = batch_scores[:, 1].exp().tolist()
+ return scores
+
+ def compute_scores(self, pairs, instruction=None, **kwargs):
+ pairs = [self.format_instruction(instruction, query, doc) for query, doc in pairs]
+ inputs = self.process_inputs(pairs)
+ scores = self.compute_logits(inputs)
+ return scores
+
+ def rerank(self, query: str, documents: list[Document], k: int = None, batch_size: int = 8, **kwargs) -> list[Document]:
+ # 1. 组装 (query, doc.content) 对
+ pairs = [(query, doc.content) for doc in documents]
+
+ # 2. 计算分数
+ all_scores = []
+ for i in range(0, len(pairs), batch_size):
+ batch_pairs = pairs[i:i+batch_size]
+ batch_scores = self.compute_scores(batch_pairs)
+ all_scores.extend(batch_scores)
+ scores = all_scores
+
+ # 3. 按分数排序
+ doc_score_pairs = list(zip(documents, scores))
+ doc_score_pairs.sort(key=lambda x: x[1], reverse=True)
+ reranked_docs = [doc for doc, score in doc_score_pairs]
+ if k is not None:
+ reranked_docs = reranked_docs[:k]
+ return reranked_docs
\ No newline at end of file
diff --git a/rag_factory/rerankers/__init__.py b/rag_factory/rerankers/__init__.py
index e69de29..2661297 100644
--- a/rag_factory/rerankers/__init__.py
+++ b/rag_factory/rerankers/__init__.py
@@ -0,0 +1,4 @@
+from .Reranker_Base import RerankerBase
+from .Reranker_Qwen3 import Qwen3Reranker
+
+__all__ = ["RerankerBase", "Qwen3Reranker"]
\ No newline at end of file