
Commit 17335ab

feat: Add Regional Guidance support for Z-Image model

Implements regional prompting for Z-Image (S3-DiT Transformer), allowing different prompts to affect different image regions using attention masks.

Backend changes:
- Add ZImageRegionalPromptingExtension for mask preparation
- Add ZImageTextConditioning and ZImageRegionalTextConditioning data classes
- Patch transformer forward to inject 4D regional attention masks
- Use additive float mask (0.0 attend, -inf block) in bfloat16 for compatibility
- Alternate regional/full attention layers for global coherence

Frontend changes:
- Update buildZImageGraph to support regional conditioning collectors
- Update addRegions to create z_image_text_encoder nodes for regions
- Update addZImageLoRAs to handle optional negCond when guidance_scale=0
- Add Z-Image validation (no IP adapters, no autoNegative)
1 parent ab6b672 commit 17335ab
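The commit message names the core mechanism: an additive float attention bias in bfloat16 (0.0 where attention is allowed, -inf where it is blocked), applied on alternating transformer layers. The extension module that builds it is not included in this excerpt, so the following is a minimal, hypothetical sketch of how such a bias could be assembled from per-region boolean token masks; the function and argument names are illustrative, not the actual ZImageRegionalPromptingExtension API.

```python
import torch


def build_regional_attn_bias(
    region_masks: list[torch.Tensor],  # one (img_seq_len,) bool tensor per prompt; True = token in region
    text_lens: list[int],  # number of text tokens contributed by each prompt, in concatenation order
    img_seq_len: int,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Return an additive attention bias: 0.0 where attention is allowed, -inf where blocked."""
    txt_seq_len = sum(text_lens)
    total = txt_seq_len + img_seq_len
    allowed = torch.zeros(total, total, dtype=torch.bool)

    # Image tokens always attend to all image tokens, preserving global coherence.
    allowed[txt_seq_len:, txt_seq_len:] = True

    offset = 0
    for region_mask, n_txt in zip(region_masks, text_lens):
        txt = slice(offset, offset + n_txt)
        img = txt_seq_len + region_mask.nonzero(as_tuple=True)[0]
        allowed[txt, txt] = True  # a prompt's text tokens attend to each other
        allowed[txt, img] = True  # text -> image attention restricted to the prompt's region
        allowed[img, txt] = True  # image -> text attention restricted likewise
        offset += n_txt

    # Boolean mask -> additive float bias in bfloat16 (0.0 = attend, -inf = block).
    bias = torch.zeros(total, total, dtype=dtype).masked_fill_(~allowed, float("-inf"))
    return bias.unsqueeze(0).unsqueeze(0)  # (1, 1, total, total), broadcast over batch and heads
```

In this sketch an image token covered by no region would attend to no text tokens at all; a real implementation would typically route such tokens to a background or global prompt.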

File tree

14 files changed (+24511, -5346 lines)


invokeai/app/invocations/fields.py

Lines changed: 5 additions & 0 deletions
@@ -327,6 +327,11 @@ class ZImageConditioningField(BaseModel):
     """A Z-Image conditioning tensor primitive value"""

     conditioning_name: str = Field(description="The name of conditioning tensor")
+    mask: Optional[TensorField] = Field(
+        default=None,
+        description="The mask associated with this conditioning tensor for regional prompting. "
+        "Excluded regions should be set to False, included regions should be set to True.",
+    )


 class ConditioningField(BaseModel):
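For illustration, a region mask satisfying the convention documented on the new field might be built as below; the (1, H, W) bool shape is an assumption about InvokeAI's mask tensors, not something stated in this diff.

```python
import torch

# True = included in the region, False = excluded, per the field's description.
height, width = 1024, 1024
region_mask = torch.zeros(1, height, width, dtype=torch.bool)
region_mask[:, :512, :512] = True  # this prompt should only affect the top-left quadrant
```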

invokeai/app/invocations/z_image_denoise.py

Lines changed: 93 additions & 33 deletions
@@ -27,18 +27,24 @@
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ZImageConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.z_image.extensions.regional_prompting_extension import ZImageRegionalPromptingExtension
+from invokeai.backend.z_image.text_conditioning import ZImageTextConditioning
+from invokeai.backend.z_image.z_image_transformer_patch import patch_transformer_for_regional_prompting


 @invocation(
     "z_image_denoise",
     title="Denoise - Z-Image",
     tags=["image", "z-image"],
     category="image",
-    version="1.1.0",
+    version="1.2.0",
     classification=Classification.Prototype,
 )
 class ZImageDenoiseInvocation(BaseInvocation):
-    """Run the denoising process with a Z-Image model."""
+    """Run the denoising process with a Z-Image model.
+
+    Supports regional prompting by connecting multiple conditioning inputs with masks.
+    """

     # If latents is provided, this means we are doing image-to-image.
     latents: Optional[LatentsField] = InputField(
@@ -53,10 +59,10 @@ class ZImageDenoiseInvocation(BaseInvocation):
     transformer: TransformerField = InputField(
         description=FieldDescriptions.z_image_model, input=Input.Connection, title="Transformer"
     )
-    positive_conditioning: ZImageConditioningField = InputField(
+    positive_conditioning: ZImageConditioningField | list[ZImageConditioningField] = InputField(
         description=FieldDescriptions.positive_cond, input=Input.Connection
     )
-    negative_conditioning: Optional[ZImageConditioningField] = InputField(
+    negative_conditioning: ZImageConditioningField | list[ZImageConditioningField] | None = InputField(
         default=None, description=FieldDescriptions.negative_cond, input=Input.Connection
     )
     # Z-Image-Turbo uses guidance_scale=0.0 by default (no CFG)
@@ -103,25 +109,50 @@ def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor)
     def _load_text_conditioning(
         self,
         context: InvocationContext,
-        conditioning_name: str,
+        cond_field: ZImageConditioningField | list[ZImageConditioningField],
+        img_height: int,
+        img_width: int,
         dtype: torch.dtype,
         device: torch.device,
-    ) -> torch.Tensor:
-        """Load Z-Image text conditioning."""
-        cond_data = context.conditioning.load(conditioning_name)
-        if len(cond_data.conditionings) != 1:
-            raise ValueError(
-                f"Expected exactly 1 conditioning entry for Z-Image, got {len(cond_data.conditionings)}. "
-                "Ensure you are using the Z-Image text encoder."
-            )
-        z_image_conditioning = cond_data.conditionings[0]
-        if not isinstance(z_image_conditioning, ZImageConditioningInfo):
-            raise TypeError(
-                f"Expected ZImageConditioningInfo, got {type(z_image_conditioning).__name__}. "
-                "Ensure you are using the Z-Image text encoder."
-            )
-        z_image_conditioning = z_image_conditioning.to(dtype=dtype, device=device)
-        return z_image_conditioning.prompt_embeds
+    ) -> list[ZImageTextConditioning]:
+        """Load Z-Image text conditioning with optional regional masks.
+
+        Args:
+            context: The invocation context.
+            cond_field: Single conditioning field or list of fields.
+            img_height: Height of the image token grid (H // patch_size).
+            img_width: Width of the image token grid (W // patch_size).
+            dtype: Target dtype.
+            device: Target device.
+
+        Returns:
+            List of ZImageTextConditioning objects with embeddings and masks.
+        """
+        # Normalize to a list
+        cond_list = [cond_field] if isinstance(cond_field, ZImageConditioningField) else cond_field
+
+        text_conditionings: list[ZImageTextConditioning] = []
+        for cond in cond_list:
+            # Load the text embeddings
+            cond_data = context.conditioning.load(cond.conditioning_name)
+            assert len(cond_data.conditionings) == 1
+            z_image_conditioning = cond_data.conditionings[0]
+            assert isinstance(z_image_conditioning, ZImageConditioningInfo)
+            z_image_conditioning = z_image_conditioning.to(dtype=dtype, device=device)
+            prompt_embeds = z_image_conditioning.prompt_embeds
+
+            # Load the mask, if provided
+            mask: torch.Tensor | None = None
+            if cond.mask is not None:
+                mask = context.tensors.load(cond.mask.tensor_name)
+                mask = mask.to(device=device)
+                mask = ZImageRegionalPromptingExtension.preprocess_regional_prompt_mask(
+                    mask, img_height, img_width, dtype, device
+                )
+
+            text_conditionings.append(ZImageTextConditioning(prompt_embeds=prompt_embeds, mask=mask))
+
+        return text_conditionings

     def _get_noise(
         self,
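preprocess_regional_prompt_mask lives in the new regional_prompting_extension module, which is not shown in this commit excerpt. Given how it is called above, a plausible sketch of what it must do is the following (downsample a full-resolution mask to one boolean entry per image token); all details here are assumptions, not the actual implementation.

```python
import torch
import torch.nn.functional as F


def preprocess_regional_prompt_mask(
    mask: torch.Tensor,  # (1, H, W) bool or float mask at image resolution
    img_height: int,  # token-grid height (latent height // patch_size)
    img_width: int,  # token-grid width (latent width // patch_size)
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """Downsample a full-resolution region mask to the image token grid."""
    mask = mask.unsqueeze(0).to(torch.float32)  # (1, 1, H, W); interpolate needs float input
    # Nearest-neighbor keeps region boundaries hard instead of blending them.
    mask = F.interpolate(mask, size=(img_height, img_width), mode="nearest")
    # One boolean per image token; conversion to an additive bias in `dtype`
    # would happen later, when the full attention mask is assembled.
    return (mask > 0.5).reshape(1, img_height * img_width).to(device=device)
```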
@@ -198,33 +229,53 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:

         transformer_info = context.models.load(self.transformer.transformer)

-        # Load positive conditioning
-        pos_prompt_embeds = self._load_text_conditioning(
+        # Calculate image token grid dimensions
+        patch_size = 2  # Z-Image uses patch_size=2
+        latent_height = self.height // LATENT_SCALE_FACTOR
+        latent_width = self.width // LATENT_SCALE_FACTOR
+        img_token_height = latent_height // patch_size
+        img_token_width = latent_width // patch_size
+        img_seq_len = img_token_height * img_token_width
+
+        # Load positive conditioning with regional masks
+        pos_text_conditionings = self._load_text_conditioning(
             context=context,
-            conditioning_name=self.positive_conditioning.conditioning_name,
+            cond_field=self.positive_conditioning,
+            img_height=img_token_height,
+            img_width=img_token_width,
             dtype=inference_dtype,
             device=device,
         )

+        # Create regional prompting extension
+        regional_extension = ZImageRegionalPromptingExtension.from_text_conditionings(
+            text_conditionings=pos_text_conditionings,
+            img_seq_len=img_seq_len,
+        )
+
+        # Get the concatenated prompt embeddings for the transformer
+        pos_prompt_embeds = regional_extension.regional_text_conditioning.prompt_embeds
+
         # Load negative conditioning if provided and guidance_scale > 0
         neg_prompt_embeds: torch.Tensor | None = None
         do_classifier_free_guidance = self.guidance_scale > 0.0 and self.negative_conditioning is not None
         if do_classifier_free_guidance:
-            if self.negative_conditioning is None:
-                raise ValueError("Negative conditioning is required when guidance_scale > 0")
-            neg_prompt_embeds = self._load_text_conditioning(
+            assert self.negative_conditioning is not None
+            # Load all negative conditionings and concatenate embeddings
+            # Note: We ignore masks for negative conditioning as regional negative prompting is not fully supported
+            neg_text_conditionings = self._load_text_conditioning(
                 context=context,
-                conditioning_name=self.negative_conditioning.conditioning_name,
+                cond_field=self.negative_conditioning,
+                img_height=img_token_height,
+                img_width=img_token_width,
                 dtype=inference_dtype,
                 device=device,
             )
-
-        # Calculate image sequence length for timestep shifting
-        patch_size = 2  # Z-Image uses patch_size=2
-        image_seq_len = ((self.height // LATENT_SCALE_FACTOR) * (self.width // LATENT_SCALE_FACTOR)) // (patch_size**2)
+            # Concatenate all negative embeddings
+            neg_prompt_embeds = torch.cat([tc.prompt_embeds for tc in neg_text_conditionings], dim=0)

         # Calculate shift based on image sequence length
-        mu = self._calculate_shift(image_seq_len)
+        mu = self._calculate_shift(img_seq_len)

         # Generate sigma schedule with time shift
         sigmas = self._get_sigmas(mu, self.steps)
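The gate above runs CFG only when guidance_scale > 0.0 and a negative prompt is connected; Z-Image-Turbo's default of 0.0 skips the unconditional pass entirely. The combination step inside the denoising loop is not visible in this excerpt, but the standard classifier-free guidance formula it implies looks like this sketch:

```python
import torch


def apply_cfg(
    noise_pred_cond: torch.Tensor,
    noise_pred_uncond: torch.Tensor | None,
    guidance_scale: float,
) -> torch.Tensor:
    """Standard classifier-free guidance: extrapolate away from the unconditional prediction."""
    if noise_pred_uncond is None or guidance_scale == 0.0:
        return noise_pred_cond  # no CFG: single conditional pass (Z-Image-Turbo default)
    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
```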
@@ -322,6 +373,15 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
             )
         )

+        # Apply regional prompting patch if we have regional masks
+        exit_stack.enter_context(
+            patch_transformer_for_regional_prompting(
+                transformer=transformer,
+                regional_attn_mask=regional_extension.regional_attn_mask,
+                img_seq_len=img_seq_len,
+            )
+        )
+
         # Denoising loop
         for step_idx in tqdm(range(total_steps)):
             sigma_curr = sigmas[step_idx]
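patch_transformer_for_regional_prompting is defined in z_image_transformer_patch, also outside this excerpt. Since it is entered on an ExitStack, it is presumably a context manager that installs the mask and restores the original behavior on exit. Here is a hypothetical sketch of that shape, assuming the transformer exposes an ordered block list and that even-indexed blocks receive the regional mask while odd-indexed blocks keep full attention (the commit's alternating-layer scheme); every attribute name below is an assumption.

```python
from contextlib import contextmanager
from typing import Iterator

import torch


@contextmanager
def patch_transformer_for_regional_prompting(
    transformer: torch.nn.Module,
    regional_attn_mask: torch.Tensor | None,
    img_seq_len: int,  # mirrors the call site; unused in this sketch
) -> Iterator[None]:
    """Temporarily route a regional attention bias into alternating layers."""
    if regional_attn_mask is None:
        yield  # no regional masks connected: leave the transformer untouched
        return

    blocks = list(transformer.blocks)  # `blocks` is an assumed attribute name
    saved = [getattr(block, "attn_mask", None) for block in blocks]
    try:
        for i, block in enumerate(blocks):
            # Even layers: regional mask. Odd layers: full attention for global coherence.
            block.attn_mask = regional_attn_mask if i % 2 == 0 else None
        yield
    finally:
        # Always restore, even if denoising raises.
        for block, old in zip(blocks, saved):
            block.attn_mask = old
```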

invokeai/app/invocations/z_image_text_encoder.py

Lines changed: 21 additions & 5 deletions
@@ -1,11 +1,18 @@
 from contextlib import ExitStack
-from typing import Iterator, Tuple
+from typing import Iterator, Optional, Tuple

 import torch
 from transformers import PreTrainedModel, PreTrainedTokenizerBase

 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, UIComponent
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    Input,
+    InputField,
+    TensorField,
+    UIComponent,
+    ZImageConditioningField,
+)
 from invokeai.app.invocations.model import Qwen3EncoderField
 from invokeai.app.invocations.primitives import ZImageConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -27,25 +34,34 @@
     title="Prompt - Z-Image",
     tags=["prompt", "conditioning", "z-image"],
     category="conditioning",
-    version="1.0.0",
+    version="1.1.0",
     classification=Classification.Prototype,
 )
 class ZImageTextEncoderInvocation(BaseInvocation):
-    """Encodes and preps a prompt for a Z-Image image."""
+    """Encodes and preps a prompt for a Z-Image image.
+
+    Supports regional prompting by connecting a mask input.
+    """

     prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
     qwen3_encoder: Qwen3EncoderField = InputField(
         title="Qwen3 Encoder",
         description=FieldDescriptions.qwen3_encoder,
         input=Input.Connection,
     )
+    mask: Optional[TensorField] = InputField(
+        default=None,
+        description="A mask defining the region that this conditioning prompt applies to.",
+    )

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ZImageConditioningOutput:
         prompt_embeds = self._encode_prompt(context, max_seq_len=Z_IMAGE_MAX_SEQ_LEN)
         conditioning_data = ConditioningFieldData(conditionings=[ZImageConditioningInfo(prompt_embeds=prompt_embeds)])
         conditioning_name = context.conditioning.save(conditioning_data)
-        return ZImageConditioningOutput.build(conditioning_name)
+        return ZImageConditioningOutput(
+            conditioning=ZImageConditioningField(conditioning_name=conditioning_name, mask=self.mask)
+        )

     def _encode_prompt(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
         """Encode prompt using Qwen3 text encoder.
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+# Z-Image backend utilities

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+# Z-Image extensions
