- import typer
- import yaml
- from pathlib import Path
+ # generator_finetune.py
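+ # Fine-tunes a causal-LM "generator" from a YAML config via the HF Trainer.
+ # Example invocation (config path is illustrative, not part of this repo):
+ #   python generator_finetune.py configs/generator.yaml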
+ import math
import logging
+ from pathlib import Path

+ import typer
+ import yaml
import torch
- from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments, IntervalStrategy
+ from transformers import (
+     DataCollatorForLanguageModeling,
+     IntervalStrategy,
+     Trainer,
+     TrainingArguments,
+ )

- from utils.wandb_setup import setup_wandb
- from utils.metrics import compute_metrics
- from models import build_generator
- from data import GeneratorDataModule
+ from src.data import GeneratorDataModule
+ from src.models import build_generator
+ from src.utils.wandb_setup import setup_wandb


- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+ logging.basicConfig(level=logging.INFO,
+                     format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

app = typer.Typer()


+ def perplexity_metrics(eval_pred):
+     """
+     For causal-LM fine-tuning we usually care about perplexity rather than
+     accuracy/F1. Perplexity is exp(eval loss), so we pull the loss out of
+     whatever object the Trainer hands to compute_metrics and exponentiate it.
+     """
+     # Depending on HF version eval_pred can be EvalPrediction or a tuple
+     if isinstance(eval_pred, tuple):
+         loss = eval_pred[0]
+     else:
+         loss = eval_pred.loss
+     return {"perplexity": math.exp(loss)}
+
+
@app.command()
- def main(config_path: Path = type.Argument(..., help="Path to YAML config")):
+ def main(
+     config_path: Path = typer.Argument(..., help="Path to YAML config"),
+ ):
+     # ------------------------------------------------------------------ CONFIG
    cfg = yaml.safe_load(config_path.read_text())
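+     # The YAML config is expected to expose three top-level sections (inferred
+     # from how `cfg` is used below; the exact contents of `model` and `data`
+     # depend on build_generator / GeneratorDataModule, which live elsewhere):
+     #   model:    ...                       # passed to build_generator
+     #   data:     ...                       # passed to GeneratorDataModule
+     #   training: output_dir plus optional overrides (batch sizes, lr, steps, ...)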

-     # --- SETUP W&B ---
+     # ------------------------------ W&B (optional – falls back to "none")
    run_name, report_to = setup_wandb(cfg)

-     # --- BUILD MODEL ---
-     model, tokenizer = build_generator(cfg['model'])
-
-     # --- SETUP DATA ---
-     data_module = GeneratorDataModule(cfg['data'], tokenizer)
-     data_module.setup()
-
-     train_dataset = data_module.get_train_dataset()
-     eval_dataset = data_module.get_eval_dataset()
-
-     # --- SETUP TRAINER ---
-     data_collator = DataCollatorForLanguageModeling(
-         tokenizer=tokenizer,
-         mlm=False,
+     # ------------------------------------------------------- MODEL & TOKENISER
+     model, tokenizer = build_generator(cfg["model"])
+
+     # ----------------------------------------------------------------- DATA
+     dm = GeneratorDataModule(cfg["data"], tokenizer)
+     dm.setup()
+     train_ds = dm.get_train_dataset()
+     eval_ds = dm.get_eval_dataset()
+
+     # ---------------------------------- DATALOADER COLLATOR (causal-LM, no MLM)
+     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
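+     # With mlm=False the collator copies input_ids into labels (padding tokens
+     # are masked to -100), so the model is trained with the usual causal-LM
+     # next-token loss.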
+
+     # ------------------------------------------------ TRAINING ARGUMENTS
+     training_args = TrainingArguments(
+         output_dir=cfg["training"]["output_dir"],
+         overwrite_output_dir=cfg["training"].get("overwrite_output_dir", True),
+         do_train=True,
+         do_eval=eval_ds is not None,
+         per_device_train_batch_size=cfg["training"].get("per_device_train_batch_size", 8),
+         per_device_eval_batch_size=cfg["training"].get("per_device_eval_batch_size", 16),
+         gradient_accumulation_steps=cfg["training"].get("gradient_accumulation_steps", 1),
+         num_train_epochs=cfg["training"].get("num_train_epochs", 3),
+         learning_rate=cfg["training"].get("learning_rate", 5e-5),
+         warmup_ratio=cfg["training"].get("warmup_ratio", 0.1),
+         fp16=cfg["training"].get("fp16", torch.cuda.is_available()),
+         logging_dir=cfg["training"].get("logging_dir", f"{cfg['training']['output_dir']}/logs"),
+         logging_steps=cfg["training"].get("logging_steps", 100),
+         eval_strategy=(
+             IntervalStrategy.STEPS if eval_ds is not None else IntervalStrategy.NO
+         ),
+         eval_steps=cfg["training"].get("eval_steps", 500),
+         save_strategy=IntervalStrategy.STEPS,
+         save_steps=cfg["training"].get("save_steps", 500),
+         save_total_limit=cfg["training"].get("save_total_limit", 2),
+         load_best_model_at_end=cfg["training"].get(
+             "load_best_model_at_end", eval_ds is not None
+         ),
+         metric_for_best_model=cfg["training"].get(
+             "metric_for_best_model", "eval_loss" if eval_ds else None
+         ),
+         greater_is_better=False,  # lower perplexity is better
+         report_to=[report_to] if report_to != "none" else [],
+         run_name=run_name,
+         remove_unused_columns=False,
+         ddp_find_unused_parameters=cfg["training"].get("ddp_find_unused_parameters", False),
    )
+     logger.info("Training args created (fp16=%s).", training_args.fp16)

-     training_args_dict = {
-         "output_dir": cfg['training']['output_dir'],
-         "overwrite_output_dir": cfg['training'].get("overwrite_output_dir", True),
-         "do_train": True,
-         "do_eval": eval_dataset is not None,
-         "per_device_train_batch_size": cfg['training'].get("per_device_train_batch_size", 8),
-         "per_device_eval_batch_size": cfg['training'].get("per_device_eval_batch_size", 16),
-         "gradient_accumulation_steps": cfg['training'].get("gradient_accumulation_steps", 1),
-         "num_train_epochs": cfg['training'].get("num_train_epochs", 3),
-         "learning_rate": cfg['training'].get("learning_rate", 5e-5),
-         "warmup_ratio": cfg['training'].get("warmup_ratio", 0.1),
-         "fp16": cfg['training'].get("fp16", torch.cuda.is_available()),
-         "logging_dir": cfg['training'].get("logging_dir", f"{cfg['training']['output_dir']}/logs"),
-         "logging_steps": cfg['training'].get("logging_steps", 100),
-         "eval_strategy": IntervalStrategy.STEPS if eval_dataset is not None else IntervalStrategy.NO,
-         "eval_steps": cfg['training'].get("eval_steps", 500),
-         "save_strategy": IntervalStrategy.STEPS,
-         "save_steps": cfg['training'].get("save_steps", 500),
-         "save_total_limit": cfg['training'].get("save_total_limit", 2),
-         "load_best_model_at_end": cfg['training'].get("load_best_model_at_end", eval_dataset is not None),
-         "metric_for_best_model": cfg['training'].get("metric_for_best_model", "eval_loss" if eval_dataset else None),
-         "greater_is_better": cfg['training'].get("greater_is_better", False),
-         "report_to": [report_to] if report_to != "none" else [],
-         "run_name": run_name,
-         "remove_unused_columns": False,
-         "ddp_find_unused_parameters": cfg['training'].get("ddp_find_unused_parameters", False),
-     }
-
-     training_args = TrainingArguments(**training_args_dict)
-     logger.info(f"Training arguments: {training_args}. FP16 Enabled: {training_args.fp16}")
-
+     # --------------------------------------------------------------- TRAINER
    trainer = Trainer(
        model=model,
        args=training_args,
-         train_dataset=train_dataset,
-         eval_dataset=eval_dataset,
+         train_dataset=train_ds,
+         eval_dataset=eval_ds,
        tokenizer=tokenizer,
        data_collator=data_collator,
-         compute_metrics=compute_metrics if eval_dataset is not None else None,
+         compute_metrics=perplexity_metrics if eval_ds is not None else None,
    )
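+     # Note: when compute_metrics is set, the Trainer gathers all eval logits
+     # before calling it, which can be memory-heavy for large vocabularies;
+     # preprocess_logits_for_metrics can be used to shrink them if needed.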

-     # --- TRAIN ---
-     logger.info("Training model...")
+     # ------------------------------------------------- TRAINING LOOP
+     logger.info("🚀 Starting training …")
    train_result = trainer.train()
-     logger.info(f"Training results: {train_result}")
+     logger.info("✅ Training finished")

-     # Save final model & metrics
-     logger.info(f"Saving best model to {training_args.output_dir}")
-     trainer.save_model()  # Saves the best model due to load_best_model_at_end=True
+     # --------------------------- SAVE FINALISED CHECKPOINT & METRICS
+     trainer.save_model()  # saves best if load_best_model_at_end=True
    trainer.save_state()

-     # Log final metrics
-     metrics = train_result.metrics
-     trainer.log_metrics("train", metrics)
-     trainer.save_metrics("train", metrics)
+     trainer.log_metrics("train", train_result.metrics)
+     trainer.save_metrics("train", train_result.metrics)

-     # Evaluate on test set if available
-     test_dataset = data_module.get_sanity_dataset()
-     if test_dataset and cfg['training'].get("do_test_eval", True):
-         logger.info("Evaluating on test set...")
-         test_metrics = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix="test")
+     # ------------------------------------------ OPTIONAL TEST EVALUATION
+     test_ds = dm.get_sanity_dataset()
+     if test_ds and cfg["training"].get("do_test_eval", True):
+         logger.info("🧪 Running test evaluation …")
+         test_metrics = trainer.evaluate(test_ds, metric_key_prefix="test")
        trainer.log_metrics("test", test_metrics)
        trainer.save_metrics("test", test_metrics)
-         logger.info(f"Test set evaluation complete: {test_metrics}")
-

-     logger.info("Script finished successfully.")
+     logger.info("🎉 Script completed successfully.")



if __name__ == "__main__":
-     app()
+     app()