Skip to content

Commit 6753311

Browse files
authored
Merge branch 'OpenDCAI:main' into main
2 parents b8e1538 + 190595e commit 6753311

55 files changed

Lines changed: 994 additions & 405 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

dataflow/cli_funcs/cli_pdf.py

Lines changed: 59 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,6 @@ def check_required_files():
199199
# 检查所有需要的内置脚本
200200
required_scripts = [
201201
"path_to_jsonl_script.py",
202-
"merge_filter_qa_pairs.py",
203202
"llama_factory_trainer.py"
204203
]
205204

@@ -321,31 +320,82 @@ def cli_pdf2model_train(lf_yaml: str = ".cache/train_config.yaml", cache_path: s
321320
print("-" * 60)
322321

323322
try:
324-
# Step 1: PDF Detection - 使用内置脚本
323+
# Step 1: PDF Detection
325324
script1_path = get_dataflow_script_path("path_to_jsonl_script.py")
326325
args1 = ["./", "--output", str(cache_path_obj / ".cache" / "gpu" / "pdf_list.jsonl")]
327326
if not run_script_with_args(script1_path, "Step 1: PDF Detection", args1, cwd=str(current_dir)):
328327
return False
329328

330-
# Step 2: Data Processing - 使用用户目录下的脚本
329+
# Step 2: Data Processing
331330
script2 = current_dir / "pdf_to_qa_pipeline.py"
332331
args2 = ["--cache", cache_path]
333332
if not run_script_with_args(script2, "Step 2: Data Processing", args2, cwd=str(current_dir)):
334333
return False
335334

336-
# Step 3: Data Conversion - 使用内置脚本
337-
script3_path = get_dataflow_script_path("merge_filter_qa_pairs.py")
338-
args3 = ["--cache", cache_path]
339-
if not run_script_with_args(script3_path, "Step 3: Data Conversion", args3, cwd=str(current_dir)):
335+
# Step 2.5: Create dataset_info.json (dynamically)
336+
print(f"\n{Fore.BLUE}Step 2.5: Creating dataset_info.json{Style.RESET_ALL}")
337+
338+
# 读取训练配置,获取数据集名称
339+
try:
340+
with open(config_path_obj, 'r', encoding='utf-8') as f:
341+
train_config = yaml.safe_load(f)
342+
343+
# 获取数据集名称
344+
dataset_name = train_config.get('dataset')
345+
if isinstance(dataset_name, list):
346+
dataset_name = dataset_name[0] # 如果是列表,取第一个
347+
348+
if not dataset_name:
349+
print("Warning: No dataset name found in train_config.yaml, using default 'kb_qa'")
350+
dataset_name = 'kb_qa'
351+
352+
print(f"Dataset name from config: {dataset_name}")
353+
354+
except Exception as e:
355+
print(f"Warning: Could not read train_config.yaml: {e}")
356+
print("Using default dataset name: kb_qa")
357+
dataset_name = 'kb_qa'
358+
359+
# 创建 dataset_info.json
360+
dataset_info_path = cache_path_obj / ".cache" / "data" / "dataset_info.json"
361+
dataset_info_path.parent.mkdir(parents=True, exist_ok=True)
362+
363+
dataset_info = {
364+
dataset_name: { # ← 使用从配置读取的名称
365+
"file_name": "qa.json",
366+
"formatting": "alpaca",
367+
"columns": {
368+
"prompt": "instruction",
369+
"query": "input",
370+
"response": "output"
371+
}
372+
}
373+
}
374+
375+
with open(dataset_info_path, 'w', encoding='utf-8') as f:
376+
json.dump(dataset_info, f, indent=2, ensure_ascii=False)
377+
378+
print(f"Created: {dataset_info_path}")
379+
print(f"Dataset registered as: {dataset_name}")
380+
print(f"{Fore.GREEN}✅ Step 2.5: Creating dataset_info.json completed{Style.RESET_ALL}")
381+
382+
# Step 3: Data Conversion - skip
383+
print(f"\n{Fore.BLUE}Step 3: Data Conversion{Style.RESET_ALL}")
384+
qa_json_path = cache_path_obj / ".cache" / "data" / "qa.json"
385+
if qa_json_path.exists():
386+
print(f"✅ qa.json already in correct format, skipping conversion")
387+
print(f"{Fore.GREEN}✅ Step 3: Data Conversion completed{Style.RESET_ALL}")
388+
else:
389+
print(f"❌ qa.json not found at {qa_json_path}")
340390
return False
341391

342-
# Step 4: Training - 使用内置脚本
392+
# Step 4: Training
343393
script4_path = get_dataflow_script_path("llama_factory_trainer.py")
344394
args4 = ["--config", str(config_path_obj), "--cache", cache_path]
345395
if not run_script_with_args(script4_path, "Step 4: Training", args4, cwd=str(current_dir)):
346396
return False
347397

348-
# 显示训练完成信息,从配置文件中读取实际的输出目录
398+
# Show completion info
349399
try:
350400
with open(config_path_obj, 'r', encoding='utf-8') as f:
351401
config = yaml.safe_load(f)
@@ -367,8 +417,6 @@ def cli_pdf2model_train(lf_yaml: str = ".cache/train_config.yaml", cache_path: s
367417

368418
def cli_pdf2model_chat(model_path=None, cache_path="./", base_model=None):
369419
"""Start LlamaFactory chat interface"""
370-
print("Starting chat interface...")
371-
372420
current_dir = Path(os.getcwd())
373421

374422
# 处理cache路径

dataflow/core/prompt.py

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import TypeVar, Protocol, Union, get_type_hints,cast
22
from functools import wraps
3+
import inspect
34
# from dataflow.core import OperatorABC
45

56
class PromptABC():
@@ -34,22 +35,33 @@ def decorator(cls:T) -> T:
3435
# self.ALLOWED_PROMPTS = list(allowed_prompts)
3536

3637
orig_init = cls.__init__
38+
sig = inspect.signature(orig_init) # 在装饰时就解析一次签名,避免每次实例化重复解析
39+
if "prompt_template" not in sig.parameters:
40+
# 若类的 __init__ 根本没有该形参,就仅维持注解/属性设置,不做运行时检查
41+
# (你也可以选择在这里直接 raise 来强制类必须声明该参数)
42+
pass
3743

3844
@wraps(orig_init)
3945
def new_init(self, *args, **kwargs):
40-
pt = kwargs.get("prompt_template", None)
41-
# if pt is None and len(args) > 1:
42-
# pt = args[1]
46+
# 用签名绑定实参:自动把位置/关键字/默认值对齐到参数名
47+
try:
48+
bound = sig.bind_partial(self, *args, **kwargs)
49+
bound.apply_defaults()
50+
except TypeError:
51+
# 参数不完整或不匹配时,交给原始 __init__ 去报错更合适
52+
return orig_init(self, *args, **kwargs)
53+
54+
pt = bound.arguments.get("prompt_template", None)
4355

4456
if pt is not None and not isinstance(pt, cls.ALLOWED_PROMPTS):
4557
if not isinstance(pt, DIYPromptABC):
46-
# 每个类的完整 import 路径,换行分隔
4758
allowed_names = "\n".join(
4859
f" - {c.__module__}.{c.__qualname__}"
4960
for c in cls.ALLOWED_PROMPTS
5061
)
5162
raise TypeError(
52-
f"[{cls.__name__}] Invalid prompt_template type: {type(pt).__module__}.{type(pt).__qualname__}\n"
63+
f"[{cls.__name__}] Invalid prompt_template type: "
64+
f"{type(pt).__module__}.{type(pt).__qualname__}\n"
5365
f"Expected one of:\n{allowed_names}\n"
5466
f"or a custom subclass of `dataflow.core.prompt.DIYPromptABC.`"
5567
)
@@ -58,10 +70,38 @@ def new_init(self, *args, **kwargs):
5870

5971
cls.__init__ = new_init
6072

61-
# 更新类型注解(运行时可见,get_type_hints 可解析)
73+
# 保持你原本的注解暴露逻辑
6274
cls.__annotations__ = dict(getattr(cls, "__annotations__", {}))
6375
cls.__annotations__["prompt_template"] = _make_diyprompt_union(allowed_prompts)
6476

65-
# return cast(T, cast(OperatorWithAllowedPrompts, cls))
6677
return cls
67-
return decorator
78+
return decorator
79+
80+
81+
if __name__ == "__main__":
82+
import pytest
83+
84+
class A(PromptABC): pass
85+
class B(PromptABC): pass
86+
class MyDIY(DIYPromptABC): pass
87+
class Other(PromptABC): pass
88+
89+
@prompt_restrict(A, B)
90+
class Op:
91+
def __init__(self, prompt_template=None):
92+
self.prompt_template = prompt_template
93+
94+
# 关键字参数:允许
95+
Op(prompt_template=A())
96+
Op(prompt_template=B())
97+
Op(prompt_template=MyDIY())
98+
Op() # None 允许
99+
100+
# 位置参数:同样被检测
101+
Op(A()) # ✅
102+
Op(MyDIY()) # ✅
103+
with pytest.raises(TypeError):
104+
Op(Other()) # ❌ 非白名单且非 DIY
105+
106+
with pytest.raises(TypeError):
107+
Op(object()) # ❌ 完全无关类型
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
[
2+
{
3+
"roll": "pig",
4+
"term": "eat"
5+
},
6+
{
7+
"roll": "tiger",
8+
"term": "chase"
9+
},
10+
{
11+
"roll": "people",
12+
"term": "drink"
13+
},
14+
{
15+
"roll": "bird",
16+
"term": "dance"
17+
}
18+
]

dataflow/operators/chemistry/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from typing import TYPE_CHECKING
22

33
if TYPE_CHECKING:
4-
from generate.extract_smiles_from_text import ExtractSmilesFromText
5-
from eval.eval_smiles_equivalence import EvaluateSmilesEquivalence
4+
from generate.extract_smiles_from_text_generator import ExtractSmilesFromTextGenerator
5+
from eval.smiles_equivalence_dataset_evaluator import SmilesEquivalenceDatasetEvaluator
66
else:
77
import sys
88
from dataflow.utils.registry import LazyLoader, generate_import_structure_from_type_checking

dataflow/operators/chemistry/eval/eval_smiles_equivalence.py renamed to dataflow/operators/chemistry/eval/smiles_equivalence_dataset_evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import json
99

1010
@OPERATOR_REGISTRY.register()
11-
class EvaluateSmilesEquivalence(OperatorABC):
11+
class SmilesEquivalenceDatasetEvaluator(OperatorABC):
1212
"""
1313
对每个块(row)里的 golden_label 与 synth_smiles 进行 SMILES 等价性评估:
1414
- 以 abbreviation 对齐

dataflow/operators/chemistry/generate/extract_smiles_from_text.py renamed to dataflow/operators/chemistry/generate/extract_smiles_from_text_generator.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,20 @@
1010
import json
1111
import re
1212

13+
from typing import Union
14+
from dataflow.core.prompt import prompt_restrict, DIYPromptABC
15+
from dataflow.prompts.chemistry import ExtractSmilesFromTextPrompt
1316

17+
18+
@prompt_restrict(
19+
ExtractSmilesFromTextPrompt
20+
)
1421
@OPERATOR_REGISTRY.register()
15-
class ExtractSmilesFromText(OperatorABC):
22+
class ExtractSmilesFromTextGenerator(OperatorABC):
1623
'''
1724
Answer Generator is a class that generates answers for given questions.
1825
'''
19-
def __init__(self, llm_serving: LLMServingABC, prompt_template = None):
26+
def __init__(self, llm_serving: LLMServingABC, prompt_template: Union[ExtractSmilesFromTextPrompt, DIYPromptABC] = ExtractSmilesFromTextPrompt):
2027
self.logger = get_logger()
2128
self.llm_serving = llm_serving
2229
self.prompt_template = prompt_template

dataflow/operators/code/eval/code_quality_sample_evaluator.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,14 @@
88
from dataflow.utils.storage import DataFlowStorage
99
from dataflow.core import OperatorABC
1010
from dataflow.core import LLMServingABC
11+
from dataflow.core.prompt import prompt_restrict, DIYPromptABC
1112
from dataflow.prompts.code import CodeQualityEvaluatorPrompt, DiyCodePrompt
1213

14+
from typing import Union
15+
@prompt_restrict(
16+
CodeQualityEvaluatorPrompt,
17+
DiyCodePrompt
18+
)
1319
@OPERATOR_REGISTRY.register()
1420
class CodeQualitySampleEvaluator(OperatorABC):
1521
"""
@@ -18,7 +24,7 @@ class CodeQualitySampleEvaluator(OperatorABC):
1824
and textual feedback, acting as an automated code reviewer.
1925
"""
2026

21-
def __init__(self, llm_serving: LLMServingABC, prompt_template=None):
27+
def __init__(self, llm_serving: LLMServingABC, prompt_template: Union[CodeQualityEvaluatorPrompt, DiyCodePrompt, DIYPromptABC] = None):
2228
"""
2329
Initializes the operator with a language model serving endpoint.
2430
"""

dataflow/operators/code/generate/code_code_to_instruction_generator.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99
from dataflow.core import LLMServingABC
1010
from dataflow.prompts.code import CodeCodeToInstructionGeneratorPrompt, DiyCodePrompt
1111

12+
from typing import Union
13+
from dataflow.core.prompt import prompt_restrict, DIYPromptABC
14+
@prompt_restrict(
15+
CodeCodeToInstructionGeneratorPrompt,
16+
DiyCodePrompt
17+
)
1218
@OPERATOR_REGISTRY.register()
1319
class CodeCodeToInstructionGenerator(OperatorABC):
1420
"""
@@ -17,7 +23,7 @@ class CodeCodeToInstructionGenerator(OperatorABC):
1723
'self-instruct' style data synthesis pipeline for code.
1824
"""
1925

20-
def __init__(self, llm_serving: LLMServingABC, prompt_template=None):
26+
def __init__(self, llm_serving: LLMServingABC, prompt_template: Union[CodeCodeToInstructionGeneratorPrompt, DiyCodePrompt, DIYPromptABC] = None):
2127
"""
2228
Initializes the operator with a language model serving endpoint.
2329
"""

dataflow/operators/code/generate/code_gen_instruction.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,14 @@
77
from dataflow.utils.storage import DataFlowStorage
88
from dataflow.core import OperatorABC
99
from dataflow.core import LLMServingABC
10-
from dataflow.prompts.code import CodeInstructionGenerate, DiyCodePrompt
10+
from dataflow.prompts.code import CodeInstructionGeneratePrompt, DiyCodePrompt
1111

12+
from typing import Union
13+
from dataflow.core.prompt import prompt_restrict, DIYPromptABC
14+
15+
@prompt_restrict(
16+
CodeInstructionGeneratePrompt,
17+
)
1218
@OPERATOR_REGISTRY.register()
1319
class CodeInstructionGenerator(OperatorABC):
1420
"""
@@ -19,7 +25,7 @@ class CodeInstructionGenerator(OperatorABC):
1925
and enhance instruction datasets for programming tasks.
2026
"""
2127

22-
def __init__(self, llm_serving: LLMServingABC, prompt_template=None, num_few_shot: int = 3, num_generate: int = 10):
28+
def __init__(self, llm_serving: LLMServingABC, prompt_template: Union[CodeInstructionGeneratePrompt, DIYPromptABC]=None, num_few_shot: int = 3, num_generate: int = 10):
2329
"""
2430
Initializes the operator with a language model serving endpoint.
2531
@@ -32,7 +38,7 @@ def __init__(self, llm_serving: LLMServingABC, prompt_template=None, num_few_sho
3238
self.num_generate = num_generate
3339
self.llm_serving = llm_serving
3440
self.num_few_shot = num_few_shot
35-
self.prompt_template = CodeInstructionGenerate()
41+
self.prompt_template = CodeInstructionGeneratePrompt()
3642

3743
@staticmethod
3844
def get_desc(lang: str = "en"):

dataflow/operators/code/generate/code_instruction_enhancement.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@
88
from dataflow.core import LLMServingABC
99
from dataflow.prompts.code import CodeInstructionEnhancement, DiyCodePrompt
1010

11+
from typing import Union
12+
from dataflow.core.prompt import prompt_restrict, DIYPromptABC
13+
@prompt_restrict(
14+
CodeInstructionEnhancement,
15+
DiyCodePrompt
16+
)
1117
@OPERATOR_REGISTRY.register()
1218
class CodeEnhancementInstructionGenerator(OperatorABC):
1319
"""
@@ -16,7 +22,7 @@ class CodeEnhancementInstructionGenerator(OperatorABC):
1622
It rewrites original instructions into standardized English instruction + code block format.
1723
"""
1824

19-
def __init__(self, llm_serving: LLMServingABC, prompt_template=None):
25+
def __init__(self, llm_serving: LLMServingABC, prompt_template: Union[CodeInstructionEnhancement, DiyCodePrompt, DIYPromptABC] = None):
2026
"""
2127
Initializes the operator with a language model serving endpoint.
2228
"""

0 commit comments

Comments
 (0)