
Commit 1003540

fine-tune generator code updated

1 parent 8c989b3 · commit 1003540

1 file changed: 99 additions, 78 deletions

src/cli/02_fine_tune_generator.py
@@ -1,114 +1,135 @@
-import typer
-import yaml
-from pathlib import Path
+# generator_finetune.py
+import math
 import logging
+from pathlib import Path
 
+import typer
+import yaml
 import torch
-from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments, IntervalStrategy
+from transformers import (
+    DataCollatorForLanguageModeling,
+    IntervalStrategy,
+    Trainer,
+    TrainingArguments,
+)
 
-from utils.wandb_setup import setup_wandb
-from utils.metrics import compute_metrics
-from models import build_generator
-from data import GeneratorDataModule
+from src.data import GeneratorDataModule
+from src.models import build_generator
+from src.utils.wandb_setup import setup_wandb
 
 
-logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+logging.basicConfig(level=logging.INFO,
+                    format="%(asctime)s - %(levelname)s - %(message)s")
 logger = logging.getLogger(__name__)
 
 app = typer.Typer()
 
 
+def perplexity_metrics(eval_pred):
+    """
+    Causal-LM fine-tuning cares about perplexity rather than accuracy/F1;
+    `compute_metrics` receives an `EvalPrediction` of (logits, label_ids),
+    so we recompute the shifted cross-entropy loss and exponentiate it.
+    """
+    # Shift by one token (position t predicts t+1); -100 masks padded labels.
+    logits = torch.from_numpy(eval_pred.predictions[:, :-1, :]).float()
+    labels = torch.from_numpy(eval_pred.label_ids[:, 1:]).long()
+    loss = torch.nn.functional.cross_entropy(
+        logits.reshape(-1, logits.size(-1)), labels.reshape(-1), ignore_index=-100)
+    return {"perplexity": math.exp(loss.item())}
+
+
 @app.command()
-def main(config_path: Path = type.Argument(..., help="Path to YAML config")):
+def main(
+    config_path: Path = typer.Argument(..., help="Path to YAML config"),
+):
+    # ------------------------------------------------------------------ CONFIG
     cfg = yaml.safe_load(config_path.read_text())
 
-    # --- SETUP W&B ---
+    # ------------------------------ W&B (optional – falls back to "none")
     run_name, report_to = setup_wandb(cfg)
 
-    # --- BUILD MODEL ---
-    model, tokenizer = build_generator(cfg['model'])
-
-    # --- SETUP DATA ---
-    data_module = GeneratorDataModule(cfg['data'], tokenizer)
-    data_module.setup()
-
-    train_dataset = data_module.get_train_dataset()
-    eval_dataset = data_module.get_eval_dataset()
-
-    # --- SETUP TRAINER ---
-    data_collator = DataCollatorForLanguageModeling(
-        tokenizer=tokenizer,
-        mlm=False,
+    # ------------------------------------------------------- MODEL & TOKENISER
+    model, tokenizer = build_generator(cfg["model"])
+
+    # ----------------------------------------------------------------- DATA
+    dm = GeneratorDataModule(cfg["data"], tokenizer)
+    dm.setup()
+    train_ds = dm.get_train_dataset()
+    eval_ds = dm.get_eval_dataset()
+
+    # ---------------------------------- DATALOADER COLLATOR (causal-LM, no MLM)
+    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
+    # ------------------------------------------------ TRAINING ARGUMENTS
+    training_args = TrainingArguments(
+        output_dir=cfg["training"]["output_dir"],
+        overwrite_output_dir=cfg["training"].get("overwrite_output_dir", True),
+        do_train=True,
+        do_eval=eval_ds is not None,
+        per_device_train_batch_size=cfg["training"].get("per_device_train_batch_size", 8),
+        per_device_eval_batch_size=cfg["training"].get("per_device_eval_batch_size", 16),
+        gradient_accumulation_steps=cfg["training"].get("gradient_accumulation_steps", 1),
+        num_train_epochs=cfg["training"].get("num_train_epochs", 3),
+        learning_rate=cfg["training"].get("learning_rate", 5e-5),
+        warmup_ratio=cfg["training"].get("warmup_ratio", 0.1),
+        fp16=cfg["training"].get("fp16", torch.cuda.is_available()),
+        logging_dir=cfg["training"].get("logging_dir", f"{cfg['training']['output_dir']}/logs"),
+        logging_steps=cfg["training"].get("logging_steps", 100),
+        eval_strategy=(
+            IntervalStrategy.STEPS if eval_ds is not None else IntervalStrategy.NO
+        ),
+        eval_steps=cfg["training"].get("eval_steps", 500),
+        save_strategy=IntervalStrategy.STEPS,
+        save_steps=cfg["training"].get("save_steps", 500),
+        save_total_limit=cfg["training"].get("save_total_limit", 2),
+        load_best_model_at_end=cfg["training"].get(
+            "load_best_model_at_end", eval_ds is not None
+        ),
+        metric_for_best_model=cfg["training"].get(
+            "metric_for_best_model", "eval_loss" if eval_ds else None
+        ),
+        greater_is_better=False,  # lower perplexity is better
+        report_to=[report_to] if report_to != "none" else [],
+        run_name=run_name,
+        remove_unused_columns=False,
+        ddp_find_unused_parameters=cfg["training"].get("ddp_find_unused_parameters", False),
     )
+    logger.info("Training args created (fp16=%s).", training_args.fp16)
 
-    training_args_dict = {
-        "output_dir": cfg['training']['output_dir'],
-        "overwrite_output_dir": cfg['training'].get("overwrite_output_dir", True),
-        "do_train": True,
-        "do_eval": eval_dataset is not None,
-        "per_device_train_batch_size": cfg['training'].get("per_device_train_batch_size", 8),
-        "per_device_eval_batch_size": cfg['training'].get("per_device_eval_batch_size", 16),
-        "gradient_accumulation_steps": cfg['training'].get("gradient_accumulation_steps", 1),
-        "num_train_epochs": cfg['training'].get("num_train_epochs", 3),
-        "learning_rate": cfg['training'].get("learning_rate", 5e-5),
-        "warmup_ratio": cfg['training'].get("warmup_ratio", 0.1),
-        "fp16": cfg['training'].get("fp16", torch.cuda.is_available()),
-        "logging_dir": cfg['training'].get("logging_dir", f"{cfg['training']['output_dir']}/logs"),
-        "logging_steps": cfg['training'].get("logging_steps", 100),
-        "eval_strategy": IntervalStrategy.STEPS if eval_dataset is not None else IntervalStrategy.NO,
-        "eval_steps": cfg['training'].get("eval_steps", 500),
-        "save_strategy": IntervalStrategy.STEPS,
-        "save_steps": cfg['training'].get("save_steps", 500),
-        "save_total_limit": cfg['training'].get("save_total_limit", 2),
-        "load_best_model_at_end": cfg['training'].get("load_best_model_at_end", eval_dataset is not None),
-        "metric_for_best_model": cfg['training'].get("metric_for_best_model", "eval_loss" if eval_dataset else None),
-        "greater_is_better": cfg['training'].get("greater_is_better", False),
-        "report_to": [report_to] if report_to != "none" else [],
-        "run_name": run_name,
-        "remove_unused_columns": False,
-        "ddp_find_unused_parameters": cfg['training'].get("ddp_find_unused_parameters", False),
-    }
-
-    training_args = TrainingArguments(**training_args_dict)
-    logger.info(f"Training arguments: {training_args}. FP16 Enabled: {training_args.fp16}")
-
+    # --------------------------------------------------------------- TRAINER
     trainer = Trainer(
         model=model,
         args=training_args,
-        train_dataset=train_dataset,
-        eval_dataset=eval_dataset,
+        train_dataset=train_ds,
+        eval_dataset=eval_ds,
         tokenizer=tokenizer,
         data_collator=data_collator,
-        compute_metrics=compute_metrics if eval_dataset is not None else None,
+        compute_metrics=perplexity_metrics if eval_ds is not None else None,
     )
 
-    # --- TRAIN ---
-    logger.info("Training model...")
+    # ------------------------------------------------- TRAINING LOOP
+    logger.info("🚀 Starting training …")
     train_result = trainer.train()
-    logger.info(f"Training results: {train_result}")
+    logger.info("✅ Training finished")
 
-    # Save final model & metrics
-    logger.info(f"Saving best model to {training_args.output_dir}")
-    trainer.save_model()  # Saves the best model due to load_best_model_at_end=True
+    # --------------------------- SAVE FINALISED CHECKPOINT & METRICS
+    trainer.save_model()  # saves best if load_best_model_at_end=True
     trainer.save_state()
 
-    # Log final metrics
-    metrics = train_result.metrics
-    trainer.log_metrics("train", metrics)
-    trainer.save_metrics("train", metrics)
+    trainer.log_metrics("train", train_result.metrics)
+    trainer.save_metrics("train", train_result.metrics)
 
-    # Evaluate on test set if available
-    test_dataset = data_module.get_sanity_dataset()
-    if test_dataset and cfg['training'].get("do_test_eval", True):
-        logger.info("Evaluating on test set...")
-        test_metrics = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix="test")
+    # ------------------------------------------ OPTIONAL TEST EVALUATION
+    test_ds = dm.get_sanity_dataset()
+    if test_ds and cfg["training"].get("do_test_eval", True):
+        logger.info("🧪 Running test evaluation …")
+        test_metrics = trainer.evaluate(test_ds, metric_key_prefix="test")
         trainer.log_metrics("test", test_metrics)
         trainer.save_metrics("test", test_metrics)
-        logger.info(f"Test set evaluation complete: {test_metrics}")
-
 
-    logger.info("Script finished successfully.")
+    logger.info("🎉 Script completed successfully.")
 
 
 if __name__ == "__main__":
-    app()
+    app()
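
For context, the rewritten script is driven entirely by the YAML file passed on the command line, split into model, data, and training sections. Below is a minimal sketch of that shape, written as Python for convenience; the "model" and "data" keys are hypothetical placeholders, since build_generator and GeneratorDataModule are defined outside this commit, while the "training" keys match what the script actually reads.

# make_config.py: sketch of a config consumed by 02_fine_tune_generator.py.
# The "model"/"data" keys below are assumptions, not the project's real schema;
# only the "training" keys correspond to values the script reads.
from pathlib import Path

import yaml

cfg = {
    "model": {"name_or_path": "gpt2"},           # hypothetical key
    "data": {"train_file": "data/train.jsonl"},  # hypothetical key
    "training": {
        "output_dir": "outputs/generator",
        "num_train_epochs": 3,
        "learning_rate": 5e-5,
        "eval_steps": 500,
        "save_steps": 500,
    },
}
Path("config.yaml").write_text(yaml.safe_dump(cfg))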

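On the perplexity metric: the Trainer already reports the mean cross-entropy under the eval_loss key, so perplexity can also be recovered after evaluation without gathering any logits, which is far lighter on memory for large eval sets. A sketch, assuming the trainer object constructed in the script:

import math

# Perplexity is exp(mean cross-entropy); Trainer.evaluate() reports the mean
# cross-entropy as "eval_loss", so no logits need to be collected for this.
metrics = trainer.evaluate()          # `trainer` as built in the script above
ppl = math.exp(metrics["eval_loss"])  # e.g. eval_loss = 2.0 gives ppl ≈ 7.39
print(f"perplexity: {ppl:.2f}")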
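Since the entry point is a Typer app, it can be invoked from the shell (python src/cli/02_fine_tune_generator.py config.yaml) or exercised in-process with Typer's test runner. A smoke-test sketch; note that the leading digits in the filename mean the module cannot be imported with a plain import statement, so obtaining app may require importlib in practice:

from typer.testing import CliRunner

# In-process smoke test of the CLI wiring. `app` is the Typer instance from
# the script above; how it is imported is left open here because the module
# name starts with a digit.
runner = CliRunner()
result = runner.invoke(app, ["config.yaml"])
assert result.exit_code == 0, result.output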