Skip to content

Commit 8f7da08

Browse files
authored
Merge branch 'OpenDCAI:main' into main
2 parents d1f992a + e1b06f2 commit 8f7da08

6 files changed

Lines changed: 403 additions & 144 deletions

File tree

dataflow/operators/conversations/__init__.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
from typing import TYPE_CHECKING
22

33
if TYPE_CHECKING:
4-
from .generate.func_call_operators import (
5-
ScenarioExtractor,
6-
ScenarioExpander,
4+
from .generate.func_call_generators import (
5+
ScenarioExtractGenerator,
6+
ScenarioExpandGenerator,
77
AtomTaskGenerator,
88
SequentialTaskGenerator,
99
ParaSeqTaskGenerator,
1010
FunctionGenerator,
1111
MultiTurnConversationGenerator,
1212
)
1313
from .generate.consistent_chat_generator import ConsistentChatGenerator
14+
15+
from .eval.func_call_conversation_sample_evaluator import FuncCallConversationSampleEvaluator
16+
17+
from .filter.composition_task_filter import CompositionTaskFilter
1418

1519
else:
1620
import sys
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import re
2+
import pandas as pd
3+
import numpy as np
4+
from tqdm import tqdm
5+
from dataflow.core import OperatorABC, LLMServingABC
6+
from dataflow.core.prompt import prompt_restrict
7+
from dataflow.utils.storage import DataFlowStorage
8+
from dataflow.prompts.func_call import ConversationEvalPrompt
9+
from dataflow.logger import get_logger
10+
from dataflow.utils.registry import OPERATOR_REGISTRY
11+
12+
@prompt_restrict(
13+
ConversationEvalPrompt
14+
)
15+
16+
@OPERATOR_REGISTRY.register()
17+
class FuncCallConversationSampleEvaluator(OperatorABC):
18+
19+
def __init__(self, llm_serving: LLMServingABC):
20+
self.llm_serving = llm_serving
21+
self.prompt = ConversationEvalPrompt()
22+
self.logger = get_logger()
23+
24+
@staticmethod
25+
def get_desc(lang: str = "zh"):
26+
if lang == "zh":
27+
return (
28+
"对对话样本进行打分评估:使用 LLM 服务根据预设评分提示词对每条对话进行评分,并将结果写回数据流。\n"
29+
"输入参数:\n"
30+
"- llm_serving:LLM 服务对象,需实现 LLMServingABC 接口\n"
31+
"- input_conversation_key:DataFrame 中对话内容字段名,默认 'conversations'\n"
32+
"- output_score_key:评分结果输出字段名,默认 'score'\n"
33+
"处理流程:\n"
34+
"- 读取存储中的 DataFrame\n"
35+
"- 将每条对话重组为评分提示词并调用 LLM 生成评分(JSON)\n"
36+
"- 解析 JSON,提取 'score' 字段写入 DataFrame;解析失败则回退为 0\n"
37+
"输出参数:\n"
38+
"- 包含评分结果列的 DataFrame\n"
39+
"- 包含输出字段名的列表(仅 'score' 或自定义的输出列名)"
40+
)
41+
elif lang == "en":
42+
return (
43+
"Evaluate conversation samples with an LLM-based scorer: the operator formats each "
44+
"conversation into a scoring prompt, calls the LLM, parses the JSON response, and writes the score back.\n"
45+
"Input Parameters:\n"
46+
"- llm_serving: LLM serving object implementing LLMServingABC\n"
47+
"- input_conversation_key: column name for conversations in the DataFrame, default 'conversations'\n"
48+
"- output_score_key: column name for the score output, default 'score'\n"
49+
"Process:\n"
50+
"- Read the DataFrame from storage\n"
51+
"- Reformat each conversation into a scoring prompt and call the LLM (expects JSON)\n"
52+
"- Parse the JSON to extract 'score'; fallback to 0 on parse errors\n"
53+
"Output:\n"
54+
"- DataFrame with a score column added\n"
55+
"- A list containing the output field name (e.g., 'score')"
56+
)
57+
else:
58+
return "Evaluate conversation samples with an LLM-based scorer and write the parsed 'score' back to the DataFrame."
59+
60+
def _reformat_prompt(self, dataframe: pd.DataFrame):
61+
formatted_prompts = []
62+
for conversation in tqdm(dataframe[self.input_conversation_key], desc="Reformatting prompts..."):
63+
formatted_prompts.append(self.prompt.build_prompt(conversation=conversation))
64+
return formatted_prompts
65+
66+
def clean_json_block(self, s: str) -> str:
67+
s = s.strip()
68+
if s.startswith("```"):
69+
# 去掉首尾 ```json 或 ``` 包裹
70+
s = s.strip("`")
71+
s = s.replace("json\n", "", 1) # 去掉开头的 json\n
72+
if s.endswith("```"):
73+
s = s[:-3]
74+
return s.strip()
75+
76+
def json_validate(self, llm_outputs):
77+
import json
78+
scores = []
79+
for item in llm_outputs:
80+
score = 0
81+
try:
82+
data = json.loads(self.clean_json_block(item))
83+
score = data['score']
84+
except Exception as e:
85+
self.logger.debug(f"json loading error in item {item}")
86+
scores.append(score)
87+
return scores
88+
89+
def run(self, storage: DataFlowStorage, input_conversation_key: str = "conversations", output_score_key = "score"):
90+
self.input_conversation_key = input_conversation_key
91+
self.output_score_key = output_score_key
92+
dataframe = storage.read("dataframe")
93+
llm_inputs = self._reformat_prompt(dataframe)
94+
llm_outputs = self.llm_serving.generate_from_input(llm_inputs)
95+
dataframe[self.output_score_key] = self.json_validate(llm_outputs)
96+
storage.write(dataframe)
97+
output_file = storage.write(dataframe)
98+
self.logger.info(f"Results saved to {output_file}")
99+
return [self.output_score_key]
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import re
2+
import pandas as pd
3+
import numpy as np
4+
from tqdm import tqdm
5+
from dataflow.core import OperatorABC, LLMServingABC
6+
from dataflow.utils.storage import DataFlowStorage
7+
from dataflow.prompts.func_call import CompositionTaskFilterPrompt
8+
from dataflow.logger import get_logger
9+
from dataflow.utils.registry import OPERATOR_REGISTRY
10+
from dataflow.core.prompt import prompt_restrict
11+
12+
@prompt_restrict(
13+
CompositionTaskFilterPrompt
14+
)
15+
16+
@OPERATOR_REGISTRY.register()
17+
class CompositionTaskFilter(OperatorABC):
18+
def __init__(self, llm_serving: LLMServingABC):
19+
self.logger = get_logger()
20+
self.prompt = CompositionTaskFilterPrompt()
21+
self.llm_serving = llm_serving
22+
self.logger.info(f"Initializing {self.__class__.__name__}...")
23+
24+
@staticmethod
25+
def get_desc(lang: str = "zh"):
26+
if lang == "zh":
27+
return (
28+
"根据组合任务及其子任务,使用LLM服务判断组合任务是否具备可行性与完备性,从而进行可运行任务的筛选。\n"
29+
"输入参数:\n"
30+
"- llm_serving:LLM服务对象,需实现LLMServingABC接口\n"
31+
"- input_composition_task_key:组合任务字段名\n"
32+
"- input_sub_tasks_keys:子任务字段名列表(如原子任务、并行任务、后继任务等)\n"
33+
"- output_key:可运行标签的输出字段名,默认'runable_label'\n"
34+
"输出参数:\n"
35+
"- 仅包含可运行组合任务的数据DataFrame\n"
36+
"- 包含输出字段名的列表(可运行标签字段)"
37+
)
38+
elif lang == "en":
39+
return (
40+
"Evaluate the feasibility and completeness of a composition task based on its sub-tasks using an LLM service, and filter out unexecutable tasks.\n"
41+
"Input Parameters:\n"
42+
"- llm_serving: LLM serving object implementing LLMServingABC interface\n"
43+
"- input_composition_task_key: Field name for the composition task\n"
44+
"- input_sub_tasks_keys: List of field names for sub-tasks (e.g., atomic, parallel, subsequent tasks)\n"
45+
"- output_key: Field name for the executability label, default 'runable_label'\n"
46+
"Output Parameters:\n"
47+
"- DataFrame containing only executable composition tasks\n"
48+
"- List containing the output field name (executability label)"
49+
)
50+
else:
51+
return "Filter composition tasks for feasibility and completeness using LLM service."
52+
53+
54+
def _reformat_prompt(self, dataframe: pd.DataFrame):
55+
formatted_prompts = []
56+
for task, sub_tasks in tqdm(zip(dataframe[self.input_composition_task_key], dataframe[self.input_sub_tasks_keys].to_dict(orient='records')), desc="Reformatting prompts..."):
57+
formatted_prompts.append(self.prompt.build_prompt(task=task, sub_tasks=sub_tasks))
58+
# formatted_prompts = [self.prompt.filter_composition_task(task=item, sub_tasks=sub_tasks) for item, sub_tasks in tqdm(zip(dataframe[self.input_composition_task_key], dataframe[self.input_sub_tasks_key]), desc=f"Reformatting prompts...")]
59+
return formatted_prompts
60+
61+
def run(self, storage: DataFlowStorage, input_composition_task_key: str, input_sub_tasks_keys: list[str], output_key: str = "runable_label"):
62+
self.input_composition_task_key = input_composition_task_key
63+
self.input_sub_tasks_keys = input_sub_tasks_keys
64+
self.output_key = output_key
65+
dataframe = storage.read("dataframe")
66+
llm_inputs = self._reformat_prompt(dataframe)
67+
self.logger.debug(f"One of formatted prompts: {llm_inputs[0]}")
68+
llm_outputs = self.llm_serving.generate_from_input(llm_inputs)
69+
self.logger.debug(f"One of LLM outputs: {llm_outputs[0]}")
70+
labels = []
71+
for item in llm_outputs:
72+
match = re.search(r"<ans>(yes|no)</ans>", item.strip(), re.IGNORECASE)
73+
if match:
74+
labels.append(1 if match.group(1).lower() == "yes" else 0)
75+
else:
76+
labels.append(0)
77+
dataframe[self.output_key] = labels
78+
dataframe = dataframe[dataframe[self.output_key] > 0]
79+
storage.write(dataframe)
80+
output_file = storage.write(dataframe)
81+
self.logger.info(f"Results saved to {output_file}")
82+
return [self.output_key]
83+

0 commit comments

Comments
 (0)