From 2354fda9c0d66ae1e1606463a767b7b8173ff73a Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 17:44:12 +0000 Subject: [PATCH 01/20] init --- ...convert_z_image_controlnet_to_diffusers.py | 103 +++ src/diffusers/models/controlnets/__init__.py | 1 + .../models/controlnets/controlnet_z_image.py | 528 ++++++++++++++ .../transformers/transformer_z_image.py | 11 +- .../z_image/pipeline_z_image_controlnet.py | 674 ++++++++++++++++++ 5 files changed, 1315 insertions(+), 2 deletions(-) create mode 100644 scripts/convert_z_image_controlnet_to_diffusers.py create mode 100644 src/diffusers/models/controlnets/controlnet_z_image.py create mode 100644 src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py diff --git a/scripts/convert_z_image_controlnet_to_diffusers.py b/scripts/convert_z_image_controlnet_to_diffusers.py new file mode 100644 index 000000000000..c4b96cda02af --- /dev/null +++ b/scripts/convert_z_image_controlnet_to_diffusers.py @@ -0,0 +1,103 @@ +import argparse +from contextlib import nullcontext + +import torch +import safetensors.torch +from accelerate import init_empty_weights +from huggingface_hub import hf_hub_download + +from diffusers.utils.import_utils import is_accelerate_available +from diffusers.models import ZImageTransformer2DModel +from diffusers.models.controlnets.controlnet_z_image import ZImageControlNetModel + +""" +python scripts/convert_z_image_controlnet_to_diffusers.py \ +--original_z_image_repo_id "Tongyi-MAI/Z-Image-Turbo" \ +--original_controlnet_repo_id "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union" \ +--filename "Z-Image-Turbo-Fun-Controlnet-Union.safetensors" +--output_path "z-image-controlnet-hf/" +""" + + +CTX = init_empty_weights if is_accelerate_available else nullcontext + +parser = argparse.ArgumentParser() +parser.add_argument("--original_z_image_repo_id", default="Tongyi-MAI/Z-Image-Turbo", type=str) +parser.add_argument("--original_controlnet_repo_id", default=None, type=str) +parser.add_argument("--filename", default="Z-Image-Turbo-Fun-Controlnet-Union.safetensors", type=str) +parser.add_argument("--checkpoint_path", default=None, type=str) +parser.add_argument("--output_path", type=str) + +args = parser.parse_args() + + +def load_original_checkpoint(args): + if args.original_controlnet_repo_id is not None: + ckpt_path = hf_hub_download(repo_id=args.original_controlnet_repo_id, filename=args.filename) + elif args.checkpoint_path is not None: + ckpt_path = args.checkpoint_path + else: + raise ValueError(" please provide either `original_controlnet_repo_id` or a local `checkpoint_path`") + + original_state_dict = safetensors.torch.load_file(ckpt_path) + return original_state_dict + +def load_z_image(args): + model = ZImageTransformer2DModel.from_pretrained(args.original_z_image_repo_id, subfolder="transformer", torch_dtype=torch.bfloat16) + return model.state_dict(), model.config + +def convert_z_image_controlnet_checkpoint_to_diffusers(z_image, original_state_dict): + converted_state_dict = {} + + converted_state_dict.update(original_state_dict) + + to_copy = {"all_x_embedder.", "noise_refiner.", "context_refiner.", "t_embedder.", "cap_embedder.", "x_pad_token", "cap_pad_token"} + + for key in z_image.keys(): + for copy_key in to_copy: + if key.startswith(copy_key): + converted_state_dict[key] = z_image[key] + + return converted_state_dict + + +def main(args): + original_ckpt = load_original_checkpoint(args) + z_image, config = load_z_image(args) + + control_in_dim = 16 + control_layers_places = [0, 5, 10, 15, 20, 25] + + 
converted_controlnet_state_dict = convert_z_image_controlnet_checkpoint_to_diffusers(z_image, original_ckpt) + + for key, tensor in converted_controlnet_state_dict.items(): + print(f"{key} - {tensor.dtype}") + + controlnet = ZImageControlNetModel( + all_patch_size=config["all_patch_size"], + all_f_patch_size=config["all_f_patch_size"], + in_channels=config["in_channels"], + dim=config["dim"], + n_layers=config["n_layers"], + n_refiner_layers=config["n_refiner_layers"], + n_heads=config["n_heads"], + n_kv_heads=config["n_kv_heads"], + norm_eps=config["norm_eps"], + qk_norm=config["qk_norm"], + cap_feat_dim=config["cap_feat_dim"], + rope_theta=config["rope_theta"], + t_scale=config["t_scale"], + axes_dims=config["axes_dims"], + axes_lens=config["axes_lens"], + control_layers_places=control_layers_places, + control_in_dim=control_in_dim, + ) + missing, unexpected = controlnet.load_state_dict(converted_controlnet_state_dict) + print(f"{missing=}") + print(f"{unexpected=}") + print("Saving Z-Image ControlNet in Diffusers format") + controlnet.save_pretrained(args.output_path, max_shard_size="5GB") + + +if __name__ == "__main__": + main(args) diff --git a/src/diffusers/models/controlnets/__init__.py b/src/diffusers/models/controlnets/__init__.py index 7ce352879daa..fee7f231e899 100644 --- a/src/diffusers/models/controlnets/__init__.py +++ b/src/diffusers/models/controlnets/__init__.py @@ -19,6 +19,7 @@ ) from .controlnet_union import ControlNetUnionModel from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel + from .controlnet_z_image import ZImageControlNetModel from .multicontrolnet import MultiControlNetModel from .multicontrolnet_union import MultiControlNetUnionModel diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py new file mode 100644 index 000000000000..d6cede86812d --- /dev/null +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -0,0 +1,528 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List + +import torch +import torch.nn as nn +from torch.nn.utils.rnn import pad_sequence + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...models.normalization import RMSNorm +from ..controlnets.controlnet import zero_module +from ..modeling_utils import ModelMixin +from ..transformers.transformer_z_image import ZImageTransformerBlock, RopeEmbedder, TimestepEmbedder, SEQ_MULTI_OF, ADALN_EMBED_DIM + + +class ZImageControlTransformerBlock(ZImageTransformerBlock): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + qk_norm: bool, + modulation=True, + block_id=0 + ): + super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation) + self.block_id = block_id + if block_id == 0: + self.before_proj = zero_module(nn.Linear(self.dim, self.dim)) + self.after_proj = zero_module(nn.Linear(self.dim, self.dim)) + + def forward(self, c: torch.Tensor, x: torch.Tensor, **kwargs): + if self.block_id == 0: + c = self.before_proj(c) + x + all_c = [] + else: + all_c = list(torch.unbind(c)) + c = all_c.pop(-1) + + c = super().forward(c, **kwargs) + c_skip = self.after_proj(c) + all_c += [c_skip, c] + c = torch.stack(all_c) + return c + +class ZImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + all_patch_size=(2,), + all_f_patch_size=(1,), + in_channels=16, + dim=3840, + n_layers=30, + n_refiner_layers=2, + n_heads=30, + n_kv_heads=30, + norm_eps=1e-5, + qk_norm=True, + cap_feat_dim=2560, + rope_theta=256.0, + t_scale=1000.0, + axes_dims=[32, 48, 48], + axes_lens=[1024, 512, 512], + control_layers_places: List[int]=None, + control_in_dim=None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels + self.all_patch_size = all_patch_size + self.all_f_patch_size = all_f_patch_size + self.dim = dim + self.n_heads = n_heads + + self.rope_theta = rope_theta + self.t_scale = t_scale + self.gradient_checkpointing = False + self.n_layers = n_layers + + assert len(all_patch_size) == len(all_f_patch_size) + + all_x_embedder = {} + for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): + x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True) + all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + + self.all_x_embedder = nn.ModuleDict(all_x_embedder) + self.noise_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 1000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=True, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.context_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) + self.cap_embedder = nn.Sequential( + RMSNorm(cap_feat_dim, eps=norm_eps), + nn.Linear(cap_feat_dim, dim, bias=True), + ) + + self.x_pad_token = nn.Parameter(torch.empty((1, dim))) + self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) + + self.axes_dims = axes_dims + self.axes_lens = axes_lens + + self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) + + ## Original Control layers + + self.control_layers_places = control_layers_places + 
self.control_in_dim = control_in_dim + + assert 0 in self.control_layers_places + + # control blocks + self.control_layers = nn.ModuleList( + [ + ZImageControlTransformerBlock( + i, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + block_id=i + ) + for i in self.control_layers_places + ] + ) + + # control patch embeddings + all_x_embedder = {} + for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): + x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True) + all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + + self.control_all_x_embedder = nn.ModuleDict(all_x_embedder) + self.control_noise_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 1000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=True, + ) + for layer_id in range(n_refiner_layers) + ] + ) + + @staticmethod + def create_coordinate_grid(size, start=None, device=None): + if start is None: + start = (0 for _ in size) + + axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)] + grids = torch.meshgrid(axes, indexing="ij") + return torch.stack(grids, dim=-1) + + def patchify( + self, + all_image: List[torch.Tensor], + patch_size: int, + f_patch_size: int, + cap_padding_len: int, + ): + pH = pW = patch_size + pF = f_patch_size + device = all_image[0].device + + all_image_out = [] + all_image_size = [] + all_image_pos_ids = [] + all_image_pad_mask = [] + + for i, image in enumerate(all_image): + ### Process Image + C, F, H, W = image.size() + all_image_size.append((F, H, W)) + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + # "c f pf h ph w pw -> (f h w) (pf ph pw c)" + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + + image_ori_len = len(image) + image_padding_len = (-image_ori_len) % SEQ_MULTI_OF + + image_ori_pos_ids = self.create_coordinate_grid( + size=(F_tokens, H_tokens, W_tokens), + start=(cap_padding_len + 1, 0, 0), + device=device, + ).flatten(0, 2) + image_padding_pos_ids = ( + self.create_coordinate_grid( + size=(1, 1, 1), + start=(0, 0, 0), + device=device, + ) + .flatten(0, 2) + .repeat(image_padding_len, 1) + ) + image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) + all_image_pos_ids.append(image_padded_pos_ids) + # pad mask + all_image_pad_mask.append( + torch.cat( + [ + torch.zeros((image_ori_len,), dtype=torch.bool, device=device), + torch.ones((image_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + ) + # padded feature + image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) + all_image_out.append(image_padded_feat) + + return ( + all_image_out, + all_image_size, + all_image_pos_ids, + all_image_pad_mask, + ) + + def patchify_and_embed( + self, + all_image: List[torch.Tensor], + all_cap_feats: List[torch.Tensor], + patch_size: int, + f_patch_size: int, + ): + pH = pW = patch_size + pF = f_patch_size + device = all_image[0].device + + all_image_out = [] + all_image_size = [] + all_image_pos_ids = [] + all_image_pad_mask = [] + all_cap_pos_ids = [] + all_cap_pad_mask = [] + all_cap_feats_out = [] + + for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)): + ### Process Caption + cap_ori_len = len(cap_feat) + cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF + # padded position ids + cap_padded_pos_ids = 
self.create_coordinate_grid( + size=(cap_ori_len + cap_padding_len, 1, 1), + start=(1, 0, 0), + device=device, + ).flatten(0, 2) + all_cap_pos_ids.append(cap_padded_pos_ids) + # pad mask + all_cap_pad_mask.append( + torch.cat( + [ + torch.zeros((cap_ori_len,), dtype=torch.bool, device=device), + torch.ones((cap_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + ) + # padded feature + cap_padded_feat = torch.cat( + [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], + dim=0, + ) + all_cap_feats_out.append(cap_padded_feat) + + ### Process Image + C, F, H, W = image.size() + all_image_size.append((F, H, W)) + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + # "c f pf h ph w pw -> (f h w) (pf ph pw c)" + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + + image_ori_len = len(image) + image_padding_len = (-image_ori_len) % SEQ_MULTI_OF + + image_ori_pos_ids = self.create_coordinate_grid( + size=(F_tokens, H_tokens, W_tokens), + start=(cap_ori_len + cap_padding_len + 1, 0, 0), + device=device, + ).flatten(0, 2) + image_padding_pos_ids = ( + self.create_coordinate_grid( + size=(1, 1, 1), + start=(0, 0, 0), + device=device, + ) + .flatten(0, 2) + .repeat(image_padding_len, 1) + ) + image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) + all_image_pos_ids.append(image_padded_pos_ids) + # pad mask + all_image_pad_mask.append( + torch.cat( + [ + torch.zeros((image_ori_len,), dtype=torch.bool, device=device), + torch.ones((image_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + ) + # padded feature + image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) + all_image_out.append(image_padded_feat) + + return ( + all_image_out, + all_cap_feats_out, + all_image_size, + all_image_pos_ids, + all_cap_pos_ids, + all_image_pad_mask, + all_cap_pad_mask, + ) + + def forward( + self, + x: List[torch.Tensor], + cap_feats: List[torch.Tensor], + control_context: List[torch.Tensor], + t=None, + patch_size=2, + f_patch_size=1, + conditioning_scale: float = 1.0, + ): + assert patch_size in self.all_patch_size + assert f_patch_size in self.all_f_patch_size + + bsz = len(x) + device = x[0].device + t = t * self.t_scale + t = self.t_embedder(t) + + ( + x, + cap_feats, + x_size, + x_pos_ids, + cap_pos_ids, + x_inner_pad_mask, + cap_inner_pad_mask, + ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) + + # x embed & refine + x_item_seqlens = [len(_) for _ in x] + assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) + x_max_item_seqlen = max(x_item_seqlens) + + x = torch.cat(x, dim=0) + x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) + + # Match t_embedder output dtype to x for layerwise casting compatibility + adaln_input = t.type_as(x) + x[torch.cat(x_inner_pad_mask)] = self.x_pad_token + x = list(x.split(x_item_seqlens, dim=0)) + x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) + + x = pad_sequence(x, batch_first=True, padding_value=0.0) + x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) + x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(x_item_seqlens): + x_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.noise_refiner: + x = 
self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input) + else: + for layer in self.noise_refiner: + x = layer(x, x_attn_mask, x_freqs_cis, adaln_input) + + # cap embed & refine + cap_item_seqlens = [len(_) for _ in cap_feats] + assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens) + cap_max_item_seqlen = max(cap_item_seqlens) + + cap_feats = torch.cat(cap_feats, dim=0) + cap_feats = self.cap_embedder(cap_feats) + cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token + cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) + cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0)) + + cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) + cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) + cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(cap_item_seqlens): + cap_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.context_refiner: + cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis) + else: + for layer in self.context_refiner: + cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis) + + # unified + unified = [] + unified_freqs_cis = [] + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]])) + unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]])) + unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)] + assert unified_item_seqlens == [len(_) for _ in unified] + unified_max_item_seqlen = max(unified_item_seqlens) + + unified = pad_sequence(unified, batch_first=True, padding_value=0.0) + unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0) + unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(unified_item_seqlens): + unified_attn_mask[i, :seq_len] = 1 + + ## Original forward_control + + # embeddings + bsz = len(control_context) + device = control_context[0].device + ( + control_context, + x_size, + x_pos_ids, + x_inner_pad_mask, + ) = self.patchify(control_context, patch_size, f_patch_size, cap_feats[0].size(0)) + + # control_context embed & refine + x_item_seqlens = [len(_) for _ in control_context] + assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) + x_max_item_seqlen = max(x_item_seqlens) + + control_context = torch.cat(control_context, dim=0) + control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context) + + # Match t_embedder output dtype to control_context for layerwise casting compatibility + adaln_input = t.type_as(control_context) + control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token + control_context = list(control_context.split(x_item_seqlens, dim=0)) + x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) + + control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0) + x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) + x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(x_item_seqlens): + x_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in 
self.control_noise_refiner: + control_context = self._gradient_checkpointing_func(layer, control_context, x_attn_mask, x_freqs_cis, adaln_input) + else: + for layer in self.control_noise_refiner: + control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input) + + # unified + cap_item_seqlens = [len(_) for _ in cap_feats] + control_context_unified = [] + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + control_context_unified.append(torch.cat([control_context[i][:x_len], cap_feats[i][:cap_len]])) + control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0) + c = control_context_unified + + new_kwargs = dict(x=unified, attn_mask=unified_attn_mask, freqs_cis=unified_freqs_cis, adaln_input=adaln_input) + + for layer in self.control_layers: + if torch.is_grad_enabled() and self.gradient_checkpointing: + c = self._gradient_checkpointing_func(layer, c, **new_kwargs) + else: + c = layer(c, **new_kwargs) + + hints = torch.unbind(c)[:-1] * conditioning_scale + controlnet_block_samples = {} + for layer_idx in range(self.n_layers): + if layer_idx in self.control_layers_places: + hints_idx = self.control_layers_places.index(layer_idx) + controlnet_block_samples[layer_idx] = hints[hints_idx] + + return controlnet_block_samples diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 5c401b9d202b..2d332217d897 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -538,6 +538,7 @@ def forward( cap_feats: List[torch.Tensor], patch_size=2, f_patch_size=1, + controlnet_block_samples: Optional[dict[int, torch.Tensor]]=None, return_dict: bool = True, ): assert patch_size in self.all_patch_size @@ -635,13 +636,19 @@ def forward( unified_attn_mask[i, :seq_len] = 1 if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.layers: + for layer_idx, layer in enumerate(self.layers): unified = self._gradient_checkpointing_func( layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input ) + if controlnet_block_samples is not None: + if layer_idx in controlnet_block_samples: + unified = unified + controlnet_block_samples[layer_idx] else: - for layer in self.layers: + for layer_idx, layer in enumerate(self.layers): unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input) + if controlnet_block_samples is not None: + if layer_idx in controlnet_block_samples: + unified = unified + controlnet_block_samples[layer_idx] unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input) unified = list(unified.unbind(dim=0)) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py new file mode 100644 index 000000000000..609b141be796 --- /dev/null +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -0,0 +1,674 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import AutoTokenizer, PreTrainedModel + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin +from ...models.autoencoders import AutoencoderKL +from ...models.controlnets import ZImageControlNetModel +from ...models.transformers import ZImageTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import ZImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ZImagePipeline + + >>> pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch. + >>> # (1) Use flash attention 2 + >>> # pipe.transformer.set_attention_backend("flash") + >>> # (2) Use flash attention 3 + >>> # pipe.transformer.set_attention_backend("_flash_3") + + >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。" + >>> image = pipe( + ... prompt, + ... height=1024, + ... width=1024, + ... num_inference_steps=9, + ... guidance_scale=0.0, + ... generator=torch.Generator("cuda").manual_seed(42), + ... ).images[0] + >>> image.save("zimage.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. 
Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ZImageControlNetPipeline(DiffusionPipeline, FromSingleFileMixin): + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: PreTrainedModel, + tokenizer: AutoTokenizer, + transformer: ZImageTransformer2DModel, + controlnet: ZImageControlNetModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + controlnet=controlnet, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + if do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = ["" for _ in prompt] + else: + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + assert len(prompt) == len(negative_prompt) + negative_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + device=device, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + else: + negative_prompt_embeds = [] + return prompt_embeds, negative_prompt_embeds + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + max_sequence_length: int = 512, + ) -> List[torch.FloatTensor]: + device = device or self._execution_device + + if prompt_embeds is not None: + return prompt_embeds + + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + + for i in range(len(prompt_embeds)): + embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) + + return embeddings_list + + def prepare_latents( + self, + batch_size, + 
num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 5.0, + control_image: PipelineImageInput = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + cfg_normalization: bool = False, + cfg_truncation: float = 1.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. 
+ sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + cfg_normalization (`bool`, *optional*, defaults to False): + Whether to apply configuration normalization. + cfg_truncation (`float`, *optional*, defaults to 1.0): + The truncation value for configuration. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain + tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. 
+ callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + height = height or 1024 + width = width or 1024 + + vae_scale = self.vae_scale_factor * 2 + if height % vae_scale != 0: + raise ValueError( + f"Height must be divisible by {vae_scale} (got {height}). " + f"Please adjust the height to a multiple of {vae_scale}." + ) + if width % vae_scale != 0: + raise ValueError( + f"Width must be divisible by {vae_scale} (got {width}). " + f"Please adjust the width to a multiple of {vae_scale}." + ) + + device = self._execution_device + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._cfg_normalization = cfg_normalization + self._cfg_truncation = cfg_truncation + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is not None and prompt is None: + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided without `prompt`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + else: + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. 
Prepare latent variables + num_channels_latents = self.transformer.in_channels + + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image.shape[-2:] + control_image = retrieve_latents(self.vae.encode(control_image), generator=generator) + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + control_image = control_image.unsqueeze(2) + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # Repeat prompt_embeds for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] + if self.do_classifier_free_guidance and negative_prompt_embeds: + negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] + + actual_batch_size = batch_size * num_images_per_prompt + image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2) + + # 5. Prepare timesteps + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.sigma_min = 0.0 + scheduler_kwargs = {"mu": mu} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + timestep = (1000 - timestep) / 1000 + # Normalized time for time-aware config (0 at start, 1 at end) + t_norm = timestep[0].item() + + # Handle cfg truncation + current_guidance_scale = self.guidance_scale + if ( + self.do_classifier_free_guidance + and self._cfg_truncation is not None + and float(self._cfg_truncation) <= 1 + ): + if t_norm > self._cfg_truncation: + current_guidance_scale = 0.0 + + # Run CFG only if configured AND scale is non-zero + apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 + + if apply_cfg: + latents_typed = latents.to(self.transformer.dtype) + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds + timestep_model_input = timestep.repeat(2) + else: + latent_model_input = latents.to(self.transformer.dtype) + prompt_embeds_model_input = prompt_embeds + timestep_model_input = timestep + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + controlnet_block_samples = self.controlnet( + latent_model_input_list, + prompt_embeds_model_input, + control_image, + timestep_model_input, + conditioning_scale=controlnet_conditioning_scale, + ) + + model_out_list = self.transformer( + latent_model_input_list, + timestep_model_input, + prompt_embeds_model_input, + controlnet_block_samples=controlnet_block_samples, + )[0] + + if apply_cfg: + # Perform CFG + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] + + noise_pred = [] + for j in range(actual_batch_size): + pos = pos_out[j].float() + neg = neg_out[j].float() + + pred = pos + current_guidance_scale * (pos - neg) + + # Renormalization + if self._cfg_normalization and float(self._cfg_normalization) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + + noise_pred.append(pred) + + noise_pred = torch.stack(noise_pred, dim=0) + else: + noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) + + noise_pred = noise_pred.squeeze(2) + noise_pred = -noise_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + assert latents.dtype == torch.float32 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, 
return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ZImagePipelineOutput(images=image) From 1e2009de435516caf7b6e67ab215f8f6299c375f Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 17:48:38 +0000 Subject: [PATCH 02/20] passed transformer --- .../models/controlnets/controlnet_z_image.py | 98 +++---------------- .../z_image/pipeline_z_image_controlnet.py | 1 + 2 files changed, 15 insertions(+), 84 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index d6cede86812d..6fe9d38ce3d1 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -23,7 +23,7 @@ from ...models.normalization import RMSNorm from ..controlnets.controlnet import zero_module from ..modeling_utils import ModelMixin -from ..transformers.transformer_z_image import ZImageTransformerBlock, RopeEmbedder, TimestepEmbedder, SEQ_MULTI_OF, ADALN_EMBED_DIM +from ..transformers.transformer_z_image import ZImageTransformer2DModel, ZImageTransformerBlock, RopeEmbedder, TimestepEmbedder, SEQ_MULTI_OF, ADALN_EMBED_DIM class ZImageControlTransformerBlock(ZImageTransformerBlock): @@ -66,87 +66,16 @@ def __init__( self, all_patch_size=(2,), all_f_patch_size=(1,), - in_channels=16, dim=3840, - n_layers=30, n_refiner_layers=2, n_heads=30, n_kv_heads=30, norm_eps=1e-5, qk_norm=True, - cap_feat_dim=2560, - rope_theta=256.0, - t_scale=1000.0, - axes_dims=[32, 48, 48], - axes_lens=[1024, 512, 512], control_layers_places: List[int]=None, control_in_dim=None, ): super().__init__() - self.in_channels = in_channels - self.out_channels = in_channels - self.all_patch_size = all_patch_size - self.all_f_patch_size = all_f_patch_size - self.dim = dim - self.n_heads = n_heads - - self.rope_theta = rope_theta - self.t_scale = t_scale - self.gradient_checkpointing = False - self.n_layers = n_layers - - assert len(all_patch_size) == len(all_f_patch_size) - - all_x_embedder = {} - for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): - x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True) - all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder - - self.all_x_embedder = nn.ModuleDict(all_x_embedder) - self.noise_refiner = nn.ModuleList( - [ - ZImageTransformerBlock( - 1000 + layer_id, - dim, - n_heads, - n_kv_heads, - norm_eps, - qk_norm, - modulation=True, - ) - for layer_id in range(n_refiner_layers) - ] - ) - self.context_refiner = nn.ModuleList( - [ - ZImageTransformerBlock( - layer_id, - dim, - n_heads, - n_kv_heads, - norm_eps, - qk_norm, - modulation=False, - ) - for layer_id in range(n_refiner_layers) - ] - ) - self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) - self.cap_embedder = nn.Sequential( - RMSNorm(cap_feat_dim, eps=norm_eps), - nn.Linear(cap_feat_dim, dim, bias=True), - ) - - self.x_pad_token = nn.Parameter(torch.empty((1, dim))) - self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) - - self.axes_dims = axes_dims - self.axes_lens = axes_lens - - self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) - - ## Original Control layers - self.control_layers_places = control_layers_places self.control_in_dim = control_in_dim @@ -366,6 +295,7 @@ def patchify_and_embed( def 
forward( self, + transformer: ZImageTransformer2DModel, x: List[torch.Tensor], cap_feats: List[torch.Tensor], control_context: List[torch.Tensor], @@ -380,7 +310,7 @@ def forward( bsz = len(x) device = x[0].device t = t * self.t_scale - t = self.t_embedder(t) + t = transformer.t_embedder(t) ( x, @@ -398,13 +328,13 @@ def forward( x_max_item_seqlen = max(x_item_seqlens) x = torch.cat(x, dim=0) - x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) + x = transformer.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) # Match t_embedder output dtype to x for layerwise casting compatibility adaln_input = t.type_as(x) - x[torch.cat(x_inner_pad_mask)] = self.x_pad_token + x[torch.cat(x_inner_pad_mask)] = transformer.x_pad_token x = list(x.split(x_item_seqlens, dim=0)) - x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) + x_freqs_cis = list(transformer.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) x = pad_sequence(x, batch_first=True, padding_value=0.0) x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) @@ -413,10 +343,10 @@ def forward( x_attn_mask[i, :seq_len] = 1 if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.noise_refiner: + for layer in transformer.noise_refiner: x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input) else: - for layer in self.noise_refiner: + for layer in transformer.noise_refiner: x = layer(x, x_attn_mask, x_freqs_cis, adaln_input) # cap embed & refine @@ -425,10 +355,10 @@ def forward( cap_max_item_seqlen = max(cap_item_seqlens) cap_feats = torch.cat(cap_feats, dim=0) - cap_feats = self.cap_embedder(cap_feats) - cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token + cap_feats = transformer.cap_embedder(cap_feats) + cap_feats[torch.cat(cap_inner_pad_mask)] = transformer.cap_pad_token cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) - cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0)) + cap_freqs_cis = list(transformer.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0)) cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) @@ -437,10 +367,10 @@ def forward( cap_attn_mask[i, :seq_len] = 1 if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.context_refiner: + for layer in transformer.context_refiner: cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis) else: - for layer in self.context_refiner: + for layer in transformer.context_refiner: cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis) # unified @@ -485,7 +415,7 @@ def forward( adaln_input = t.type_as(control_context) control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token control_context = list(control_context.split(x_item_seqlens, dim=0)) - x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) + x_freqs_cis = list(transformer.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0) x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py index 609b141be796..d374b8032ea8 100644 --- 
a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -594,6 +594,7 @@ def __call__( latent_model_input_list = list(latent_model_input.unbind(dim=0)) controlnet_block_samples = self.controlnet( + self.transformer, latent_model_input_list, prompt_embeds_model_input, control_image, From 0c308394049f2c7a65c697cf88ea9c40d9ca4333 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 17:50:01 +0000 Subject: [PATCH 03/20] ruff --- ...convert_z_image_controlnet_to_diffusers.py | 21 ++++++++--- .../models/controlnets/controlnet_z_image.py | 36 +++++++++---------- .../transformers/transformer_z_image.py | 2 +- .../z_image/pipeline_z_image_controlnet.py | 3 +- 4 files changed, 36 insertions(+), 26 deletions(-) diff --git a/scripts/convert_z_image_controlnet_to_diffusers.py b/scripts/convert_z_image_controlnet_to_diffusers.py index c4b96cda02af..a9f97d81676d 100644 --- a/scripts/convert_z_image_controlnet_to_diffusers.py +++ b/scripts/convert_z_image_controlnet_to_diffusers.py @@ -1,14 +1,15 @@ import argparse from contextlib import nullcontext -import torch import safetensors.torch +import torch from accelerate import init_empty_weights from huggingface_hub import hf_hub_download -from diffusers.utils.import_utils import is_accelerate_available from diffusers.models import ZImageTransformer2DModel from diffusers.models.controlnets.controlnet_z_image import ZImageControlNetModel +from diffusers.utils.import_utils import is_accelerate_available + """ python scripts/convert_z_image_controlnet_to_diffusers.py \ @@ -42,16 +43,28 @@ def load_original_checkpoint(args): original_state_dict = safetensors.torch.load_file(ckpt_path) return original_state_dict + def load_z_image(args): - model = ZImageTransformer2DModel.from_pretrained(args.original_z_image_repo_id, subfolder="transformer", torch_dtype=torch.bfloat16) + model = ZImageTransformer2DModel.from_pretrained( + args.original_z_image_repo_id, subfolder="transformer", torch_dtype=torch.bfloat16 + ) return model.state_dict(), model.config + def convert_z_image_controlnet_checkpoint_to_diffusers(z_image, original_state_dict): converted_state_dict = {} converted_state_dict.update(original_state_dict) - to_copy = {"all_x_embedder.", "noise_refiner.", "context_refiner.", "t_embedder.", "cap_embedder.", "x_pad_token", "cap_pad_token"} + to_copy = { + "all_x_embedder.", + "noise_refiner.", + "context_refiner.", + "t_embedder.", + "cap_embedder.", + "x_pad_token", + "cap_pad_token", + } for key in z_image.keys(): for copy_key in to_copy: diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index 6fe9d38ce3d1..b76a2c54c3d8 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -20,15 +20,18 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...models.normalization import RMSNorm from ..controlnets.controlnet import zero_module from ..modeling_utils import ModelMixin -from ..transformers.transformer_z_image import ZImageTransformer2DModel, ZImageTransformerBlock, RopeEmbedder, TimestepEmbedder, SEQ_MULTI_OF, ADALN_EMBED_DIM +from ..transformers.transformer_z_image import ( + SEQ_MULTI_OF, + ZImageTransformer2DModel, + ZImageTransformerBlock, +) class ZImageControlTransformerBlock(ZImageTransformerBlock): def __init__( - self, + self, layer_id: int, dim: int, n_heads: 
int, @@ -36,7 +39,7 @@ def __init__( norm_eps: float, qk_norm: bool, modulation=True, - block_id=0 + block_id=0, ): super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation) self.block_id = block_id @@ -57,7 +60,8 @@ def forward(self, c: torch.Tensor, x: torch.Tensor, **kwargs): all_c += [c_skip, c] c = torch.stack(all_c) return c - + + class ZImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): _supports_gradient_checkpointing = True @@ -72,7 +76,7 @@ def __init__( n_kv_heads=30, norm_eps=1e-5, qk_norm=True, - control_layers_places: List[int]=None, + control_layers_places: List[int] = None, control_in_dim=None, ): super().__init__() @@ -84,15 +88,7 @@ def __init__( # control blocks self.control_layers = nn.ModuleList( [ - ZImageControlTransformerBlock( - i, - dim, - n_heads, - n_kv_heads, - norm_eps, - qk_norm, - block_id=i - ) + ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i) for i in self.control_layers_places ] ) @@ -425,7 +421,9 @@ def forward( if torch.is_grad_enabled() and self.gradient_checkpointing: for layer in self.control_noise_refiner: - control_context = self._gradient_checkpointing_func(layer, control_context, x_attn_mask, x_freqs_cis, adaln_input) + control_context = self._gradient_checkpointing_func( + layer, control_context, x_attn_mask, x_freqs_cis, adaln_input + ) else: for layer in self.control_noise_refiner: control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input) @@ -440,14 +438,14 @@ def forward( control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0) c = control_context_unified - new_kwargs = dict(x=unified, attn_mask=unified_attn_mask, freqs_cis=unified_freqs_cis, adaln_input=adaln_input) - + new_kwargs = {"x": unified, "attn_mask": unified_attn_mask, "freqs_cis": unified_freqs_cis, "adaln_input": adaln_input} + for layer in self.control_layers: if torch.is_grad_enabled() and self.gradient_checkpointing: c = self._gradient_checkpointing_func(layer, c, **new_kwargs) else: c = layer(c, **new_kwargs) - + hints = torch.unbind(c)[:-1] * conditioning_scale controlnet_block_samples = {} for layer_idx in range(self.n_layers): diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 2d332217d897..70ffced8b63a 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -538,7 +538,7 @@ def forward( cap_feats: List[torch.Tensor], patch_size=2, f_patch_size=1, - controlnet_block_samples: Optional[dict[int, torch.Tensor]]=None, + controlnet_block_samples: Optional[dict[int, torch.Tensor]] = None, return_dict: bool = True, ): assert patch_size in self.all_patch_size diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py index d374b8032ea8..44906a0db519 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -89,7 +89,6 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -509,7 +508,7 @@ def __call__( num_images_per_prompt=num_images_per_prompt, device=device, dtype=self.vae.dtype, - ) + ) height, width = 
control_image.shape[-2:] control_image = retrieve_latents(self.vae.encode(control_image), generator=generator) control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor From 52f996e226dfd1e7f1a1b0d001c022dae71e24a8 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:03:23 +0000 Subject: [PATCH 04/20] convert passed --- ...convert_z_image_controlnet_to_diffusers.py | 58 ++----------------- 1 file changed, 6 insertions(+), 52 deletions(-) diff --git a/scripts/convert_z_image_controlnet_to_diffusers.py b/scripts/convert_z_image_controlnet_to_diffusers.py index a9f97d81676d..aed27c14f205 100644 --- a/scripts/convert_z_image_controlnet_to_diffusers.py +++ b/scripts/convert_z_image_controlnet_to_diffusers.py @@ -1,19 +1,17 @@ import argparse from contextlib import nullcontext -import safetensors.torch import torch +import safetensors.torch from accelerate import init_empty_weights from huggingface_hub import hf_hub_download -from diffusers.models import ZImageTransformer2DModel from diffusers.models.controlnets.controlnet_z_image import ZImageControlNetModel from diffusers.utils.import_utils import is_accelerate_available """ python scripts/convert_z_image_controlnet_to_diffusers.py \ ---original_z_image_repo_id "Tongyi-MAI/Z-Image-Turbo" \ --original_controlnet_repo_id "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union" \ --filename "Z-Image-Turbo-Fun-Controlnet-Union.safetensors" --output_path "z-image-controlnet-hf/" @@ -23,7 +21,6 @@ CTX = init_empty_weights if is_accelerate_available else nullcontext parser = argparse.ArgumentParser() -parser.add_argument("--original_z_image_repo_id", default="Tongyi-MAI/Z-Image-Turbo", type=str) parser.add_argument("--original_controlnet_repo_id", default=None, type=str) parser.add_argument("--filename", default="Z-Image-Turbo-Fun-Controlnet-Union.safetensors", type=str) parser.add_argument("--checkpoint_path", default=None, type=str) @@ -44,72 +41,29 @@ def load_original_checkpoint(args): return original_state_dict -def load_z_image(args): - model = ZImageTransformer2DModel.from_pretrained( - args.original_z_image_repo_id, subfolder="transformer", torch_dtype=torch.bfloat16 - ) - return model.state_dict(), model.config - - -def convert_z_image_controlnet_checkpoint_to_diffusers(z_image, original_state_dict): +def convert_z_image_controlnet_checkpoint_to_diffusers(original_state_dict): converted_state_dict = {} converted_state_dict.update(original_state_dict) - to_copy = { - "all_x_embedder.", - "noise_refiner.", - "context_refiner.", - "t_embedder.", - "cap_embedder.", - "x_pad_token", - "cap_pad_token", - } - - for key in z_image.keys(): - for copy_key in to_copy: - if key.startswith(copy_key): - converted_state_dict[key] = z_image[key] - return converted_state_dict def main(args): original_ckpt = load_original_checkpoint(args) - z_image, config = load_z_image(args) control_in_dim = 16 control_layers_places = [0, 5, 10, 15, 20, 25] - converted_controlnet_state_dict = convert_z_image_controlnet_checkpoint_to_diffusers(z_image, original_ckpt) - - for key, tensor in converted_controlnet_state_dict.items(): - print(f"{key} - {tensor.dtype}") + converted_controlnet_state_dict = convert_z_image_controlnet_checkpoint_to_diffusers(original_ckpt) controlnet = ZImageControlNetModel( - all_patch_size=config["all_patch_size"], - all_f_patch_size=config["all_f_patch_size"], - in_channels=config["in_channels"], - dim=config["dim"], - n_layers=config["n_layers"], - n_refiner_layers=config["n_refiner_layers"], - 
n_heads=config["n_heads"], - n_kv_heads=config["n_kv_heads"], - norm_eps=config["norm_eps"], - qk_norm=config["qk_norm"], - cap_feat_dim=config["cap_feat_dim"], - rope_theta=config["rope_theta"], - t_scale=config["t_scale"], - axes_dims=config["axes_dims"], - axes_lens=config["axes_lens"], control_layers_places=control_layers_places, control_in_dim=control_in_dim, - ) - missing, unexpected = controlnet.load_state_dict(converted_controlnet_state_dict) - print(f"{missing=}") - print(f"{unexpected=}") + ).to(torch.bfloat16) + controlnet.load_state_dict(converted_controlnet_state_dict) print("Saving Z-Image ControlNet in Diffusers format") - controlnet.save_pretrained(args.output_path, max_shard_size="5GB") + controlnet.save_pretrained(args.output_path) if __name__ == "__main__": From 4b446b394150575b322836b048acd3eeeb2072a3 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:03:30 +0000 Subject: [PATCH 05/20] __init__ --- src/diffusers/__init__.py | 4 ++++ src/diffusers/models/__init__.py | 2 ++ src/diffusers/pipelines/__init__.py | 4 ++-- src/diffusers/pipelines/z_image/__init__.py | 2 ++ 4 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index eb8e86c4c89d..f45be1560716 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -277,6 +277,7 @@ "WanTransformer3DModel", "WanVACETransformer3DModel", "ZImageTransformer2DModel", + "ZImageControlNetModel", "attention_backend", ] ) @@ -661,6 +662,7 @@ "WuerstchenDecoderPipeline", "WuerstchenPriorPipeline", "ZImagePipeline", + "ZImageControlNetPipeline", ] ) @@ -1004,6 +1006,7 @@ WanTransformer3DModel, WanVACETransformer3DModel, ZImageTransformer2DModel, + ZImageControlNetModel, attention_backend, ) from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks @@ -1357,6 +1360,7 @@ WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ZImagePipeline, + ZImageControlNetPipeline, ) try: diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 29d8b0b5a55d..7ea15ef2a215 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -66,6 +66,7 @@ _import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"] _import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"] _import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] + _import_structure["controlnets.controlnet_z_image"] = ["ZImageControlNetModel"] _import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"] _import_structure["controlnets.multicontrolnet_union"] = ["MultiControlNetUnionModel"] _import_structure["embeddings"] = ["ImageProjection"] @@ -180,6 +181,7 @@ SD3MultiControlNetModel, SparseControlNetModel, UNetControlNetXSModel, + ZImageControlNetModel, ) from .embeddings import ImageProjection from .modeling_utils import ModelMixin diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 3d669aecf556..fe6af5cd1e0b 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -404,7 +404,7 @@ "Kandinsky5T2IPipeline", "Kandinsky5I2IPipeline", ] - _import_structure["z_image"] = ["ZImagePipeline"] + _import_structure["z_image"] = ["ZImagePipeline", "ZImageControlNetPipeline"] _import_structure["skyreels_v2"] = [ "SkyReelsV2DiffusionForcingPipeline", "SkyReelsV2DiffusionForcingImageToVideoPipeline", @@ -841,7 +841,7 @@ WuerstchenDecoderPipeline, 
WuerstchenPriorPipeline, ) - from .z_image import ZImagePipeline + from .z_image import ZImagePipeline, ZImageControlNetPipeline try: if not is_onnx_available(): diff --git a/src/diffusers/pipelines/z_image/__init__.py b/src/diffusers/pipelines/z_image/__init__.py index f95b3e5a0bed..842d5690e3d7 100644 --- a/src/diffusers/pipelines/z_image/__init__.py +++ b/src/diffusers/pipelines/z_image/__init__.py @@ -23,6 +23,7 @@ else: _import_structure["pipeline_output"] = ["ZImagePipelineOutput"] _import_structure["pipeline_z_image"] = ["ZImagePipeline"] + _import_structure["pipeline_z_image_controlnet"] = ["ZImageControlNetPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -35,6 +36,7 @@ else: from .pipeline_output import ZImagePipelineOutput from .pipeline_z_image import ZImagePipeline + from .pipeline_z_image_controlnet import ZImageControlNetPipeline else: import sys From a1ff390ecebb5afb9f6282209526adfdaf31c5d5 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:03:37 +0000 Subject: [PATCH 06/20] pipeline example --- .../pipelines/z_image/pipeline_z_image_controlnet.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py index 44906a0db519..ae81105eea27 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -36,9 +36,13 @@ Examples: ```py >>> import torch - >>> from diffusers import ZImagePipeline + >>> from diffusers import ZImageControlNetPipeline + >>> from diffusers import ZImageControlNetModel - >>> pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16) + >>> controlnet_model = "..." + >>> controlnet = ZImageControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16) + + >>> pipe = ZImageControlNetPipeline.from_pretrained("Z-a-o/Z-Image-Turbo", controlnet=controlnet, torch_dtype=torch.bfloat16) >>> pipe.to("cuda") >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch. @@ -47,9 +51,11 @@ >>> # (2) Use flash attention 3 >>> # pipe.transformer.set_attention_backend("_flash_3") - >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。" + >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg") + >>> prompt = "A girl in city, 25 years old, cool, futuristic" >>> image = pipe( ... prompt, + ... control_image=control_image, ... height=1024, ... width=1024, ... 
num_inference_steps=9, From 7ab347d812a5b78076f79541f4046d59948d0464 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:04:01 +0000 Subject: [PATCH 07/20] ruff --- scripts/convert_z_image_controlnet_to_diffusers.py | 2 +- src/diffusers/__init__.py | 4 ++-- src/diffusers/models/controlnets/controlnet_z_image.py | 7 ++++++- src/diffusers/pipelines/__init__.py | 2 +- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/scripts/convert_z_image_controlnet_to_diffusers.py b/scripts/convert_z_image_controlnet_to_diffusers.py index aed27c14f205..e5d5f34e36e8 100644 --- a/scripts/convert_z_image_controlnet_to_diffusers.py +++ b/scripts/convert_z_image_controlnet_to_diffusers.py @@ -1,8 +1,8 @@ import argparse from contextlib import nullcontext -import torch import safetensors.torch +import torch from accelerate import init_empty_weights from huggingface_hub import hf_hub_download diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f45be1560716..398f72167ad3 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -1005,8 +1005,8 @@ WanAnimateTransformer3DModel, WanTransformer3DModel, WanVACETransformer3DModel, - ZImageTransformer2DModel, ZImageControlNetModel, + ZImageTransformer2DModel, attention_backend, ) from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks @@ -1359,8 +1359,8 @@ WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, WuerstchenPriorPipeline, - ZImagePipeline, ZImageControlNetPipeline, + ZImagePipeline, ) try: diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index b76a2c54c3d8..ff148781f49a 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -438,7 +438,12 @@ def forward( control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0) c = control_context_unified - new_kwargs = {"x": unified, "attn_mask": unified_attn_mask, "freqs_cis": unified_freqs_cis, "adaln_input": adaln_input} + new_kwargs = { + "x": unified, + "attn_mask": unified_attn_mask, + "freqs_cis": unified_freqs_cis, + "adaln_input": adaln_input, + } for layer in self.control_layers: if torch.is_grad_enabled() and self.gradient_checkpointing: diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index fe6af5cd1e0b..10ce49fe8111 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -841,7 +841,7 @@ WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ) - from .z_image import ZImagePipeline, ZImageControlNetPipeline + from .z_image import ZImageControlNetPipeline, ZImagePipeline try: if not is_onnx_available(): From 8cab0c953c7b732324a92d4e5de067b8bb290a5d Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:05:36 +0000 Subject: [PATCH 08/20] pipeline load_image --- src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py index ae81105eea27..67771dddabd7 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -38,6 +38,7 @@ >>> import torch >>> from diffusers import ZImageControlNetPipeline >>> from diffusers import ZImageControlNetModel + >>> from diffusers.utils import 
load_image >>> controlnet_model = "..." >>> controlnet = ZImageControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16) From 8688fa66a110bd0d77df8d03299fc3a42130ce07 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:24:05 +0000 Subject: [PATCH 09/20] t_scale --- src/diffusers/models/controlnets/controlnet_z_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index ff148781f49a..070724a85883 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -305,7 +305,7 @@ def forward( bsz = len(x) device = x[0].device - t = t * self.t_scale + t = t * transformer.t_scale t = transformer.t_embedder(t) ( From 9051272d47082c5cf6bc409b6368332cdac16f97 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:28:26 +0000 Subject: [PATCH 10/20] x_pad_token --- src/diffusers/models/controlnets/controlnet_z_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index 070724a85883..48b9a66a25a3 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -409,7 +409,7 @@ def forward( # Match t_embedder output dtype to control_context for layerwise casting compatibility adaln_input = t.type_as(control_context) - control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token + control_context[torch.cat(x_inner_pad_mask)] = transformer.x_pad_token control_context = list(control_context.split(x_item_seqlens, dim=0)) x_freqs_cis = list(transformer.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) From 0d8c3f1a28180fc85fc9a4e0696d5f4f11def56f Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:28:34 +0000 Subject: [PATCH 11/20] controlnet_block_samples --- src/diffusers/models/controlnets/controlnet_z_image.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index 48b9a66a25a3..0127f7f9683f 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -452,10 +452,6 @@ def forward( c = layer(c, **new_kwargs) hints = torch.unbind(c)[:-1] * conditioning_scale - controlnet_block_samples = {} - for layer_idx in range(self.n_layers): - if layer_idx in self.control_layers_places: - hints_idx = self.control_layers_places.index(layer_idx) - controlnet_block_samples[layer_idx] = hints[hints_idx] + controlnet_block_samples = {layer_idx: hints[idx] for idx, layer_idx in enumerate(self.control_layers_places)} return controlnet_block_samples From f789325ccd8f3f6fb35dffdd4acea6f21f30084e Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:29:54 +0000 Subject: [PATCH 12/20] conditioning_scale --- src/diffusers/models/controlnets/controlnet_z_image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index 0127f7f9683f..3a200b252a01 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -451,7 +451,7 @@ def forward( else: c = layer(c, 
**new_kwargs) - hints = torch.unbind(c)[:-1] * conditioning_scale - controlnet_block_samples = {layer_idx: hints[idx] for idx, layer_idx in enumerate(self.control_layers_places)} + hints = torch.unbind(c)[:-1] + controlnet_block_samples = {layer_idx: hints[idx] * conditioning_scale for idx, layer_idx in enumerate(self.control_layers_places)} return controlnet_block_samples From 5f8ab7bf98549ff6bdc63db500ad1433b7cf84e2 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:33:36 +0000 Subject: [PATCH 13/20] self.config --- src/diffusers/models/controlnets/controlnet_z_image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index 3a200b252a01..d0f8b861e0c9 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -300,8 +300,8 @@ def forward( f_patch_size=1, conditioning_scale: float = 1.0, ): - assert patch_size in self.all_patch_size - assert f_patch_size in self.all_f_patch_size + assert patch_size in self.config.all_patch_size + assert f_patch_size in self.config.all_f_patch_size bsz = len(x) device = x[0].device From bc72f9ce93ca691018fb8f1b420684b6e18a6d55 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 20:08:24 +0000 Subject: [PATCH 14/20] sample_mode, default controlnet_conditioning_scale --- .../pipelines/z_image/pipeline_z_image_controlnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py index 67771dddabd7..d0460cf09244 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -360,7 +360,7 @@ def __call__( sigmas: Optional[List[float]] = None, guidance_scale: float = 5.0, control_image: PipelineImageInput = None, - controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + controlnet_conditioning_scale: Union[float, List[float]] = 0.75, cfg_normalization: bool = False, cfg_truncation: float = 1.0, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -517,7 +517,7 @@ def __call__( dtype=self.vae.dtype, ) height, width = control_image.shape[-2:] - control_image = retrieve_latents(self.vae.encode(control_image), generator=generator) + control_image = retrieve_latents(self.vae.encode(control_image), generator=generator, sample_mode="argmax") control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor control_image = control_image.unsqueeze(2) From 13b706a99f209197352bcb8790260727d75b2b9b Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 20:16:49 +0000 Subject: [PATCH 15/20] ruff --- src/diffusers/models/controlnets/controlnet_z_image.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index d0f8b861e0c9..c121f42c1a78 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -452,6 +452,8 @@ def forward( c = layer(c, **new_kwargs) hints = torch.unbind(c)[:-1] - controlnet_block_samples = {layer_idx: hints[idx] * conditioning_scale for idx, layer_idx in enumerate(self.control_layers_places)} + controlnet_block_samples = { + layer_idx: hints[idx] * conditioning_scale for 
idx, layer_idx in enumerate(self.control_layers_places) + } return controlnet_block_samples From 09849a77465e49f6d1ca056638d917b6937f4b95 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 5 Dec 2025 20:48:16 +0000 Subject: [PATCH 16/20] ZImageControlTransformer2DModel --- src/diffusers/__init__.py | 2 + src/diffusers/loaders/peft.py | 1 + src/diffusers/models/__init__.py | 2 + .../models/controlnets/controlnet_z_image.py | 360 +------- src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_z_image.py | 7 - .../transformer_z_image_control.py | 784 ++++++++++++++++++ .../z_image/pipeline_z_image_controlnet.py | 18 +- 8 files changed, 808 insertions(+), 367 deletions(-) create mode 100644 src/diffusers/models/transformers/transformer_z_image_control.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 398f72167ad3..746021bfd706 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -277,6 +277,7 @@ "WanTransformer3DModel", "WanVACETransformer3DModel", "ZImageTransformer2DModel", + "ZImageControlTransformer2DModel", "ZImageControlNetModel", "attention_backend", ] @@ -1006,6 +1007,7 @@ WanTransformer3DModel, WanVACETransformer3DModel, ZImageControlNetModel, + ZImageControlTransformer2DModel, ZImageTransformer2DModel, attention_backend, ) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 3f8519bbfa32..62182f2d205f 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -65,6 +65,7 @@ "QwenImageTransformer2DModel": lambda model_cls, weights: weights, "Flux2Transformer2DModel": lambda model_cls, weights: weights, "ZImageTransformer2DModel": lambda model_cls, weights: weights, + "ZImageControlTransformer2DModel": lambda model_cls, weights: weights, } diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 7ea15ef2a215..48d06def1c3c 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -117,6 +117,7 @@ _import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"] _import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"] _import_structure["transformers.transformer_z_image"] = ["ZImageTransformer2DModel"] + _import_structure["transformers.transformer_z_image_control"] = ["ZImageControlTransformer2DModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"] @@ -231,6 +232,7 @@ WanAnimateTransformer3DModel, WanTransformer3DModel, WanVACETransformer3DModel, + ZImageControlTransformer2DModel, ZImageTransformer2DModel, ) from .unets import ( diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index c121f42c1a78..0972fb46c07b 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -12,19 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List +from typing import List, Optional import torch import torch.nn as nn -from torch.nn.utils.rnn import pad_sequence from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin from ..controlnets.controlnet import zero_module from ..modeling_utils import ModelMixin from ..transformers.transformer_z_image import ( - SEQ_MULTI_OF, - ZImageTransformer2DModel, ZImageTransformerBlock, ) @@ -47,7 +44,14 @@ def __init__( self.before_proj = zero_module(nn.Linear(self.dim, self.dim)) self.after_proj = zero_module(nn.Linear(self.dim, self.dim)) - def forward(self, c: torch.Tensor, x: torch.Tensor, **kwargs): + def forward( + self, + c: torch.Tensor, + x: torch.Tensor, + attn_mask: torch.Tensor, + freqs_cis: torch.Tensor, + adaln_input: Optional[torch.Tensor] = None, + ): if self.block_id == 0: c = self.before_proj(c) + x all_c = [] @@ -55,7 +59,7 @@ def forward(self, c: torch.Tensor, x: torch.Tensor, **kwargs): all_c = list(torch.unbind(c)) c = all_c.pop(-1) - c = super().forward(c, **kwargs) + c = super().forward(c, attn_mask, freqs_cis, adaln_input) c_skip = self.after_proj(c) all_c += [c_skip, c] c = torch.stack(all_c) @@ -115,345 +119,5 @@ def __init__( ] ) - @staticmethod - def create_coordinate_grid(size, start=None, device=None): - if start is None: - start = (0 for _ in size) - - axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)] - grids = torch.meshgrid(axes, indexing="ij") - return torch.stack(grids, dim=-1) - - def patchify( - self, - all_image: List[torch.Tensor], - patch_size: int, - f_patch_size: int, - cap_padding_len: int, - ): - pH = pW = patch_size - pF = f_patch_size - device = all_image[0].device - - all_image_out = [] - all_image_size = [] - all_image_pos_ids = [] - all_image_pad_mask = [] - - for i, image in enumerate(all_image): - ### Process Image - C, F, H, W = image.size() - all_image_size.append((F, H, W)) - F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW - - image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) - # "c f pf h ph w pw -> (f h w) (pf ph pw c)" - image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) - - image_ori_len = len(image) - image_padding_len = (-image_ori_len) % SEQ_MULTI_OF - - image_ori_pos_ids = self.create_coordinate_grid( - size=(F_tokens, H_tokens, W_tokens), - start=(cap_padding_len + 1, 0, 0), - device=device, - ).flatten(0, 2) - image_padding_pos_ids = ( - self.create_coordinate_grid( - size=(1, 1, 1), - start=(0, 0, 0), - device=device, - ) - .flatten(0, 2) - .repeat(image_padding_len, 1) - ) - image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) - all_image_pos_ids.append(image_padded_pos_ids) - # pad mask - all_image_pad_mask.append( - torch.cat( - [ - torch.zeros((image_ori_len,), dtype=torch.bool, device=device), - torch.ones((image_padding_len,), dtype=torch.bool, device=device), - ], - dim=0, - ) - ) - # padded feature - image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) - all_image_out.append(image_padded_feat) - - return ( - all_image_out, - all_image_size, - all_image_pos_ids, - all_image_pad_mask, - ) - - def patchify_and_embed( - self, - all_image: List[torch.Tensor], - all_cap_feats: List[torch.Tensor], - patch_size: int, - f_patch_size: int, - ): - pH = pW = patch_size - pF = f_patch_size - device = all_image[0].device - - all_image_out = [] - all_image_size = [] - 
all_image_pos_ids = [] - all_image_pad_mask = [] - all_cap_pos_ids = [] - all_cap_pad_mask = [] - all_cap_feats_out = [] - - for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)): - ### Process Caption - cap_ori_len = len(cap_feat) - cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF - # padded position ids - cap_padded_pos_ids = self.create_coordinate_grid( - size=(cap_ori_len + cap_padding_len, 1, 1), - start=(1, 0, 0), - device=device, - ).flatten(0, 2) - all_cap_pos_ids.append(cap_padded_pos_ids) - # pad mask - all_cap_pad_mask.append( - torch.cat( - [ - torch.zeros((cap_ori_len,), dtype=torch.bool, device=device), - torch.ones((cap_padding_len,), dtype=torch.bool, device=device), - ], - dim=0, - ) - ) - # padded feature - cap_padded_feat = torch.cat( - [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], - dim=0, - ) - all_cap_feats_out.append(cap_padded_feat) - - ### Process Image - C, F, H, W = image.size() - all_image_size.append((F, H, W)) - F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW - - image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) - # "c f pf h ph w pw -> (f h w) (pf ph pw c)" - image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) - - image_ori_len = len(image) - image_padding_len = (-image_ori_len) % SEQ_MULTI_OF - - image_ori_pos_ids = self.create_coordinate_grid( - size=(F_tokens, H_tokens, W_tokens), - start=(cap_ori_len + cap_padding_len + 1, 0, 0), - device=device, - ).flatten(0, 2) - image_padding_pos_ids = ( - self.create_coordinate_grid( - size=(1, 1, 1), - start=(0, 0, 0), - device=device, - ) - .flatten(0, 2) - .repeat(image_padding_len, 1) - ) - image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) - all_image_pos_ids.append(image_padded_pos_ids) - # pad mask - all_image_pad_mask.append( - torch.cat( - [ - torch.zeros((image_ori_len,), dtype=torch.bool, device=device), - torch.ones((image_padding_len,), dtype=torch.bool, device=device), - ], - dim=0, - ) - ) - # padded feature - image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) - all_image_out.append(image_padded_feat) - - return ( - all_image_out, - all_cap_feats_out, - all_image_size, - all_image_pos_ids, - all_cap_pos_ids, - all_image_pad_mask, - all_cap_pad_mask, - ) - - def forward( - self, - transformer: ZImageTransformer2DModel, - x: List[torch.Tensor], - cap_feats: List[torch.Tensor], - control_context: List[torch.Tensor], - t=None, - patch_size=2, - f_patch_size=1, - conditioning_scale: float = 1.0, - ): - assert patch_size in self.config.all_patch_size - assert f_patch_size in self.config.all_f_patch_size - - bsz = len(x) - device = x[0].device - t = t * transformer.t_scale - t = transformer.t_embedder(t) - - ( - x, - cap_feats, - x_size, - x_pos_ids, - cap_pos_ids, - x_inner_pad_mask, - cap_inner_pad_mask, - ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) - - # x embed & refine - x_item_seqlens = [len(_) for _ in x] - assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) - x_max_item_seqlen = max(x_item_seqlens) - - x = torch.cat(x, dim=0) - x = transformer.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) - - # Match t_embedder output dtype to x for layerwise casting compatibility - adaln_input = t.type_as(x) - x[torch.cat(x_inner_pad_mask)] = transformer.x_pad_token - x = list(x.split(x_item_seqlens, dim=0)) - x_freqs_cis = list(transformer.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) - - x = 
pad_sequence(x, batch_first=True, padding_value=0.0) - x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) - x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(x_item_seqlens): - x_attn_mask[i, :seq_len] = 1 - - if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in transformer.noise_refiner: - x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input) - else: - for layer in transformer.noise_refiner: - x = layer(x, x_attn_mask, x_freqs_cis, adaln_input) - - # cap embed & refine - cap_item_seqlens = [len(_) for _ in cap_feats] - assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens) - cap_max_item_seqlen = max(cap_item_seqlens) - - cap_feats = torch.cat(cap_feats, dim=0) - cap_feats = transformer.cap_embedder(cap_feats) - cap_feats[torch.cat(cap_inner_pad_mask)] = transformer.cap_pad_token - cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) - cap_freqs_cis = list(transformer.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0)) - - cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) - cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) - cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(cap_item_seqlens): - cap_attn_mask[i, :seq_len] = 1 - - if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in transformer.context_refiner: - cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis) - else: - for layer in transformer.context_refiner: - cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis) - - # unified - unified = [] - unified_freqs_cis = [] - for i in range(bsz): - x_len = x_item_seqlens[i] - cap_len = cap_item_seqlens[i] - unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]])) - unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]])) - unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)] - assert unified_item_seqlens == [len(_) for _ in unified] - unified_max_item_seqlen = max(unified_item_seqlens) - - unified = pad_sequence(unified, batch_first=True, padding_value=0.0) - unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0) - unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(unified_item_seqlens): - unified_attn_mask[i, :seq_len] = 1 - - ## Original forward_control - - # embeddings - bsz = len(control_context) - device = control_context[0].device - ( - control_context, - x_size, - x_pos_ids, - x_inner_pad_mask, - ) = self.patchify(control_context, patch_size, f_patch_size, cap_feats[0].size(0)) - - # control_context embed & refine - x_item_seqlens = [len(_) for _ in control_context] - assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) - x_max_item_seqlen = max(x_item_seqlens) - - control_context = torch.cat(control_context, dim=0) - control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context) - - # Match t_embedder output dtype to control_context for layerwise casting compatibility - adaln_input = t.type_as(control_context) - control_context[torch.cat(x_inner_pad_mask)] = transformer.x_pad_token - control_context = list(control_context.split(x_item_seqlens, dim=0)) - x_freqs_cis = 
list(transformer.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) - - control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0) - x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) - x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(x_item_seqlens): - x_attn_mask[i, :seq_len] = 1 - - if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.control_noise_refiner: - control_context = self._gradient_checkpointing_func( - layer, control_context, x_attn_mask, x_freqs_cis, adaln_input - ) - else: - for layer in self.control_noise_refiner: - control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input) - - # unified - cap_item_seqlens = [len(_) for _ in cap_feats] - control_context_unified = [] - for i in range(bsz): - x_len = x_item_seqlens[i] - cap_len = cap_item_seqlens[i] - control_context_unified.append(torch.cat([control_context[i][:x_len], cap_feats[i][:cap_len]])) - control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0) - c = control_context_unified - - new_kwargs = { - "x": unified, - "attn_mask": unified_attn_mask, - "freqs_cis": unified_freqs_cis, - "adaln_input": adaln_input, - } - - for layer in self.control_layers: - if torch.is_grad_enabled() and self.gradient_checkpointing: - c = self._gradient_checkpointing_func(layer, c, **new_kwargs) - else: - c = layer(c, **new_kwargs) - - hints = torch.unbind(c)[:-1] - controlnet_block_samples = { - layer_idx: hints[idx] * conditioning_scale for idx, layer_idx in enumerate(self.control_layers_places) - } - - return controlnet_block_samples + def forward(self): + pass diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index a42f6b2716e1..13322aa29ae4 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -48,3 +48,4 @@ from .transformer_wan_animate import WanAnimateTransformer3DModel from .transformer_wan_vace import WanVACETransformer3DModel from .transformer_z_image import ZImageTransformer2DModel + from .transformer_z_image_control import ZImageControlTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 70ffced8b63a..7c01361b681d 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -538,7 +538,6 @@ def forward( cap_feats: List[torch.Tensor], patch_size=2, f_patch_size=1, - controlnet_block_samples: Optional[dict[int, torch.Tensor]] = None, return_dict: bool = True, ): assert patch_size in self.all_patch_size @@ -640,15 +639,9 @@ def forward( unified = self._gradient_checkpointing_func( layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input ) - if controlnet_block_samples is not None: - if layer_idx in controlnet_block_samples: - unified = unified + controlnet_block_samples[layer_idx] else: for layer_idx, layer in enumerate(self.layers): unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input) - if controlnet_block_samples is not None: - if layer_idx in controlnet_block_samples: - unified = unified + controlnet_block_samples[layer_idx] unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input) unified = list(unified.unbind(dim=0)) diff --git 
a/src/diffusers/models/transformers/transformer_z_image_control.py b/src/diffusers/models/transformers/transformer_z_image_control.py new file mode 100644 index 000000000000..61b752f58f68 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_z_image_control.py @@ -0,0 +1,784 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence + +from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...models.attention_processor import Attention +from ...models.modeling_utils import ModelMixin +from ...models.normalization import RMSNorm +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention_dispatch import dispatch_attention_fn +from ..modeling_outputs import Transformer2DModelOutput +from .transformer_z_image import ZImageTransformer2DModel + + +ADALN_EMBED_DIM = 256 +SEQ_MULTI_OF = 32 + + +class TimestepEmbedder(nn.Module): + def __init__(self, out_size, mid_size=None, frequency_embedding_size=256): + super().__init__() + if mid_size is None: + mid_size = out_size + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, mid_size, bias=True), + nn.SiLU(), + nn.Linear(mid_size, out_size, bias=True), + ) + + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + with torch.amp.autocast("cuda", enabled=False): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + weight_dtype = self.mlp[0].weight.dtype + compute_dtype = getattr(self.mlp[0], "compute_dtype", None) + if weight_dtype.is_floating_point: + t_freq = t_freq.to(weight_dtype) + elif compute_dtype is not None: + t_freq = t_freq.to(compute_dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class ZSingleStreamAttnProcessor: + """ + Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the + original Z-ImageAttention module. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." 
+ ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + # Apply Norms + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE + def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + with torch.amp.autocast("cuda", enabled=False): + x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x * freqs_cis).flatten(3) + return x_out.type_as(x_in) # todo + + if freqs_cis is not None: + query = apply_rotary_emb(query, freqs_cis) + key = apply_rotary_emb(key, freqs_cis) + + # Cast to correct dtype + dtype = query.dtype + query, key = query.to(dtype), key.to(dtype) + + # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len] + if attention_mask is not None and attention_mask.ndim == 2: + attention_mask = attention_mask[:, None, None, :] + + # Compute joint attention + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + + # Reshape back + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(dtype) + + output = attn.to_out[0](hidden_states) + if len(attn.to_out) > 1: # dropout + output = attn.to_out[1](output) + + return output + + +class FeedForward(nn.Module): + def __init__(self, dim: int, hidden_dim: int): + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def _forward_silu_gating(self, x1, x3): + return F.silu(x1) * x3 + + def forward(self, x): + return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) + + +@maybe_allow_in_graph +class ZImageTransformerBlock(nn.Module): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + qk_norm: bool, + modulation=True, + ): + super().__init__() + self.dim = dim + self.head_dim = dim // n_heads + + # Refactored to use diffusers Attention with custom processor + # Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm + self.attention = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=dim // n_heads, + heads=n_heads, + qk_norm="rms_norm" if qk_norm else None, + eps=1e-5, + bias=False, + out_bias=False, + processor=ZSingleStreamAttnProcessor(), + ) + + self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8)) + self.layer_id = layer_id + + self.attention_norm1 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + + self.attention_norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + self.modulation = modulation + if modulation: + self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True)) + + def forward( + self, + x: torch.Tensor, + attn_mask: torch.Tensor, + freqs_cis: torch.Tensor, + adaln_input: 
Optional[torch.Tensor] = None, + ): + if self.modulation: + assert adaln_input is not None + scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2) + gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() + scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp + + # Attention block + attn_out = self.attention( + self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis + ) + x = x + gate_msa * self.attention_norm2(attn_out) + + # FFN block + x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp)) + else: + # Attention block + attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis) + x = x + self.attention_norm2(attn_out) + + # FFN block + x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x))) + + return x + + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True), + ) + + def forward(self, x, c): + scale = 1.0 + self.adaLN_modulation(c) + x = self.norm_final(x) * scale.unsqueeze(1) + x = self.linear(x) + return x + + +class RopeEmbedder: + def __init__( + self, + theta: float = 256.0, + axes_dims: List[int] = (16, 56, 56), + axes_lens: List[int] = (64, 128, 128), + ): + self.theta = theta + self.axes_dims = axes_dims + self.axes_lens = axes_lens + assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length" + self.freqs_cis = None + + @staticmethod + def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0): + with torch.device("cpu"): + freqs_cis = [] + for i, (d, e) in enumerate(zip(dim, end)): + freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) + timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) + freqs = torch.outer(timestep, freqs).float() + freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64 + freqs_cis.append(freqs_cis_i) + + return freqs_cis + + def __call__(self, ids: torch.Tensor): + assert ids.ndim == 2 + assert ids.shape[-1] == len(self.axes_dims) + device = ids.device + + if self.freqs_cis is None: + self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + else: + # Ensure freqs_cis are on the same device as ids + if self.freqs_cis[0].device != device: + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + + result = [] + for i in range(len(self.axes_dims)): + index = ids[:, i] + result.append(self.freqs_cis[i][index]) + return torch.cat(result, dim=-1) + + +class ZImageControlTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + _supports_gradient_checkpointing = True + _no_split_modules = ["ZImageTransformerBlock", "ZImageControlTransformerBlock"] + _repeated_blocks = ["ZImageTransformerBlock", "ZImageControlTransformerBlock"] + _skip_layerwise_casting_patterns = ["t_embedder", "cap_embedder"] # precision sensitive layers + + @register_to_config + def __init__( + self, + all_patch_size=(2,), + all_f_patch_size=(1,), + in_channels=16, + dim=3840, + n_layers=30, + n_refiner_layers=2, + n_heads=30, + 
n_kv_heads=30, + norm_eps=1e-5, + qk_norm=True, + cap_feat_dim=2560, + rope_theta=256.0, + t_scale=1000.0, + axes_dims=[32, 48, 48], + axes_lens=[1024, 512, 512], + control_layers_places: List[int] = None, + control_in_dim=None, + ) -> None: + super().__init__() + from ...models.controlnets.controlnet_z_image import ZImageControlTransformerBlock + + self.in_channels = in_channels + self.out_channels = in_channels + self.all_patch_size = all_patch_size + self.all_f_patch_size = all_f_patch_size + self.dim = dim + self.n_heads = n_heads + + self.rope_theta = rope_theta + self.t_scale = t_scale + self.gradient_checkpointing = False + + assert len(all_patch_size) == len(all_f_patch_size) + + all_x_embedder = {} + all_final_layer = {} + for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): + x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True) + all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + + final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels) + all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer + + self.all_x_embedder = nn.ModuleDict(all_x_embedder) + self.all_final_layer = nn.ModuleDict(all_final_layer) + self.noise_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 1000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=True, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.context_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) + self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True)) + + self.x_pad_token = nn.Parameter(torch.empty((1, dim))) + self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) + + self.layers = nn.ModuleList( + [ + ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm) + for layer_id in range(n_layers) + ] + ) + head_dim = dim // n_heads + assert head_dim == sum(axes_dims) + self.axes_dims = axes_dims + self.axes_lens = axes_lens + + self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) + + self.control_layers_places = control_layers_places + self.control_in_dim = control_in_dim + + assert 0 in self.control_layers_places + + # control blocks + self.control_layers = nn.ModuleList( + [ + ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i) + for i in self.control_layers_places + ] + ) + + # control patch embeddings + all_x_embedder = {} + for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): + x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True) + all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + + self.control_all_x_embedder = nn.ModuleDict(all_x_embedder) + self.control_noise_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 1000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=True, + ) + for layer_id in range(n_refiner_layers) + ] + ) + + @classmethod + def from_controlnet( + cls, + transformer: ZImageTransformer2DModel, + controlnet, + load_weights: bool = True, + ): + controlnet.to(device=transformer.device) + + if transformer.config["dim"] != 
controlnet.config["dim"]: + raise ValueError("Incompatible ControlNet, got a different dim.") + + config = dict(transformer.config) + config["_class_name"] = cls.__name__ + + config["control_layers_places"] = controlnet.config["control_layers_places"] + config["control_in_dim"] = controlnet.config["control_in_dim"] + + expected_kwargs, optional_kwargs = cls._get_signature_keys(cls) + config = FrozenDict({k: config.get(k) for k in config if k in expected_kwargs or k in optional_kwargs}) + config["_class_name"] = cls.__name__ + model = cls.from_config(config) + + if not load_weights: + return model + + for i, control_layer in enumerate(controlnet.control_layers): + model.control_layers[i].load_state_dict(control_layer.state_dict()) + + for i, control_all_x_embedder in enumerate(controlnet.control_all_x_embedder): + model.control_all_x_embedder[i].load_state_dict(control_all_x_embedder.state_dict()) + + for i, control_noise_refiner in enumerate(controlnet.control_noise_refiner): + model.control_noise_refiner[i].load_state_dict(control_noise_refiner.state_dict()) + + model.to(transformer.dtype) + + return model + + def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]: + pH = pW = patch_size + pF = f_patch_size + bsz = len(x) + assert len(size) == bsz + for i in range(bsz): + F, H, W = size[i] + ori_len = (F // pF) * (H // pH) * (W // pW) + # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)" + x[i] = ( + x[i][:ori_len] + .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, F, H, W) + ) + return x + + @staticmethod + def create_coordinate_grid(size, start=None, device=None): + if start is None: + start = (0 for _ in size) + + axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)] + grids = torch.meshgrid(axes, indexing="ij") + return torch.stack(grids, dim=-1) + + def patchify_and_embed( + self, + all_image: List[torch.Tensor], + all_cap_feats: List[torch.Tensor], + patch_size: int, + f_patch_size: int, + ): + pH = pW = patch_size + pF = f_patch_size + device = all_image[0].device + + all_image_out = [] + all_image_size = [] + all_image_pos_ids = [] + all_image_pad_mask = [] + all_cap_pos_ids = [] + all_cap_pad_mask = [] + all_cap_feats_out = [] + + for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)): + ### Process Caption + cap_ori_len = len(cap_feat) + cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF + # padded position ids + cap_padded_pos_ids = self.create_coordinate_grid( + size=(cap_ori_len + cap_padding_len, 1, 1), + start=(1, 0, 0), + device=device, + ).flatten(0, 2) + all_cap_pos_ids.append(cap_padded_pos_ids) + # pad mask + cap_pad_mask = torch.cat( + [ + torch.zeros((cap_ori_len,), dtype=torch.bool, device=device), + torch.ones((cap_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + all_cap_pad_mask.append( + cap_pad_mask if cap_padding_len > 0 else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device) + ) + + # padded feature + cap_padded_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0) + all_cap_feats_out.append(cap_padded_feat) + + ### Process Image + C, F, H, W = image.size() + all_image_size.append((F, H, W)) + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + # "c f pf h ph w pw -> (f h w) (pf ph pw c)" + image = image.permute(1, 3, 5, 2, 4, 6, 
0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + + image_ori_len = len(image) + image_padding_len = (-image_ori_len) % SEQ_MULTI_OF + + image_ori_pos_ids = self.create_coordinate_grid( + size=(F_tokens, H_tokens, W_tokens), + start=(cap_ori_len + cap_padding_len + 1, 0, 0), + device=device, + ).flatten(0, 2) + image_padded_pos_ids = torch.cat( + [ + image_ori_pos_ids, + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(image_padding_len, 1), + ], + dim=0, + ) + all_image_pos_ids.append(image_padded_pos_ids if image_padding_len > 0 else image_ori_pos_ids) + # pad mask + image_pad_mask = torch.cat( + [ + torch.zeros((image_ori_len,), dtype=torch.bool, device=device), + torch.ones((image_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + all_image_pad_mask.append( + image_pad_mask + if image_padding_len > 0 + else torch.zeros((image_ori_len,), dtype=torch.bool, device=device) + ) + # padded feature + image_padded_feat = torch.cat( + [image, image[-1:].repeat(image_padding_len, 1)], + dim=0, + ) + all_image_out.append(image_padded_feat if image_padding_len > 0 else image) + + return ( + all_image_out, + all_cap_feats_out, + all_image_size, + all_image_pos_ids, + all_cap_pos_ids, + all_image_pad_mask, + all_cap_pad_mask, + ) + + def forward( + self, + x: List[torch.Tensor], + t, + cap_feats: List[torch.Tensor], + patch_size=2, + f_patch_size=1, + control_context: Optional[List[torch.Tensor]] = None, + conditioning_scale: float = 1.0, + return_dict: bool = True, + ): + assert patch_size in self.all_patch_size + assert f_patch_size in self.all_f_patch_size + + bsz = len(x) + device = x[0].device + t = t * self.t_scale + t = self.t_embedder(t) + + ( + x, + cap_feats, + x_size, + x_pos_ids, + cap_pos_ids, + x_inner_pad_mask, + cap_inner_pad_mask, + ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) + + # x embed & refine + x_item_seqlens = [len(_) for _ in x] + assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) + x_max_item_seqlen = max(x_item_seqlens) + + x = torch.cat(x, dim=0) + x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) + + # Match t_embedder output dtype to x for layerwise casting compatibility + adaln_input = t.type_as(x) + x[torch.cat(x_inner_pad_mask)] = self.x_pad_token + x = list(x.split(x_item_seqlens, dim=0)) + x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0)) + + x = pad_sequence(x, batch_first=True, padding_value=0.0) + x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) + # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors + x_freqs_cis = x_freqs_cis[:, : x.shape[1]] + + x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(x_item_seqlens): + x_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.noise_refiner: + x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input) + else: + for layer in self.noise_refiner: + x = layer(x, x_attn_mask, x_freqs_cis, adaln_input) + + # cap embed & refine + cap_item_seqlens = [len(_) for _ in cap_feats] + cap_max_item_seqlen = max(cap_item_seqlens) + + cap_feats = torch.cat(cap_feats, dim=0) + cap_feats = self.cap_embedder(cap_feats) + cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token + cap_feats = 
list(cap_feats.split(cap_item_seqlens, dim=0)) + cap_freqs_cis = list( + self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0) + ) + + cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) + cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) + # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors + cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]] + + cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(cap_item_seqlens): + cap_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.context_refiner: + cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis) + else: + for layer in self.context_refiner: + cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis) + + # unified + unified = [] + unified_freqs_cis = [] + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]])) + unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]])) + unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)] + assert unified_item_seqlens == [len(_) for _ in unified] + unified_max_item_seqlen = max(unified_item_seqlens) + + unified = pad_sequence(unified, batch_first=True, padding_value=0.0) + unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0) + unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(unified_item_seqlens): + unified_attn_mask[i, :seq_len] = 1 + + ## ControlNet start + + controlnet_block_samples = None + if control_context is not None: + control_context = torch.cat(control_context, dim=0) + control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context) + + control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token + control_context = list(control_context.split(x_item_seqlens, dim=0)) + + control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.control_noise_refiner: + control_context = self._gradient_checkpointing_func( + layer, control_context, x_attn_mask, x_freqs_cis, adaln_input + ) + else: + for layer in self.control_noise_refiner: + control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input) + + # unified + control_context_unified = [] + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + control_context_unified.append(torch.cat([control_context[i][:x_len], cap_feats[i][:cap_len]])) + control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0) + + for layer in self.control_layers: + if torch.is_grad_enabled() and self.gradient_checkpointing: + control_context_unified = self._gradient_checkpointing_func( + layer, control_context_unified, unified, unified_attn_mask, unified_freqs_cis, adaln_input + ) + else: + control_context_unified = layer( + control_context_unified, unified, unified_attn_mask, unified_freqs_cis, adaln_input + ) + + hints = torch.unbind(control_context_unified)[:-1] + controlnet_block_samples = { + layer_idx: hints[idx] * conditioning_scale for idx, layer_idx in 
enumerate(self.control_layers_places) + } + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer_idx, layer in enumerate(self.layers): + unified = self._gradient_checkpointing_func( + layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input + ) + if controlnet_block_samples is not None: + if layer_idx in controlnet_block_samples: + unified = unified + controlnet_block_samples[layer_idx] + else: + for layer_idx, layer in enumerate(self.layers): + unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input) + if controlnet_block_samples is not None: + if layer_idx in controlnet_block_samples: + unified = unified + controlnet_block_samples[layer_idx] + + unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input) + unified = list(unified.unbind(dim=0)) + x = self.unpatchify(unified, x_size, patch_size, f_patch_size) + + if not return_dict: + return (x,) + + return Transformer2DModelOutput(sample=x) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py index d0460cf09244..2faea94fe134 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -22,7 +22,7 @@ from ...loaders import FromSingleFileMixin from ...models.autoencoders import AutoencoderKL from ...models.controlnets import ZImageControlNetModel -from ...models.transformers import ZImageTransformer2DModel +from ...models.transformers import ZImageControlTransformer2DModel, ZImageTransformer2DModel from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging, replace_example_docstring @@ -167,10 +167,12 @@ def __init__( vae: AutoencoderKL, text_encoder: PreTrainedModel, tokenizer: AutoTokenizer, - transformer: ZImageTransformer2DModel, + transformer: Union[ZImageControlTransformer2DModel, ZImageTransformer2DModel], controlnet: ZImageControlNetModel, ): super().__init__() + if isinstance(transformer, ZImageTransformer2DModel): + transformer = ZImageControlTransformer2DModel.from_controlnet(transformer, controlnet) self.register_modules( vae=vae, @@ -599,20 +601,12 @@ def __call__( latent_model_input = latent_model_input.unsqueeze(2) latent_model_input_list = list(latent_model_input.unbind(dim=0)) - controlnet_block_samples = self.controlnet( - self.transformer, - latent_model_input_list, - prompt_embeds_model_input, - control_image, - timestep_model_input, - conditioning_scale=controlnet_conditioning_scale, - ) - model_out_list = self.transformer( latent_model_input_list, timestep_model_input, prompt_embeds_model_input, - controlnet_block_samples=controlnet_block_samples, + control_context=control_image, + conditioning_scale=controlnet_conditioning_scale, )[0] if apply_cfg: From f63a5a8ddf474ae76e04bbf98f9cb63a7d2f1c24 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 5 Dec 2025 21:11:44 +0000 Subject: [PATCH 17/20] ModuleDict --- .../models/transformers/transformer_z_image_control.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_z_image_control.py b/src/diffusers/models/transformers/transformer_z_image_control.py index 61b752f58f68..90c5bd6c7f24 100644 --- a/src/diffusers/models/transformers/transformer_z_image_control.py +++ b/src/diffusers/models/transformers/transformer_z_image_control.py @@ -471,8 +471,8 @@ def from_controlnet( for i, 
control_layer in enumerate(controlnet.control_layers): model.control_layers[i].load_state_dict(control_layer.state_dict()) - for i, control_all_x_embedder in enumerate(controlnet.control_all_x_embedder): - model.control_all_x_embedder[i].load_state_dict(control_all_x_embedder.state_dict()) + for key, control_all_x_embedder in controlnet.control_all_x_embedder.items(): + model.control_all_x_embedder[key].load_state_dict(control_all_x_embedder.state_dict()) for i, control_noise_refiner in enumerate(controlnet.control_noise_refiner): model.control_noise_refiner[i].load_state_dict(control_noise_refiner.state_dict()) From f9540cbb14e02e4498e123ddbe9d0fd6031ed830 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 5 Dec 2025 21:13:07 +0000 Subject: [PATCH 18/20] patchify control_context --- .../transformer_z_image_control.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/src/diffusers/models/transformers/transformer_z_image_control.py b/src/diffusers/models/transformers/transformer_z_image_control.py index 90c5bd6c7f24..9d48bb15cc6a 100644 --- a/src/diffusers/models/transformers/transformer_z_image_control.py +++ b/src/diffusers/models/transformers/transformer_z_image_control.py @@ -610,6 +610,34 @@ def patchify_and_embed( all_cap_pad_mask, ) + def patchify( + self, + all_image: List[torch.Tensor], + patch_size: int, + f_patch_size: int, + ): + pH = pW = patch_size + pF = f_patch_size + all_image_out = [] + + for i, image in enumerate(all_image): + ### Process Image + C, F, H, W = image.size() + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + # "c f pf h ph w pw -> (f h w) (pf ph pw c)" + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + + image_ori_len = len(image) + image_padding_len = (-image_ori_len) % SEQ_MULTI_OF + + # padded feature + image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) + all_image_out.append(image_padded_feat) + + return all_image_out + def forward( self, x: List[torch.Tensor], @@ -719,6 +747,7 @@ def forward( controlnet_block_samples = None if control_context is not None: + control_context = self.patchify(control_context, patch_size, f_patch_size) control_context = torch.cat(control_context, dim=0) control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context) From 3e472ac4a43dd7082dbe7d0851a0705afa72f0aa Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 5 Dec 2025 21:23:36 +0000 Subject: [PATCH 19/20] transformer weights --- .../transformer_z_image_control.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/diffusers/models/transformers/transformer_z_image_control.py b/src/diffusers/models/transformers/transformer_z_image_control.py index 9d48bb15cc6a..55116bb75690 100644 --- a/src/diffusers/models/transformers/transformer_z_image_control.py +++ b/src/diffusers/models/transformers/transformer_z_image_control.py @@ -468,6 +468,29 @@ def from_controlnet( if not load_weights: return model + for key, all_x_embedder in transformer.all_x_embedder.items(): + model.all_x_embedder[key].load_state_dict(all_x_embedder.state_dict()) + + for key, all_final_layer in transformer.all_final_layer.items(): + model.all_final_layer[key].load_state_dict(all_final_layer.state_dict()) + + for i, noise_refiner in enumerate(transformer.noise_refiner): + model.noise_refiner[i].load_state_dict(noise_refiner.state_dict()) + + for i, context_refiner in 
enumerate(transformer.context_refiner): + model.context_refiner[i].load_state_dict(context_refiner.state_dict()) + + model.t_embedder.load_state_dict(transformer.t_embedder.state_dict()) + + model.cap_embedder.load_state_dict(transformer.cap_embedder.state_dict()) + + model.x_pad_token = transformer.x_pad_token + + model.cap_pad_token = transformer.cap_pad_token + + for i, layer in enumerate(transformer.layers): + model.layers[i].load_state_dict(layer.state_dict()) + for i, control_layer in enumerate(controlnet.control_layers): model.control_layers[i].load_state_dict(control_layer.state_dict()) From 0e7c643f02b523dd37b3b598e2b16c4c6839f4e9 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 5 Dec 2025 21:30:08 +0000 Subject: [PATCH 20/20] -enumerate in ZImageTransformer2DModel --- src/diffusers/models/transformers/transformer_z_image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 7c01361b681d..5c401b9d202b 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -635,12 +635,12 @@ def forward( unified_attn_mask[i, :seq_len] = 1 if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer_idx, layer in enumerate(self.layers): + for layer in self.layers: unified = self._gradient_checkpointing_func( layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input ) else: - for layer_idx, layer in enumerate(self.layers): + for layer in self.layers: unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input) unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
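
For reference, a minimal sketch of the wrapping step that the updated pipeline performs in __init__ and that patches 17-19 complete on the weight-loading side. This is a sketch under assumptions: the import paths mirror the re-exports visible in the pipeline diff, and the checkpoint locations are placeholders (the local ControlNet directory is whatever the conversion script wrote via --output_path).

import torch

from diffusers.models.controlnets import ZImageControlNetModel
from diffusers.models.transformers import ZImageControlTransformer2DModel, ZImageTransformer2DModel

# Base Z-Image transformer plus the converted ControlNet checkpoint (placeholder paths).
transformer = ZImageTransformer2DModel.from_pretrained(
    "Tongyi-MAI/Z-Image-Turbo", subfolder="transformer", torch_dtype=torch.bfloat16
)
controlnet = ZImageControlNetModel.from_pretrained("z-image-controlnet-hf", torch_dtype=torch.bfloat16)

# from_controlnet builds the combined config and, with load_weights=True, copies the base
# transformer weights (patch 19) together with the control layers, the control patch
# embedders keyed by "{patch_size}-{f_patch_size}" (patch 17) and the control noise refiner.
control_transformer = ZImageControlTransformer2DModel.from_controlnet(transformer, controlnet)

# The combined forward consumes the conditioning latents directly: control_context is
# patchified internally (patch 18), embedded, refined, and its scaled hints are added to
# the trunk at the layers listed in control_layers_places, e.g.
#   control_transformer(latents, timestep, prompt_embeds,
#                       control_context=control_latents, conditioning_scale=1.0)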
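
An end-to-end usage sketch for the converted checkpoint follows. The pipeline class name ZImageControlNetPipeline and its top-level export are assumptions inferred from pipeline_z_image_controlnet.py (the class definition sits outside the hunks shown here); control_image and controlnet_conditioning_scale follow the __call__ arguments visible in the pipeline diff, the usual diffusers ControlNet loading pattern is assumed, and the prompt, image file and repo/paths are placeholders.

import torch

from diffusers import ZImageControlNetPipeline  # assumed export name
from diffusers.models.controlnets import ZImageControlNetModel
from diffusers.utils import load_image

# ControlNet converted by scripts/convert_z_image_controlnet_to_diffusers.py (placeholder path).
controlnet = ZImageControlNetModel.from_pretrained("z-image-controlnet-hf", torch_dtype=torch.bfloat16)

# The pipeline wraps the base transformer with the ControlNet via from_controlnet in __init__.
pipe = ZImageControlNetPipeline.from_pretrained(
    "Tongyi-MAI/Z-Image-Turbo", controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

# Any spatial conditioning image supported by the union ControlNet checkpoint.
control_image = load_image("control.png")  # placeholder

image = pipe(
    prompt="a portrait photo, soft studio lighting",
    control_image=control_image,
    controlnet_conditioning_scale=1.0,  # forwarded to the transformer as conditioning_scale
).images[0]
image.save("z_image_controlnet.png")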