Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@
"WanVACETransformer3DModel",
"ZImageTransformer2DModel",
"attention_backend",
"NextDiT_3B_GQA_patch2_Adaln_Refiner_WHIT_CLIP",
]
)
_import_structure["modular_pipelines"].extend(
Expand Down Expand Up @@ -663,6 +664,7 @@
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
"ZImagePipeline",
"NewbiePipeline",
]
)

Expand Down Expand Up @@ -1009,6 +1011,7 @@
WanVACETransformer3DModel,
ZImageTransformer2DModel,
attention_backend,
NextDiT_3B_GQA_patch2_Adaln_Refiner_WHIT_CLIP,
)
from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks
from .optimization import (
Expand Down Expand Up @@ -1361,6 +1364,7 @@
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
ZImagePipeline,
NewbiePipeline,
)

try:
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@
_import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
_import_structure["unets.unet_stable_cascade"] = ["StableCascadeUNet"]
_import_structure["unets.uvit_2d"] = ["UVit2DModel"]
_import_structure["transformers.transformer_newbie"] = ["NextDiT_3B_GQA_patch2_Adaln_Refiner_WHIT_CLIP"]

if is_flax_available():
_import_structure["controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
Expand Down Expand Up @@ -230,6 +231,7 @@
WanTransformer3DModel,
WanVACETransformer3DModel,
ZImageTransformer2DModel,
NextDiT_3B_GQA_patch2_Adaln_Refiner_WHIT_CLIP,
)
from .unets import (
I2VGenXLUNet,
Expand Down
54 changes: 54 additions & 0 deletions src/diffusers/models/components.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import warnings

import torch
import torch.nn as nn

try:
from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")

class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
"""
Initialize the RMSNorm normalization layer.

Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.

"""
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))

def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.

Args:
x (torch.Tensor): The input tensor.

Returns:
torch.Tensor: The normalized tensor.

"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x):
"""
Forward pass through the RMSNorm layer.

Args:
x (torch.Tensor): The input tensor.

Returns:
torch.Tensor: The output tensor after applying RMSNorm.

"""
output = self._norm(x.float()).type_as(x)
return output * self.weight
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,4 @@
from .transformer_wan_animate import WanAnimateTransformer3DModel
from .transformer_wan_vace import WanVACETransformer3DModel
from .transformer_z_image import ZImageTransformer2DModel
from .transformer_newbie import NextDiT_3B_GQA_patch2_Adaln_Refiner_WHIT_CLIP
Loading