-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: run_train.py
More file actions
147 lines (122 loc) · 5.36 KB
/
run_train.py
File metadata and controls
147 lines (122 loc) · 5.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import os
import argparse
import pytorch_lightning as pl
from braceexpand import braceexpand
from torch.utils.data import DataLoader
from datasets.webdataset import MultiWebDataset
from cldm.logger import ImageLogger
from cldm.model import create_model, load_state_dict
from torch.utils.data import ConcatDataset
from cldm.hack import disable_verbosity, enable_sliced_attention
from omegaconf import OmegaConf
import torch
from datasets.base import BaseDataset
class BaseLogic(BaseDataset):
    """Thin helper that carries collage-construction parameters.

    Subclasses BaseDataset so its inherited ``_construct_collage`` method can
    be handed to MultiWebDataset as a callback (see ``main``).
    """

    def __init__(self, area_ratio, obj_thr):
        # NOTE(review): super().__init__() is deliberately not called, matching
        # the original code — confirm BaseDataset needs no initialization.
        self.area_ratio = area_ratio  # area-ratio parameter for collage construction
        self.obj_thr = obj_thr        # object-count threshold for collage construction
        print("Number of GPUs available: ", torch.cuda.device_count())
        # Fix: the original unconditionally queried the current CUDA device,
        # which raises RuntimeError on CPU-only machines; guard the queries.
        if torch.cuda.is_available():
            print("Current device: ", torch.cuda.current_device())
            print("Device name: ", torch.cuda.get_device_name(0))
        else:
            print("CUDA not available; running on CPU.")
def get_args_parser():
    """Build the CLI argument parser for the PICS training script.

    Returns:
        argparse.ArgumentParser: parser with ``add_help=False`` so it can be
        used as a parent parser by the ``__main__`` entry point.
    """
    parser = argparse.ArgumentParser('PICS Training Script', add_help=False)
    # Fix: the original passed required=None (wrong type, accidentally falsy);
    # an optional checkpoint path is expressed with default=None.
    parser.add_argument('--resume_path', default=None, type=str,
                        help="Optional checkpoint to resume from")
    parser.add_argument('--root_dir', required=True, type=str,
                        help="Output root for logs and checkpoints")
    parser.add_argument('--batch_size', default=1, type=int)
    parser.add_argument('--limit_train_batches', default=1, type=float)
    parser.add_argument('--logger_freq', default=1000, type=int,
                        help="Image-logging frequency in steps")
    parser.add_argument('--learning_rate', default=1e-5, type=float)
    # Fix: "Seprate" typo in the help text.
    parser.add_argument('--is_joint', action='store_true',
                        help="Joint/Separate training")
    parser.add_argument("--dataset_name", type=str, default='lvis',
                        help="Dataset name")
    return parser
def _shard_weights(args):
    """Return per-dataset shard oversampling weights for this run.

    Joint training oversamples the smaller datasets; single-dataset training
    enables exactly one source and zeroes out the rest.

    Args:
        args: parsed CLI namespace with ``is_joint`` and ``dataset_name``.

    Raises:
        ValueError: if ``args.dataset_name`` is not a known dataset.
    """
    if args.is_joint:
        # weights = {'LVIS': 30, 'VITONHD': 60, 'Objects365': 1, 'Cityscapes': 180, 'MapillaryVistas': 180,'BDD100K': 180}
        return {'LVIS': 3, 'VITONHD': 6, 'Objects365': 1,
                'Cityscapes': 18, 'MapillaryVistas': 18, 'BDD100K': 18}
    # CLI name -> config key dispatch table (replaces the original elif ladder).
    name_to_key = {
        'lvis': 'LVIS',
        'vitonhd': 'VITONHD',
        'object365': 'Objects365',
        'cityscapes': 'Cityscapes',
        'mapillaryvistas': 'MapillaryVistas',
        'bdd100k': 'BDD100K',
    }
    key = name_to_key.get(args.dataset_name)
    if key is None:
        raise ValueError(f"Unsupported dataset name: {args.dataset_name}")
    weights = {k: 0 for k in name_to_key.values()}
    weights[key] = 1
    return weights


def main(args):
    """Build the model, weighted web-dataset pipeline, and trainer, then fit.

    Side effects: loads configs from ./configs, optionally loads a checkpoint,
    writes checkpoints under ``args.root_dir``, and launches training.
    """
    save_memory = False
    disable_verbosity()
    if save_memory:
        enable_sliced_attention()

    sd_locked = False
    only_mid_control = False
    accumulate_grad_batches = 1
    obj_thr = {'obj_thr': 2}  # forwarded to ImageLogger.log_images

    # Model is created on CPU; the Trainer moves it to GPU.
    model = create_model('./configs/pics.yaml').cpu()
    if args.resume_path and os.path.exists(args.resume_path):
        print(f"Loading checkpoint from: {args.resume_path}")
        checkpoint = load_state_dict(args.resume_path, location='cpu')
        # strict=False tolerates architecture additions missing from the checkpoint.
        model.load_state_dict(checkpoint, strict=False)
    else:
        print("No checkpoint found or provided. Training from scratch...")
    model.learning_rate = args.learning_rate
    model.sd_locked = sd_locked
    model.only_mid_control = only_mid_control

    DConf = OmegaConf.load('./configs/datasets.yaml')
    weights = _shard_weights(args)

    # Expand each dataset's brace-pattern shard spec and repeat it by weight
    # so sampling frequency follows the configured mix.
    dataset_shards = [
        ('LVIS', DConf.Train.LVIS.shards),
        ('VITONHD', DConf.Train.VITONHD.shards),
        ('Objects365', DConf.Train.Objects365.shards),
        ('Cityscapes', DConf.Train.Cityscapes.shards),
        ('MapillaryVistas', DConf.Train.MapillaryVistas.shards),
        ('BDD100K', DConf.Train.BDD100K.shards)
    ]
    all_urls = []
    for name, path in dataset_shards:
        expanded = list(braceexpand(path))
        all_urls.extend(expanded * weights.get(name, 1))
    # Fix: seed the shard shuffle so runs are reproducible — the original used
    # an unseeded random.shuffle even though seed=42 is passed to the dataset.
    import random
    random.Random(42).shuffle(all_urls)

    logic_helper = BaseLogic(
        area_ratio=DConf.Defaults.area_ratio,
        obj_thr=DConf.Defaults.obj_thr
    )
    dataset = MultiWebDataset(
        urls=all_urls,
        construct_collage_fn=logic_helper._construct_collage,
        shuffle_size=10000,
        seed=42,
        decode_mode="pil",
    )
    dataloader = DataLoader(
        dataset,
        num_workers=8,
        batch_size=args.batch_size,
    )

    logger = ImageLogger(batch_frequency=args.logger_freq, log_images_kwargs=obj_thr)
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=os.path.join(args.root_dir, 'checkpoints'),
        filename='pics-{step:06d}',
        every_n_train_steps=2000,
        save_top_k=-1,  # keep every checkpoint
    )
    trainer = pl.Trainer(
        default_root_dir=args.root_dir,
        limit_train_batches=args.limit_train_batches,
        accelerator="gpu",
        devices=1,
        precision=16,
        callbacks=[logger, checkpoint_callback],
        accumulate_grad_batches=accumulate_grad_batches,
        max_epochs=50,
        val_check_interval=2000,
    )
    trainer.fit(model, dataloader)
if __name__ == '__main__':
    # Compose the CLI from the shared parent parser and launch training.
    cli = argparse.ArgumentParser('PICS Training', parents=[get_args_parser()])
    main(cli.parse_args())