 import torch

 import PIL
-from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTokenizer

 from ...configuration_utils import FrozenDict
@@ -91,6 +90,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
     safety_checker: OnnxRuntimeModel
     feature_extractor: CLIPFeatureExtractor

+    _optional_components = ["safety_checker", "feature_extractor"]
+
     def __init__(
         self,
         vae_encoder: OnnxRuntimeModel,
@@ -149,27 +150,6 @@ def __init__(
149150 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
150151 )
151152
152- is_unet_version_less_0_9_0 = hasattr (unet .config , "_diffusers_version" ) and version .parse (
153- version .parse (unet .config ._diffusers_version ).base_version
154- ) < version .parse ("0.9.0.dev0" )
155- is_unet_sample_size_less_64 = hasattr (unet .config , "sample_size" ) and unet .config .sample_size < 64
156- if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64 :
157- deprecation_message = (
158- "The configuration file of the unet has set the default `sample_size` to smaller than"
159- " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
160- " following: \n - CompVis/stable-diffusion-v1-4 \n - CompVis/stable-diffusion-v1-3 \n -"
161- " CompVis/stable-diffusion-v1-2 \n - CompVis/stable-diffusion-v1-1 \n - runwayml/stable-diffusion-v1-5"
162- " \n - runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
163- " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
164- " in the config might lead to incorrect results in future versions. If you have downloaded this"
165- " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
166- " the `unet/config.json` file"
167- )
168- deprecate ("sample_size<64" , "1.0.0" , deprecation_message , standard_warn = False )
169- new_config = dict (unet .config )
170- new_config ["sample_size" ] = 64
171- unet ._internal_dict = FrozenDict (new_config )
172-
173153 self .register_modules (
174154 vae_encoder = vae_encoder ,
175155 vae_decoder = vae_decoder ,
@@ -180,7 +160,6 @@ def __init__(
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.register_to_config(requires_safety_checker=requires_safety_checker)

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
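
With `safety_checker` and `feature_extractor` registered in `_optional_components`, the pipeline can now be constructed, saved, and re-loaded without either component. A minimal usage sketch, assuming the usual `from_pretrained` API; the checkpoint name and the `revision`/`provider` arguments are illustrative, not taken from this diff:

```python
from diffusers import OnnxStableDiffusionInpaintPipeline

# Illustrative only: with both components marked optional, passing
# safety_checker=None is a supported configuration rather than an error.
pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",  # hypothetical checkpoint choice
    revision="onnx",                         # assumes an ONNX export exists
    provider="CPUExecutionProvider",
    safety_checker=None,
)
```
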
@@ -267,8 +246,8 @@ def __call__(
         prompt: Union[str, List[str]],
         image: PIL.Image.Image,
         mask_image: PIL.Image.Image,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
+        height: Optional[int] = 512,
+        width: Optional[int] = 512,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -296,9 +275,9 @@ def __call__(
                 repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
                 to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
                 instead of 3, so the expected shape would be `(B, H, W, 1)`.
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            height (`int`, *optional*, defaults to 512):
                 The height in pixels of the generated image.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            width (`int`, *optional*, defaults to 512):
                 The width in pixels of the generated image.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -343,9 +322,6 @@ def __call__(
             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content, according to the `safety_checker`.
         """
-        # 0. Default height and width to unet
-        height = height or self.unet.config.sample_size * self.vae_scale_factor
-        width = width or self.unet.config.sample_size * self.vae_scale_factor

         if isinstance(prompt, str):
             batch_size = 1
@@ -381,12 +357,7 @@ def __call__(
             )

         num_channels_latents = NUM_LATENT_CHANNELS
-        latents_shape = (
-            batch_size * num_images_per_prompt,
-            num_channels_latents,
-            height // self.vae_scale_factor,
-            width // self.vae_scale_factor,
-        )
+        latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             latents = generator.randn(*latents_shape).astype(latents_dtype)
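
For context on the constants hard-coded above: the Stable Diffusion VAE downsamples by a factor of 8 and its latent space has 4 channels, so the new fixed 512×512 defaults produce a 64×64 latent. A quick sketch of that arithmetic, mirroring the variable names in the hunk (values are the defaults):

```python
# Sketch of the latent-shape arithmetic hard-coded in the hunk above.
batch_size = 1
num_images_per_prompt = 1
num_channels_latents = 4  # NUM_LATENT_CHANNELS for Stable Diffusion checkpoints
height = width = 512      # the new fixed defaults

latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
print(latents_shape)      # (1, 4, 64, 64)
```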