Authors: Da Chang, Yongxiang Liu, Ganzhao Yuan
Paper: https://arxiv.org/abs/2509.15816
We introduce MuonMVR, which applies variance reduction to Muon, and provide a detailed theoretical analysis of its convergence.
The core idea of MuonMVR is to reduce gradient variance by leveraging information from the previous training step.
This requires a specific training loop structure: the gradient of the previous batch must be calculated before the optimizer performs step() on the gradient of the current batch.
This operation is handled by the optimizer.update_last_grad() method.
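To make the required ordering concrete, here is a minimal, hypothetical sketch of what a method like update_last_grad() could do internally: snapshot the gradients of the previous batch so that step() can later combine them with the current gradients. The class _SketchMVR, its plain SGD-style update, and the gamma-scaled correction are our illustrative assumptions, not the repository's actual MuonMVR implementation (which is a Muon variant and additionally applies momentum and Newton-Schulz orthogonalization).

```python
import torch

class _SketchMVR(torch.optim.Optimizer):
    """Hypothetical skeleton used only to illustrate the update_last_grad()/step() ordering."""

    def __init__(self, params, lr=1e-3, gamma=0.025):
        super().__init__(params, dict(lr=lr, gamma=gamma))

    @torch.no_grad()
    def update_last_grad(self):
        # Assumed behavior: snapshot the gradients computed on the *previous* batch.
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    self.state[p]["last_grad"] = p.grad.detach().clone()

    @torch.no_grad()
    def step(self):
        # Assumed behavior: combine the current gradient with the stored snapshot into a
        # variance-reduced direction, then take a plain SGD-style step (the real MuonMVR
        # would feed this into its momentum and orthogonalization machinery instead).
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                last = self.state[p].get("last_grad", p.grad)
                corrected = p.grad + group["gamma"] * (p.grad - last)
                p.add_(corrected, alpha=-group["lr"])
```

Whatever the internal details, the point is that update_last_grad() must see the previous batch's gradients and step() the current batch's, which is exactly the ordering enforced by the training loops below.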
In practice, the widely used approximate version, MVR1, is sufficient and keeps both computational cost and memory requirements low.
The ideas behind the different modes of the MuonMVR optimizer are described below; see the paper for the precise update formulas.
Standard Muon (EMA)
The baseline optimizer uses a standard Exponential Moving Average (EMA) of the gradients. It serves as the foundation for the variance-reduced variants.
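For reference, a common way to write the baseline update (our notation; the exact EMA weighting and scaling may differ from the paper's) is:

$$
M_t = \mu\,M_{t-1} + (1-\mu)\,G_t,\qquad O_t = \mathrm{NewtonSchulz5}(M_t),\qquad W_t = W_{t-1} - \eta\,O_t,
$$

where $G_t$ is the current mini-batch gradient, $M_t$ is the EMA momentum buffer, $\mathrm{NewtonSchulz5}$ is Muon's iterative orthogonalization of the matrix-shaped momentum, and $\eta$ is the learning rate.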
MVR1: Single-Batch Variance Reduction
Its correction term is based on the difference between the gradient of the current batch (using the current parameters) and gradient information carried over from the previous training step, without storing or reloading additional model states. It corresponds to the is_approx=True mode in the practical implementation.
MVR2: Dual-Batch Variance Reduction (Standard MVR)
This is the standard practice discussed in the theory, where the variance-reduction term is calculated on the same data batch as the current gradient, which requires an extra forward and backward pass per step.
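As background, the classical momentum-based variance-reduction (MVR/STORM-style) estimator from the literature, which this same-batch correction follows in spirit, is usually written as (notation ours; the paper's exact scaling may differ):

$$
D_t = \nabla f(W_t;\,\xi_t) + (1-a)\bigl(D_{t-1} - \nabla f(W_{t-1};\,\xi_t)\bigr),
$$

i.e. the same batch $\xi_t$ is differentiated at both the current parameters $W_t$ and the previous parameters $W_{t-1}$, which is why an extra forward and backward pass is needed.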
MVR3: Dual-Batch Variance Reduction (Approximate MVR)
This is another version and corresponds to the is_approx=False mode. Its correction term uses the difference between the gradient of the current batch (with the current parameters) and the gradient of the same batch (with the parameters from the previous iteration), which is why the previous model state must be kept and reloaded (see the exact-mode example below).
Because MVR2 and MVR3 require extra forward and backward passes and the storage of additional data or model states, the widely used approximate version, MVR1, is usually sufficient in practice and keeps computational cost and memory requirements low.
- Standard Training (No Gradient Accumulation)
This is the most basic usage. In each iteration, first calculate the gradient for the previous batch and call update_last_grad(), then calculate the gradient for the current batch and call step().

```python
# Initialize the optimizer (use the exact version or a high-precision approximate version)
optimizer = MuonMVR(model.parameters(), lr=args.lr, gamma=args.gamma, is_approx=False)

# Training loop
previous_X, previous_Y = None, None
for epoch in range(epochs):
    for X, Y in data_loader:
        if previous_X is not None:
            # Calculate the gradient for the previous batch
            logits, loss = model(previous_X, previous_Y)
            loss.backward()
            optimizer.update_last_grad()
            optimizer.zero_grad(set_to_none=True)

        # Process the current batch
        logits, loss = model(X, Y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        # Store the current batch for the next iteration
        previous_X, previous_Y = X.clone(), Y.clone()
```
- Training with Gradient Accumulation
To use gradient accumulation, the losses for both the current and previous batches must be scaled by dividing by the number of accumulation steps.
optimizer.step() is only called after the required number of gradients has been accumulated.

```python
# Initialize the optimizer
optimizer = MuonMVR(model.parameters(), lr=args.lr, gamma=args.gamma, is_approx=False)

# Training loop
previous_X, previous_Y = None, None
accum_steps = 4  # Number of gradient accumulation steps
for epoch in range(epochs):
    for i, (X, Y) in enumerate(data_loader):
        # Process the current batch
        logits, loss = model(X, Y)
        # Scale the loss to average the gradients
        loss = loss / accum_steps
        loss.backward()

        if previous_X is not None:
            # Calculate the gradient for the previous batch
            prev_logits, prev_loss = model(previous_X, previous_Y)
            prev_loss = prev_loss / accum_steps
            prev_loss.backward()
            optimizer.update_last_grad()
            # Clear the gradients for the previous batch, keeping the gradients for the current batch
            optimizer.zero_grad(set_to_none=True)

        # Update parameters when enough gradients have been accumulated
        if (i + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            previous_X, previous_Y = None, None  # Reset for the next accumulation cycle
        else:
            previous_X, previous_Y = X.clone(), Y.clone()
```
Optimizer Modes
MuonMVR can be initialized in different modes to trade off between precision and computational cost.
Exact Variance Reduction (is_approx=False)
To achieve the most precise variance reduction, you must manually manage the model state: before computing the old gradient (the one recorded via update_last_grad()), you need to load the model state from the previous iteration. This ensures that this gradient is computed with the correct model weights.

```python
optimizer = MuonMVR(model.parameters(), lr=args.lr, gamma=args.gamma, is_approx=False)

old_state_dict = {}
for batch_idx, (inputs, targets) in enumerate(trainloader):
    # Store the current model state
    cur_state_dict = {k: v.data.clone() for k, v in net.state_dict().items()}

    if old_state_dict:
        # Load the previous model state to compute the old gradient
        net.load_state_dict(old_state_dict)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.update_last_grad()
        # Restore the current model state to compute the new gradient
        net.load_state_dict(cur_state_dict)

    old_state_dict = {k: v.data.clone() for k, v in cur_state_dict.items()}

    # Standard forward/backward pass and step
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
```
Approximate Version (is_approx=True)
This mode uses a more computationally efficient approximation. Its training loop structure is the same as the standard training example.
```python
# Initialize the optimizer and enable the approximate mode
optimizer = MuonMVR(model.parameters(), lr=args.lr, gamma=args.gamma, is_approx=True)
```

Acknowledgements
We referenced the original author's implementation of Muon as well as the implementation from Moonlight-Muon, and adopted the gradient clipping strategy from MARS to further control gradient noise. We thank them for their open-source contributions!
Citation

```bibtex
@article{Chang2025OnTC,
  title={On the Convergence of Muon and Beyond},
  author={Da Chang and Yongxiang Liu and Ganzhao Yuan},
  journal={ArXiv},
  year={2025},
  volume={abs/2509.15816}
}
```