diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index f7693ec5d3ac..224b2c156927 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -106,7 +106,8 @@ def apply_rotary_emb( freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, ): - x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + x1 = hidden_states[..., 0::2] + x2 = hidden_states[..., 1::2] cos = freqs_cos[..., 0::2] sin = freqs_sin[..., 1::2] out = torch.empty_like(hidden_states)