2727from invokeai .backend .stable_diffusion .diffusers_pipeline import PipelineIntermediateState
2828from invokeai .backend .stable_diffusion .diffusion .conditioning_data import ZImageConditioningInfo
2929from invokeai .backend .util .devices import TorchDevice
30+ from invokeai .backend .z_image .extensions .regional_prompting_extension import ZImageRegionalPromptingExtension
31+ from invokeai .backend .z_image .text_conditioning import ZImageTextConditioning
32+ from invokeai .backend .z_image .z_image_transformer_patch import patch_transformer_for_regional_prompting
3033
3134
3235@invocation (
3336 "z_image_denoise" ,
3437 title = "Denoise - Z-Image" ,
3538 tags = ["image" , "z-image" ],
3639 category = "image" ,
37- version = "1.1 .0" ,
40+ version = "1.2 .0" ,
3841 classification = Classification .Prototype ,
3942)
4043class ZImageDenoiseInvocation (BaseInvocation ):
41- """Run the denoising process with a Z-Image model."""
44+ """Run the denoising process with a Z-Image model.
45+
46+ Supports regional prompting by connecting multiple conditioning inputs with masks.
47+ """
4248
4349 # If latents is provided, this means we are doing image-to-image.
4450 latents : Optional [LatentsField ] = InputField (
@@ -53,10 +59,10 @@ class ZImageDenoiseInvocation(BaseInvocation):
5359 transformer : TransformerField = InputField (
5460 description = FieldDescriptions .z_image_model , input = Input .Connection , title = "Transformer"
5561 )
56- positive_conditioning : ZImageConditioningField = InputField (
62+ positive_conditioning : ZImageConditioningField | list [ ZImageConditioningField ] = InputField (
5763 description = FieldDescriptions .positive_cond , input = Input .Connection
5864 )
59- negative_conditioning : Optional [ZImageConditioningField ] = InputField (
65+ negative_conditioning : ZImageConditioningField | list [ZImageConditioningField ] | None = InputField (
6066 default = None , description = FieldDescriptions .negative_cond , input = Input .Connection
6167 )
6268 # Z-Image-Turbo uses guidance_scale=0.0 by default (no CFG)
@@ -103,25 +109,50 @@ def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor)
103109 def _load_text_conditioning (
104110 self ,
105111 context : InvocationContext ,
106- conditioning_name : str ,
112+ cond_field : ZImageConditioningField | list [ZImageConditioningField ],
113+ img_height : int ,
114+ img_width : int ,
107115 dtype : torch .dtype ,
108116 device : torch .device ,
109- ) -> torch .Tensor :
110- """Load Z-Image text conditioning."""
111- cond_data = context .conditioning .load (conditioning_name )
112- if len (cond_data .conditionings ) != 1 :
113- raise ValueError (
114- f"Expected exactly 1 conditioning entry for Z-Image, got { len (cond_data .conditionings )} . "
115- "Ensure you are using the Z-Image text encoder."
116- )
117- z_image_conditioning = cond_data .conditionings [0 ]
118- if not isinstance (z_image_conditioning , ZImageConditioningInfo ):
119- raise TypeError (
120- f"Expected ZImageConditioningInfo, got { type (z_image_conditioning ).__name__ } . "
121- "Ensure you are using the Z-Image text encoder."
122- )
123- z_image_conditioning = z_image_conditioning .to (dtype = dtype , device = device )
124- return z_image_conditioning .prompt_embeds
117+ ) -> list [ZImageTextConditioning ]:
118+ """Load Z-Image text conditioning with optional regional masks.
119+
120+ Args:
121+ context: The invocation context.
122+ cond_field: Single conditioning field or list of fields.
123+ img_height: Height of the image token grid (H // patch_size).
124+ img_width: Width of the image token grid (W // patch_size).
125+ dtype: Target dtype.
126+ device: Target device.
127+
128+ Returns:
129+ List of ZImageTextConditioning objects with embeddings and masks.
130+ """
131+ # Normalize to a list
132+ cond_list = [cond_field ] if isinstance (cond_field , ZImageConditioningField ) else cond_field
133+
134+ text_conditionings : list [ZImageTextConditioning ] = []
135+ for cond in cond_list :
136+ # Load the text embeddings
137+ cond_data = context .conditioning .load (cond .conditioning_name )
138+ assert len (cond_data .conditionings ) == 1
139+ z_image_conditioning = cond_data .conditionings [0 ]
140+ assert isinstance (z_image_conditioning , ZImageConditioningInfo )
141+ z_image_conditioning = z_image_conditioning .to (dtype = dtype , device = device )
142+ prompt_embeds = z_image_conditioning .prompt_embeds
143+
144+ # Load the mask, if provided
145+ mask : torch .Tensor | None = None
146+ if cond .mask is not None :
147+ mask = context .tensors .load (cond .mask .tensor_name )
148+ mask = mask .to (device = device )
149+ mask = ZImageRegionalPromptingExtension .preprocess_regional_prompt_mask (
150+ mask , img_height , img_width , dtype , device
151+ )
152+
153+ text_conditionings .append (ZImageTextConditioning (prompt_embeds = prompt_embeds , mask = mask ))
154+
155+ return text_conditionings
125156
126157 def _get_noise (
127158 self ,
@@ -198,33 +229,53 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
198229
199230 transformer_info = context .models .load (self .transformer .transformer )
200231
201- # Load positive conditioning
202- pos_prompt_embeds = self ._load_text_conditioning (
232+ # Calculate image token grid dimensions
233+ patch_size = 2 # Z-Image uses patch_size=2
234+ latent_height = self .height // LATENT_SCALE_FACTOR
235+ latent_width = self .width // LATENT_SCALE_FACTOR
236+ img_token_height = latent_height // patch_size
237+ img_token_width = latent_width // patch_size
238+ img_seq_len = img_token_height * img_token_width
239+
240+ # Load positive conditioning with regional masks
241+ pos_text_conditionings = self ._load_text_conditioning (
203242 context = context ,
204- conditioning_name = self .positive_conditioning .conditioning_name ,
243+ cond_field = self .positive_conditioning ,
244+ img_height = img_token_height ,
245+ img_width = img_token_width ,
205246 dtype = inference_dtype ,
206247 device = device ,
207248 )
208249
250+ # Create regional prompting extension
251+ regional_extension = ZImageRegionalPromptingExtension .from_text_conditionings (
252+ text_conditionings = pos_text_conditionings ,
253+ img_seq_len = img_seq_len ,
254+ )
255+
256+ # Get the concatenated prompt embeddings for the transformer
257+ pos_prompt_embeds = regional_extension .regional_text_conditioning .prompt_embeds
258+
209259 # Load negative conditioning if provided and guidance_scale > 0
210260 neg_prompt_embeds : torch .Tensor | None = None
211261 do_classifier_free_guidance = self .guidance_scale > 0.0 and self .negative_conditioning is not None
212262 if do_classifier_free_guidance :
213- if self .negative_conditioning is None :
214- raise ValueError ("Negative conditioning is required when guidance_scale > 0" )
215- neg_prompt_embeds = self ._load_text_conditioning (
263+ assert self .negative_conditioning is not None
264+ # Load all negative conditionings and concatenate embeddings
265+ # Note: We ignore masks for negative conditioning as regional negative prompting is not fully supported
266+ neg_text_conditionings = self ._load_text_conditioning (
216267 context = context ,
217- conditioning_name = self .negative_conditioning .conditioning_name ,
268+ cond_field = self .negative_conditioning ,
269+ img_height = img_token_height ,
270+ img_width = img_token_width ,
218271 dtype = inference_dtype ,
219272 device = device ,
220273 )
221-
222- # Calculate image sequence length for timestep shifting
223- patch_size = 2 # Z-Image uses patch_size=2
224- image_seq_len = ((self .height // LATENT_SCALE_FACTOR ) * (self .width // LATENT_SCALE_FACTOR )) // (patch_size ** 2 )
274+ # Concatenate all negative embeddings
275+ neg_prompt_embeds = torch .cat ([tc .prompt_embeds for tc in neg_text_conditionings ], dim = 0 )
225276
226277 # Calculate shift based on image sequence length
227- mu = self ._calculate_shift (image_seq_len )
278+ mu = self ._calculate_shift (img_seq_len )
228279
229280 # Generate sigma schedule with time shift
230281 sigmas = self ._get_sigmas (mu , self .steps )
@@ -322,6 +373,15 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
322373 )
323374 )
324375
376+ # Apply regional prompting patch if we have regional masks
377+ exit_stack .enter_context (
378+ patch_transformer_for_regional_prompting (
379+ transformer = transformer ,
380+ regional_attn_mask = regional_extension .regional_attn_mask ,
381+ img_seq_len = img_seq_len ,
382+ )
383+ )
384+
325385 # Denoising loop
326386 for step_idx in tqdm (range (total_steps )):
327387 sigma_curr = sigmas [step_idx ]
0 commit comments