|
| 1 | +import re |
| 2 | +import pandas as pd |
| 3 | +import numpy as np |
| 4 | +from tqdm import tqdm |
| 5 | +from dataflow.core import OperatorABC, LLMServingABC |
| 6 | +from dataflow.utils.storage import DataFlowStorage |
| 7 | +from dataflow.prompts.func_call import CompositionTaskFilterPrompt |
| 8 | +from dataflow.logger import get_logger |
| 9 | +from dataflow.utils.registry import OPERATOR_REGISTRY |
| 10 | +from dataflow.core.prompt import prompt_restrict |
| 11 | + |
| 12 | +@prompt_restrict( |
| 13 | + CompositionTaskFilterPrompt |
| 14 | +) |
| 15 | + |
| 16 | +@OPERATOR_REGISTRY.register() |
| 17 | +class CompositionTaskFilter(OperatorABC): |
| 18 | + def __init__(self, llm_serving: LLMServingABC): |
| 19 | + self.logger = get_logger() |
| 20 | + self.prompt = CompositionTaskFilterPrompt() |
| 21 | + self.llm_serving = llm_serving |
| 22 | + self.logger.info(f"Initializing {self.__class__.__name__}...") |
| 23 | + |
| 24 | + @staticmethod |
| 25 | + def get_desc(lang: str = "zh"): |
| 26 | + if lang == "zh": |
| 27 | + return ( |
| 28 | + "根据组合任务及其子任务,使用LLM服务判断组合任务是否具备可行性与完备性,从而进行可运行任务的筛选。\n" |
| 29 | + "输入参数:\n" |
| 30 | + "- llm_serving:LLM服务对象,需实现LLMServingABC接口\n" |
| 31 | + "- input_composition_task_key:组合任务字段名\n" |
| 32 | + "- input_sub_tasks_keys:子任务字段名列表(如原子任务、并行任务、后继任务等)\n" |
| 33 | + "- output_key:可运行标签的输出字段名,默认'runable_label'\n" |
| 34 | + "输出参数:\n" |
| 35 | + "- 仅包含可运行组合任务的数据DataFrame\n" |
| 36 | + "- 包含输出字段名的列表(可运行标签字段)" |
| 37 | + ) |
| 38 | + elif lang == "en": |
| 39 | + return ( |
| 40 | + "Evaluate the feasibility and completeness of a composition task based on its sub-tasks using an LLM service, and filter out unexecutable tasks.\n" |
| 41 | + "Input Parameters:\n" |
| 42 | + "- llm_serving: LLM serving object implementing LLMServingABC interface\n" |
| 43 | + "- input_composition_task_key: Field name for the composition task\n" |
| 44 | + "- input_sub_tasks_keys: List of field names for sub-tasks (e.g., atomic, parallel, subsequent tasks)\n" |
| 45 | + "- output_key: Field name for the executability label, default 'runable_label'\n" |
| 46 | + "Output Parameters:\n" |
| 47 | + "- DataFrame containing only executable composition tasks\n" |
| 48 | + "- List containing the output field name (executability label)" |
| 49 | + ) |
| 50 | + else: |
| 51 | + return "Filter composition tasks for feasibility and completeness using LLM service." |
| 52 | + |
| 53 | + |
| 54 | + def _reformat_prompt(self, dataframe: pd.DataFrame): |
| 55 | + formatted_prompts = [] |
| 56 | + for task, sub_tasks in tqdm(zip(dataframe[self.input_composition_task_key], dataframe[self.input_sub_tasks_keys].to_dict(orient='records')), desc="Reformatting prompts..."): |
| 57 | + formatted_prompts.append(self.prompt.build_prompt(task=task, sub_tasks=sub_tasks)) |
| 58 | + # formatted_prompts = [self.prompt.filter_composition_task(task=item, sub_tasks=sub_tasks) for item, sub_tasks in tqdm(zip(dataframe[self.input_composition_task_key], dataframe[self.input_sub_tasks_key]), desc=f"Reformatting prompts...")] |
| 59 | + return formatted_prompts |
| 60 | + |
| 61 | + def run(self, storage: DataFlowStorage, input_composition_task_key: str, input_sub_tasks_keys: list[str], output_key: str = "runable_label"): |
| 62 | + self.input_composition_task_key = input_composition_task_key |
| 63 | + self.input_sub_tasks_keys = input_sub_tasks_keys |
| 64 | + self.output_key = output_key |
| 65 | + dataframe = storage.read("dataframe") |
| 66 | + llm_inputs = self._reformat_prompt(dataframe) |
| 67 | + self.logger.debug(f"One of formatted prompts: {llm_inputs[0]}") |
| 68 | + llm_outputs = self.llm_serving.generate_from_input(llm_inputs) |
| 69 | + self.logger.debug(f"One of LLM outputs: {llm_outputs[0]}") |
| 70 | + labels = [] |
| 71 | + for item in llm_outputs: |
| 72 | + match = re.search(r"<ans>(yes|no)</ans>", item.strip(), re.IGNORECASE) |
| 73 | + if match: |
| 74 | + labels.append(1 if match.group(1).lower() == "yes" else 0) |
| 75 | + else: |
| 76 | + labels.append(0) |
| 77 | + dataframe[self.output_key] = labels |
| 78 | + dataframe = dataframe[dataframe[self.output_key] > 0] |
| 79 | + storage.write(dataframe) |
| 80 | + output_file = storage.write(dataframe) |
| 81 | + self.logger.info(f"Results saved to {output_file}") |
| 82 | + return [self.output_key] |
| 83 | + |
0 commit comments