Skip to content

Commit cbe451a

Browse files
Merge pull request #355 from AI-Hypercomputer:wan-opt
PiperOrigin-RevId: 882762284
2 parents b433a51 + 62e0805 commit cbe451a

2 files changed

Lines changed: 20 additions & 12 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 19 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1081,18 +1081,28 @@ def __init__(
10811081
)
10821082

10831083
def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tuple[jax.Array, jax.Array]:
1084-
dtype = xq.dtype
1085-
reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
1086-
reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)
1084+
# 1. Extract cos and sin and cast them to the input dtype (xq.dtype, e.g. bfloat16)
1085+
cos = jnp.real(freqs_cis).astype(xq.dtype)
1086+
sin = jnp.imag(freqs_cis).astype(xq.dtype)
10871087

1088-
xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1])
1089-
xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1])
1088+
# 2. Reshape the last dimension into pairs
1089+
xq_reshaped = xq.reshape(*xq.shape[:-1], -1, 2)
1090+
xk_reshaped = xk.reshape(*xk.shape[:-1], -1, 2)
10901091

1091-
xq_out_complex = xq_ * freqs_cis
1092-
xk_out_complex = xk_ * freqs_cis
1092+
# 3. Unbind the pairs
1093+
xq_0, xq_1 = xq_reshaped[..., 0], xq_reshaped[..., 1]
1094+
xk_0, xk_1 = xk_reshaped[..., 0], xk_reshaped[..., 1]
10931095

1094-
xq_out = jnp.stack([jnp.real(xq_out_complex), jnp.imag(xq_out_complex)], axis=-1).reshape(xq.shape).astype(dtype)
1095-
xk_out = jnp.stack([jnp.real(xk_out_complex), jnp.imag(xk_out_complex)], axis=-1).reshape(xk.shape).astype(dtype)
1096+
# 4. Pure real arithmetic (elementwise multiply-adds that XLA can typically fuse)
1097+
xq_out_0 = xq_0 * cos - xq_1 * sin
1098+
xq_out_1 = xq_0 * sin + xq_1 * cos
1099+
1100+
xk_out_0 = xk_0 * cos - xk_1 * sin
1101+
xk_out_1 = xk_0 * sin + xk_1 * cos
1102+
1103+
# 5. Stack and reshape back to original
1104+
xq_out = jnp.stack([xq_out_0, xq_out_1], axis=-1).reshape(xq.shape)
1105+
xk_out = jnp.stack([xk_out_0, xk_out_1], axis=-1).reshape(xk.shape)
10961106

10971107
return xq_out, xk_out
10981108

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -684,7 +684,5 @@ def layer_forward(hidden_states):
684684
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
685685
)
686686
hidden_states = jnp.transpose(hidden_states, (0, 7, 1, 4, 2, 5, 3, 6))
687-
hidden_states = jax.lax.collapse(hidden_states, 6, None)
688-
hidden_states = jax.lax.collapse(hidden_states, 4, 6)
689-
hidden_states = jax.lax.collapse(hidden_states, 2, 4)
687+
hidden_states = hidden_states.reshape(batch_size, -1, num_frames, height, width)
690688
return hidden_states

0 commit comments

Comments
 (0)