-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathtrain.py
More file actions
540 lines (461 loc) · 21.4 KB
/
train.py
File metadata and controls
540 lines (461 loc) · 21.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 MaPO authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
import os
import shutil
from pathlib import Path
import torch
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from tqdm.auto import tqdm
from transformers import AutoTokenizer
import diffusers
from args import parse_args
from diffusers import (
AutoencoderKL,
DDPMScheduler,
UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.utils.import_utils import is_xformers_available
from utils import (
collate_fn,
compute_loss,
compute_time_ids,
encode_prompt,
get_dataset_preprocessor,
get_wandb_url,
import_model_class_from_model_name_or_path,
log_validation,
)
# Module-level logger built with accelerate's `get_logger`, so log calls can pass
# `main_process_only=...` (used in `main` to log the accelerator state on every process).
logger = get_logger(__name__)
def main(args):
    """Fine-tune the UNet of a two-text-encoder diffusion pipeline end to end.

    The function builds an ``Accelerator``, loads two tokenizers / text encoders,
    the VAE and the UNet from ``args.pretrained_model_name_or_path``, freezes
    everything except the UNet, prepares the dataset and optimizer, optionally
    resumes from a checkpoint under ``args.output_dir``, runs the training loop
    (with periodic checkpointing and optional validation), and finally saves the
    UNet to ``args.output_dir`` and optionally uploads the folder to the Hub.

    Args:
        args: Configuration namespace as returned by ``parse_args()``. It is
            mutated in place: ``learning_rate`` (when ``scale_lr`` is set),
            ``max_train_steps``, ``num_train_epochs`` and
            ``resume_from_checkpoint`` may be overwritten.

    Raises:
        ValueError: If xformers attention is requested but xformers is not
            available, if the scheduler's ``prediction_type`` is neither
            ``"epsilon"`` nor ``"v_prediction"``, or if a save/load hook
            receives a model that is not the UNet.
        ImportError: If ``args.use_8bit_adam`` is set and ``bitsandbytes`` is
            not installed.
    """
    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation.
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            # `repo_id` is only bound on the main process; it is only read there too
            # (in the final upload step below).
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name,
                exist_ok=True,
            ).repo_id

    # Load the tokenizers.
    tokenizer_one = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer",
        revision=args.revision,
        use_fast=False,
        cache_dir=args.cache_dir,
    )
    tokenizer_two = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer_2",
        revision=args.revision,
        use_fast=False,
        cache_dir=args.cache_dir,
    )

    # Import the correct text encoder classes.
    text_encoder_cls_one = import_model_class_from_model_name_or_path(
        args.pretrained_model_name_or_path, args.revision
    )
    text_encoder_cls_two = import_model_class_from_model_name_or_path(
        args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
    )

    # Load scheduler and models.
    noise_scheduler = DDPMScheduler.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="scheduler", cache_dir=args.cache_dir
    )
    text_encoder_one = text_encoder_cls_one.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=args.revision,
        variant=args.variant,
        cache_dir=args.cache_dir,
    )
    text_encoder_two = text_encoder_cls_two.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder_2",
        revision=args.revision,
        variant=args.variant,
        cache_dir=args.cache_dir,
    )
    # A separately supplied VAE is loaded from its own root; otherwise the "vae"
    # subfolder of the base model is used.
    vae_path = (
        args.pretrained_model_name_or_path
        if args.pretrained_vae_model_name_or_path is None
        else args.pretrained_vae_model_name_or_path
    )
    vae = AutoencoderKL.from_pretrained(
        vae_path,
        subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
        revision=args.revision,
        variant=args.variant,
        cache_dir=args.cache_dir,
    )
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="unet",
        revision=args.revision,
        variant=args.variant,
        cache_dir=args.cache_dir,
    )

    # Set the grad requirements: only the UNet is trained.
    vae.requires_grad_(False)
    text_encoder_one.requires_grad_(False)
    text_encoder_two.requires_grad_(False)
    unet.requires_grad_(True)

    # For mixed precision training we cast all non-trainable weights to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encoders to device and cast to weight_dtype.
    text_encoder_one.to(accelerator.device, dtype=weight_dtype)
    text_encoder_two.to(accelerator.device, dtype=weight_dtype)

    # The VAE is always in float32 to avoid NaN losses.
    vae.to(accelerator.device, dtype=torch.float32)

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warning(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # Create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format.
    def save_model_hook(models, weights, output_dir):
        """Save each tracked UNet under ``<output_dir>/unet`` (main process only)."""
        if accelerator.is_main_process:
            for model in models:
                if isinstance(model, type(accelerator.unwrap_model(unet))):
                    model.save_pretrained(os.path.join(output_dir, "unet"))
                else:
                    raise ValueError(f"unexpected save model: {model.__class__}")

                # Make sure to pop weight so that corresponding model is not saved again.
                weights.pop()

    def load_model_hook(models, input_dir):
        """Restore the UNet config and weights from ``<input_dir>/unet``."""
        unet_ = None

        # Drain `models` so the entries handled here are removed from the caller's list.
        while len(models) > 0:
            model = models.pop()

            if isinstance(model, type(accelerator.unwrap_model(unet))):
                unet_ = model
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")

        load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
        unet_.register_to_config(**load_model.config)
        unet_.load_state_dict(load_model.state_dict())
        del load_model

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs.
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError as err:
            # Chain explicitly so the original import failure stays visible as the cause.
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            ) from err

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # Optimizer creation.
    params_to_optimize = [param for param in unet.parameters() if param.requires_grad]
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Dataset and DataLoaders creation:
    train_dataset = load_dataset(
        args.dataset_name,
        cache_dir=args.cache_dir,
        split=args.dataset_split_name,
    )

    # Preprocessing the datasets.
    preprocess_train_fn = get_dataset_preprocessor(args, tokenizer_one, tokenizer_two)
    with accelerator.main_process_first():
        if args.max_train_samples is not None:
            train_dataset = train_dataset.shuffle(seed=args.seed).select(range(args.max_train_samples))
        # Optionally drop examples whose two labels are equal.
        if args.remove_ties:
            train_dataset = train_dataset.filter(lambda example: example["label_0"] != example["label_1"])
        # Set the training transforms.
        train_dataset = train_dataset.with_transform(preprocess_train_fn)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=args.dataloader_num_workers,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs.
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers(args.tracker_name, config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num batches each epoch = {len(train_dataloader)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save.
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            # Normalize first: `os.path.basename("ckpt/")` is "" and would otherwise make us
            # load from `output_dir` itself and then fail while parsing the step number.
            path = os.path.basename(os.path.normpath(args.resume_from_checkpoint))
        else:
            # Get the most recent checkpoint (highest step number in the directory name).
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if dirs else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            # Checkpoint directories are named "checkpoint-<global_step>".
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch
    else:
        initial_global_step = 0

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    unet.train()
    # Bind `epoch` up front: if the loop below runs zero iterations (e.g. resuming from a
    # checkpoint that is already at the last epoch) the final validation still needs it.
    epoch = first_epoch
    for epoch in range(first_epoch, args.num_train_epochs):
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                # (batch_size, 2*channels, h, w) -> (2*batch_size, channels, h, w)
                pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                feed_pixel_values = torch.cat(pixel_values.chunk(2, dim=1))

                # Encode in slices of `vae_encode_batch_size` to bound VAE memory use.
                latents = []
                for i in range(0, feed_pixel_values.shape[0], args.vae_encode_batch_size):
                    latents.append(
                        vae.encode(feed_pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()
                    )
                latents = torch.cat(latents, dim=0)
                latents = latents * vae.config.scaling_factor
                if args.pretrained_vae_model_name_or_path is None:
                    latents = latents.to(weight_dtype)

                # Sample noise that we'll add to the latents; the same noise is reused
                # for both halves of the doubled batch.
                noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1, 1)

                # Sample a random timestep for each image (shared by both halves).
                bsz = latents.shape[0] // 2
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long
                ).repeat(2)

                # Add noise to the model input according to the noise magnitude at each timestep
                # (this is the forward diffusion process).
                noisy_model_input = noise_scheduler.add_noise(latents, noise, timesteps)

                # Time ids.
                add_time_ids = torch.cat(
                    [
                        compute_time_ids(args, accelerator, weight_dtype, s, c)
                        for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])
                    ]
                ).repeat(2, 1)

                # Get the text embedding for conditioning.
                prompt_embeds, pooled_prompt_embeds = encode_prompt(
                    [text_encoder_one, text_encoder_two], [batch["input_ids_one"], batch["input_ids_two"]]
                )
                prompt_embeds = prompt_embeds.repeat(2, 1, 1)
                pooled_prompt_embeds = pooled_prompt_embeds.repeat(2, 1)

                # Predict the noise residual.
                model_pred = unet(
                    noisy_model_input,
                    timesteps,
                    prompt_embeds,
                    added_cond_kwargs={"time_ids": add_time_ids, "text_embeds": pooled_prompt_embeds},
                ).sample

                # Get the target for loss depending on the prediction type.
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                # Full ORPO loss
                loss, model_losses_w, model_losses_l, ratio_losses = compute_loss(
                    args=args, noise_scheduler=noise_scheduler, model_pred=model_pred, target=target
                )

                # Backprop.
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if accelerator.is_main_process:
                    if global_step % args.checkpointing_steps == 0:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if args.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(args.output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                            if len(checkpoints) >= args.checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

                    if args.run_validation and global_step % args.validation_steps == 0:
                        log_validation(
                            args, unet=unet, vae=vae, accelerator=accelerator, weight_dtype=weight_dtype, epoch=epoch
                        )

            logs = {
                "epoch": args.num_train_epochs * progress_bar.n / args.max_train_steps,
                "total loss": loss.detach().item(),
                "Win Score": ((args.snr_value * model_losses_w) / (torch.exp(args.snr_value * model_losses_w) - 1))
                .mean()
                .detach()
                .item(),
                "Lose Score": ((args.snr_value * model_losses_l) / (torch.exp(args.snr_value * model_losses_l) - 1))
                .mean()
                .detach()
                .item(),
                "lr": lr_scheduler.get_last_lr()[0],
                "OR loss": -ratio_losses.mean().detach().item(),
                "model_losses_w": model_losses_w.mean().detach().item(),
                "model_losses_l": model_losses_l.mean().detach().item(),
            }
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = accelerator.unwrap_model(unet)
        unet.save_pretrained(args.output_dir)

        # Final validation?
        if args.run_validation:
            log_validation(
                args,
                unet=None,
                vae=vae,
                accelerator=accelerator,
                weight_dtype=weight_dtype,
                epoch=epoch,
                is_final_validation=True,
            )

        if args.push_to_hub:
            wandb_info = get_wandb_url()
            # Explicit encoding so the README bytes do not depend on the platform locale.
            with open(os.path.join(args.output_dir, "README.md"), "w", encoding="utf-8") as f:
                f.write(wandb_info)

            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*"],
            )

    accelerator.end_training()
if __name__ == "__main__":
    # Script entry point: parse the command line once, then hand the resulting
    # namespace to the training routine.
    args = parse_args()
    main(args)