Describe the bug
When using the .from_transformer classmethod of ZImageControlNetModel, the layers copied from the transformer are shallow copies, i.e. the ControlNet modules reference the same parameter tensors as the transformer. This means that any change made to the ControlNet weights after copying also affects the transformer weights.
I'm unclear on what the best solution would be. FluxControlNetModel.from_transformer already prepares the layers in __init__ and can therefore load the weights directly via state dicts, but that is harder here given how __init__ and from_transformer are written in ZImageControlNetModel. Short term, I've been able to circumvent the issue with copy.deepcopy(...) (see the sketch below), but I think something more elegant is possible.
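For illustration, a minimal sketch of the workaround, assuming the controlnet and transformer objects from the reproduction below; t_embedder is only one example, and the full set of submodules that needs replacing depends on what from_transformer assigns:
import copy

# If from_transformer only shallow-copies modules, the ControlNet parameter and
# the transformer parameter should point at the same storage (assumption based
# on the behaviour described above):
print(
    controlnet.t_embedder.mlp[0].weight.data_ptr()
    == transformer.t_embedder.mlp[0].weight.data_ptr()
)  # expected: True

# Short-term workaround: replace the shared submodule with an independent deep
# copy so the ControlNet owns its own parameters.
controlnet.t_embedder = copy.deepcopy(transformer.t_embedder)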
Reproduction
import torch
from diffusers import ZImageControlNetModel, ZImageTransformer2DModel
transformer = ZImageTransformer2DModel.from_pretrained(
    "Tongyi-MAI/Z-Image",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
controlnet = ZImageControlNetModel(
    control_layers_places=[0, 15, 29],  # These parameters are not relevant, just for instantiation
    control_refiner_layers_places=[0, 1],  # ^
    add_control_noise_refiner="control_noise_refiner",  # ^
    control_in_dim=16,  # ^
)
controlnet = ZImageControlNetModel.from_transformer(
    controlnet=controlnet,
    transformer=transformer,
)
# Weights before modification
print("Transformer t_embedder weights before modification:")
print(transformer.t_embedder.mlp[0].weight)
# Fill ControlNet t_embedder weights with constant value
torch.nn.init.constant_(controlnet.t_embedder.mlp[0].weight, 42.0)
print("ControlNet t_embedder weights after modification:")
print(controlnet.t_embedder.mlp[0].weight) # As expected, filled with 42s
# Should remain unchanged, but is also filled with 42s
print("Transformer t_embedder weights after ControlNet modification:")
print(transformer.t_embedder.mlp[0].weight)
Logs
Transformer t_embedder weights before modification:
Parameter containing:
tensor([[ 0.0069, -0.0051, 0.0036, ..., 0.0325, -0.0153, -0.0089],
[ 0.0004, -0.0102, 0.0078, ..., 0.0262, 0.0229, 0.0505],
[-0.0024, 0.0032, -0.0040, ..., 0.0237, 0.0195, 0.0017],
...,
[-0.0059, -0.0013, 0.0008, ..., 0.0223, 0.0209, 0.0132],
[ 0.0026, -0.0014, -0.0052, ..., 0.0021, 0.0142, 0.0630],
[-0.0004, 0.0034, 0.0032, ..., 0.0339, 0.0187, 0.0254]],
dtype=torch.bfloat16, requires_grad=True)
ControlNet t_embedder weights after modification:
Parameter containing:
tensor([[42., 42., 42., ..., 42., 42., 42.],
[42., 42., 42., ..., 42., 42., 42.],
[42., 42., 42., ..., 42., 42., 42.],
...,
[42., 42., 42., ..., 42., 42., 42.],
[42., 42., 42., ..., 42., 42., 42.],
[42., 42., 42., ..., 42., 42., 42.]], dtype=torch.bfloat16,
requires_grad=True)
Transformer t_embedder weights after ControlNet modification:
Parameter containing:
tensor([[42., 42., 42., ..., 42., 42., 42.],
[42., 42., 42., ..., 42., 42., 42.],
[42., 42., 42., ..., 42., 42., 42.],
...,
[42., 42., 42., ..., 42., 42., 42.],
[42., 42., 42., ..., 42., 42., 42.],
[42., 42., 42., ..., 42., 42., 42.]], dtype=torch.bfloat16,
requires_grad=True)
System Info
- 🤗 Diffusers version: 0.37.0.dev0
- Platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.11.14
- PyTorch version (GPU?): 2.9.0+cu126 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.35.3
- Transformers version: 4.57.1
- Accelerate version: 1.11.0
- PEFT version: 0.17.1
- Bitsandbytes version: 0.48.2
- Safetensors version: 0.6.2
- xFormers version: not installed
- Accelerator: NVIDIA RTX 6000 Ada Generation, 49140 MiB
  NVIDIA RTX A6000, 49140 MiB
  NVIDIA RTX A6000, 49140 MiB
- Using GPU in script?: No
- Using distributed or parallel set-up in script?: No
Who can help?
No response