204 changes: 192 additions & 12 deletions optax/contrib/_muon.py
@@ -22,7 +22,7 @@

import functools
import math
from typing import Any, Callable, NamedTuple, Optional, Union, Sequence
from typing import Any, Callable, NamedTuple, Optional, Union, Sequence, Literal

import jax
import jax.numpy as jnp
@@ -38,12 +38,24 @@

ReshapeFn = Callable[[jax.Array], jax.Array]


_PRECONDITIONINGS = ['frobenius', 'spectral', 'aol', 'schatten']
_DEFAULT_NS_COEFFS = (3.4445, -4.7750, 2.0315)
_DION_NS_COEFFS = [
(4.0848, -6.8946, 2.9270),
(3.9505, -6.3029, 2.6377),
(3.7418, -5.5913, 2.3037),
(2.8769, -3.1427, 1.2046),
(2.8366, -3.0525, 1.2012),
]
_NS_COEFFS_PRESET_DICT = {
'standard': _DEFAULT_NS_COEFFS,
'dion': _DION_NS_COEFFS,
}
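# Each row (c_0, c_1, c_2) of a preset parameterizes one quintic
# Newton-Schulz step f(X) = c_0 X + c_1 (XX^T)X + c_2 (XX^T)^2 X (see
# _base_newton_schulz_iteration below), so the 'dion' preset defines five
# distinct steps while 'standard' reuses a single coefficient triple.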


class MuonDimensionNumbers(NamedTuple):
"""Specification for which weight axes participate in matrix projection.
"""
Specification for which weight axes participate in matrix projection.

Muon defines an orthogonalization for 2D matrix weights for matrix-vector
products:
@@ -193,7 +205,37 @@ def update_fn(updates, state, params=None):
return base.GradientTransformation(base.init_empty_state, update_fn)


def _newton_schulz_iterator(x: jax.Array, coeffs: jax.Array) -> jax.Array:
def _aol_first_newton_schulz_iteration(
x: jax.Array,
coeffs: jax.Array,
eps: jax.typing.ArrayLike = 1e-8,
) -> jax.Array:
# Implements the first Newton-Schulz step with AOL preconditioning
# which allows for better orthogonalization performance.
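# (Sketch of why this works, per the AOL bound of Prach & Lampert used by
# Turbo-Muon: with A = X X^H positive semi-definite, scaling row i of X by
# s_i = (sum_j |A_ij|)^(-1/2) guarantees the rescaled X has spectral norm
# at most 1, so no separate norm division is needed before iterating.)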
a = x @ x.T.conj()
rescaling = jnp.clip(jnp.abs(a).sum(axis=-1), min=eps)
s = jnp.expand_dims(jax.lax.rsqrt(rescaling), -1)
x, a = x * s, a * s * s.transpose(-1, -2)
b = coeffs[1] * a + coeffs[2] * a @ a
return coeffs[0] * x + b @ x


def _schatten_first_newton_schulz_iteration(
x: jax.Array,
coeffs: jax.Array,
eps: jax.typing.ArrayLike = 1e-8,
) -> jax.Array:
# Implements the first Newton-Schulz step with Schatten-4 norm
# preconditioning which allows for better orthogonalization performance.
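# (Note: for A = X X^T, ||A||_F equals ||X||_{S4}^2, so dividing X by
# ||A||_F^(1/2) normalizes the Schatten-4 norm of X to 1; since the spectral
# norm is bounded by the Schatten-4 norm, the rescaled X also has spectral
# norm at most 1.)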
a = x @ x.T
rescaling = jnp.clip(jnp.linalg.norm(a, ord='fro', axis=(-2, -1)), min=eps)
s = jnp.expand_dims(jax.lax.rsqrt(rescaling), (0, -1))
x, a = x * s, a * s ** 2
b = coeffs[1] * a + coeffs[2] * a @ a
return coeffs[0] * x + b @ x


def _base_newton_schulz_iteration(x: jax.Array, coeffs: jax.Array) -> jax.Array:
# Implements Newton-Schulz step f(X) = c_0 X + c_1 (XX^T)X + c_2 (XX^T)^2X,
# with quintic form f(X) = c_0 X + (c_1 A + c_2 AA)X, where A = XX^T.
# The NS step has the property f(X) = f(X^T)^T. That is, we can get equivalent
@@ -204,10 +246,40 @@ def _newton_schulz_iterator(x: jax.Array, coeffs: jax.Array) -> jax.Array:
return coeffs[0] * x + b @ x


def _aol_ns_iterator(i, x, coeffs):
# Modified first step using AOL rescaling
return jax.lax.cond(
i == 0,
lambda x: _aol_first_newton_schulz_iteration(x, coeffs),
lambda x: _base_newton_schulz_iteration(x, coeffs),
x,
)


def _schatten_ns_iterator(i, x, coeffs):
# Modified first step using Schatten-4 norm rescaling
return jax.lax.cond(
i == 0,
lambda x: _schatten_first_newton_schulz_iteration(x, coeffs),
lambda x: _base_newton_schulz_iteration(x, coeffs),
x,
)


def _base_ns_iterator(i, x, coeffs):
return _base_newton_schulz_iteration(x, coeffs)
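# All three wrappers share the signature (step_index, x, coeffs) so the
# loops below can dispatch on the preconditioning mode uniformly; the base
# variant simply ignores the step index.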


def orthogonalize_via_newton_schulz(
x: jax.Array,
ns_coeffs: jax.Array,
ns_steps: jax.typing.ArrayLike = 5,
preconditioning: Literal[
'frobenius',
'spectral',
'aol',
'schatten',
] = 'frobenius',
eps: jax.typing.ArrayLike = 1e-8,
dimension_numbers: MuonDimensionNumbers | None = None,
) -> jax.Array:
@@ -228,6 +300,7 @@ def orthogonalize_via_newton_schulz(
Must have shape (n, 3) where n is the number of iterations.
ns_steps: Number of Newton-schulz iterations.
Ignored if `ns_coeffs` is a 2D array.
preconditioning: Which preconditioning method to use.
eps: Term added to denominators to improve numerical stability.
dimension_numbers: Optional spec for reshaping a tensor before and after the
orthogonalization, to support non-2D parameters.
@@ -251,16 +324,38 @@ def _orthogonalize(x):
x = x.T
transposed = True

x /= jnp.linalg.norm(x) + eps # Ensure spectral norm is at most 1
ns_iterators = {
'frobenius': _base_ns_iterator,
'spectral': _base_ns_iterator,
'aol': _aol_ns_iterator,
'schatten': _schatten_ns_iterator,
}
if preconditioning not in _PRECONDITIONINGS:
raise ValueError(f'Unknown preconditioning {preconditioning}')
_ns_iterator = ns_iterators[preconditioning]

if preconditioning == 'frobenius':
x /= jnp.linalg.norm(x, ord='fro') + eps
elif preconditioning == 'spectral':
x /= jnp.linalg.norm(x, ord=2) + eps
else:
pass
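# The Frobenius norm upper-bounds the spectral norm, so the 'frobenius'
# branch cheaply guarantees the spectral norm is at most 1; 'spectral'
# (ord=2) divides by the exact largest singular value, which is tighter
# but costs an SVD. 'aol' and 'schatten' fold their rescaling into the
# first NS step instead, hence no upfront normalization here.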

ns_coeffs_ = ns_coeffs.astype(x.dtype)

if ns_coeffs_.ndim == 1:
x = jax.lax.fori_loop(
0, ns_steps, lambda _, x: _newton_schulz_iterator(x, ns_coeffs_), x,
0, ns_steps, lambda i, x: _ns_iterator(i, x, ns_coeffs_), x,
unroll=True) # Unroll to ensure efficient composition with jax.vmap.
else:
x, _ = jax.lax.scan(
lambda x, abc: (_newton_schulz_iterator(x, abc), None), x, ns_coeffs_
)
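# The step index carried through the scan lets the 'aol' and 'schatten'
# variants special-case the first iteration, mirroring the fori_loop
# branch above.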
def _scan_body(carry, coeffs_step):
i, x = carry
x_new = _ns_iterator(i, x, coeffs_step)
return (i + 1, x_new), None

init_carry = (jnp.asarray(0, dtype=jnp.int32), x)
(_, x), _ = jax.lax.scan(_scan_body, init_carry, ns_coeffs_)

if transposed:
x = x.T
return x
@@ -289,6 +384,12 @@ def scale_by_muon(
*,
nesterov: bool = True,
adaptive: bool = False,
preconditioning: Literal[
'frobenius',
'spectral',
'aol',
'schatten',
] = 'frobenius',
weight_dimension_numbers: WeightDimNumOrFn | None = None,
) -> base.GradientTransformation:
r"""Rescale updates according to the Muon algorithm.
@@ -309,6 +410,12 @@ def scale_by_muon(
nesterov: Whether to use Nesterov momentum.
adaptive: Whether to scale the updates by the dual norm of the
original updates. See <https://arxiv.org/abs/2409.20325>
preconditioning: What type of preconditioning to use before NS iterations.
Available options are:
- 'frobenius' (default): Use Frobenius norm rescaling before NS.
- 'spectral': Use spectral norm rescaling before NS.
- 'aol': Use AOL rescaling to improve orthogonality.
- 'schatten': Use the Schatten-4 norm for rescaling.
weight_dimension_numbers: An optional tree with the same structure as the
params of `MuonDimensionNumbers`s, specifying how to reshape the
parameters before and after the orthogonalization OR a callable returning
@@ -323,16 +430,42 @@ def scale_by_muon(

Bernstein et al., `Old Optimizer, New Norm: An Anthology
<https://arxiv.org/abs/2409.20325>`_, 2024

Liu et al., `Muon is Scalable for LLM Training
<https://arxiv.org/abs/2502.16982>`_, 2025

Boissin et al., `Turbo-Muon: Accelerating Orthogonality-Based
Optimization with Pre-Conditioning
<https://arxiv.org/abs/2512.04632>`_, 2025

Ahn et al., `Dion: Distributed Orthonormalized Updates
<https://arxiv.org/abs/2504.05295>`_, 2025

Grishina et al., `Accelerating Newton-Schulz Iteration for Orthogonalization
via Chebyshev-type Polynomials
<https://arxiv.org/abs/2506.10935>`_, 2025

Amsel et al., `The Polar Express: Optimal Matrix Sign Methods and Their
Application to the Muon Algorithm
<https://arxiv.org/pdf/2505.16932>`_, 2025
"""
mu_dtype = utils.canonicalize_dtype(mu_dtype)

def init_fn(params):
mu = optax.tree.zeros_like(params, dtype=mu_dtype) # First moment
ns_coeffs_ = jnp.asarray(ns_coeffs)

if ns_coeffs_.ndim > 2 or ns_coeffs_.shape[-1] != 3:
raise ValueError(
f'ns_coeffs must have shape (3,) or (n, 3), got {ns_coeffs_.shape}'
)
if ns_coeffs_.ndim == 2:
if ns_coeffs_.shape[0] < ns_steps:
raise ValueError(
f'Not enough coeffs to perform {ns_steps} steps'
)
ns_coeffs_ = ns_coeffs_[-ns_steps:]
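# E.g. a (5, 3) coefficient array with ns_steps=3 keeps only its last
# three rows; requesting more steps than coefficient rows raises above.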

return MuonState(
count=jnp.zeros([], jnp.int32),
mu=mu,
@@ -364,7 +497,7 @@ def update_fn(updates, state, params=None):
# Apply Newton-schulz orthogonalization.
updates = jax.tree.map(
lambda x, dim_num: orthogonalize_via_newton_schulz(
x, state.ns_coeffs, ns_steps, eps, dim_num),
x, state.ns_coeffs, ns_steps, preconditioning, eps, dim_num),
mu_hat, resolved_weight_dim_nums, is_leaf=_is_weight_dim_nums)
if adaptive:
# Scale the orthogonalized updates by the dual norm of the original
@@ -388,6 +521,7 @@ def muon(
tuple[jax.typing.ArrayLike, jax.typing.ArrayLike, jax.typing.ArrayLike],
tuple[tuple[jax.typing.ArrayLike, jax.typing.ArrayLike,
jax.typing.ArrayLike], ...],
str,
] = _DEFAULT_NS_COEFFS,
ns_steps: jax.typing.ArrayLike = 5,
beta: jax.typing.ArrayLike = 0.95,
@@ -400,6 +534,12 @@ def muon(
*,
nesterov: bool = True,
adaptive: bool = False,
preconditioning: Literal[
'frobenius',
'spectral',
'aol',
'schatten',
] = 'frobenius',
adam_b1: jax.typing.ArrayLike = 0.9,
adam_b2: jax.typing.ArrayLike = 0.999,
adam_eps_root: jax.typing.ArrayLike = 0.0,
@@ -424,7 +564,8 @@ def muon(
Args:
learning_rate: A global scaling factor, either fixed or evolving along
iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
ns_coeffs: Coefficients for the Newton-schulz method.
ns_coeffs: Coefficients for the Newton-schulz method (can be a string
indicator for a preset). Existing presets: `standard`, `dion`.
ns_steps: Number of Newton-schulz iterations.
If `ns_coeffs` is a tuple of tuples, the last `ns_steps` rows of
coefficients are used.
beta: Decay rate for the exponentially weighted average of grads.
@@ -442,6 +583,20 @@ def muon(
nesterov: Whether to use Nesterov momentum.
adaptive: Whether to scale the updates by the dual norm of the
original updates. See <https://arxiv.org/abs/2409.20325>
preconditioning: What type of preconditioning to use before NS iterations.
Available options are:
- 'frobenius' (default): Use Frobenius norm rescaling before NS:
safe and standard, but degrades orthogonalization quality when using
fewer than 5 NS steps.
- 'spectral': Use spectral norm rescaling before NS:
much more computationally intensive, but gives better orthogonalization
quality.
- 'aol': Use AOL rescaling to improve orthogonality with little to
no overhead; usually allows removing one NS step.
See <https://arxiv.org/abs/2512.04632>.
- 'schatten': Use the Schatten-4 norm for rescaling, which improves
orthogonalization quality at little to no extra cost.
See <https://arxiv.org/abs/2506.10935>.
adam_b1: Exponential decay rate for Adam's first moment estimates.
adam_b2: Exponential decay rate for Adam's second moment estimates.
adam_eps_root: Epsilon to stabilize division in Adam, square root version.
@@ -473,11 +628,35 @@ def muon(

Liu et al., `Muon is Scalable for LLM Training
<https://arxiv.org/abs/2502.16982>`_, 2025

Boissin et al., `Turbo-Muon: Accelerating Orthogonality-Based
Optimization with Pre-Conditioning
<https://arxiv.org/abs/2512.04632>`_, 2025

Ahn et al., `Dion: Distributed Orthonormalized Updates
<https://arxiv.org/abs/2504.05295>`_, 2025

Grishina et al., `Accelerating Newton-Schulz Iteration for Orthogonalization
via Chebyshev-type Polynomials
<https://arxiv.org/abs/2506.10935>`_, 2025

Amsel et al., `The Polar Express: Optimal Matrix Sign Methods and Their
Application to the Muon Algorithm
<https://arxiv.org/pdf/2505.16932>`_, 2025
"""

if adam_learning_rate is None:
adam_learning_rate = learning_rate

if isinstance(ns_coeffs, str):
if ns_coeffs not in _NS_COEFFS_PRESET_DICT:
raise ValueError(
f'Unknown ns_coeffs preset string: {ns_coeffs}'
)
ns_coeffs_ = _NS_COEFFS_PRESET_DICT[ns_coeffs]
else:
ns_coeffs_ = ns_coeffs
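# Illustrative usage (hypothetical hyperparameter values): the preset string
# resolves to the table at the top of this file, so
#   muon(learning_rate=0.02, ns_coeffs='dion', preconditioning='aol')
# is equivalent to passing the five Dion coefficient rows explicitly.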

# None at root indicates the default 2D rule.
if muon_weight_dimension_numbers is None:
param_labels = lambda params: jax.tree.map(
@@ -516,13 +695,14 @@ def muon_weight_dim_nums_fn(params):
transforms={
'muon': combine.chain(
scale_by_muon(
ns_coeffs=ns_coeffs,
ns_coeffs=ns_coeffs_,
ns_steps=ns_steps,
beta=beta,
eps=eps,
mu_dtype=mu_dtype,
nesterov=nesterov,
adaptive=adaptive,
preconditioning=preconditioning,
weight_dimension_numbers=muon_weight_dim_nums_fn,
),
scale_by_shape(