@@ -183,6 +183,25 @@ def enable_sequential_cpu_offload(self):
183183 if cpu_offloaded_model is not None :
184184 cpu_offload (cpu_offloaded_model , device )
185185
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
    r"""
    Returns the device on which the pipeline's models will be executed. After calling
    `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
    hooks.
    """
    # Offload mode is recognizable by the pipeline sitting on the "meta" device
    # with Accelerate hooks attached to the UNet; in every other case the
    # pipeline's own device attribute is authoritative.
    offloaded = self.device == torch.device("meta") and hasattr(self.unet, "_hf_hook")
    if not offloaded:
        return self.device
    # Walk the UNet's submodules and take the first hook that declares an
    # execution device.
    for submodule in self.unet.modules():
        hook = getattr(submodule, "_hf_hook", None)
        exec_device = getattr(hook, "execution_device", None)
        if exec_device is not None:
            return torch.device(exec_device)
    return self.device
204+
186205 def enable_xformers_memory_efficient_attention (self ):
187206 r"""
188207 Enable memory efficient attention as implemented in xformers.
@@ -303,6 +322,8 @@ def __call__(
303322 f" { type (callback_steps )} ."
304323 )
305324
325+ device = self ._execution_device
326+
306327 # get prompt text embeddings
307328 text_inputs = self .tokenizer (
308329 prompt ,
@@ -319,7 +340,7 @@ def __call__(
319340 f" { self .tokenizer .model_max_length } tokens: { removed_text } "
320341 )
321342 text_input_ids = text_input_ids [:, : self .tokenizer .model_max_length ]
322- text_embeddings = self .text_encoder (text_input_ids .to (self . device ))[0 ]
343+ text_embeddings = self .text_encoder (text_input_ids .to (device ))[0 ]
323344
324345 # duplicate text embeddings for each generation per prompt, using mps friendly method
325346 bs_embed , seq_len , _ = text_embeddings .shape
@@ -359,7 +380,7 @@ def __call__(
359380 truncation = True ,
360381 return_tensors = "pt" ,
361382 )
362- uncond_embeddings = self .text_encoder (uncond_input .input_ids .to (self . device ))[0 ]
383+ uncond_embeddings = self .text_encoder (uncond_input .input_ids .to (device ))[0 ]
363384
364385 # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
365386 seq_len = uncond_embeddings .shape [1 ]
@@ -379,17 +400,15 @@ def __call__(
379400 latents_shape = (batch_size * num_images_per_prompt , num_channels_latents , height // 8 , width // 8 )
380401 latents_dtype = text_embeddings .dtype
381402 if latents is None :
382- if self . device .type == "mps" :
403+ if device .type == "mps" :
383404 # randn does not exist on mps
384- latents = torch .randn (latents_shape , generator = generator , device = "cpu" , dtype = latents_dtype ).to (
385- self .device
386- )
405+ latents = torch .randn (latents_shape , generator = generator , device = "cpu" , dtype = latents_dtype ).to (device )
387406 else :
388- latents = torch .randn (latents_shape , generator = generator , device = self . device , dtype = latents_dtype )
407+ latents = torch .randn (latents_shape , generator = generator , device = device , dtype = latents_dtype )
389408 else :
390409 if latents .shape != latents_shape :
391410 raise ValueError (f"Unexpected latents shape, got { latents .shape } , expected { latents_shape } " )
392- latents = latents .to (self . device )
411+ latents = latents .to (device )
393412
394413 # prepare mask and masked_image
395414 mask , masked_image = prepare_mask_and_masked_image (image , mask_image )
@@ -398,9 +417,9 @@ def __call__(
398417 # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
399418 # and half precision
400419 mask = torch .nn .functional .interpolate (mask , size = (height // 8 , width // 8 ))
401- mask = mask .to (device = self . device , dtype = text_embeddings .dtype )
420+ mask = mask .to (device = device , dtype = text_embeddings .dtype )
402421
403- masked_image = masked_image .to (device = self . device , dtype = text_embeddings .dtype )
422+ masked_image = masked_image .to (device = device , dtype = text_embeddings .dtype )
404423
405424 # encode the mask image into latents space so we can concatenate it to the latents
406425 masked_image_latents = self .vae .encode (masked_image ).latent_dist .sample (generator = generator )
@@ -416,7 +435,7 @@ def __call__(
416435 )
417436
418437 # aligning device to prevent device errors when concating it with the latent model input
419- masked_image_latents = masked_image_latents .to (device = self . device , dtype = text_embeddings .dtype )
438+ masked_image_latents = masked_image_latents .to (device = device , dtype = text_embeddings .dtype )
420439
421440 num_channels_mask = mask .shape [1 ]
422441 num_channels_masked_image = masked_image_latents .shape [1 ]
@@ -431,7 +450,7 @@ def __call__(
431450 )
432451
433452 # set timesteps and move to the correct device
434- self .scheduler .set_timesteps (num_inference_steps , device = self . device )
453+ self .scheduler .set_timesteps (num_inference_steps , device = device )
435454 timesteps_tensor = self .scheduler .timesteps
436455
437456 # scale the initial noise by the standard deviation required by the scheduler
@@ -484,9 +503,7 @@ def __call__(
484503 image = image .cpu ().permute (0 , 2 , 3 , 1 ).float ().numpy ()
485504
486505 if self .safety_checker is not None :
487- safety_checker_input = self .feature_extractor (self .numpy_to_pil (image ), return_tensors = "pt" ).to (
488- self .device
489- )
506+ safety_checker_input = self .feature_extractor (self .numpy_to_pil (image ), return_tensors = "pt" ).to (device )
490507 image , has_nsfw_concept = self .safety_checker (
491508 images = image , clip_input = safety_checker_input .pixel_values .to (text_embeddings .dtype )
492509 )
0 commit comments