
ZImageControlNet.from_transformer creates a shallow copy of the transformer weights #13077

@Christophernph


Describe the bug

When using the .from_transformer classmethod of ZImageControlNetModel, the layers copied from the transformer are shallow copies. This means any change made to the controlnet weights after copying also affects the transformer weights.
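
For context, the behavior is consistent with from_transformer assigning the transformer's submodules to the controlnet directly (an assumption about the internals, based on the reproduction below), so both models end up holding the same parameter tensors. A minimal PyTorch sketch of the difference between sharing and copying:

import copy

import torch.nn as nn

src = nn.Linear(4, 4)
shared = src                 # plain assignment: same module, same parameters
copied = copy.deepcopy(src)  # deep copy: parameters get their own storage

print(shared.weight.data_ptr() == src.weight.data_ptr())  # True: mutations propagate
print(copied.weight.data_ptr() == src.weight.data_ptr())  # False: independent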

I'm unclear on what the best solution would be. FluxControlNetModel.from_transformer already prepares the layers in __init__ and can therefore copy weights directly via state dicts, but that is harder here given how __init__ and from_transformer are structured in ZImageControlNetModel. Short term, I've been able to work around it with copy.deepcopy(...), but I think something more elegant is possible.
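
For reference, a sketch of that workaround, using the controlnet and transformer from the reproduction below. t_embedder is shown here; every other module that from_transformer copies over needs the same treatment:

import copy

# Replace the shared submodule with an independent deep copy so that
# modifying the controlnet weights no longer touches the transformer.
controlnet.t_embedder = copy.deepcopy(transformer.t_embedder)

# Sanity check: the parameters no longer share storage.
assert (
    controlnet.t_embedder.mlp[0].weight.data_ptr()
    != transformer.t_embedder.mlp[0].weight.data_ptr()
)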

Reproduction

import torch
from diffusers import ZImageControlNetModel, ZImageTransformer2DModel

transformer = ZImageTransformer2DModel.from_pretrained(
    "Tongyi-MAI/Z-Image",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

controlnet = ZImageControlNetModel(
    control_layers_places=[0, 15, 29],                  # These parameters are not relevant, just for instantiation 
    control_refiner_layers_places=[0, 1],               # ^
    add_control_noise_refiner="control_noise_refiner",  # ^
    control_in_dim=16                                   # ^
)

controlnet = ZImageControlNetModel.from_transformer(
    controlnet=controlnet,
    transformer=transformer
)

# Weights before modification
print("Transformer t_embedder weights before modification:")
print(transformer.t_embedder.mlp[0].weight)

# Fill ControlNet t_embedder weights with constant value
torch.nn.init.constant_(controlnet.t_embedder.mlp[0].weight, 42.0)
print("ControlNet t_embedder weights after modification:")
print(controlnet.t_embedder.mlp[0].weight) # As expected, filled with 42s

# Should remain unchanged, but is also filled with 42s
print("Transformer t_embedder weights after ControlNet modification:")
print(transformer.t_embedder.mlp[0].weight)

Logs

Transformer t_embedder weights before modification:
Parameter containing:
tensor([[ 0.0069, -0.0051,  0.0036,  ...,  0.0325, -0.0153, -0.0089],
        [ 0.0004, -0.0102,  0.0078,  ...,  0.0262,  0.0229,  0.0505],
        [-0.0024,  0.0032, -0.0040,  ...,  0.0237,  0.0195,  0.0017],
        ...,
        [-0.0059, -0.0013,  0.0008,  ...,  0.0223,  0.0209,  0.0132],
        [ 0.0026, -0.0014, -0.0052,  ...,  0.0021,  0.0142,  0.0630],
        [-0.0004,  0.0034,  0.0032,  ...,  0.0339,  0.0187,  0.0254]],
       dtype=torch.bfloat16, requires_grad=True)
ControlNet t_embedder weights after modification:
Parameter containing:
tensor([[42., 42., 42.,  ..., 42., 42., 42.],
        [42., 42., 42.,  ..., 42., 42., 42.],
        [42., 42., 42.,  ..., 42., 42., 42.],
        ...,
        [42., 42., 42.,  ..., 42., 42., 42.],
        [42., 42., 42.,  ..., 42., 42., 42.],
        [42., 42., 42.,  ..., 42., 42., 42.]], dtype=torch.bfloat16,
       requires_grad=True)
Transformer t_embedder weights after ControlNet modification:
Parameter containing:
tensor([[42., 42., 42.,  ..., 42., 42., 42.],
        [42., 42., 42.,  ..., 42., 42., 42.],
        [42., 42., 42.,  ..., 42., 42., 42.],
        ...,
        [42., 42., 42.,  ..., 42., 42., 42.],
        [42., 42., 42.,  ..., 42., 42., 42.],
        [42., 42., 42.,  ..., 42., 42., 42.]], dtype=torch.bfloat16,
       requires_grad=True)

System Info

  • 🤗 Diffusers version: 0.37.0.dev0
  • Platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.11.14
  • PyTorch version (GPU?): 2.9.0+cu126 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.35.3
  • Transformers version: 4.57.1
  • Accelerate version: 1.11.0
  • PEFT version: 0.17.1
  • Bitsandbytes version: 0.48.2
  • Safetensors version: 0.6.2
  • xFormers version: not installed
  • Accelerator: NVIDIA RTX 6000 Ada Generation, 49140 MiB
    NVIDIA RTX A6000, 49140 MiB
    NVIDIA RTX A6000, 49140 MiB
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help?

No response
