Skip to content

Commit 053e6b7

Browse files
committed
dreambooth: guard against passing keep_fp32_wrapper arg to older versions of accelerate. part of fix for #1566
1 parent c3922b5 commit 053e6b7

1 file changed

Lines changed: 11 additions & 2 deletions

File tree

examples/dreambooth/train_dreambooth.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import argparse
22
import hashlib
3+
import inspect
34
import itertools
45
import math
56
import os
@@ -680,10 +681,18 @@ def main(args):
680681

681682
if global_step % args.save_steps == 0:
682683
if accelerator.is_main_process:
684+
# newer versions of accelerate allow the 'keep_fp32_wrapper' arg. without passing
685+
# it, the models will be unwrapped, and when they are then used for further training,
686+
# we will crash. pass this, but only to newer versions of accelerate. fixes
687+
# https://github.com/huggingface/diffusers/issues/1566
688+
accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
689+
inspect.signature(accelerator.unwrap_model).parameters.keys()
690+
)
691+
extra_args = {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {}
683692
pipeline = DiffusionPipeline.from_pretrained(
684693
args.pretrained_model_name_or_path,
685-
unet=accelerator.unwrap_model(unet, True),
686-
text_encoder=accelerator.unwrap_model(text_encoder, True),
694+
unet=accelerator.unwrap_model(unet, **extra_args),
695+
text_encoder=accelerator.unwrap_model(text_encoder, **extra_args),
687696
revision=args.revision,
688697
)
689698
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")

0 commit comments

Comments
 (0)