diff --git a/.gitignore b/.gitignore index f094905..ceddcab 100644 --- a/.gitignore +++ b/.gitignore @@ -17,4 +17,6 @@ cache cache_local .cache -test_mineru \ No newline at end of file +test_mineru + +requirements-kbc.txt \ No newline at end of file diff --git a/dataflow/example/KBCleaningPipeline/pdf_test.json b/dataflow/example/KBCleaningPipeline/kbc_placeholder.json similarity index 100% rename from dataflow/example/KBCleaningPipeline/pdf_test.json rename to dataflow/example/KBCleaningPipeline/kbc_placeholder.json diff --git a/dataflow/example/KBCleaningPipeline/test.doc b/dataflow/example/KBCleaningPipeline/test.doc new file mode 100644 index 0000000..39ff962 Binary files /dev/null and b/dataflow/example/KBCleaningPipeline/test.doc differ diff --git a/dataflow/example/KBCleaningPipeline/test.pdf b/dataflow/example/KBCleaningPipeline/test.pdf new file mode 100644 index 0000000..3921837 Binary files /dev/null and b/dataflow/example/KBCleaningPipeline/test.pdf differ diff --git a/dataflow/operators/generate/KnowledgeCleaning/KnowledgeCleaner.py b/dataflow/operators/generate/KnowledgeCleaning/KnowledgeCleaner.py index 8a6122e..2092b3a 100644 --- a/dataflow/operators/generate/KnowledgeCleaning/KnowledgeCleaner.py +++ b/dataflow/operators/generate/KnowledgeCleaning/KnowledgeCleaner.py @@ -12,9 +12,9 @@ class KnowledgeCleaner(OperatorABC): ''' KnowledgeCleaner is a class that cleans knowledge for RAG to make them more accurate, reliable and readable. ''' - def __init__(self, llm_serving: LLMServingABC, lang="zh"): + def __init__(self, llm_serving: LLMServingABC, lang="en"): self.logger = get_logger() - self.prompts = KnowledgeCleanerPrompt(lang="zh") + self.prompts = KnowledgeCleanerPrompt(lang=lang) self.llm_serving = llm_serving @staticmethod diff --git a/dataflow/operators/generate/KnowledgeCleaning/KnowledgeExtractor.py b/dataflow/operators/generate/KnowledgeCleaning/KnowledgeExtractor.py index e9bba9a..7034620 100644 --- a/dataflow/operators/generate/KnowledgeCleaning/KnowledgeExtractor.py +++ b/dataflow/operators/generate/KnowledgeCleaning/KnowledgeExtractor.py @@ -3,18 +3,8 @@ from dataflow import get_logger from dataflow.utils.storage import DataFlowStorage from dataflow.core import OperatorABC - +from dataflow.utils.kbcleaning import _parse_pdf_to_md,_parse_doc_to_md,_parse_xml_to_md import os -from pathlib import Path -from mineru.data.data_reader_writer import FileBasedDataWriter -from mineru.backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc_analyze -from mineru.backend.pipeline.pipeline_middle_json_mkcontent import union_make as pipeline_union_make -from mineru.backend.pipeline.model_json_to_middle_json import result_to_middle_json as pipeline_result_to_middle_json -from mineru.utils.enum_class import MakeMode -from magic_doc.docconv import DocConverter -import chonkie -import subprocess -from trafilatura import fetch_url, extract @OPERATOR_REGISTRY.register() class KnowledgeExtractor(OperatorABC): @@ -26,7 +16,7 @@ def __init__(self, **kwargs): self.intermediate_dir=kwargs.get("intermediate_dir", "intermediate") @staticmethod - def get_desc(lang="en"): + def get_desc(self, lang="en"): """ 返回算子功能描述 (根据run()函数的功能实现) """ @@ -56,79 +46,13 @@ def get_desc(lang="en"): "- Supports both local files and URLs\n" "- Generates intermediate files to specified directory(intermediate_dir)" ) - - def _parse_pdf_to_md( - self, - input_pdf_path: str, - output_dir: str, - lang: str = "ch", - parse_method: str = "auto" # 解析方法:auto/txt/ocr - ): - """ - 
将PDF转换为Markdown(仅使用Pipeline后端) - """ - # 读取PDF文件 - pdf_bytes = Path(input_pdf_path).read_bytes() - pdf_name = Path(input_pdf_path).stem - - # 解析PDF - infer_results, all_image_lists, all_pdf_docs, _, ocr_enabled_list = pipeline_doc_analyze( - [pdf_bytes], [lang], parse_method=parse_method - ) - - # 准备输出目录 - image_dir = os.path.join(output_dir, f"{pdf_name}_images") - os.makedirs(image_dir, exist_ok=True) - image_writer = FileBasedDataWriter(image_dir) - md_writer = FileBasedDataWriter(output_dir) - - # 生成中间结果和Markdown - middle_json = pipeline_result_to_middle_json( - infer_results[0], all_image_lists[0], all_pdf_docs[0], - image_writer, lang, ocr_enabled_list[0], True - ) - md_content = pipeline_union_make(middle_json["pdf_info"], MakeMode.MM_MD, os.path.basename(image_dir)) - # 保存Markdown - md_writer.write_string(f"{pdf_name}_pdf.md", md_content) - print(f"Markdown saved to: {os.path.join(output_dir, f'{pdf_name}_pdf.md')}") - return os.path.join(output_dir,f"{pdf_name}_pdf.md") - - def _parse_doc_to_md(self, input_file: str, output_file: str): - """ - support conversion of doc/ppt/pptx/pdf files to markdowns - """ - converter = DocConverter(s3_config=None) - markdown_content, time_cost = converter.convert(input_file, conv_timeout=300) - print("time cost: ", time_cost) - with open(output_file, "w",encoding='utf-8') as f: - f.write(markdown_content) - return output_file - - def _parse_xml_to_md(self, raw_file:str=None, url:str=None, output_file:str=None): - if(url): - downloaded=fetch_url(url) - elif(raw_file): - with open(raw_file, "r", encoding='utf-8') as f: - downloaded=f.read() - else: - raise Exception("Please provide at least one of file path and url string.") - try: - result=extract(downloaded, output_format="markdown", with_metadata=True) - self.logger.info(f"Extracted content is written into {output_file}") - with open(output_file,"w", encoding="utf-8") as f: - f.write(result) - except Exception as e: - print("Error during extract this file or link: ", e) - - return output_file - - def run(self, storage:DataFlowStorage ,raw_file=None, url=None,lang="ch"): + def run(self, storage:DataFlowStorage ,raw_file=None, url=None,lang="en"): self.logger.info("starting to extract...") self.logger.info("If you are providing a url or a large file, this may take a while, please wait...") if(url): output_file=os.path.join(os.path.dirname(storage.first_entry_file_name), "raw/crawled.md") - output_file=self._parse_xml_to_md(url=url,output_file=output_file) + output_file=_parse_xml_to_md(url=url,output_file=output_file) self.logger.info(f"Primary extracted result written to: {output_file}") return output_file @@ -137,20 +61,46 @@ def run(self, storage:DataFlowStorage ,raw_file=None, url=None,lang="ch"): raw_file_suffix_no_dot=raw_file_suffix.replace(".","") output_file=os.path.join(self.intermediate_dir,f"{raw_file_name}_{raw_file_suffix_no_dot}.md") if(raw_file_suffix==".pdf"): + try: + from mineru.data.data_reader_writer import FileBasedDataWriter + from mineru.backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc_analyze + from mineru.backend.pipeline.pipeline_middle_json_mkcontent import union_make as pipeline_union_make + from mineru.backend.pipeline.model_json_to_middle_json import result_to_middle_json as pipeline_result_to_middle_json + from mineru.utils.enum_class import MakeMode + except: + raise Exception( + """ +MinerU is not installed in this environment yet. +Please refer to https://github.com/opendatalab/mineru to install. 
+Or you can just execute 'pip install mineru[pipeline]' and 'mineru-models-download' to fix this error. +please make sure you have gpu on your machine. +""" + ) # optional: 是否从本地加载OCR模型 os.environ['MINERU_MODEL_SOURCE'] = "local" - output_file=self._parse_pdf_to_md( + output_file=_parse_pdf_to_md( raw_file, self.intermediate_dir, lang, "txt" ) elif(raw_file_suffix in [".doc", ".docx", ".pptx", ".ppt"]): + try: + from magic_doc.docconv import DocConverter + except: + raise Exception( + """ +Fairy-doc is not installed in this environment yet. +Please refer to https://github.com/opendatalab/magic-doc to install. +Or you can just execute 'apt-get/yum/brew install libreoffice' and 'pip install fairy-doc[gpu]' to fix this error. +please make sure you have gpu on your machine. +""" + ) if(raw_file_suffix==".docx"): raise Exception("Function Under Maintaining...Please try .doc format file instead.") - output_file=self._parse_doc_to_md(raw_file, output_file) + output_file=_parse_doc_to_md(raw_file, output_file) elif(raw_file_suffix in [".html", ".xml"]): - output_file=self._parse_xml_to_md(raw_file=raw_file,output_file=output_file) + output_file=_parse_xml_to_md(raw_file=raw_file,output_file=output_file) elif(raw_file_suffix in [".txt",".md"]): # for .txt and .md file, no action is taken output_file=raw_file @@ -159,3 +109,4 @@ def run(self, storage:DataFlowStorage ,raw_file=None, url=None,lang="ch"): self.logger.info(f"Primary extracted result written to: {output_file}") return output_file + diff --git a/dataflow/operators/generate/KnowledgeCleaning/MultiHopQAGenerator.py b/dataflow/operators/generate/KnowledgeCleaning/MultiHopQAGenerator.py new file mode 100644 index 0000000..ac156da --- /dev/null +++ b/dataflow/operators/generate/KnowledgeCleaning/MultiHopQAGenerator.py @@ -0,0 +1,523 @@ +from dataflow.prompts.multihopqa import MultiHopQAGeneratorPrompt +import pandas as pd +from dataflow.utils.registry import OPERATOR_REGISTRY +from dataflow import get_logger + +from dataflow.utils.storage import DataFlowStorage +from dataflow.core import OperatorABC +from dataflow.core import LLMServingABC +import random +from typing import Any, Dict, List, Optional, Sequence +import json +from tqdm import tqdm +import re + +class MultiHopQAGenerator(OperatorABC): + r"""A processor for generating multi-hop question-answer pairs from user + data. + + This class handles the processing of text data to generate multi-hop + question-answer pairs using either an AI model or rule-based approaches. + It manages the entire pipeline from text preprocessing to dataset curation. + """ + + def __init__(self, + llm_serving: LLMServingABC, + seed: int = 0, + lang = "en", + ): + r"""Initialize the UserDataProcessor. + + Args: + config (Optional[ProcessorConfig], optional): Configuration for + data processing. (default: :obj:`None`) + """ + self.rng = random.Random(seed) + self.llm_serving=llm_serving + self.lang = lang + self.logger=get_logger() + + @staticmethod + def get_desc(self, lang: str = "zh") -> tuple: + """Returns a description of the processor's functionality. + + Args: + lang (str, optional): Language for description ('zh' or 'en'). + Defaults to None (uses instance language). 
+ + Returns: + tuple: Description strings in specified language + """ + + if lang == "zh": + return ( + "MultiHopQAGenerator 是多跳问答对生成处理器", + "支持从文本数据自动生成需要多步推理的问题-答案对", + "包含文本预处理、信息抽取和智能问答生成全流程", + "支持配置语言模型服务及多种生成参数" + ) + else: # Default to English + return ( + "MultiHopQAGenerator processes text to create multi-hop QA pairs", + "Automatically generates questions requiring multi-step reasoning", + "Handles full pipeline: text preprocessing, information extraction", + "and intelligent QA generation with configurable LLM backend" + ) + + def process_text( + self, text: str, source: str = "user_input" + ) -> List[Dict[str, Any]]: + r"""Process a single text to generate multi-hop QA pairs. + + Args: + text (str): The input text to process. + source (str, optional): Source identifier for the text. + (default: :obj:`"user_input"`) + + Returns: + List[Dict[str, Any]]: List of processed examples with QA pairs and + metadata. + """ + # Convert text to standard format + raw_data = [ + { + 'text': text, + 'source': source, + } + ] + + # Construct examples + constructor = ExampleConstructor(lang=self.lang, llm_serving=self.llm_serving) + examples = constructor.construct_examples(raw_data) + + # Manage data + # curator = DataCurator(self.config, self.rng) + # final_dataset = curator.curate_dataset(examples) + + return examples + + def process_batch( + self, texts: List[str], sources: Optional[List[str]] = None + ) -> List[Dict[str, Any]]: + r"""Process multiple texts in batch to generate multi-hop QA pairs. + + Args: + texts (List[str]): List of input texts to process. + sources (Optional[List[str]], optional): List of source + identifiers. (default: :obj:`None`) + + Returns: + List[Dict[str, Any]]: List of processed examples with QA pairs and + metadata. + + Raises: + ValueError: If length of sources doesn't match length of texts. + """ + if sources is None: + sources = ["default_source"] * len(texts) + elif len(sources) != len(texts): + raise ValueError("Length of sources must match length of texts") + + raw_data = [ + { + 'text': text, + 'source': source, + } + for text, source in zip(texts, sources) + ] + + # Construct examples + constructor = ExampleConstructor(lang=self.lang, llm_serving=self.llm_serving) + examples = constructor.construct_examples(raw_data) + + # # Manage data + # curator = DataCurator(self.config, self.rng) + # final_dataset = curator.curate_dataset(examples) + + return examples + + def _validate_dataframe(self, dataframe: pd.DataFrame): + required_keys = [self.input_key] + forbidden_keys = [self.output_key] + + missing = [k for k in required_keys if k not in dataframe.columns] + conflict = [k for k in forbidden_keys if k in dataframe.columns] + + if missing: + raise ValueError(f"Missing required column(s): {missing}") + if conflict: + raise ValueError(f"The following column(s) already exist and would be overwritten: {conflict}") + + def run( + self, + input_key:str='', + output_key:str='', + storage: DataFlowStorage=None, + ): + self.input_key, self.output_key = input_key, output_key + dataframe = storage.read("dataframe") + self._validate_dataframe(dataframe) + texts = dataframe[self.input_key].tolist() + qa_pairs=self.process_batch(texts) + dataframe[self.output_key] = qa_pairs + output_file = storage.write(dataframe) + self.logger.info(f"Results saved to {output_file}") + + return [output_key] + + +class ExampleConstructor: + r"""Constructs training examples from raw text data. 
+ + This class handles the construction of training examples by preprocessing + text, extracting information pairs, and generating question-answer pairs. + """ + + def __init__( + self, + lang:str = "en", + llm_serving: LLMServingABC = None, + min_text_length:int = 100, + max_text_length:int = 200000, + ): + r"""Initialize the ExampleConstructor. + + Args: + config (ProcessorConfig): Configuration for example construction. + multi_hop_agent (Optional[MultiHopGeneratorAgent], optional): + Agent for generating multi-hop QA pairs. (default: :obj:`None`) + """ + self.lang=lang + self.llm_sering=llm_serving + self.logger=get_logger() + self.max_length=max_text_length + self.min_length=min_text_length + self.prompt=MultiHopQAGeneratorPrompt(lang = self.lang) + + def construct_examples( + self, raw_data: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + r"""Construct training examples from raw data. + + Args: + raw_data (List[Dict[str, Any]]): List of raw data dictionaries + containing text and metadata. + + Returns: + List[Dict[str, Any]]: List of constructed examples with QA pairs + and metadata. + """ + self.logger.info("Starting to construct examples...") + examples = [] + + for data in tqdm(raw_data, desc="Constructing examples"): + # 1. Text preprocessing + processed_text = self._preprocess_text(data.get('text', '')) + if not processed_text: + example = { + 'text': processed_text, + 'qa_pairs': [], + 'metadata': { + 'source': data.get('source', 'unknown'), + 'timestamp': data.get('timestamp', ''), + 'complexity': 0, + }, + } + examples.append(example) + continue + + # 2. Generate key information pairs + info_pairs = self._extract_info_pairs(processed_text) + + # 3. Construct question-answer pairs + qa_pairs = self._generate_qa_pairs(info_pairs) + + # 4. Add metadata + example = { + 'text': processed_text, + 'qa_pairs': qa_pairs, + 'metadata': { + 'source': data.get('source', 'unknown'), + 'timestamp': data.get('timestamp', ''), + 'complexity': self._calculate_complexity(qa_pairs), + }, + } + + examples.append(example) + + self.logger.info(f"Successfully constructed {len(examples)} examples") + return examples + + def _preprocess_text(self, text: str) -> str: + r"""Preprocess input text for example construction. + + Args: + text (str): Input text to preprocess. + + Returns: + str: Preprocessed text, or empty string if text fails quality + checks. + """ + if not isinstance(text, str): + return '' + + # 1. Basic cleaning + text = text.strip() + + # 2. Length check + if ( + len(text) < self.min_length + or len(text) > self.max_length + ): + self.logger.warning("text fail to pass length check.") + return '' + + # 3. Quality check + if not self._check_text_quality(text): + self.logger.warning("text fail to pass quality check.") + return '' + + return text + + def _calculate_special_char_ratio(self,text): + # 中文字符的Unicode范围(基本汉字+扩展) + chinese_ranges = [ + (0x4E00, 0x9FFF), # 基本汉字 + (0x3400, 0x4DBF), # 扩展A + (0x20000, 0x2A6DF), # 扩展B + (0x2A700, 0x2B73F), # 扩展C + (0x2B740, 0x2B81F), # 扩展D + (0x2B820, 0x2CEAF) # 扩展E + ] + + special_count = 0 + for c in text: + # 检查是否为中文、字母数字或空格 + is_chinese = any(start <= ord(c) <= end for start, end in chinese_ranges) + if not (c.isalnum() or c.isspace() or is_chinese): + special_count += 1 + + return special_count / len(text) if text else 0 + + def _check_text_quality(self, text: str) -> bool: + r"""Check the quality of input text. + + Args: + text (str): Text to check quality for. + + Returns: + bool: True if text passes quality checks, False otherwise. 
+ """ + # 1. Basic quality check + if (self.lang=="en" and text.count('.') < 2): # Must have at least 2 sentences + return False + elif(text.count("。") < 2): + return False + + # 2. Special character ratio check + special_char_ratio = self._calculate_special_char_ratio(text) + if special_char_ratio > 0.3: # No more than 30% special characters + return False + + return True + + def _extract_info_pairs(self, text: str) -> List[Dict[str, Sequence[str]]]: + r"""Extract information pairs and relationships from text. + + Args: + text (str): Input text to extract information from. + + Returns: + List[Dict[str, Sequence[str]]]: List of dictionaries containing + premise, intermediate, conclusion, and related contexts. + """ + # Split into sentences + if(self.lang=="en"): + sentences = [s.strip() for s in text.split('.') if s.strip()] + else: + sentences = [s.strip() for s in text.split('。') if s.strip()] + + info_pairs = [] + + # Extract combinations of multiple related sentences + for i in range(len(sentences) - 2): + if len(sentences[i]) > 10 and len(sentences[i + 1]) > 10: + info_pairs.append( + { + 'premise': sentences[i], + 'intermediate': sentences[i + 1], + 'conclusion': sentences[i + 2] + if i + 2 < len(sentences) + else '', + 'related_contexts': [ + s + for j, s in enumerate(sentences) + if j != i and j != i + 1 and len(s) > 10 + ][:2], + # Limit to 2 additional related contexts + } + ) + + return info_pairs + + def _generate_qa_pairs( + self, info_pairs: List[Dict[str, Sequence[str]]] + ) -> List[Dict[str, str]]: + r"""Generate multi-hop question-answer pairs from information pairs. + + Args: + info_pairs (List[Dict[str, Sequence[str]]]): List of information + pairs extracted from text. + + Returns: + List[Dict[str, str]]: List of generated QA pairs. + """ + user_inputs=[] + for pair in info_pairs: + # 1. Generate multi-hop question-answer pair using AI + # Construct full context + context = ( + f"{pair['premise']}. {pair['intermediate']}." 
+ f" {pair['conclusion']}" + ) + user_inputs.append(self.prompt._multihop_qa_generator_user_prompt(context)) + + sys_prompt=self.prompt.system_text + + responses = self.llm_sering.generate_from_input(user_inputs=user_inputs,system_prompt=sys_prompt) + qa_pairs=self._extract_qa_pairs(responses) + + return qa_pairs + + def _extract_qa_pairs(self, responses: List[str]) -> List[Dict[str, Any]]: + """ + 从原始响应中精确提取符合结构的QA对 + 自动跳过非法JSON和干扰文本 + """ + qa_pairs = [] + for response in responses: + self.logger.info(f"generated qa: {response}") + + # 方法1:尝试直接解析整个响应为JSON + try: + qa_pair = json.loads(response) + if isinstance(qa_pair, dict) and "question" in qa_pair: + qa_pairs.append(qa_pair) + continue + elif isinstance(qa_pair, list): + for item in qa_pair: + if isinstance(item, dict) and "question" in item: + qa_pairs.append(item) + continue + except json.JSONDecodeError: + pass + + # 方法2:使用正则表达式查找所有JSON对象 + try: + # 查找所有以 { 开始的JSON对象 + json_pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}' + + # 更精确的模式,匹配完整的JSON对象 + brace_count = 0 + start_pos = -1 + json_objects = [] + + for i, char in enumerate(response): + if char == '{': + if brace_count == 0: + start_pos = i + brace_count += 1 + elif char == '}': + brace_count -= 1 + if brace_count == 0 and start_pos != -1: + json_str = response[start_pos:i+1] + json_objects.append(json_str) + start_pos = -1 + + # 尝试解析找到的每个JSON字符串 + for json_str in json_objects: + try: + qa_pair = json.loads(json_str) + if (isinstance(qa_pair, dict) and \ + "question" in qa_pair and \ + "reasoning_steps" in qa_pair and \ + "answer" in qa_pair and \ + "supporting_facts" in qa_pair and \ + "type" in qa_pair): + qa_pairs.append(qa_pair) + self.logger.info(f"Successfully extracted QA pair: {qa_pair['question']}") + except json.JSONDecodeError as e: + self.logger.debug(f"Failed to parse JSON object: {json_str[:100]}... Error: {e}") + continue + + # 对qa_pairs中重复的question进行去重 + if qa_pairs: + seen_questions = set() + unique_qa_pairs = [] + + for qa_pair in qa_pairs: + question = qa_pair.get("question", "").strip().lower() + if question and question not in seen_questions: + seen_questions.add(question) + unique_qa_pairs.append(qa_pair) + self.logger.debug(f"Added unique question: {qa_pair['question']}") + else: + self.logger.debug(f"Skipped duplicate question: {qa_pair.get('question', 'N/A')}") + + qa_pairs = unique_qa_pairs + self.logger.info(f"After deduplication: {len(qa_pairs)} unique QA pairs") + + # 如果没有找到有效的JSON对象,记录警告 + if not json_objects: + self.logger.warning("No JSON objects found in model response.") + + except Exception as e: + self.logger.warning(f"Failed to parse QA information from model response. Error: {e}") + + return qa_pairs + + def _calculate_complexity(self, qa_pairs: List[Dict[str, Any]]) -> float: + r"""Calculate the complexity score for a set of QA pairs. + + Args: + qa_pairs (List[Dict[str, Any]]): List of QA pairs to calculate + complexity for. + + Returns: + float: Complexity score between 0.0 and 1.0. + """ + if not qa_pairs: + return 0.0 + + # Calculate complexity based on multiple factors + complexities = [] + for qa in qa_pairs: + # 1. Number of reasoning steps + reasoning_steps_count = len(qa.get('reasoning_steps', [])) + + # 2. Number of supporting facts + supporting_facts_count = len(qa.get('supporting_facts', [])) + + # 3. Question length + question_length = len(qa.get('question', '').split()) + + # 4. 
Answer length + answer_length = len(qa.get('answer', '').split()) + + # Calculate complexity of a single QA pair + qa_complexity = ( + min(reasoning_steps_count / 3, 1.0) + * 0.4 # Weight for reasoning steps + + min(supporting_facts_count / 3, 1.0) + * 0.3 # Weight for supporting facts + + min(question_length / 20, 1.0) + * 0.15 # Weight for question length + + min(answer_length / 50, 1.0) * 0.15 + # Weight for answer length + ) + + complexities.append(qa_complexity) + + return sum(complexities) / len(complexities) + + diff --git a/dataflow/operators/generate/KnowledgeCleaning/__init__.py b/dataflow/operators/generate/KnowledgeCleaning/__init__.py index cb4db60..28ce60b 100644 --- a/dataflow/operators/generate/KnowledgeCleaning/__init__.py +++ b/dataflow/operators/generate/KnowledgeCleaning/__init__.py @@ -1,9 +1,11 @@ from .CorpusTextSplitter import CorpusTextSplitter from .KnowledgeExtractor import KnowledgeExtractor from .KnowledgeCleaner import KnowledgeCleaner +from .MultiHopQAGenerator import MultiHopQAGenerator __all__ = [ "CorpusTextSplitter", "KnowledgeExtractor", "KnowledgeCleaner", + "MultiHopQAGenerator", ] \ No newline at end of file diff --git a/dataflow/operators/generate/__init__.py b/dataflow/operators/generate/__init__.py index 5e9d97b..77ce601 100644 --- a/dataflow/operators/generate/__init__.py +++ b/dataflow/operators/generate/__init__.py @@ -4,7 +4,7 @@ from .Reasoning import * from .Text2SQL import * -# from .KnowledgeCleaning import * +from .KnowledgeCleaning import * from .AgenticRAG import * @@ -30,6 +30,7 @@ "CorpusTextSplitter": (cur_path + "KnowledgeCleaning/CorpusTextSplitter.py", "CorpusTextSplitter"), "KnowledgeExtractor": (cur_path + "KnowledgeCleaning/KnowledgeExtractor.py", "KnowledgeExtractor"), "KnowledgeCleaner": (cur_path + "KnowledgeCleaning/KnowledgeCleaner.py", "KnowledgeCleaner"), + "MultiHopQAGenerator": (cur_path + "KnowledgeCleaning/MultiHopQAGenerator.py", "MultiHopQAGenerator"), "AutoPromptGenerator": (cur_path + "AgenticRAG/AutoPromptGenerator.py", "AutoPromptGenerator"), "QAScorer": (cur_path + "AgenticRAG/QAScorer.py", "QAScorer"), "QAGenerator": (cur_path + "AgenticRAG/QAGenerator.py", "QAGenerator"), diff --git a/dataflow/prompts/kbcleaning.py b/dataflow/prompts/kbcleaning.py index 43c9caf..7d4c6cb 100644 --- a/dataflow/prompts/kbcleaning.py +++ b/dataflow/prompts/kbcleaning.py @@ -3,15 +3,94 @@ class KnowledgeCleanerPrompt: 知识清洗提示词生成器,支持中英文多语言适配 Specialized in refining raw content with multilingual support. ''' - def __init__(self, lang: str = "zh", strict_mode: bool = True): + def __init__(self, lang: str = "en", strict_mode: bool = True): self.lang = lang self.strict_mode = strict_mode self._init_prompt_header() def _init_prompt_header(self): """根据语言初始化提示词头部模板""" - if self.lang == "zh": + if self.lang == "en": self.prompt_header = f""" +You are a meticulous Knowledge Refinement Engineer. Apply these rules STRICTLY: + +1. Remove redundant tags but retain: + - Semantic tags like , + - Meaningful attributes + +2. Normalize special characters: + - Standardize quotes and dashes + - Convert ellipsis (...) + +3. URL handling: + - Preserve footnote URLs + - Extract display texts + +4. Text structure: + - Maintain paragraph/list breaks + - Keep code indentation + - Limit empty lines (max=2) + +5. Reference processing (NEW): + - Images → "[Image: alt_text]" + - Signatures → "[Signature]" + +6. Code blocks: {"(strict)" if self.strict_mode else ""} + - {"Force closure" if self.strict_mode else "Preserve raw"} + - Mark fragments as /*...*/ + +7. 
Absolute fidelity:
+   - NO fact/number modifications
+   - NO term paraphrasing
+   - NO table structure changes
+
+8. Security Processing (NEW):
+   - PII: Phone/ID/Email must be masked, e.g.
+     Original: phone 13800138000 → Processed: phone 138****8000
+   - Classified: Mark 【Confidential】as 〖SEC∶classified〗
+   - Illegal: Replace sensitive content with 〖ILLEGAL∶removed〗
+   - Encryption tags: Use 〖〗for encrypted sections
+
+Example:
+Input:
+<div>
+  <h1>Knowledge Cleaning™</h1>
+  <figure>
+    <img alt="Cleaning Flowchart (Three Phases)">
+    <figcaption>Fig.1: Core Process</figcaption>
+  </figure>
+  <p>Contact: +8613800138000</p>
+  <p>Text with “curly quotes” and – dash – here…</p>
+  <table><tr><td>Table data</td></tr></table>
+  <code>function test() {{</code>
+  <div class="signature">Signature: John <img alt="e-signature"></div>
+  <p>Confidential: Project budget is 【Secret】</p>
+  <p>Diagram: <img src="demo.jpg" alt="Diagram"></p>
+</div>
+
+Output:
+
+Knowledge Cleaning™
+
+[Image: Cleaning Flowchart (Three Phases) Fig.1: Core Process]
+
+Contact: +86*****8000
+
+Text with "straight quotes" and - dash - here...
+
+<table><tr><td>Table data</td></tr></table>
+ +function test() {{ /*...*/ }} + +[Signature]Signature: John [Image: e-signature] + +〖SEC∶classified content〗 + +Diagram: [Image: Diagram demo.jpg] + +""" + else: + self.prompt_header =f""" 你是一名严谨的知识清洗工程师。请严格按照以下规则处理原始内容: 1. 移除冗余HTML/XML标签,但保留: @@ -92,100 +171,11 @@ def _init_prompt_header(self): 示意图:[引用图片:示意图demo.jpg] -""" - else: - self.prompt_header = f""" -You are a meticulous Knowledge Refinement Engineer. Apply these rules STRICTLY: - -1. Remove redundant tags but retain: - - Semantic tags like , - - Meaningful attributes - -2. Normalize special characters: - - Standardize quotes and dashes - - Convert ellipsis (...) - -3. URL handling: - - Preserve footnote URLs - - Extract display texts - -4. Text structure: - - Maintain paragraph/list breaks - - Keep code indentation - - Limit empty lines (max=2) - -5. Reference processing (NEW): - - Images → "[Image: alt_text]" - - Signatures → "[Signature]" - -6. Code blocks: {"(strict)" if self.strict_mode else ""} - - {"Force closure" if self.strict_mode else "Preserve raw"} - - Mark fragments as /*...*/ - -7. Absolute fidelity: - - NO fact/number modifications - - NO term paraphrasing - - NO table structure changes - -8. Security Processing (NEW): - - PII: Phone/ID/Email must be masked, e.g. - Original: phone 13800138000 → Processed: phone 138****8000 - - Classified: Mark 【Confidential】as 〖SEC∶classified〗 - - Illegal: Replace sensitive content with 〖ILLEGAL∶removed〗 - - Encryption tags: Use 〖〗for encrypted sections - -Example: -Input: -
-<div>
-  <h1>Knowledge Cleaning™</h1>
-  <figure>
-    <img alt="Cleaning Flowchart (Three Phases)">
-    <figcaption>Fig.1: Core Process</figcaption>
-  </figure>
-  <p>Contact: +8613800138000</p>
-  <p>Text with “curly quotes” and – dash – here…</p>
-  <table><tr><td>Table data</td></tr></table>
-  <code>function test() {{</code>
-  <div class="signature">Signature: John <img alt="e-signature"></div>
-  <p>Confidential: Project budget is 【Secret】</p>
-  <p>Diagram: <img src="demo.jpg" alt="Diagram"></p>
-</div>
-
-Output:
-
-Knowledge Cleaning™
-
-[Image: Cleaning Flowchart (Three Phases) Fig.1: Core Process]
-
-Contact: +86*****8000
-
-Text with "straight quotes" and - dash - here...
-
-<table><tr><td>Table data</td></tr></table>
- -function test() {{ /*...*/ }} - -[Signature]Signature: John [Image: e-signature] - -〖SEC∶classified content〗 - -Diagram: [Image: Diagram demo.jpg] - """ def Classic_COT_Prompt(self, raw_content: str) -> str: """生成知识清洗的思维链提示词(保持原有格式)""" - if self.lang == "zh": - processing_steps = """ -处理步骤: -1. [标签分析] 识别并分类所有标记标签 -2. [引用提取] 分离图片/表格/签名等引用内容 -3. [字符审核] 记录特殊字符变更 -4. [结构检查] 验证文本层级 -5. [最终输出] 生成清洗后文本 -""".strip() - output_requirement = '响应必须只包含清洗后文本,以开头,结尾,无其他内容。' - else: + if self.lang == "en": processing_steps = """ Processing Steps: 1. [Tag Analysis] Classify markup tags @@ -195,6 +185,16 @@ def Classic_COT_Prompt(self, raw_content: str) -> str: 5. [Final Output] Generate cleaned text """.strip() output_requirement = 'Response must contain ONLY cleaned text between and .' + else: + processing_steps = """ +处理步骤: +1. [标签分析] 识别并分类所有标记标签 +2. [引用提取] 分离图片/表格/签名等引用内容 +3. [字符审核] 记录特殊字符变更 +4. [结构检查] 验证文本层级 +5. [最终输出] 生成清洗后文本 +""".strip() + output_requirement = '响应必须只包含清洗后文本,以开头,结尾,无其他内容。' return f""" {self.prompt_header} diff --git a/dataflow/prompts/multihopqa.py b/dataflow/prompts/multihopqa.py new file mode 100644 index 0000000..030c765 --- /dev/null +++ b/dataflow/prompts/multihopqa.py @@ -0,0 +1,173 @@ +import textwrap +from typing import Dict, Literal + +class MultiHopQAGeneratorPrompt: + ''' + 多跳问答生成器(严格JSON格式输出) + 根据语言参数提供完全独立的专业提示模板 + ''' + def __init__(self, lang: str = "en"): + self.lang = lang + self.system_text = self._build_system_prompt() + + def _build_system_prompt(self) -> str: + """构建专业级多跳问答提示""" + if self.lang == "en": + return textwrap.dedent("""\ + You are a professional multi-hop QA specialist with strict protocols: + + █ Core Requirements + 1. Must identify 2-3 interrelated facts in context + 2. Design complex questions requiring cross-fact reasoning + 3. Reasoning chains must: + - Contain 2-3 logical steps (numbered) + - Show clear causal/progressive relationships + - Each step must reference specific facts + 4. Final answer must synthesize all reasoning conclusions + + █ Output Specifications + 1. Only pure JSON in this structure: + { + "question": "Multi-fact reasoning question", + "reasoning_steps": [ + {"step": "First step (must use Fact 1)"}, + {"step": "Second step (must link Fact 2)"} + ], + "answer": "Synthesized final answer", + "supporting_facts": ["Verbatim Fact 1", "Verbatim Fact 2"], + "type": "domain_tag" + } + 2. Supporting facts must: + - Be verbatim from context + - Directly support corresponding steps + - No paraphrasing allowed + + █ Demonstration + Context: + "Photosynthesis converts CO2 to oxygen. This process sustains plant growth. Plants form the base of food chains." + + Valid Output: + { + "question": "How does photosynthesis impact ecosystems?", + "reasoning_steps": [ + {"step": "Photosynthesis produces oxygen"}, + {"step": "Plants using photosynthesis form food chain bases"} + ], + "answer": "It provides oxygen and sustains ecosystem food chains", + "supporting_facts": [ + "Photosynthesis converts CO2 to oxygen", + "Plants form the base of food chains" + ], + "type": "biology" + } + + █ Rejection Criteria + Reject if: + - Fewer than 2 reasoning steps + - Unreferenced supporting facts exist + - Any non-JSON content appears + """) + else: + return textwrap.dedent("""\ + 您是专业的多跳问答生成专家,必须严格遵循以下专业标准: + + █ 核心要求 + 1. 必须识别上下文中的2-3个关联事实 + 2. 设计需要跨事实推理的复杂问题 + 3. 推理链必须满足: + - 至少包含2-3个逻辑步骤 + - 每个步骤明确标注序号 + - 步骤间存在因果或递进关系 + 4. 最终答案必须整合所有推理结论 + + █ 输出规范 + 1. 
仅允许输出以下结构的纯JSON: + { + "question": "需要跨事实推理的问题", + "reasoning_steps": [ + {"step": "第一推理步骤(必须引用事实1)"}, + {"step": "第二推理步骤(必须关联事实2)"} + ], + "answer": "整合所有步骤的最终答案", + "supporting_facts": ["原文事实1", "原文事实2"], + "type": "领域标签" + } + 2. 支撑事实必须: + - 从上下文逐字提取 + - 与推理步骤严格对应 + - 不得改写或概括 + + █ 示例 + 上下文: + "量子纠缠现象由爱因斯坦提出质疑。后来贝尔实验证实了其真实性。该现象是量子计算的基础。" + + 合格输出: + { + "question": "为什么量子纠缠现象对量子计算很重要?", + "reasoning_steps": [ + {"step": "贝尔实验证实了量子纠缠的真实性"}, + {"step": "该现象是量子计算的基础"} + ], + "answer": "因为量子纠缠被证实真实且是量子计算的基础", + "supporting_facts": [ + "后来贝尔实验证实了其真实性", + "该现象是量子计算的基础" + ], + "type": "量子物理" + } + + █ 违规处理 + 以下情况将拒绝输出: + - 推理步骤少于2步 + - 存在未引用的支撑事实 + - JSON外出现任何附加文本 + """) + + def _multihop_qa_generator_user_prompt(self, text: str) -> str: + """生成完全专业化的用户提示""" + if self.lang == "en": + user_prompt = textwrap.dedent(f"""\ + Generate professional multi-hop QA from: + + Context: + {text} + + Strict requirements: + 1. Extract exactly 2-3 interrelated facts + 2. Question must demonstrate cross-fact reasoning + 3. Use this exact JSON structure (include all quotes/braces): + {{ + "question": "...", + "reasoning_steps": [ + {{"step": "Must explicitly use Fact 1"}}, + {{"step": "Must explicitly link Fact 2"}} + ], + "answer": "...", + "supporting_facts": ["Verbatim Fact 1", "Verbatim Fact 2"], + "type": "..." + }} + """) + else: + user_prompt = textwrap.dedent(f"""\ + 请基于以下上下文生成专业级多跳问答: + + 上下文: + {text} + + 严格按照以下要求执行: + 1. 必须从上述上下文中提取2-3个关联事实 + 2. 问题需体现跨事实推理的复杂性 + 3. 使用此精确JSON结构(包括所有引号和括号): + {{ + "question": "...", + "reasoning_steps": [ + {{"step": "必须明确引用事实1"}}, + {{"step": "必须明确关联事实2"}} + ], + "answer": "...", + "supporting_facts": ["事实1原文", "事实2原文"], + "type": "..." + }} + """) + + return user_prompt \ No newline at end of file diff --git a/dataflow/utils/kbcleaning.py b/dataflow/utils/kbcleaning.py new file mode 100644 index 0000000..ce639ea --- /dev/null +++ b/dataflow/utils/kbcleaning.py @@ -0,0 +1,73 @@ +import os +from pathlib import Path +from trafilatura import fetch_url, extract +from dataflow.logger import get_logger + +def _parse_pdf_to_md( + input_pdf_path: str, + output_dir: str, + lang: str = "ch", + parse_method: str = "auto" # 解析方法:auto/txt/ocr +): + """ + 将PDF转换为Markdown(仅使用Pipeline后端) + """ + logger=get_logger() + # 读取PDF文件 + pdf_bytes = Path(input_pdf_path).read_bytes() + pdf_name = Path(input_pdf_path).stem + + # 解析PDF + infer_results, all_image_lists, all_pdf_docs, _, ocr_enabled_list = pipeline_doc_analyze( + [pdf_bytes], [lang], parse_method=parse_method + ) + + # 准备输出目录 + image_dir = os.path.join(output_dir, f"{pdf_name}_images") + os.makedirs(image_dir, exist_ok=True) + image_writer = FileBasedDataWriter(image_dir) + md_writer = FileBasedDataWriter(output_dir) + + # 生成中间结果和Markdown + middle_json = pipeline_result_to_middle_json( + infer_results[0], all_image_lists[0], all_pdf_docs[0], + image_writer, lang, ocr_enabled_list[0], True + ) + md_content = pipeline_union_make(middle_json["pdf_info"], MakeMode.MM_MD, os.path.basename(image_dir)) + # 保存Markdown + md_writer.write_string(f"{pdf_name}_pdf.md", md_content) + logger.info(f"Markdown saved to: {os.path.join(output_dir, f'{pdf_name}_pdf.md')}") + + return os.path.join(output_dir,f"{pdf_name}_pdf.md") + +def _parse_doc_to_md(input_file: str, output_file: str): + """ + support conversion of doc/ppt/pptx/pdf files to markdowns + """ + logger=get_logger() + converter = DocConverter(s3_config=None) + markdown_content, time_cost = converter.convert(input_file, conv_timeout=300) + logger.info("time cost: ", time_cost) + with 
open(output_file, "w",encoding='utf-8') as f: + f.write(markdown_content) + return output_file + +def _parse_xml_to_md(raw_file:str=None, url:str=None, output_file:str=None): + logger=get_logger() + if(url): + downloaded=fetch_url(url) + elif(raw_file): + with open(raw_file, "r", encoding='utf-8') as f: + downloaded=f.read() + else: + raise Exception("Please provide at least one of file path and url string.") + + try: + result=extract(downloaded, output_format="markdown", with_metadata=True) + logger.info(f"Extracted content is written into {output_file}") + with open(output_file,"w", encoding="utf-8") as f: + f.write(result) + except Exception as e: + logger.error("Error during extract this file or link: ", e) + + return output_file \ No newline at end of file diff --git a/dataflow/utils/storage.py b/dataflow/utils/storage.py index 6804d77..9b131f9 100644 --- a/dataflow/utils/storage.py +++ b/dataflow/utils/storage.py @@ -157,7 +157,7 @@ def write(self, data: Any) -> Any: os.makedirs(os.path.dirname(file_path), exist_ok=True) print(f"Writing data to {file_path} with type {self.cache_type}") if self.cache_type == "json": - dataframe.to_json(file_path, orient="records", force_ascii=False) + dataframe.to_json(file_path, orient="records", force_ascii=False, indent=2) elif self.cache_type == "jsonl": dataframe.to_json(file_path, orient="records", lines=True, force_ascii=False) elif self.cache_type == "csv": diff --git a/pyproject.toml b/pyproject.toml index b2bfcf6..ae9820b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,5 +49,7 @@ packages = ["dataflow"] # 显式指定主包 version = {attr = "dataflow.version.__version__"} dependencies = {file = "requirements.txt"} + [project.optional-dependencies] -vllm = ["vllm<0.8"] \ No newline at end of file +vllm = ["vllm<0.8"] + diff --git a/requirements-kbc.txt b/requirements-kbc.txt index e8dadcc..9d5c41d 100644 --- a/requirements-kbc.txt +++ b/requirements-kbc.txt @@ -1,5 +1,3 @@ -mineru[core]==2.0.6 -fairy-doc[gpu]==0.1.43 -libreoffice -chonkie==1.0.10 -trafilatura==2.0.0 \ No newline at end of file +fairy-doc[gpu] +chonkie +trafilatura \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index a555a44..2b1ef86 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,4 +22,6 @@ pydantic nltk func_timeout - +# knowledge base cleaning +chonkie +trafilatura diff --git a/test/test_dockbcleaning.py b/test/test_dockbcleaning.py new file mode 100644 index 0000000..78c0139 --- /dev/null +++ b/test/test_dockbcleaning.py @@ -0,0 +1,88 @@ +import sys +from pathlib import Path +sys.path.append(str(Path(__file__).parent.parent)) + +from dataflow.operators.generate.KnowledgeCleaning import ( + CorpusTextSplitter, + KnowledgeExtractor, + KnowledgeCleaner, + MultiHopQAGenerator, +) +from dataflow.utils.storage import FileStorage +from dataflow.llmserving import LocalModelLLMServing + +class KBCleaningPipeline(): + def __init__(self): + + self.storage = FileStorage( + first_entry_file_name="dataflow/example/KBCleaningPipeline/kbc_placeholder.json", + cache_path="./.cache", + file_name_prefix="doc_cleaning_step", + cache_type="json", + ) + + # api_llm_serving = APILLMServing_request( + # api_url="http://123.129.219.111:3000/v1/chat/completions", + # model_name="gpt-4o", + # max_workers=100 + # ) + + local_llm_serving = LocalModelLLMServing( + model_name_or_path="/data0/models/Qwen2.5-7B-Instruct", + max_tokens=512, + tensor_parallel_size=4, + model_source="local", + gpu_memory_utilization=0.6, + repetition_penalty=1.2 + ) + + 
self.knowledge_cleaning_step1 = KnowledgeExtractor( + intermediate_dir="dataflow/example/KBCleaningPipeline/raw/" + ) + + self.knowledge_cleaning_step2 = CorpusTextSplitter( + split_method="token", + chunk_size=1024, + tokenizer_name="/data0/hzy/RARE/model_base/Qwen2.5-3B-Instruct", + ) + + self.knowledge_cleaning_step3 = KnowledgeCleaner( + llm_serving=local_llm_serving, + lang="ch" + ) + + self.knowledge_cleaning_step4 = MultiHopQAGenerator( + llm_serving=local_llm_serving, + lang="ch" + ) + + def forward(self, url:str=None, raw_file:str=None): + extracted=self.knowledge_cleaning_step1.run( + storage=self.storage, + raw_file=raw_file, + url=url, + lang="ch" + ) + + self.knowledge_cleaning_step2.run( + storage=self.storage.step(), + input_file=extracted, + output_key="raw_content", + ) + + self.knowledge_cleaning_step3.run( + storage=self.storage.step(), + input_key= "raw_content", + output_key="cleaned", + ) + + self.knowledge_cleaning_step4.run( + storage=self.storage.step(), + input_key="cleaned", + output_key="MultiHop_QA" + ) + +if __name__ == "__main__": + model = KBCleaningPipeline() + model.forward(raw_file="/data0/hzy/DataFlow-Preview/dataflow/example/KBCleaningPipeline/test.doc") + diff --git a/test/test_ragkbcleaning.py b/test/test_pdfkbcleaning.py similarity index 82% rename from test/test_ragkbcleaning.py rename to test/test_pdfkbcleaning.py index 7526d54..a4d1eaf 100644 --- a/test/test_ragkbcleaning.py +++ b/test/test_pdfkbcleaning.py @@ -6,19 +6,19 @@ CorpusTextSplitter, KnowledgeExtractor, KnowledgeCleaner, + MultiHopQAGenerator, ) from dataflow.utils.storage import FileStorage from dataflow.llmserving import LocalModelLLMServing -# 这里或许未来可以有个pipeline基类 class KBCleaningPipeline(): def __init__(self): self.storage = FileStorage( - first_entry_file_name="dataflow/example/KBCleaningPipeline/pdf_test.json", + first_entry_file_name="dataflow/example/KBCleaningPipeline/kbc_placeholder.json", cache_path="./.cache", file_name_prefix="pdf_cleaning_step", - cache_type="jsonl", + cache_type="json", ) # api_llm_serving = APILLMServing_request( @@ -33,6 +33,7 @@ def __init__(self): tensor_parallel_size=4, model_source="local", gpu_memory_utilization=0.6, + repetition_penalty=1.2 ) self.knowledge_cleaning_step1 = KnowledgeExtractor( @@ -47,15 +48,20 @@ def __init__(self): self.knowledge_cleaning_step3 = KnowledgeCleaner( llm_serving=local_llm_serving, - lang="zh" + lang="en" ) - # 未来或许可以维护一个类似nn.sequential的容器,方便添加并实例化多个算子 + + self.knowledge_cleaning_step4 = MultiHopQAGenerator( + llm_serving=local_llm_serving, + lang="en" + ) + def forward(self, url:str=None, raw_file:str=None): extracted=self.knowledge_cleaning_step1.run( storage=self.storage, raw_file=raw_file, url=url, - lang="ch" + lang="en" ) self.knowledge_cleaning_step2.run( @@ -69,8 +75,12 @@ def forward(self, url:str=None, raw_file:str=None): input_key= "raw_content", output_key="cleaned", ) + self.knowledge_cleaning_step4.run( + storage=self.storage.step(), + input_key="cleaned", + output_key="MultiHop_QA" + ) if __name__ == "__main__": model = KBCleaningPipeline() - model.forward(raw_file="/data0/hzy/DataFlow-Preview/test_mineru/muban.pdf") - + model.forward(raw_file="/data0/hzy/DataFlow-Preview/test_mineru/muban.pdf") \ No newline at end of file diff --git a/test/test_urlkbcleaning.py b/test/test_urlkbcleaning.py new file mode 100644 index 0000000..2d2b3db --- /dev/null +++ b/test/test_urlkbcleaning.py @@ -0,0 +1,90 @@ +import sys +from pathlib import Path +sys.path.append(str(Path(__file__).parent.parent)) + +from 
dataflow.operators.generate.KnowledgeCleaning import ( + CorpusTextSplitter, + KnowledgeExtractor, + KnowledgeCleaner, + MultiHopQAGenerator, +) +from dataflow.utils.storage import FileStorage +from dataflow.llmserving import LocalModelLLMServing + +# 这里或许未来可以有个pipeline基类 +class KBCleaningPipeline(): + def __init__(self): + + self.storage = FileStorage( + first_entry_file_name="dataflow/example/KBCleaningPipeline/kbc_placeholder.json", + cache_path="./.cache", + file_name_prefix="url_cleaning_step", + cache_type="json", + ) + + # api_llm_serving = APILLMServing_request( + # api_url="http://123.129.219.111:3000/v1/chat/completions", + # model_name="gpt-4o", + # max_workers=100 + # ) + + local_llm_serving = LocalModelLLMServing( + model_name_or_path="/data0/models/Qwen2.5-7B-Instruct", + max_tokens=1024, + tensor_parallel_size=4, + model_source="local", + gpu_memory_utilization=0.6, + repetition_penalty=1.2 + ) + + self.knowledge_cleaning_step1 = KnowledgeExtractor( + intermediate_dir="dataflow/example/KBCleaningPipeline/raw/" + ) + + self.knowledge_cleaning_step2 = CorpusTextSplitter( + split_method="token", + chunk_size=512, + tokenizer_name="/data0/hzy/RARE/model_base/Qwen2.5-3B-Instruct", + ) + + self.knowledge_cleaning_step3 = KnowledgeCleaner( + llm_serving=local_llm_serving, + lang="en" + ) + + self.knowledge_cleaning_step4 = MultiHopQAGenerator( + llm_serving=local_llm_serving, + lang="en" + ) + + # 未来或许可以维护一个类似nn.sequential的容器,方便添加并实例化多个算子 + def forward(self, url:str=None, raw_file:str=None): + extracted=self.knowledge_cleaning_step1.run( + storage=self.storage, + raw_file=raw_file, + url=url, + lang="en" + ) + + self.knowledge_cleaning_step2.run( + storage=self.storage.step(), + input_file=extracted, + output_key="raw_content", + ) + + self.knowledge_cleaning_step3.run( + storage=self.storage.step(), + input_key= "raw_content", + output_key="cleaned", + ) + + self.knowledge_cleaning_step4.run( + storage=self.storage.step(), + input_key="cleaned", + output_key="MultiHop_QA" + ) + +if __name__ == "__main__": + model = KBCleaningPipeline() + model.forward(url="https://trafilatura.readthedocs.io/en/latest/quickstart.html") +