5 changes: 5 additions & 0 deletions invokeai/app/invocations/fields.py
@@ -327,6 +327,11 @@ class ZImageConditioningField(BaseModel):
     """A Z-Image conditioning tensor primitive value"""
 
     conditioning_name: str = Field(description="The name of conditioning tensor")
+    mask: Optional[TensorField] = Field(
+        default=None,
+        description="The mask associated with this conditioning tensor for regional prompting. "
+        "Excluded regions should be set to False, included regions should be set to True.",
+    )
 
 
 class ConditioningField(BaseModel):
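Downstream nodes consume this new field by pairing a saved conditioning with a saved mask tensor. A minimal sketch of what a caller might build (the names passed to the fields are hypothetical; in a real invocation they come from `context.conditioning.save(...)` and `context.tensors.save(...)`):

```python
import torch

from invokeai.app.invocations.fields import TensorField, ZImageConditioningField

# Boolean region mask at image resolution: True = include, False = exclude.
mask = torch.zeros(1, 1024, 1024, dtype=torch.bool)
mask[:, :, :512] = True  # this prompt applies to the left half

# In an invocation, persist the tensor first, e.g.:
#   tensor_name = context.tensors.save(mask)
cond = ZImageConditioningField(
    conditioning_name="saved_conditioning_name",  # hypothetical
    mask=TensorField(tensor_name="saved_mask_name"),  # hypothetical
)
```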
126 changes: 93 additions & 33 deletions invokeai/app/invocations/z_image_denoise.py
@@ -32,23 +32,29 @@
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ZImageConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.z_image.extensions.regional_prompting_extension import ZImageRegionalPromptingExtension
+from invokeai.backend.z_image.text_conditioning import ZImageTextConditioning
 from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter
 from invokeai.backend.z_image.z_image_controlnet_extension import (
     ZImageControlNetExtension,
     z_image_forward_with_control,
 )
+from invokeai.backend.z_image.z_image_transformer_patch import patch_transformer_for_regional_prompting
 
 
 @invocation(
     "z_image_denoise",
     title="Denoise - Z-Image",
     tags=["image", "z-image"],
     category="image",
-    version="1.1.0",
+    version="1.2.0",
     classification=Classification.Prototype,
 )
 class ZImageDenoiseInvocation(BaseInvocation):
-    """Run the denoising process with a Z-Image model."""
+    """Run the denoising process with a Z-Image model.
+
+    Supports regional prompting by connecting multiple conditioning inputs with masks.
+    """
 
     # If latents is provided, this means we are doing image-to-image.
     latents: Optional[LatentsField] = InputField(
@@ -63,10 +69,10 @@ class ZImageDenoiseInvocation(BaseInvocation):
     transformer: TransformerField = InputField(
         description=FieldDescriptions.z_image_model, input=Input.Connection, title="Transformer"
     )
-    positive_conditioning: ZImageConditioningField = InputField(
+    positive_conditioning: ZImageConditioningField | list[ZImageConditioningField] = InputField(
         description=FieldDescriptions.positive_cond, input=Input.Connection
     )
-    negative_conditioning: Optional[ZImageConditioningField] = InputField(
+    negative_conditioning: ZImageConditioningField | list[ZImageConditioningField] | None = InputField(
         default=None, description=FieldDescriptions.negative_cond, input=Input.Connection
     )
     # Z-Image-Turbo works best without CFG (guidance_scale=1.0)
@@ -126,25 +132,50 @@ def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor)
     def _load_text_conditioning(
         self,
         context: InvocationContext,
-        conditioning_name: str,
+        cond_field: ZImageConditioningField | list[ZImageConditioningField],
+        img_height: int,
+        img_width: int,
         dtype: torch.dtype,
         device: torch.device,
-    ) -> torch.Tensor:
-        """Load Z-Image text conditioning."""
-        cond_data = context.conditioning.load(conditioning_name)
-        if len(cond_data.conditionings) != 1:
-            raise ValueError(
-                f"Expected exactly 1 conditioning entry for Z-Image, got {len(cond_data.conditionings)}. "
-                "Ensure you are using the Z-Image text encoder."
-            )
-        z_image_conditioning = cond_data.conditionings[0]
-        if not isinstance(z_image_conditioning, ZImageConditioningInfo):
-            raise TypeError(
-                f"Expected ZImageConditioningInfo, got {type(z_image_conditioning).__name__}. "
-                "Ensure you are using the Z-Image text encoder."
-            )
-        z_image_conditioning = z_image_conditioning.to(dtype=dtype, device=device)
-        return z_image_conditioning.prompt_embeds
+    ) -> list[ZImageTextConditioning]:
+        """Load Z-Image text conditioning with optional regional masks.
+
+        Args:
+            context: The invocation context.
+            cond_field: Single conditioning field or list of fields.
+            img_height: Height of the image token grid (H // patch_size).
+            img_width: Width of the image token grid (W // patch_size).
+            dtype: Target dtype.
+            device: Target device.
+
+        Returns:
+            List of ZImageTextConditioning objects with embeddings and masks.
+        """
+        # Normalize to a list
+        cond_list = [cond_field] if isinstance(cond_field, ZImageConditioningField) else cond_field
+
+        text_conditionings: list[ZImageTextConditioning] = []
+        for cond in cond_list:
+            # Load the text embeddings
+            cond_data = context.conditioning.load(cond.conditioning_name)
+            assert len(cond_data.conditionings) == 1
+            z_image_conditioning = cond_data.conditionings[0]
+            assert isinstance(z_image_conditioning, ZImageConditioningInfo)
+            z_image_conditioning = z_image_conditioning.to(dtype=dtype, device=device)
+            prompt_embeds = z_image_conditioning.prompt_embeds
+
+            # Load the mask, if provided
+            mask: torch.Tensor | None = None
+            if cond.mask is not None:
+                mask = context.tensors.load(cond.mask.tensor_name)
+                mask = mask.to(device=device)
+                mask = ZImageRegionalPromptingExtension.preprocess_regional_prompt_mask(
+                    mask, img_height, img_width, dtype, device
+                )
+
+            text_conditionings.append(ZImageTextConditioning(prompt_embeds=prompt_embeds, mask=mask))
+
+        return text_conditionings
 
     def _get_noise(
         self,
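`preprocess_regional_prompt_mask` itself is not part of this diff. Conceptually it has to reduce an image-resolution mask to one value per image token on the `(img_height, img_width)` grid; a plausible sketch under that assumption (the name suffix and the nearest-neighbor choice are illustrative, not the actual implementation):

```python
import torch
import torch.nn.functional as F


def preprocess_regional_prompt_mask_sketch(
    mask: torch.Tensor, img_height: int, img_width: int, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    # Accept (H, W) or (1, H, W); normalize to (1, 1, H, W) for interpolate.
    mask = mask.to(device=device, dtype=torch.float32)
    while mask.ndim < 4:
        mask = mask.unsqueeze(0)
    # Nearest-neighbor keeps the mask binary after downsampling to the token grid.
    mask = F.interpolate(mask, size=(img_height, img_width), mode="nearest")
    # One value per image token: (1, img_height * img_width, 1).
    return mask.reshape(1, img_height * img_width, 1).to(dtype=dtype)
```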
@@ -221,14 +252,33 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
 
         transformer_info = context.models.load(self.transformer.transformer)
 
-        # Load positive conditioning
-        pos_prompt_embeds = self._load_text_conditioning(
+        # Calculate image token grid dimensions
+        patch_size = 2  # Z-Image uses patch_size=2
+        latent_height = self.height // LATENT_SCALE_FACTOR
+        latent_width = self.width // LATENT_SCALE_FACTOR
+        img_token_height = latent_height // patch_size
+        img_token_width = latent_width // patch_size
+        img_seq_len = img_token_height * img_token_width
+
+        # Load positive conditioning with regional masks
+        pos_text_conditionings = self._load_text_conditioning(
             context=context,
-            conditioning_name=self.positive_conditioning.conditioning_name,
+            cond_field=self.positive_conditioning,
+            img_height=img_token_height,
+            img_width=img_token_width,
             dtype=inference_dtype,
             device=device,
         )
 
+        # Create regional prompting extension
+        regional_extension = ZImageRegionalPromptingExtension.from_text_conditionings(
+            text_conditionings=pos_text_conditionings,
+            img_seq_len=img_seq_len,
+        )
+
+        # Get the concatenated prompt embeddings for the transformer
+        pos_prompt_embeds = regional_extension.regional_text_conditioning.prompt_embeds
+
         # Load negative conditioning if provided and guidance_scale != 1.0
         # CFG formula: pred = pred_uncond + cfg_scale * (pred_cond - pred_uncond)
         # At cfg_scale=1.0: pred = pred_cond (no effect, skip uncond computation)
@@ -238,21 +288,22 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
             not math.isclose(self.guidance_scale, 1.0) and self.negative_conditioning is not None
         )
         if do_classifier_free_guidance:
-            if self.negative_conditioning is None:
-                raise ValueError("Negative conditioning is required when guidance_scale != 1.0")
-            neg_prompt_embeds = self._load_text_conditioning(
+            assert self.negative_conditioning is not None
+            # Load all negative conditionings and concatenate embeddings
+            # Note: We ignore masks for negative conditioning as regional negative prompting is not fully supported
+            neg_text_conditionings = self._load_text_conditioning(
                 context=context,
-                conditioning_name=self.negative_conditioning.conditioning_name,
+                cond_field=self.negative_conditioning,
+                img_height=img_token_height,
+                img_width=img_token_width,
                 dtype=inference_dtype,
                 device=device,
             )
-
-        # Calculate image sequence length for timestep shifting
-        patch_size = 2  # Z-Image uses patch_size=2
-        image_seq_len = ((self.height // LATENT_SCALE_FACTOR) * (self.width // LATENT_SCALE_FACTOR)) // (patch_size**2)
+            # Concatenate all negative embeddings
+            neg_prompt_embeds = torch.cat([tc.prompt_embeds for tc in neg_text_conditionings], dim=0)
 
         # Calculate shift based on image sequence length
-        mu = self._calculate_shift(image_seq_len)
+        mu = self._calculate_shift(img_seq_len)
 
         # Generate sigma schedule with time shift
         sigmas = self._get_sigmas(mu, self.steps)
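To make the token-grid arithmetic concrete: assuming `LATENT_SCALE_FACTOR` is 8 (the usual VAE downsampling factor) and `patch_size` is 2 as stated in the diff, a 1024×1024 request yields:

```python
LATENT_SCALE_FACTOR = 8  # assumed VAE scale factor
patch_size = 2

height = width = 1024
latent_height = height // LATENT_SCALE_FACTOR     # 128
latent_width = width // LATENT_SCALE_FACTOR       # 128
img_token_height = latent_height // patch_size    # 64
img_token_width = latent_width // patch_size      # 64
img_seq_len = img_token_height * img_token_width  # 4096 image tokens

# The refactor is value-preserving: img_seq_len equals the image_seq_len
# the old code computed inline for timestep shifting.
assert img_seq_len == ((height // LATENT_SCALE_FACTOR) * (width // LATENT_SCALE_FACTOR)) // (patch_size**2)
```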
@@ -443,6 +494,15 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
             )
         )
 
+        # Apply regional prompting patch if we have regional masks
+        exit_stack.enter_context(
+            patch_transformer_for_regional_prompting(
+                transformer=transformer,
+                regional_attn_mask=regional_extension.regional_attn_mask,
+                img_seq_len=img_seq_len,
+            )
+        )
+
         # Denoising loop
         for step_idx in tqdm(range(total_steps)):
             sigma_curr = sigmas[step_idx]
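The construction of `regional_attn_mask` lives in `ZImageRegionalPromptingExtension` and is not shown in this diff. The common pattern (used by InvokeAI's FLUX regional prompting) is a joint text+image attention mask in which each prompt's text tokens and its region's image tokens can attend to each other, while cross-region text/image attention is blocked. A simplified sketch under that assumption (the sequence layout and dtype are guesses, not the extension's actual code):

```python
import torch


def build_regional_attn_mask_sketch(
    region_masks: list[torch.Tensor],  # one (img_seq_len,) bool mask per prompt
    text_seq_lens: list[int],          # token count of each prompt's embedding
    img_seq_len: int,
) -> torch.Tensor:
    """Return a (total, total) bool mask, True = attention allowed.

    Assumed sequence layout: [text_0 | text_1 | ... | image tokens].
    """
    total_text = sum(text_seq_lens)
    total = total_text + img_seq_len
    attn = torch.zeros(total, total, dtype=torch.bool)

    # Image tokens always attend to each other so global structure stays coherent.
    attn[total_text:, total_text:] = True

    offset = 0
    for region, t_len in zip(region_masks, text_seq_lens):
        t_slice = slice(offset, offset + t_len)
        img_idx = total_text + torch.nonzero(region, as_tuple=True)[0]
        attn[t_slice, t_slice] = True  # text attends within its own prompt
        attn[t_slice, img_idx] = True  # text -> image tokens in its region
        attn[img_idx, t_slice] = True  # image tokens in the region -> their text
        offset += t_len
    return attn
```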
26 changes: 21 additions & 5 deletions invokeai/app/invocations/z_image_text_encoder.py
@@ -1,11 +1,18 @@
 from contextlib import ExitStack
-from typing import Iterator, Tuple
+from typing import Iterator, Optional, Tuple
 
 import torch
 from transformers import PreTrainedModel, PreTrainedTokenizerBase
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, UIComponent
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    Input,
+    InputField,
+    TensorField,
+    UIComponent,
+    ZImageConditioningField,
+)
 from invokeai.app.invocations.model import Qwen3EncoderField
 from invokeai.app.invocations.primitives import ZImageConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -27,25 +34,34 @@
     title="Prompt - Z-Image",
     tags=["prompt", "conditioning", "z-image"],
     category="conditioning",
-    version="1.0.0",
+    version="1.1.0",
     classification=Classification.Prototype,
 )
 class ZImageTextEncoderInvocation(BaseInvocation):
-    """Encodes and preps a prompt for a Z-Image image."""
+    """Encodes and preps a prompt for a Z-Image image.
+
+    Supports regional prompting by connecting a mask input.
+    """
 
     prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
     qwen3_encoder: Qwen3EncoderField = InputField(
         title="Qwen3 Encoder",
         description=FieldDescriptions.qwen3_encoder,
         input=Input.Connection,
     )
+    mask: Optional[TensorField] = InputField(
+        default=None,
+        description="A mask defining the region that this conditioning prompt applies to.",
+    )
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ZImageConditioningOutput:
         prompt_embeds = self._encode_prompt(context, max_seq_len=Z_IMAGE_MAX_SEQ_LEN)
         conditioning_data = ConditioningFieldData(conditionings=[ZImageConditioningInfo(prompt_embeds=prompt_embeds)])
         conditioning_name = context.conditioning.save(conditioning_data)
-        return ZImageConditioningOutput.build(conditioning_name)
+        return ZImageConditioningOutput(
+            conditioning=ZImageConditioningField(conditioning_name=conditioning_name, mask=self.mask)
+        )
 
     def _encode_prompt(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
         """Encode prompt using Qwen3 text encoder.
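Taken together with the denoise changes, a regional workflow encodes each prompt with its own mask and feeds the resulting fields to the denoise node as a list. Conceptual wiring (node names match the invocation titles; the mask sources are whatever produces `TensorField` masks in a given graph):

```python
# Conceptual wiring, not a runnable graph definition:
#
#   Prompt - Z-Image (prompt="a red fox",   mask=left_half)   -> cond_a
#   Prompt - Z-Image (prompt="a blue bird", mask=right_half)  -> cond_b
#   Prompt - Z-Image (prompt="forest background")             -> cond_global
#
#   Denoise - Z-Image:
#       positive_conditioning = [cond_a, cond_b, cond_global]
#
# Each ZImageConditioningField carries its mask (None for a global prompt);
# the denoise node builds the regional attention mask from the set.
```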
2 changes: 1 addition & 1 deletion invokeai/backend/z_image/__init__.py
@@ -1,4 +1,4 @@
-# Z-Image Control Transformer support for InvokeAI
+# Z-Image backend utilities
 from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter
 from invokeai.backend.z_image.z_image_control_transformer import ZImageControlTransformer2DModel
 from invokeai.backend.z_image.z_image_controlnet_extension import (
1 change: 1 addition & 0 deletions invokeai/backend/z_image/extensions/__init__.py
@@ -0,0 +1 @@
+# Z-Image extensions