From ef84055b184b6336017ff901515c1238bc51fdbf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E6=98=8E=E7=A5=AF?= Date: Wed, 6 Aug 2025 18:26:44 +0800 Subject: [PATCH 1/4] fix: invoke return all documents --- .../Retrieval/Retriever/Retriever_BM25.py | 106 +++++++++++------- 1 file changed, 63 insertions(+), 43 deletions(-) diff --git a/rag_factory/Retrieval/Retriever/Retriever_BM25.py b/rag_factory/Retrieval/Retriever/Retriever_BM25.py index 278b55e..09723f0 100644 --- a/rag_factory/Retrieval/Retriever/Retriever_BM25.py +++ b/rag_factory/Retrieval/Retriever/Retriever_BM25.py @@ -9,12 +9,12 @@ from pydantic import ConfigDict, Field, model_validator logger = logging.getLogger(__name__) - +import numpy as np from rag_factory.Retrieval.RetrieverBase import BaseRetriever, Document def default_preprocessing_func(text: str) -> List[str]: - """默认的文本预处理函数 + """默认的文本预处理函数,仅在英文文本上有效 Args: text: 输入文本 @@ -25,33 +25,51 @@ def default_preprocessing_func(text: str) -> List[str]: return text.split() -def chinese_preprocessing_func(text: str) -> List[str]: - """中文文本预处理函数 - - Args: - text: 输入的中文文本 - - Returns: - 分词后的词语列表 - """ - try: - import jieba - return list(jieba.cut(text)) - except ImportError: - logger.warning("jieba 未安装,使用默认分词方法。请安装: pip install jieba") - return text.split() class BM25Retriever(BaseRetriever): - """BM25 检索器实现 - - 基于 BM25 算法的文档检索器。 - 使用 rank_bm25 库实现高效的 BM25 搜索。 - - 注意:BM25 算法适用于相对静态的文档集合。虽然支持动态添加/删除文档, - 但每次操作都会重建整个索引,在大型文档集合上可能有性能问题。 - 对于频繁更新的场景,建议使用 VectorStoreRetriever。 - + """ + BM25Retriever 是一个基于 BM25 算法的文档检索器,适用于信息检索、问答系统、知识库等场景下的高效文本相关性排序。 + + 该类通过集成 rank_bm25 库,实现了对文档集合的 BM25 检索,支持文档的动态添加、删除、批量构建索引等操作。 + 适合文档集合相对静态、检索速度要求较高的场景。对于频繁增删文档的场景,建议使用向量检索(如 VectorStoreRetriever)。 + + 主要特性: + - 支持从文本列表或 Document 对象列表快速构建 BM25 检索器。 + - 支持自定义分词/预处理函数,适配不同语言和分词需求。 + - 支持动态添加、删除文档(每次操作会重建索引,适合中小规模数据集)。 + - 可获取检索分数、top-k 文档及分数、检索器配置信息等。 + - 兼容异步文档添加/删除,便于大规模数据处理。 + - 通过 Pydantic 校验参数,保证配置安全。 + + 主要参数: + vectorizer (Any): BM25 向量化器实例(通常为 BM25Okapi)。 + docs 
(List[Document]): 当前检索器持有的文档对象列表。 + k (int): 默认返回的相关文档数量。 + preprocess_func (Callable): 文本分词/预处理函数,默认为空格分词。 + bm25_params (Dict): 传递给 BM25Okapi 的参数(如 k1、b 等)。 + + 核心方法: + - from_texts/from_documents: 从原始文本或 Document 构建检索器。 + - _get_relevant_documents: 检索与查询最相关的前 k 个文档。 + - get_scores: 获取查询对所有文档的 BM25 分数。 + - get_top_k_with_scores: 获取 top-k 文档及其分数。 + - add_documents/delete_documents: 动态增删文档并重建索引。 + - get_bm25_info: 获取检索器配置信息和统计。 + - update_k: 动态调整返回文档数量。 + + 性能注意事项: + - 每次添加/删除文档都会重建 BM25 索引,适合文档量较小或更新不频繁的场景。 + - 文档量较大或频繁更新时,建议使用向量检索方案。 + - 支持异步操作,便于大规模数据处理。 + + 典型用法: + >>> retriever = BM25Retriever.from_texts(["文本1", "文本2"], k=3) + >>> results = retriever._get_relevant_documents("查询语句") + >>> retriever.add_documents([Document(content="新文档")]) + >>> retriever.delete_documents(ids=["doc_id"]) + >>> info = retriever.get_bm25_info() + Attributes: vectorizer: BM25 向量化器实例 docs: 文档列表 @@ -125,7 +143,7 @@ def validate_params(cls, values: Dict[str, Any]) -> Dict[str, Any]: Returns: 验证后的值 """ - k = values.get("k", 4) + k = values.get("k", 5) if k <= 0: raise ValueError(f"k 必须大于 0,当前值: {k}") @@ -259,46 +277,48 @@ def from_documents( ) def _get_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]: - """获取与查询相关的文档 - + """获取与查询相关的前k个文档 + Args: query: 查询字符串 **kwargs: 其他参数,可能包含 'k' 来覆盖默认的返回数量 - + Returns: 相关文档列表 - + Raises: ValueError: 如果向量化器未初始化 """ if self.vectorizer is None: raise ValueError("BM25 向量化器未初始化") - + if not self.docs: logger.warning("文档列表为空,返回空结果") return [] - + # 获取返回文档数量 k = kwargs.get('k', self.k) k = min(k, len(self.docs)) # 确保不超过总文档数 - + try: # 预处理查询 processed_query = self.preprocess_func(query) logger.debug(f"预处理后的查询: {processed_query}") + + # 获取所有文档的分数 + scores = self.vectorizer.get_scores(processed_query) + # 获取分数最高的前k个文档索引 - # 获取相关文档 - relevant_docs = self.vectorizer.get_top_n( - processed_query, self.docs, n=k - ) - - logger.debug(f"找到 {len(relevant_docs)} 个相关文档") - return relevant_docs - + top_indices = np.argsort(scores)[::-1][:k] + # 
返回前k个文档 + top_docs = [self.docs[idx] for idx in top_indices] + logger.debug(f"找到 {len(top_docs)} 个相关文档") + return top_docs + except Exception as e: logger.error(f"BM25 搜索时发生错误: {e}") raise - + def get_scores(self, query: str) -> List[float]: """获取查询对所有文档的 BM25 分数 From 80cc581ca3f6644ae03a8ee0b53562f839ebb352 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E6=98=8E=E7=A5=AF?= Date: Wed, 6 Aug 2025 18:27:19 +0800 Subject: [PATCH 2/4] fix: import path --- rag_factory/Retrieval/Retriever/Retriever_VectorStore.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rag_factory/Retrieval/Retriever/Retriever_VectorStore.py b/rag_factory/Retrieval/Retriever/Retriever_VectorStore.py index 734c7d8..b43bc0a 100644 --- a/rag_factory/Retrieval/Retriever/Retriever_VectorStore.py +++ b/rag_factory/Retrieval/Retriever/Retriever_VectorStore.py @@ -10,8 +10,8 @@ import logging from pydantic import ConfigDict, Field, model_validator -from Retrieval.RetrieverBase import BaseRetriever, Document -from Store.VectorStore.VectorStoreBase import VectorStore +from ..RetrieverBase import BaseRetriever, Document +from ...Store.VectorStore.VectorStoreBase import VectorStore logger = logging.getLogger(__name__) From 10422453397fa16aa1df938ea05c7c96dc800093 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E6=98=8E=E7=A5=AF?= Date: Wed, 6 Aug 2025 18:28:22 +0800 Subject: [PATCH 3/4] fix: import path --- rag_factory/Store/VectorStore/VectorStore_Faiss.py | 3 ++- rag_factory/Store/VectorStore/registry.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/rag_factory/Store/VectorStore/VectorStore_Faiss.py b/rag_factory/Store/VectorStore/VectorStore_Faiss.py index 0e38e37..ab68f2d 100644 --- a/rag_factory/Store/VectorStore/VectorStore_Faiss.py +++ b/rag_factory/Store/VectorStore/VectorStore_Faiss.py @@ -6,10 +6,11 @@ import numpy as np from typing import Any, Optional, Callable from .VectorStoreBase import VectorStore, Document -from Embed import Embeddings 
+from ...Embed.Embedding_Base import Embeddings import asyncio from concurrent.futures import ThreadPoolExecutor +# TODO 需要支持GPU,提高速度 def _mmr_select( docs_and_scores: list[tuple[Document, float]], diff --git a/rag_factory/Store/VectorStore/registry.py b/rag_factory/Store/VectorStore/registry.py index 611b690..7aa4999 100644 --- a/rag_factory/Store/VectorStore/registry.py +++ b/rag_factory/Store/VectorStore/registry.py @@ -1,7 +1,7 @@ # VectorStore/registry.py from typing import Dict, Type, Any, Optional from .VectorStoreBase import VectorStore -from Embed.Embedding_Base import Embeddings +from ...Embed.Embedding_Base import Embeddings from .VectorStore_Faiss import FaissVectorStore
From 78af53534e74052a09c69a46871bd66ececf504d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=99=88=E6=98=8E=E7=A5=AF?=
Date: Wed, 6 Aug 2025 18:29:32 +0800
Subject: [PATCH 4/4] add: Rerank module

---
 rag_factory/rerankers/Reranker_Base.py  |  42 ++++++++
 rag_factory/rerankers/Reranker_Qwen3.py | 132 +++++++++++++++++++++++++
 rag_factory/rerankers/__init__.py       |   4 +
 3 files changed, 178 insertions(+)
 create mode 100644 rag_factory/rerankers/Reranker_Base.py
 create mode 100644 rag_factory/rerankers/Reranker_Qwen3.py

diff --git a/rag_factory/rerankers/Reranker_Base.py b/rag_factory/rerankers/Reranker_Base.py
new file mode 100644
index 0000000..518b81e
--- /dev/null
+++ b/rag_factory/rerankers/Reranker_Base.py
@@ -0,0 +1,42 @@
+from abc import ABC, abstractmethod
+from ..Retrieval import Document
+import warnings
+
+class RerankerBase(ABC):
+    """
+    Abstract base class for rerankers.
+
+    Every reranker should inherit from this class and implement ``rerank``.
+    Instantiating this class directly is not recommended.
+
+    Usage:
+        class MyReranker(RerankerBase):
+            def rerank(self, query: str, documents: list[Document], **kwargs) -> list[Document]:
+                # implement the actual reranking logic here
+                ...
+    """
+    def __init__(self):
+        # Defensive only: ABC already refuses to instantiate a class that still has abstract methods.
+        if type(self) is RerankerBase:
+            warnings.warn("RerankerBase 是抽象基类,不应直接实例化。请继承并实现 rerank 方法。", UserWarning)
+
+    @abstractmethod
+    def rerank(self, query: str, documents: list[Document], **kwargs) -> list[Document]:
+        """
+        Rerank the documents based on the query.
+
+        Subclasses must implement this method.
+
+        Args:
+            query: The query string.
+            documents: Candidate documents to reorder.
+            **kwargs: Implementation-specific options.
+
+        Returns:
+            The documents in reranked order.
+
+        Raises:
+            NotImplementedError: Whenever this base implementation itself is called.
+        """
+        warnings.warn("调用了未实现的 rerank 方法。请在子类中实现该方法。", UserWarning)
+        raise NotImplementedError("子类必须实现 rerank 方法。")
diff --git a/rag_factory/rerankers/Reranker_Qwen3.py b/rag_factory/rerankers/Reranker_Qwen3.py
new file mode 100644
index 0000000..b909c2b
--- /dev/null
+++ b/rag_factory/rerankers/Reranker_Qwen3.py
@@ -0,0 +1,132 @@
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from .Reranker_Base import RerankerBase
+from ..Retrieval.RetrieverBase import Document
+
+class Qwen3Reranker(RerankerBase):
+    """
+    Reranker that scores (query, document) pairs with a Qwen3-Reranker style causal LM.
+
+    Each pair is rendered into a chat-style prompt, and the model's logits for the
+    tokens "yes" and "no" at the last position are turned into a relevance score.
+    """
+    def __init__(self, model_name_or_path: str, max_length: int = 4096, instruction=None, attn_type='causal', device_id="cuda:0", **kwargs):
+        """
+        Load the tokenizer and model and pre-tokenize the prompt prefix and suffix.
+
+        Args:
+            model_name_or_path: Model name or path passed to ``from_pretrained``.
+            max_length: Maximum prompt length in tokens, prefix and suffix included.
+            instruction: Default task instruction; a built-in one is used when this is None or empty.
+            attn_type: Accepted but not used by this class.
+            device_id: Torch device string the model is moved to.
+            **kwargs: Accepted but not used by this class.
+        """
+        super().__init__()
+        device = torch.device(device_id)
+        self.max_length = max_length
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, padding_side='left')
+        self.lm = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype=torch.float16)
+        self.lm = self.lm.to(device).eval()
+        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
+        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
+        self.prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
+        # The assistant turn opens with an empty <think></think> block, as in the Qwen3-Reranker prompt template.
+        self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+        self.prefix_tokens = self.tokenizer.encode(self.prefix, add_special_tokens=False)
+        self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)
+        self.instruction = instruction or "Given the user query, retrieval the relevant passages"
+        self.device = device
+
+    def format_instruction(self, instruction, query, doc):
+        """
+        Render one (instruction, query, document) triple as the user turn of the prompt.
+
+        Falls back to the instance's default instruction when ``instruction`` is None.
+        """
+        if instruction is None:
+            instruction = self.instruction
+        # The <Instruct>/<Query>/<Document> field tags are part of the prompt format the system prompt refers to.
+        output = f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"
+        return output
+
+    def process_inputs(self, pairs):
+        """
+        Tokenize formatted prompts, wrap them in the prefix/suffix tokens and pad to tensors.
+
+        The truncation budget leaves room for the prefix and suffix tokens, and the
+        resulting tensors are moved to the model's device.
+        """
+        out = self.tokenizer(
+            pairs, padding=False, truncation='longest_first',
+            return_attention_mask=False, max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
+        )
+        for i, ele in enumerate(out['input_ids']):
+            out['input_ids'][i] = self.prefix_tokens + ele + self.suffix_tokens
+        out = self.tokenizer.pad(out, padding=True, return_tensors="pt", max_length=self.max_length)
+        for key in out:
+            out[key] = out[key].to(self.lm.device)
+        return out
+
+    @torch.no_grad()
+    def compute_logits(self, inputs, **kwargs):
+        """
+        Return, for each input, the probability of "yes" versus "no" at the last position.
+        """
+        batch_scores = self.lm(**inputs).logits[:, -1, :]
+        true_vector = batch_scores[:, self.token_true_id]
+        false_vector = batch_scores[:, self.token_false_id]
+        batch_scores = torch.stack([false_vector, true_vector], dim=1)
+        # log_softmax followed by exp() is a two-way softmax over (no, yes).
+        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
+        scores = batch_scores[:, 1].exp().tolist()
+        return scores
+
+    def compute_scores(self, pairs, instruction=None, **kwargs):
+        """
+        Score a batch of (query, document text) pairs and return one float per pair.
+        """
+        pairs = [self.format_instruction(instruction, query, doc) for query, doc in pairs]
+        inputs = self.process_inputs(pairs)
+        scores = self.compute_logits(inputs)
+        return scores
+
+    def rerank(self, query: str, documents: list[Document], k: int = None, batch_size: int = 8, **kwargs) -> list[Document]:
+        """
+        Sort ``documents`` by descending relevance to ``query``.
+
+        Args:
+            query: The query string.
+            documents: Documents to rerank; each one's ``content`` is scored.
+            k: If given, only the top ``k`` documents are returned.
+            batch_size: Number of pairs scored per forward pass; must be positive.
+            **kwargs: Forwarded to ``compute_scores`` (e.g. ``instruction``).
+
+        Returns:
+            The documents ordered from highest to lowest score.
+
+        Raises:
+            ValueError: If ``batch_size`` is not positive.
+        """
+        # A negative step would make range() empty and silently return no documents.
+        if batch_size <= 0:
+            raise ValueError(f"batch_size must be positive, got {batch_size}")
+
+        # 1. Build the (query, doc.content) pairs.
+        pairs = [(query, doc.content) for doc in documents]
+
+        # 2. Score the pairs batch by batch.
+        all_scores = []
+        for i in range(0, len(pairs), batch_size):
+            batch_pairs = pairs[i:i+batch_size]
+            batch_scores = self.compute_scores(batch_pairs, **kwargs)
+            all_scores.extend(batch_scores)
+        scores = all_scores
+
+        # 3. Sort by score, highest first.
+        doc_score_pairs = list(zip(documents, scores))
+        doc_score_pairs.sort(key=lambda x: x[1], reverse=True)
+        reranked_docs = [doc for doc, score in doc_score_pairs]
+        if k is not None:
+            reranked_docs = reranked_docs[:k]
+        return reranked_docs
diff --git a/rag_factory/rerankers/__init__.py b/rag_factory/rerankers/__init__.py
index e69de29..2661297 100644
--- a/rag_factory/rerankers/__init__.py
+++ b/rag_factory/rerankers/__init__.py
@@ -0,0 +1,4 @@
+from .Reranker_Base import RerankerBase
+from .Reranker_Qwen3 import Qwen3Reranker
+
+__all__ = ["RerankerBase", "Qwen3Reranker"]