
Commit 88b7e24

Merge branch 'feat/z-image-turbo-support' into feat/z-image-regional-guidance
2 parents: 9d6290a + 4ce0ef5

File tree: 25 files changed (+1272 / -149 lines)


invokeai/app/invocations/metadata.py

Lines changed: 5 additions & 1 deletion
@@ -170,7 +170,7 @@ def invoke(self, context: InvocationContext) -> MetadataOutput:
     title="Core Metadata",
     tags=["metadata"],
     category="metadata",
-    version="2.0.0",
+    version="2.1.0",
     classification=Classification.Internal,
 )
 class CoreMetadataInvocation(BaseInvocation):
@@ -221,6 +221,10 @@ class CoreMetadataInvocation(BaseInvocation):
         default=None,
         description="The VAE used for decoding, if the main model's default was not used",
     )
+    qwen3_encoder: Optional[ModelIdentifierField] = InputField(
+        default=None,
+        description="The Qwen3 text encoder model used for Z-Image inference",
+    )
 
     # High resolution fix metadata.
     hrf_enabled: Optional[bool] = InputField(
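Because the new qwen3_encoder field is Optional with default=None, existing workflows that never supply it keep validating unchanged, which is what the additive 2.0.0 to 2.1.0 version bump signals. A minimal sketch of that behaviour using plain pydantic (class and field types here are illustrative, not InvokeAI's actual InputField machinery):

from typing import Optional
from pydantic import BaseModel

class CoreMetadataSketch(BaseModel):
    vae: Optional[str] = None            # existing field
    qwen3_encoder: Optional[str] = None  # new field; older payloads simply omit it

old_payload = {"vae": "some-vae"}                                  # pre-2.1.0 style metadata
new_payload = {"vae": None, "qwen3_encoder": "some-qwen3-encoder"} # hypothetical Z-Image metadata

print(CoreMetadataSketch(**old_payload))  # still validates; qwen3_encoder defaults to None
print(CoreMetadataSketch(**new_payload))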

invokeai/app/invocations/z_image_denoise.py

Lines changed: 16 additions & 14 deletions
@@ -14,11 +14,9 @@
     Input,
     InputField,
     LatentsField,
-    WithBoard,
-    WithMetadata,
     ZImageConditioningField,
 )
-from invokeai.app.invocations.model import LoRAField, TransformerField
+from invokeai.app.invocations.model import TransformerField
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat
@@ -39,14 +37,11 @@
     title="Denoise - Z-Image",
     tags=["image", "z-image"],
     category="image",
-    version="1.2.0",
+    version="1.1.0",
     classification=Classification.Prototype,
 )
-class ZImageDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
-    """Run the denoising process with a Z-Image model.
-
-    Supports regional prompting by connecting multiple conditioning inputs with masks.
-    """
+class ZImageDenoiseInvocation(BaseInvocation):
+    """Run the denoising process with a Z-Image model."""
 
     # If latents is provided, this means we are doing image-to-image.
     latents: Optional[LatentsField] = InputField(
@@ -167,8 +162,10 @@ def _get_noise(
         seed: int,
     ) -> torch.Tensor:
         """Generate initial noise tensor."""
+        # Generate noise as float32 on CPU for maximum compatibility,
+        # then cast to target dtype/device
         rand_device = "cpu"
-        rand_dtype = torch.float16
+        rand_dtype = torch.float32
 
         return torch.randn(
             batch_size,
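The switch to float32 CPU noise above is about portability and reproducibility: the random draw always happens on one device in one dtype, and only the finished tensor is cast, so a given seed yields the same latents regardless of the inference device or precision. A minimal standalone sketch of the idea (the shapes and helper name are examples, not the invocation's actual code):

import torch

def make_noise(shape, seed, device, dtype):
    # Draw the noise on the CPU in float32 with a seeded generator, then cast.
    # Keeping the RNG on a single device/dtype means the same seed reproduces
    # the same sample whether inference later runs in fp16, bf16, or fp32.
    generator = torch.Generator(device="cpu").manual_seed(seed)
    noise = torch.randn(shape, generator=generator, device="cpu", dtype=torch.float32)
    return noise.to(device=device, dtype=dtype)

ref = make_noise((1, 4, 64, 64), seed=42, device="cpu", dtype=torch.float32)
half = make_noise((1, 4, 64, 64), seed=42, device="cpu", dtype=torch.float16)
print(torch.allclose(ref, half.float(), atol=5e-3))  # same draw, only precision differs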
@@ -224,8 +221,8 @@ def time_shift(mu: float, sigma: float, t: float) -> float:
         return sigmas
 
     def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
-        inference_dtype = torch.bfloat16
         device = TorchDevice.choose_torch_device()
+        inference_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
 
         transformer_info = context.models.load(self.transformer.transformer)
 
@@ -324,7 +321,8 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
         inpaint_mask = self._prep_inpaint_mask(context, latents)
         inpaint_extension: RectifiedFlowInpaintExtension | None = None
         if inpaint_mask is not None:
-            assert init_latents is not None
+            if init_latents is None:
+                raise ValueError("Initial latents are required when using an inpaint mask (image-to-image inpainting)")
             inpaint_extension = RectifiedFlowInpaintExtension(
                 init_latents=init_latents,
                 inpaint_mask=inpaint_mask,
@@ -352,7 +350,7 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
         # Determine if the model is quantized.
         # If the model is quantized, then we need to apply the LoRA weights as sidecar layers. This results in
         # slower inference than direct patching, but is agnostic to the quantization format.
-        if transformer_config.format in [ModelFormat.Diffusers]:
+        if transformer_config.format in [ModelFormat.Diffusers, ModelFormat.Checkpoint]:
             model_is_quantized = False
         elif transformer_config.format in [ModelFormat.GGUFQuantized]:
             model_is_quantized = True
@@ -456,6 +454,10 @@ def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
         """Iterate over LoRA models to apply to the transformer."""
         for lora in self.transformer.loras:
             lora_info = context.models.load(lora.lora)
-            assert isinstance(lora_info.model, ModelPatchRaw)
+            if not isinstance(lora_info.model, ModelPatchRaw):
+                raise TypeError(
+                    f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}. "
+                    "The LoRA model may be corrupted or incompatible."
+                )
             yield (lora_info.model, lora.weight)
             del lora_info
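The other notable change in this file is replacing a hard-coded torch.bfloat16 with TorchDevice.choose_bfloat16_safe_dtype(device): hard-coding bf16 fails on hardware without native bf16 support (for example pre-Ampere CUDA GPUs). The helper's body is not part of this diff, so the following is only a plausible sketch of such a fallback, not InvokeAI's actual implementation:

import torch

def choose_bf16_safe_dtype_sketch(device: torch.device) -> torch.dtype:
    # Hypothetical stand-in for TorchDevice.choose_bfloat16_safe_dtype();
    # the real helper lives in InvokeAI's device utilities and may differ.
    if device.type == "cuda":
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    if device.type == "mps":
        return torch.float16  # bf16 support on MPS varies across macOS/torch versions
    return torch.float32      # CPU: keep full precision

print(choose_bf16_safe_dtype_sketch(torch.device("cpu")))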

invokeai/app/invocations/z_image_image_to_latents.py

Lines changed: 15 additions & 3 deletions
@@ -41,10 +41,18 @@ class ZImageImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
 
     @staticmethod
     def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
-        assert isinstance(vae_info.model, (AutoencoderKL, FluxAutoEncoder))
+        if not isinstance(vae_info.model, (AutoencoderKL, FluxAutoEncoder)):
+            raise TypeError(
+                f"Expected AutoencoderKL or FluxAutoEncoder for Z-Image VAE, got {type(vae_info.model).__name__}. "
+                "Ensure you are using a compatible VAE model."
+            )
 
         with vae_info.model_on_device() as (_, vae):
-            assert isinstance(vae, (AutoencoderKL, FluxAutoEncoder))
+            if not isinstance(vae, (AutoencoderKL, FluxAutoEncoder)):
+                raise TypeError(
+                    f"Expected AutoencoderKL or FluxAutoEncoder, got {type(vae).__name__}. "
+                    "VAE model type changed unexpectedly after loading."
+                )
 
             vae_dtype = next(iter(vae.parameters())).dtype
             image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
@@ -80,7 +88,11 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
         image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
 
         vae_info = context.models.load(self.vae.vae)
-        assert isinstance(vae_info.model, (AutoencoderKL, FluxAutoEncoder))
+        if not isinstance(vae_info.model, (AutoencoderKL, FluxAutoEncoder)):
+            raise TypeError(
+                f"Expected AutoencoderKL or FluxAutoEncoder for Z-Image VAE, got {type(vae_info.model).__name__}. "
+                "Ensure you are using a compatible VAE model."
+            )
 
         context.util.signal_progress("Running VAE")
         latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
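Before encoding, invoke() adds a batch axis to the single CHW image with einops.rearrange, as seen in the hunk above. A small self-contained illustration of that reshape (the 512x512 size is arbitrary):

import einops
import torch

image_tensor = torch.rand(3, 512, 512)                        # C, H, W in [0, 1]
batched = einops.rearrange(image_tensor, "c h w -> 1 c h w")  # add a leading batch axis
print(batched.shape)  # torch.Size([1, 3, 512, 512])
# Equivalent to image_tensor.unsqueeze(0); the einops pattern documents the layout explicitly.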

invokeai/app/invocations/z_image_latents_to_image.py

Lines changed: 11 additions & 5 deletions
@@ -45,20 +45,26 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
         latents = context.tensors.load(self.latents.latents_name)
 
         vae_info = context.models.load(self.vae.vae)
-        assert isinstance(vae_info.model, (AutoencoderKL, FluxAutoEncoder))
+        if not isinstance(vae_info.model, (AutoencoderKL, FluxAutoEncoder)):
+            raise TypeError(
+                f"Expected AutoencoderKL or FluxAutoEncoder for Z-Image VAE, got {type(vae_info.model).__name__}. "
+                "Ensure you are using a compatible VAE model."
+            )
 
         is_flux_vae = isinstance(vae_info.model, FluxAutoEncoder)
 
         # FLUX VAE doesn't support seamless, so only apply for AutoencoderKL
         seamless_context = (
-            nullcontext()
-            if is_flux_vae
-            else SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes)
+            nullcontext() if is_flux_vae else SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes)
         )
 
         with seamless_context, vae_info.model_on_device() as (_, vae):
             context.util.signal_progress("Running VAE")
-            assert isinstance(vae, (AutoencoderKL, FluxAutoEncoder))
+            if not isinstance(vae, (AutoencoderKL, FluxAutoEncoder)):
+                raise TypeError(
+                    f"Expected AutoencoderKL or FluxAutoEncoder, got {type(vae).__name__}. "
+                    "VAE model type changed unexpectedly after loading."
+                )
 
             vae_dtype = next(iter(vae.parameters())).dtype
             latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
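For the FLUX VAE the seamless patch is skipped by substituting contextlib.nullcontext(), so the with-statement still runs but does nothing. A self-contained sketch of that pattern (the patch class below is a stand-in, not SeamlessExt):

from contextlib import nullcontext

def maybe_patch(enabled: bool):
    # Stand-in for a model-patching context manager; prints instead of patching.
    class _Patch:
        def __enter__(self):
            print("patch applied")
            return self
        def __exit__(self, *exc):
            print("patch removed")
    return _Patch() if enabled else nullcontext()

for enabled in (True, False):
    with maybe_patch(enabled):
        print("decoding latents...")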

invokeai/app/invocations/z_image_lora_loader.py

Lines changed: 5 additions & 1 deletion
@@ -136,7 +136,11 @@ def invoke(self, context: InvocationContext) -> ZImageLoRALoaderOutput:
             if not context.models.exists(lora.lora.key):
                 raise Exception(f"Unknown lora: {lora.lora.key}!")
 
-            assert lora.lora.base is BaseModelType.ZImage
+            if lora.lora.base is not BaseModelType.ZImage:
+                raise ValueError(
+                    f"LoRA '{lora.lora.key}' is for {lora.lora.base.value if lora.lora.base else 'unknown'} models, "
+                    "not Z-Image models. Ensure you are using a Z-Image compatible LoRA."
+                )
 
             added_loras.append(lora.lora.key)
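Like the rest of this commit, this hunk trades a bare assert for an explicit exception with an actionable message. Beyond the clearer error, there is a practical reason: CPython strips assert statements entirely under python -O / PYTHONOPTIMIZE, so assert-based validation can silently disappear. A small illustration:

# Why explicit raises are preferred over assert for user-facing validation.

def check_with_assert(base: str) -> None:
    assert base == "z-image", "wrong base model"  # compiled away under `python -O`

def check_with_raise(base: str) -> None:
    if base != "z-image":
        raise ValueError(f"LoRA is for {base!r} models, not Z-Image models.")  # always enforced

check_with_raise("z-image")   # ok
# check_with_raise("sdxl")    # would raise ValueError with an actionable message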

invokeai/app/invocations/z_image_text_encoder.py

Lines changed: 55 additions & 13 deletions
@@ -79,27 +79,41 @@ def _encode_prompt(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
             (_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())
 
             # Apply LoRA models to the text encoder
+            lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
             exit_stack.enter_context(
                 LayerPatcher.apply_smart_model_patches(
                     model=text_encoder,
                     patches=self._lora_iterator(context),
                     prefix=Z_IMAGE_LORA_QWEN3_PREFIX,
-                    dtype=torch.bfloat16,
+                    dtype=lora_dtype,
                 )
             )
 
             context.util.signal_progress("Running Qwen3 text encoder")
-            assert isinstance(text_encoder, PreTrainedModel)
-            assert isinstance(tokenizer, PreTrainedTokenizerBase)
+            if not isinstance(text_encoder, PreTrainedModel):
+                raise TypeError(
+                    f"Expected PreTrainedModel for text encoder, got {type(text_encoder).__name__}. "
+                    "The Qwen3 encoder model may be corrupted or incompatible."
+                )
+            if not isinstance(tokenizer, PreTrainedTokenizerBase):
+                raise TypeError(
+                    f"Expected PreTrainedTokenizerBase for tokenizer, got {type(tokenizer).__name__}. "
+                    "The Qwen3 tokenizer may be corrupted or incompatible."
+                )
 
             # Apply chat template similar to diffusers ZImagePipeline
             # The chat template formats the prompt for the Qwen3 model
-            prompt_formatted = tokenizer.apply_chat_template(
-                [{"role": "user", "content": prompt}],
-                tokenize=False,
-                add_generation_prompt=True,
-                enable_thinking=True,
-            )
+            try:
+                prompt_formatted = tokenizer.apply_chat_template(
+                    [{"role": "user", "content": prompt}],
+                    tokenize=False,
+                    add_generation_prompt=True,
+                    enable_thinking=True,
+                )
+            except (AttributeError, TypeError) as e:
+                # Fallback if tokenizer doesn't support apply_chat_template or enable_thinking
+                context.logger.warning(f"Chat template failed ({e}), using raw prompt.")
+                prompt_formatted = prompt
 
             # Tokenize the formatted prompt
             text_inputs = tokenizer(
@@ -113,8 +127,16 @@ def _encode_prompt(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
 
             text_input_ids = text_inputs.input_ids
             attention_mask = text_inputs.attention_mask
-            assert isinstance(text_input_ids, torch.Tensor)
-            assert isinstance(attention_mask, torch.Tensor)
+            if not isinstance(text_input_ids, torch.Tensor):
+                raise TypeError(
+                    f"Expected torch.Tensor for input_ids, got {type(text_input_ids).__name__}. "
+                    "Tokenizer returned unexpected type."
+                )
+            if not isinstance(attention_mask, torch.Tensor):
+                raise TypeError(
+                    f"Expected torch.Tensor for attention_mask, got {type(attention_mask).__name__}. "
+                    "Tokenizer returned unexpected type."
+                )
 
             # Check for truncation
             untruncated_ids = tokenizer(prompt_formatted, padding="longest", return_tensors="pt").input_ids
@@ -135,6 +157,18 @@ def _encode_prompt(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
                 attention_mask=prompt_mask,
                 output_hidden_states=True,
             )
+
+            # Validate hidden_states output
+            if not hasattr(outputs, "hidden_states") or outputs.hidden_states is None:
+                raise RuntimeError(
+                    "Text encoder did not return hidden_states. "
+                    "Ensure output_hidden_states=True is supported by this model."
+                )
+            if len(outputs.hidden_states) < 2:
+                raise RuntimeError(
+                    f"Expected at least 2 hidden states from text encoder, got {len(outputs.hidden_states)}. "
+                    "This may indicate an incompatible model or configuration."
+                )
             prompt_embeds = outputs.hidden_states[-2]
 
             # Z-Image expects a 2D tensor [seq_len, hidden_dim] with only valid tokens
@@ -143,13 +177,21 @@ def _encode_prompt(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
             # Since batch_size=1, we take the first item and filter by mask
             prompt_embeds = prompt_embeds[0][prompt_mask[0]]
 
-            assert isinstance(prompt_embeds, torch.Tensor)
+            if not isinstance(prompt_embeds, torch.Tensor):
+                raise TypeError(
+                    f"Expected torch.Tensor for prompt embeddings, got {type(prompt_embeds).__name__}. "
+                    "Text encoder returned unexpected type."
+                )
             return prompt_embeds
 
     def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
         """Iterate over LoRA models to apply to the Qwen3 text encoder."""
         for lora in self.qwen3_encoder.loras:
             lora_info = context.models.load(lora.lora)
-            assert isinstance(lora_info.model, ModelPatchRaw)
+            if not isinstance(lora_info.model, ModelPatchRaw):
+                raise TypeError(
+                    f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}. "
+                    "The LoRA model may be corrupted or incompatible."
+                )
             yield (lora_info.model, lora.weight)
             del lora_info
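Two details of the embedding extraction in this file are worth spelling out. With output_hidden_states=True a Hugging Face encoder returns one tensor per layer plus the input embeddings, so hidden_states[-2] is the penultimate layer, and the prompt_embeds[0][prompt_mask[0]] indexing drops padded positions so only valid tokens reach Z-Image. A toy, model-free sketch of the shape handling (random tensors stand in for real Qwen3 outputs, and a boolean mask is used here for clarity):

import torch

# Toy stand-in for the encoder output: a tuple of (num_layers + 1) hidden states,
# each [batch, seq_len, hidden_dim].
batch, seq_len, hidden_dim, num_layers = 1, 8, 16, 4
hidden_states = tuple(torch.randn(batch, seq_len, hidden_dim) for _ in range(num_layers + 1))

# Mask marks the 5 real tokens; the rest is padding.
prompt_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0]], dtype=torch.bool)

prompt_embeds = hidden_states[-2]                  # penultimate layer, shape [1, 8, 16]
prompt_embeds = prompt_embeds[0][prompt_mask[0]]   # keep valid tokens only -> [5, 16]
print(prompt_embeds.shape)  # torch.Size([5, 16])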

invokeai/backend/model_manager/configs/factory.py

Lines changed: 9 additions & 1 deletion
@@ -60,6 +60,7 @@
     Main_Checkpoint_SD2_Config,
     Main_Checkpoint_SDXL_Config,
     Main_Checkpoint_SDXLRefiner_Config,
+    Main_Checkpoint_ZImage_Config,
     Main_Diffusers_CogView4_Config,
     Main_Diffusers_SD1_Config,
     Main_Diffusers_SD2_Config,
@@ -71,13 +72,17 @@
     Main_GGUF_ZImage_Config,
     MainModelDefaultSettings,
 )
+from invokeai.backend.model_manager.configs.qwen3_encoder import (
+    Qwen3Encoder_Checkpoint_Config,
+    Qwen3Encoder_GGUF_Config,
+    Qwen3Encoder_Qwen3Encoder_Config,
+)
 from invokeai.backend.model_manager.configs.siglip import SigLIP_Diffusers_Config
 from invokeai.backend.model_manager.configs.spandrel import Spandrel_Checkpoint_Config
 from invokeai.backend.model_manager.configs.t2i_adapter import (
     T2IAdapter_Diffusers_SD1_Config,
     T2IAdapter_Diffusers_SDXL_Config,
 )
-from invokeai.backend.model_manager.configs.qwen3_encoder import Qwen3Encoder_Qwen3Encoder_Config
 from invokeai.backend.model_manager.configs.t5_encoder import T5Encoder_BnBLLMint8_Config, T5Encoder_T5Encoder_Config
 from invokeai.backend.model_manager.configs.textual_inversion import (
     TI_File_SD1_Config,
@@ -150,6 +155,7 @@
     Annotated[Main_Checkpoint_SDXL_Config, Main_Checkpoint_SDXL_Config.get_tag()],
     Annotated[Main_Checkpoint_SDXLRefiner_Config, Main_Checkpoint_SDXLRefiner_Config.get_tag()],
     Annotated[Main_Checkpoint_FLUX_Config, Main_Checkpoint_FLUX_Config.get_tag()],
+    Annotated[Main_Checkpoint_ZImage_Config, Main_Checkpoint_ZImage_Config.get_tag()],
     # Main (Pipeline) - quantized formats
     Annotated[Main_BnBNF4_FLUX_Config, Main_BnBNF4_FLUX_Config.get_tag()],
     Annotated[Main_GGUF_FLUX_Config, Main_GGUF_FLUX_Config.get_tag()],
@@ -194,6 +200,8 @@
     Annotated[T5Encoder_BnBLLMint8_Config, T5Encoder_BnBLLMint8_Config.get_tag()],
     # Qwen3 Encoder
     Annotated[Qwen3Encoder_Qwen3Encoder_Config, Qwen3Encoder_Qwen3Encoder_Config.get_tag()],
+    Annotated[Qwen3Encoder_Checkpoint_Config, Qwen3Encoder_Checkpoint_Config.get_tag()],
+    Annotated[Qwen3Encoder_GGUF_Config, Qwen3Encoder_GGUF_Config.get_tag()],
     # TI - file format
     Annotated[TI_File_SD1_Config, TI_File_SD1_Config.get_tag()],
     Annotated[TI_File_SD2_Config, TI_File_SD2_Config.get_tag()],
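The config union in factory.py grows by tagging each new class with Annotated[..., Config.get_tag()]. The tag plumbing is InvokeAI's own, but the overall shape is a Pydantic discriminated union; a simplified sketch with a plain string discriminator (class names and the file path below are placeholders):

from typing import Annotated, Literal, Union
from pydantic import BaseModel, Field, TypeAdapter

class Qwen3EncoderCheckpointConfig(BaseModel):
    format: Literal["checkpoint"] = "checkpoint"
    path: str

class Qwen3EncoderGGUFConfig(BaseModel):
    format: Literal["gguf_quantized"] = "gguf_quantized"
    path: str

AnyQwen3Config = Annotated[
    Union[Qwen3EncoderCheckpointConfig, Qwen3EncoderGGUFConfig],
    Field(discriminator="format"),
]

adapter = TypeAdapter(AnyQwen3Config)
cfg = adapter.validate_python({"format": "gguf_quantized", "path": "encoder.gguf"})
print(type(cfg).__name__)  # Qwen3EncoderGGUFConfig -- the discriminator picks the class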