-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: train.py
More file actions
87 lines (67 loc) · 2.6 KB
/
train.py
File metadata and controls
87 lines (67 loc) · 2.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import sys
import argparse
from option.train_options import TrainOptions
from option.enums import ModelNames
from option.config import BaseOptionsConfig, TrainOptionsConfig
from call_methods import make_dataset, make_model
from utils.utils import set_seed
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
def run(opt: argparse.Namespace) -> None:
    """
    Run the training loop for the DDPM model.

    Parameters
    ----------
    opt : argparse.Namespace
        Parsed training options. Reads ``seed``, ``model_name``,
        ``Time_steps_FD``, ``dataset_name``, ``batch_size``, ``n_epochs``,
        ``print_freq``, ``save_freq`` and, for CFG model variants,
        ``cfg_scale`` and ``label_usage``.

    Process
    -------
    1. Set the seed
    2. Create the DDPM model
    3. Create the dataset
    4. Split the dataset into training and testing
    5. Create the dataloader
    6. Run the training loop
    7. Save the model and optimizer
    """
    set_seed(opt.seed)
    # Create the DDPM model
    model = make_model(model_name=opt.model_name, T=opt.Time_steps_FD, opt=opt)
    # Create dataset; may be a (train, test) pair or a sequence whose first
    # element is the training split.
    dataset = make_dataset(dataset_name=opt.dataset_name, opt=opt)
    if isinstance(dataset, tuple) and len(dataset) == 2:
        # The test split is not used by this script; discard it explicitly.
        train_dataset, _ = dataset
    else:
        # NOTE(review): assumes a non-tuple result is still indexable —
        # confirm make_dataset never returns a bare dataset object here.
        train_dataset = dataset[0]
    # Model variants that take classifier-free-guidance parameters.
    # Hoisted out of the batch loop: a tuple membership test with == keeps
    # the original list-`in` semantics without rebuilding the list per batch.
    cfg_model_names = (
        ModelNames.CFG_DDPM,
        ModelNames.CFG_DDPM_EMA,
        ModelNames.CFG_DDPM_PowerLawEMA,
        ModelNames.CFG_Plus_DDPM,
        ModelNames.CFG_Plus_DDPM_EMA,
        ModelNames.CFG_Plus_DDPM_PowerLawEMA,
    )
    uses_cfg = opt.model_name in cfg_model_names
    # Batch count is invariant across epochs; compute once for logging.
    n_batches = len(train_dataset.dataloader)
    # Training loop
    for epoch in range(opt.n_epochs):
        for i, (images, labels) in enumerate(train_dataset.dataloader):
            images = images.to(model._device)
            labels = labels.to(model._device)
            # Train parameters
            train_params = {
                BaseOptionsConfig.BATCH_SIZE: opt.batch_size,
                BaseOptionsConfig.DATASET: images,
            }
            if uses_cfg:
                train_params[TrainOptionsConfig.CFG_SCALE] = opt.cfg_scale
                train_params[TrainOptionsConfig.LABEL_USAGE] = opt.label_usage
            # Forward pass and training step
            eps, eps_predicted = model.train_method(**train_params)
            loss = model._compute_loss(eps, eps_predicted)
            if i % opt.print_freq == 0:
                print(
                    f"Epoch {epoch} [{i}/{n_batches}] - Loss: {loss.item()}"
                )
        # Save checkpoint periodically.
        if epoch > 0 and epoch % opt.save_freq == 0:
            model.save_networks(epoch)
    # Final save of the model and optimizer
    model.save_networks("final")
    print("Training complete. Model and optimizer saved.")
if __name__ == "__main__":
    # Parse the command-line training options and launch the training run.
    parsed_options = TrainOptions().parse()
    run(parsed_options)