@@ -1081,18 +1081,28 @@ def __init__(
10811081 )
10821082
10831083 def _apply_rope (self , xq : jax .Array , xk : jax .Array , freqs_cis : jax .Array ) -> Tuple [jax .Array , jax .Array ]:
1084- dtype = xq . dtype
1085- reshape_xq = xq . astype ( jnp .float32 ). reshape ( * xq .shape [: - 1 ], - 1 , 2 )
1086- reshape_xk = xk . astype ( jnp .float32 ). reshape ( * xk . shape [: - 1 ], - 1 , 2 )
1084+ # 1. Extract cos and sin, keeping them in native bfloat16
1085+ cos = jnp .real ( freqs_cis ). astype ( xq .dtype )
1086+ sin = jnp .imag ( freqs_cis ). astype ( xq . dtype )
10871087
1088- xq_ = jax .lax .complex (reshape_xq [..., 0 ], reshape_xq [..., 1 ])
1089- xk_ = jax .lax .complex (reshape_xk [..., 0 ], reshape_xk [..., 1 ])
1088+ # 2. Reshape the last dimension into pairs
1089+ xq_reshaped = xq .reshape (* xq .shape [:- 1 ], - 1 , 2 )
1090+ xk_reshaped = xk .reshape (* xk .shape [:- 1 ], - 1 , 2 )
10901091
1091- xq_out_complex = xq_ * freqs_cis
1092- xk_out_complex = xk_ * freqs_cis
1092+ # 3. Unbind the pairs
1093+ xq_0 , xq_1 = xq_reshaped [..., 0 ], xq_reshaped [..., 1 ]
1094+ xk_0 , xk_1 = xk_reshaped [..., 0 ], xk_reshaped [..., 1 ]
10931095
1094- xq_out = jnp .stack ([jnp .real (xq_out_complex ), jnp .imag (xq_out_complex )], axis = - 1 ).reshape (xq .shape ).astype (dtype )
1095- xk_out = jnp .stack ([jnp .real (xk_out_complex ), jnp .imag (xk_out_complex )], axis = - 1 ).reshape (xk .shape ).astype (dtype )
1096+ # 4. Pure real arithmetic (XLA will fuse these instantly into FMA instructions)
1097+ xq_out_0 = xq_0 * cos - xq_1 * sin
1098+ xq_out_1 = xq_0 * sin + xq_1 * cos
1099+
1100+ xk_out_0 = xk_0 * cos - xk_1 * sin
1101+ xk_out_1 = xk_0 * sin + xk_1 * cos
1102+
1103+ # 5. Stack and reshape back to original
1104+ xq_out = jnp .stack ([xq_out_0 , xq_out_1 ], axis = - 1 ).reshape (xq .shape )
1105+ xk_out = jnp .stack ([xk_out_0 , xk_out_1 ], axis = - 1 ).reshape (xk .shape )
10961106
10971107 return xq_out , xk_out
10981108
0 commit comments