import os
import copy

import torch
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy

from config import ex
from data import _datamodules
from model.face_tts import FaceTTS
from model.face_tts_w_discriminator import FaceTTSWithDiscriminator
from callbacks.custom_callbacks import (
    SaveEpochZeroCallback,
    SaveBestCheckpointPath,
    StepwiseEvalCallback,
    SaveEpoch96Callback,
    EarlyStoppingCallback,
    CompositeBestMelCallback,
)
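
# `ex` is presumably a Sacred Experiment defined in config.py, so `@ex.automain`
# makes this file runnable directly and config values overridable from the CLI.
# A hypothetical invocation (option names depend on config.py):
#
#     python train.py with use_gan=1 num_gpus=1 batch_size=64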


@ex.automain
def main(_config):
    """
    Unified script that trains either FaceTTS or FaceTTSWithDiscriminator
    based on _config["use_gan"].
    """
    print("[DEBUG] Starting script...")

    # Copy the config so we don't mutate it globally
    _config = copy.deepcopy(_config)

    # Set random seed
    pl.seed_everything(_config["seed"])

    # --------------------------------------------------------------------------
    # Data Module
    # --------------------------------------------------------------------------
    dm = _datamodules["dataset_" + _config["dataset"]](_config)
    print("[DEBUG] Data module initialized")

    # --------------------------------------------------------------------------
    # Checkpoint Directory and Prevention of Overwriting resume_from
    # --------------------------------------------------------------------------
    # checkpoint_dir = _config.get("checkpoint_dir", "/mnt/qb/work/butz/bst080/faceGANtts/checkpoints")
    # if _config["resume_from"] and os.path.dirname(_config["resume_from"]) == checkpoint_dir:
    #     raise ValueError(f"ERROR: resume_from ({_config['resume_from']}) is inside checkpoint_dir ({checkpoint_dir}). "
    #                      "This could lead to overwriting! Choose a different checkpoint directory.")
    checkpoint_callback_epoch = pl.callbacks.ModelCheckpoint(
        # dirpath=checkpoint_dir,  # separate directory for new checkpoints
        # filename="epoch={epoch}-step={step}",
        save_weights_only=False,
        save_top_k=3,  # keep the 3 best checkpoints by the monitored metric
        verbose=True,
        monitor="val/total_loss",
        mode="min",
        save_last=True,  # also maintain a rolling last.ckpt
        auto_insert_metric_name=True,
        every_n_epochs=1,  # check for saving at every epoch
        # filename="epoch={epoch}-step={step}-{val/total_loss:.4f}",
    )
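
    # Note: monitor="val/total_loss" assumes the LightningModule logs that key
    # during validation; a minimal sketch of what the model is expected to do
    # (compute_loss is a hypothetical helper):
    #
    #     def validation_step(self, batch, batch_idx):
    #         loss = self.compute_loss(batch)
    #         self.log("val/total_loss", loss, sync_dist=True)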

    lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step")
    model_summary_callback = pl.callbacks.ModelSummary(max_depth=2)
    save_epoch0_callback = SaveEpochZeroCallback()
    save_epoch96_callback = SaveEpoch96Callback()
    stepwise_eval_callback = StepwiseEvalCallback(_config)
    save_best_ckpt_callback = SaveBestCheckpointPath()
    early_stopping = EarlyStoppingCallback(
        patience=_config["early_stopping_patience"],
        min_delta=_config.get("early_stopping_min_delta", 0.0),
    )
    best_mel_callback = CompositeBestMelCallback(_config, last_n=10)

    callbacks = [
        checkpoint_callback_epoch,
        lr_callback,
        model_summary_callback,
        stepwise_eval_callback,
        save_epoch0_callback,
        save_epoch96_callback,
        save_best_ckpt_callback,
        early_stopping,
        best_mel_callback,
    ]

    # --------------------------------------------------------------------------
    # Model selection based on _config["use_gan"]
    # --------------------------------------------------------------------------
    use_gan = bool(_config["use_gan"])  # 0 => False, 1 => True
    if use_gan:
        print("[INFO] use_gan=True -> using FaceTTSWithDiscriminator")
        model = FaceTTSWithDiscriminator(_config).to(torch.device("cuda"))
    else:
        print("[INFO] use_gan=False -> using FaceTTS")
        model = FaceTTS(_config).to(torch.device("cuda"))
    # The eager .to("cuda") keeps the pre-fit checkpoint loading below on GPU;
    # the Trainer would place the model on the device during fit() anyway.
    print("[DEBUG] Model initialized")

    # --------------------------------------------------------------------------
    # GPU / Trainer settings
    # --------------------------------------------------------------------------
    num_gpus = _config["num_gpus"] if isinstance(_config["num_gpus"], int) else len(_config["num_gpus"])
    grad_steps = _config["batch_size"] // (_config["per_gpu_batchsize"] * num_gpus * _config["num_nodes"])
    max_steps = _config["max_steps"]  # may be None, in which case no step limit is set here
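
    # grad_steps implements gradient accumulation so the effective batch size
    # matches _config["batch_size"]. With hypothetical values batch_size=64,
    # per_gpu_batchsize=8, num_gpus=2, num_nodes=1:
    #     grad_steps = 64 // (8 * 2 * 1) = 4
    # i.e. gradients accumulate over 4 per-GPU batches before each optimizer step.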

    # --------------------------------------------------------------------------
    # Loading checkpoint - Prevent Overwriting
    # --------------------------------------------------------------------------
    if _config["resume_from"] and os.path.exists(_config["resume_from"]):  # also guards against a None/empty path
        print(f"[INFO] Loading checkpoint from {_config['resume_from']}")
        checkpoint = torch.load(_config["resume_from"], map_location="cuda")
        checkpoint.pop("callbacks", None)  # drop saved callback state so it cannot override the new callbacks
        if use_gan:
            # Keep only generator weights; the discriminator starts from scratch.
            generator_state_dict = {k: v for k, v in checkpoint["state_dict"].items() if "discriminator" not in k}
            model.load_state_dict(generator_state_dict, strict=False)
            print("[INFO] Loaded generator weights (discriminator keys ignored).")
        else:
            model.load_state_dict(checkpoint["state_dict"], strict=False)
            print("[INFO] Loaded entire model state_dict.")
    else:
        print(f"[WARNING] No checkpoint found at {_config['resume_from']}. Training from scratch.")
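
    # load_state_dict(strict=False) silently tolerates mismatched keys; PyTorch
    # returns them, so a sketch for verifying what was actually restored:
    #
    #     missing, unexpected = model.load_state_dict(generator_state_dict, strict=False)
    #     print(f"[INFO] missing={len(missing)}, unexpected={len(unexpected)}")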

    print(f"[INFO] Using {torch.cuda.device_count()} GPU(s)")

    # --------------------------------------------------------------------------
    # Trainer setup
    # --------------------------------------------------------------------------
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=_config["num_gpus"],
        num_nodes=_config["num_nodes"],
        # find_unused_parameters=True avoids DDP errors when some parameters
        # (e.g. the discriminator's) receive no gradient in a given step, at
        # the cost of some communication overhead.
        strategy=DDPStrategy(gradient_as_bucket_view=True, find_unused_parameters=True),
        max_steps=max_steps,
        callbacks=callbacks,
        accumulate_grad_batches=grad_steps,
        log_every_n_steps=50,
        enable_model_summary=True,
        val_check_interval=_config["val_check_interval"],
    )
    print("[DEBUG] Trainer initialized")

    if trainer.logger is not None:
        trainer.logger.log_hyperparams(_config)

    # --------------------------------------------------------------------------
    # Train or Test
    # --------------------------------------------------------------------------
    if not _config["test_only"]:
        print("[DEBUG] Starting training...")
        trainer.fit(model, datamodule=dm)
    else:
        print("[DEBUG] Running test...")
        trainer.test(model, datamodule=dm)
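
# Assuming the same Sacred CLI (hypothetical checkpoint path), an
# evaluation-only run might look like:
#
#     python train.py with test_only=True resume_from=/path/to/model.ckpt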