
Commit ab6b672

Feature: Add Z-Image-Turbo model support (#8671)
Add comprehensive support for Z-Image-Turbo (S3-DiT) models including:

Backend:
- New BaseModelType.ZImage in taxonomy
- Z-Image model config classes (ZImageTransformerConfig, Qwen3TextEncoderConfig)
- Model loader for Z-Image transformer and Qwen3 text encoder
- Z-Image conditioning data structures
- Step callback support for Z-Image with FLUX latent RGB factors

Invocations:
- z_image_model_loader: Load Z-Image transformer and Qwen3 encoder
- z_image_text_encoder: Encode prompts using Qwen3 with chat template
- z_image_denoise: Flow matching denoising with time-shifted sigmas
- z_image_image_to_latents: Encode images to 16-channel latents
- z_image_latents_to_image: Decode latents using FLUX VAE

Frontend:
- Z-Image graph builder for text-to-image generation
- Model picker and validation updates for z-image base type
- CFG scale now allows 0 (required for Z-Image-Turbo)
- Clip skip disabled for Z-Image (uses Qwen3, not CLIP)
- Optimal dimension settings for Z-Image (1024x1024)

Technical details:
- Uses Qwen3 text encoder (not CLIP/T5)
- 16 latent channels with FLUX-compatible VAE
- Flow matching scheduler with dynamic time shift
- 8 inference steps recommended for Turbo variant
- bfloat16 inference dtype

## QA Instructions

- Install a Z-Image-Turbo model (e.g., from HuggingFace)
- Select the model in the Model Picker
- Generate a text-to-image with:
  - CFG Scale: 0
  - Steps: 8
  - Resolution: 1024x1024
- Verify the generated image is coherent (not noise)

## Merge Plan

Standard merge, no special considerations needed.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _❗Changes to a redux slice have a corresponding migration_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
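For orientation, the sampling described above ("flow matching denoising with time-shifted sigmas", 8 steps recommended for the Turbo variant) generally follows the loop sketched below. This is a minimal illustration assuming the FLUX/SD3-style shift formula and an illustrative transformer call signature; it is not the code added by this commit.

```python
import torch


def time_shifted_sigmas(num_steps: int, shift: float = 3.0) -> torch.Tensor:
    """Linear sigma schedule on [1, 0], warped by a time shift.

    Assumes the FLUX/SD3-style shift: sigma' = shift * sigma / (1 + (shift - 1) * sigma).
    A "dynamic" time shift would derive `shift` from the latent sequence length
    instead of using a constant.
    """
    sigmas = torch.linspace(1.0, 0.0, num_steps + 1)
    return shift * sigmas / (1.0 + (shift - 1.0) * sigmas)


@torch.no_grad()
def flow_matching_denoise(transformer, latents, conditioning, num_steps: int = 8):
    """Euler sampling for a flow-matching model: x <- x + (sigma_next - sigma) * v."""
    sigmas = time_shifted_sigmas(num_steps)
    for i in range(num_steps):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        # The model predicts a velocity field; the call signature here is illustrative.
        velocity = transformer(latents, timestep=sigma.expand(latents.shape[0]), **conditioning)
        latents = latents + (sigma_next - sigma) * velocity
    return latents
```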
2 parents 926923b + 93a587d commit ab6b672


70 files changed (+42667, -822 lines)


invokeai/app/api/dependencies.py

Lines changed: 2 additions & 0 deletions

@@ -49,6 +49,7 @@
     FLUXConditioningInfo,
     SD3ConditioningInfo,
     SDXLConditioningInfo,
+    ZImageConditioningInfo,
 )
 from invokeai.backend.util.logging import InvokeAILogger
 from invokeai.version.invokeai_version import __version__
@@ -129,6 +130,7 @@ def initialize(
                 FLUXConditioningInfo,
                 SD3ConditioningInfo,
                 CogView4ConditioningInfo,
+                ZImageConditioningInfo,
             ],
             ephemeral=True,
         ),

invokeai/app/invocations/fields.py

Lines changed: 8 additions & 0 deletions

@@ -154,6 +154,7 @@ class FieldDescriptions:
     clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
     t5_encoder = "T5 tokenizer and text encoder"
     glm_encoder = "GLM (THUDM) tokenizer and text encoder"
+    qwen3_encoder = "Qwen3 tokenizer and text encoder"
     clip_embed_model = "CLIP Embed loader"
     clip_g_model = "CLIP-G Embed loader"
     unet = "UNet (scheduler, LoRAs)"
@@ -169,6 +170,7 @@ class FieldDescriptions:
     flux_model = "Flux model (Transformer) to load"
     sd3_model = "SD3 model (MMDiTX) to load"
     cogview4_model = "CogView4 model (Transformer) to load"
+    z_image_model = "Z-Image model (Transformer) to load"
     sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
     sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
     onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
@@ -321,6 +323,12 @@ class CogView4ConditioningField(BaseModel):
     conditioning_name: str = Field(description="The name of conditioning tensor")


+class ZImageConditioningField(BaseModel):
+    """A Z-Image conditioning tensor primitive value"""
+
+    conditioning_name: str = Field(description="The name of conditioning tensor")
+
+
 class ConditioningField(BaseModel):
     """A conditioning tensor primitive value"""

invokeai/app/invocations/latents_to_image.py

Lines changed: 2 additions & 25 deletions

@@ -2,12 +2,6 @@

 import torch
 from diffusers.image_processor import VaeImageProcessor
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
-    XFormersAttnProcessor,
-)
 from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
 from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny

@@ -77,26 +71,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
         assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
         latents = latents.to(TorchDevice.choose_torch_device())
         if self.fp32:
+            # FP32 mode: convert everything to float32 for maximum precision
             vae.to(dtype=torch.float32)
-
-            use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
-                vae.decoder.mid_block.attentions[0].processor,
-                (
-                    AttnProcessor2_0,
-                    XFormersAttnProcessor,
-                    LoRAXFormersAttnProcessor,
-                    LoRAAttnProcessor2_0,
-                ),
-            )
-            # if xformers or torch_2_0 is used attention block does not need
-            # to be in float32 which can save lots of memory
-            if use_torch_2_0_or_xformers:
-                vae.post_quant_conv.to(latents.dtype)
-                vae.decoder.conv_in.to(latents.dtype)
-                vae.decoder.mid_block.to(latents.dtype)
-            else:
-                latents = latents.float()
-
+            latents = latents.float()
         else:
             vae.to(dtype=torch.float16)
             latents = latents.half()

invokeai/app/invocations/metadata.py

Lines changed: 9 additions & 1 deletion

@@ -158,6 +158,10 @@ def invoke(self, context: InvocationContext) -> MetadataOutput:
     "cogview4_img2img",
     "cogview4_inpaint",
     "cogview4_outpaint",
+    "z_image_txt2img",
+    "z_image_img2img",
+    "z_image_inpaint",
+    "z_image_outpaint",
 ]


@@ -166,7 +170,7 @@ def invoke(self, context: InvocationContext) -> MetadataOutput:
     title="Core Metadata",
     tags=["metadata"],
     category="metadata",
-    version="2.0.0",
+    version="2.1.0",
     classification=Classification.Internal,
 )
 class CoreMetadataInvocation(BaseInvocation):
@@ -217,6 +221,10 @@ class CoreMetadataInvocation(BaseInvocation):
         default=None,
         description="The VAE used for decoding, if the main model's default was not used",
     )
+    qwen3_encoder: Optional[ModelIdentifierField] = InputField(
+        default=None,
+        description="The Qwen3 text encoder model used for Z-Image inference",
+    )

     # High resolution fix metadata.
     hrf_enabled: Optional[bool] = InputField(

invokeai/app/invocations/model.py

Lines changed: 8 additions & 0 deletions

@@ -72,6 +72,14 @@ class GlmEncoderField(BaseModel):
     text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")


+class Qwen3EncoderField(BaseModel):
+    """Field for Qwen3 text encoder used by Z-Image models."""
+
+    tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
+    text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
+    loras: List[LoRAField] = Field(default_factory=list, description="LoRAs to apply on model loading")
+
+
 class VAEField(BaseModel):
     vae: ModelIdentifierField = Field(description="Info to load vae submodel")
     seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
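A hypothetical sketch of how a loader-style invocation might populate the new field; the helper name and identifier arguments are illustrative and not taken from this commit. Like the CLIP/T5/GLM encoder fields, it only carries model identifiers, and the submodels are loaded later by the text-encoder invocation.

```python
from invokeai.app.invocations.model import ModelIdentifierField, Qwen3EncoderField


def build_qwen3_encoder_field(
    tokenizer_id: ModelIdentifierField,
    text_encoder_id: ModelIdentifierField,
) -> Qwen3EncoderField:
    # Identifiers point at the tokenizer and Qwen3 text encoder submodels of the
    # selected Z-Image installation; no weights are loaded at this point.
    return Qwen3EncoderField(
        tokenizer=tokenizer_id,
        text_encoder=text_encoder_id,
        loras=[],  # LoRAs, if any, are applied when the model is loaded
    )
```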

invokeai/app/invocations/primitives.py

Lines changed: 12 additions & 0 deletions

@@ -27,6 +27,7 @@
     SD3ConditioningField,
     TensorField,
     UIComponent,
+    ZImageConditioningField,
 )
 from invokeai.app.services.images.images_common import ImageDTO
 from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -461,6 +462,17 @@ def build(cls, conditioning_name: str) -> "CogView4ConditioningOutput":
         return cls(conditioning=CogView4ConditioningField(conditioning_name=conditioning_name))


+@invocation_output("z_image_conditioning_output")
+class ZImageConditioningOutput(BaseInvocationOutput):
+    """Base class for nodes that output a Z-Image text conditioning tensor."""
+
+    conditioning: ZImageConditioningField = OutputField(description=FieldDescriptions.cond)
+
+    @classmethod
+    def build(cls, conditioning_name: str) -> "ZImageConditioningOutput":
+        return cls(conditioning=ZImageConditioningField(conditioning_name=conditioning_name))
+
+
 @invocation_output("conditioning_output")
 class ConditioningOutput(BaseInvocationOutput):
     """Base class for nodes that output a single conditioning tensor"""
