@@ -199,7 +199,6 @@ def check_required_files():
199199 # 检查所有需要的内置脚本
200200 required_scripts = [
201201 "path_to_jsonl_script.py" ,
202- "merge_filter_qa_pairs.py" ,
203202 "llama_factory_trainer.py"
204203 ]
205204
@@ -321,31 +320,82 @@ def cli_pdf2model_train(lf_yaml: str = ".cache/train_config.yaml", cache_path: s
321320 print ("-" * 60 )
322321
323322 try :
324- # Step 1: PDF Detection - 使用内置脚本
323+ # Step 1: PDF Detection
325324 script1_path = get_dataflow_script_path ("path_to_jsonl_script.py" )
326325 args1 = ["./" , "--output" , str (cache_path_obj / ".cache" / "gpu" / "pdf_list.jsonl" )]
327326 if not run_script_with_args (script1_path , "Step 1: PDF Detection" , args1 , cwd = str (current_dir )):
328327 return False
329328
330- # Step 2: Data Processing - 使用用户目录下的脚本
329+ # Step 2: Data Processing
331330 script2 = current_dir / "pdf_to_qa_pipeline.py"
332331 args2 = ["--cache" , cache_path ]
333332 if not run_script_with_args (script2 , "Step 2: Data Processing" , args2 , cwd = str (current_dir )):
334333 return False
335334
336- # Step 3: Data Conversion - 使用内置脚本
337- script3_path = get_dataflow_script_path ("merge_filter_qa_pairs.py" )
338- args3 = ["--cache" , cache_path ]
339- if not run_script_with_args (script3_path , "Step 3: Data Conversion" , args3 , cwd = str (current_dir )):
335+ # Step 2.5: Create dataset_info.json (dynamically)
336+ print (f"\n { Fore .BLUE } Step 2.5: Creating dataset_info.json{ Style .RESET_ALL } " )
337+
338+ # 读取训练配置,获取数据集名称
339+ try :
340+ with open (config_path_obj , 'r' , encoding = 'utf-8' ) as f :
341+ train_config = yaml .safe_load (f )
342+
343+ # 获取数据集名称
344+ dataset_name = train_config .get ('dataset' )
345+ if isinstance (dataset_name , list ):
346+ dataset_name = dataset_name [0 ] # 如果是列表,取第一个
347+
348+ if not dataset_name :
349+ print ("Warning: No dataset name found in train_config.yaml, using default 'kb_qa'" )
350+ dataset_name = 'kb_qa'
351+
352+ print (f"Dataset name from config: { dataset_name } " )
353+
354+ except Exception as e :
355+ print (f"Warning: Could not read train_config.yaml: { e } " )
356+ print ("Using default dataset name: kb_qa" )
357+ dataset_name = 'kb_qa'
358+
359+ # 创建 dataset_info.json
360+ dataset_info_path = cache_path_obj / ".cache" / "data" / "dataset_info.json"
361+ dataset_info_path .parent .mkdir (parents = True , exist_ok = True )
362+
363+ dataset_info = {
364+ dataset_name : { # ← 使用从配置读取的名称
365+ "file_name" : "qa.json" ,
366+ "formatting" : "alpaca" ,
367+ "columns" : {
368+ "prompt" : "instruction" ,
369+ "query" : "input" ,
370+ "response" : "output"
371+ }
372+ }
373+ }
374+
375+ with open (dataset_info_path , 'w' , encoding = 'utf-8' ) as f :
376+ json .dump (dataset_info , f , indent = 2 , ensure_ascii = False )
377+
378+ print (f"Created: { dataset_info_path } " )
379+ print (f"Dataset registered as: { dataset_name } " )
380+ print (f"{ Fore .GREEN } ✅ Step 2.5: Creating dataset_info.json completed{ Style .RESET_ALL } " )
381+
382+ # Step 3: Data Conversion - skip
383+ print (f"\n { Fore .BLUE } Step 3: Data Conversion{ Style .RESET_ALL } " )
384+ qa_json_path = cache_path_obj / ".cache" / "data" / "qa.json"
385+ if qa_json_path .exists ():
386+ print (f"✅ qa.json already in correct format, skipping conversion" )
387+ print (f"{ Fore .GREEN } ✅ Step 3: Data Conversion completed{ Style .RESET_ALL } " )
388+ else :
389+ print (f"❌ qa.json not found at { qa_json_path } " )
340390 return False
341391
342- # Step 4: Training - 使用内置脚本
392+ # Step 4: Training
343393 script4_path = get_dataflow_script_path ("llama_factory_trainer.py" )
344394 args4 = ["--config" , str (config_path_obj ), "--cache" , cache_path ]
345395 if not run_script_with_args (script4_path , "Step 4: Training" , args4 , cwd = str (current_dir )):
346396 return False
347397
348- # 显示训练完成信息,从配置文件中读取实际的输出目录
398+ # Show completion info
349399 try :
350400 with open (config_path_obj , 'r' , encoding = 'utf-8' ) as f :
351401 config = yaml .safe_load (f )
@@ -367,8 +417,6 @@ def cli_pdf2model_train(lf_yaml: str = ".cache/train_config.yaml", cache_path: s
367417
368418def cli_pdf2model_chat (model_path = None , cache_path = "./" , base_model = None ):
369419 """Start LlamaFactory chat interface"""
370- print ("Starting chat interface..." )
371-
372420 current_dir = Path (os .getcwd ())
373421
374422 # 处理cache路径
0 commit comments