Skip to content

Commit ed854ba

Browse files
runame authored and facebook-github-bot committed
Sync KL-Shampoo to OSS
Summary: Sync KL-Shampoo implementation to OSS. Also, add KL-Shampoo reference to README. Differential Revision: D89193149
1 parent a8601eb commit ed854ba

7 files changed

Lines changed: 387 additions & 10 deletions

File tree

distributed_shampoo/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ Key distinctives of this implementation include:
5151
- Option to (approximately) correct the eigenvalues/run Adam in the eigenbasis of Shampoo's preconditioner (SOAP) [2,6,7].
5252
- Option to use an adaptive preconditioner update frequency when symmetric eigendecompositions or the QR algorithm is used [8].
5353
- Spectral descent via reduced SVD or Newton-Schulz iteration for 2D gradients, or gradients that have been reshaped to 2D [9,10]. This can be used to implement Muon [11], see [Example 6](#example-6-muon).
54+
- KL-Shampoo (without per-factor matrix eigenvalue correction) [12].
5455

5556
## Requirements
5657

@@ -784,3 +785,4 @@ If you use PyTorch Distributed Shampoo in your work, please use the following Bi
784785
9. [Preconditioned Spectral Descent for Deep Learning](https://papers.nips.cc/paper_files/paper/2015/hash/f50a6c02a3fc5a3a5d4d9391f05f3efc-Abstract.html). David E. Carlson, Edo Collins, Ya-Ping Hsieh, Lawrence Carin, Volkan Cevher. NeurIPS, 2015.
785786
10. [Old Optimizer, New Norm: An Anthology](https://arxiv.org/abs/2409.20325). Jeremy Bernstein, Laker Newhouse. Tech report, 2024.
786787
11. [Muon: An optimizer for hidden layers in neural networks](https://kellerjordan.github.io/posts/muon/). Keller Jordan, Yuchen Jin, Vlado Boza, Jiacheng You, Franz Cesista, Laker Newhouse, Jeremy Bernstein. Blog post, 2024.
788+
12. [Understanding and Improving Shampoo and SOAP via Kullback-Leibler Minimization](https://arxiv.org/abs/2509.03378). Wu Lin, Scott C. Lowe, Felix Dangel, Runa Eschenhagen, Zikun Xu, Roger B. Grosse. Tech report, 2025.

distributed_shampoo/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
DefaultSOAPConfig,
4444
DefaultSpectralDescentPreconditionerConfig,
4545
DistributedConfig,
46+
EigendecomposedKLShampooPreconditionerConfig,
4647
EigendecomposedShampooPreconditionerConfig,
4748
EigenvalueCorrectedShampooPreconditionerConfig,
4849
FSDPDistributedConfig,
@@ -52,6 +53,7 @@
5253
HybridShardDistributedConfig,
5354
PreconditionerConfig,
5455
RMSpropPreconditionerConfig,
56+
RootInvKLShampooPreconditionerConfig,
5557
RootInvShampooPreconditionerConfig,
5658
SGDPreconditionerConfig,
5759
ShampooPreconditionerConfig,
@@ -83,7 +85,9 @@
8385
"ShampooPreconditionerConfig", # Abstract base class (based on `AmortizedPreconditionerConfig`).
8486
"RootInvShampooPreconditionerConfig", # Based on `ShampooPreconditionerConfig`.
8587
"DefaultShampooConfig", # Default `RootInvShampooPreconditionerConfig` using `EigenConfig`.
88+
"RootInvKLShampooPreconditionerConfig", # Based on `RootInvShampooPreconditionerConfig`.
8689
"EigendecomposedShampooPreconditionerConfig", # Based on `ShampooPreconditionerConfig`.
90+
"EigendecomposedKLShampooPreconditionerConfig", # Based on `EigendecomposedShampooPreconditionerConfig`.
8791
"EigenvalueCorrectedShampooPreconditionerConfig", # Based on `AmortizedPreconditionerConfig`.
8892
"DefaultEigenvalueCorrectedShampooConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `EighEigendecompositionConfig`.
8993
"DefaultSOAPConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `QREigendecompositionConfig`.

distributed_shampoo/distributed_shampoo.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,10 @@
4848
SGDPreconditionerList,
4949
)
5050
from distributed_shampoo.preconditioner.shampoo_preconditioner_list import (
51+
EigendecomposedKLShampooPreconditionerList,
5152
EigendecomposedShampooPreconditionerList,
5253
EigenvalueCorrectedShampooPreconditionerList,
54+
RootInvKLShampooPreconditionerList,
5355
RootInvShampooPreconditionerList,
5456
)
5557
from distributed_shampoo.preconditioner.sign_descent_preconditioner_list import (
@@ -72,6 +74,7 @@
7274
DISTRIBUTED_CONFIG,
7375
DistributedConfig,
7476
DISTRIBUTOR,
77+
EigendecomposedKLShampooPreconditionerConfig,
7578
EigendecomposedShampooPreconditionerConfig,
7679
EigenvalueCorrectedShampooPreconditionerConfig,
7780
EPSILON,
@@ -98,6 +101,7 @@
98101
PreconditionerConfig,
99102
PREVIOUS_GRAD_SELECTOR,
100103
RMSpropPreconditionerConfig,
104+
RootInvKLShampooPreconditionerConfig,
101105
RootInvShampooPreconditionerConfig,
102106
SGDPreconditionerConfig,
103107
SHAMPOO_PRECONDITIONER_LIST,
@@ -638,13 +642,17 @@ def _preconditioner_config_to_list_cls(
638642
RootInvShampooPreconditionerConfig()
639643
| EigendecomposedShampooPreconditionerConfig()
640644
| EigenvalueCorrectedShampooPreconditionerConfig()
645+
| RootInvKLShampooPreconditionerConfig()
646+
| EigendecomposedKLShampooPreconditionerConfig()
641647
):
642648
preconditioner_config_to_list_cls: dict[
643649
type[PreconditionerConfig], Callable[..., PreconditionerList]
644650
] = {
645651
RootInvShampooPreconditionerConfig: RootInvShampooPreconditionerList,
646652
EigendecomposedShampooPreconditionerConfig: EigendecomposedShampooPreconditionerList,
647653
EigenvalueCorrectedShampooPreconditionerConfig: EigenvalueCorrectedShampooPreconditionerList,
654+
RootInvKLShampooPreconditionerConfig: RootInvKLShampooPreconditionerList,
655+
EigendecomposedKLShampooPreconditionerConfig: EigendecomposedKLShampooPreconditionerList,
648656
}
649657
return preconditioner_config_to_list_cls[type(preconditioner_config)](
650658
block_list=state_lists[DISTRIBUTOR].local_blocked_params,

distributed_shampoo/preconditioner/shampoo_preconditioner_list.py

Lines changed: 122 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,6 +1270,25 @@ def compress_preconditioner_list(
12701270
self._local_preconditioned_dims_selector_list, local_grad_selector
12711271
)
12721272

1273+
@profile_decorator
1274+
def _compute_outer_product_list(
1275+
self,
1276+
grad: Tensor,
1277+
order: int,
1278+
preconditioned_dims_selector: tuple[bool, ...],
1279+
kronecker_factors: _ShampooKroneckerFactorsUnwrappedType,
1280+
) -> tuple[Tensor, ...]:
1281+
# Construct outer product list for updating Kronecker factors.
1282+
return tuple(
1283+
torch.tensordot(
1284+
grad,
1285+
grad,
1286+
# Contracts across all dimensions except for k.
1287+
dims=[[*chain(range(k), range(k + 1, order))]] * 2, # type: ignore[has-type]
1288+
)
1289+
for k in compress_list(range(order), preconditioned_dims_selector)
1290+
)
1291+
12731292
@profile_decorator
12741293
def _update_factor_matrices(self, masked_grad_list: tuple[Tensor, ...]) -> None:
12751294
# NOTE: Unlike AdagradPreconditionerList, we will loop through each gradient individually.
@@ -1286,15 +1305,8 @@ def _update_factor_matrices(self, masked_grad_list: tuple[Tensor, ...]) -> None:
12861305
if not kronecker_factors.factor_matrices:
12871306
continue
12881307

1289-
# Construct outer product list for updating Kronecker factors.
1290-
outer_product_list = tuple(
1291-
torch.tensordot(
1292-
grad,
1293-
grad,
1294-
# Contracts across all dimensions except for k.
1295-
dims=[[*chain(range(k), range(k + 1, order))]] * 2, # type: ignore[has-type]
1296-
)
1297-
for k in compress_list(range(order), preconditioned_dims_selector)
1308+
outer_product_list = self._compute_outer_product_list(
1309+
grad, order, preconditioned_dims_selector, kronecker_factors
12981310
)
12991311

13001312
if self._beta2 != 1.0:
@@ -1673,7 +1685,6 @@ def update_preconditioners(
16731685
kronecker_factors.corrected_eigenvalues.mul_(self._beta2)
16741686

16751687
# NOTE: The case when self._weighting_factor == 1.0 is not well tested and might not be stable.
1676-
print(f"{self._weighting_factor=}")
16771688
kronecker_factors.corrected_eigenvalues.addcmul_(
16781689
grad, grad, value=self._weighting_factor
16791690
)
@@ -1707,3 +1718,104 @@ def _compute_preconditioned_gradient(
17071718
preconditioner_list=kronecker_factors.factor_matrices_eigenvectors,
17081719
dims=([0], [1]),
17091720
)
1721+
1722+
1723+
class RootInvKLShampooPreconditionerList(RootInvShampooPreconditionerList):
1724+
"""Root inverse KL-Shampoo preconditioners for list of parameters."""
1725+
1726+
@profile_decorator
1727+
def _compute_outer_product_list(
1728+
self,
1729+
grad: Tensor,
1730+
order: int,
1731+
preconditioned_dims_selector: tuple[bool, ...],
1732+
kronecker_factors: RootInvShampooKroneckerFactorsUnwrapped,
1733+
) -> tuple[Tensor, ...]:
1734+
# Construct outer product list for updating Kronecker factors.
1735+
outer_product_list = []
1736+
for idx_of_k, k in enumerate(
1737+
compress_list(range(order), preconditioned_dims_selector)
1738+
):
1739+
# KL-Shampoo uses the gradient preconditioned (along all dimensions that are contracted) with the inverse root of the factor matrices to compute the outer products.
1740+
local_preconditioned_dims_selector = list(preconditioned_dims_selector)
1741+
local_preconditioned_dims_selector[k] = False
1742+
preconditioned_grad = self._precondition_grad(
1743+
grad=grad,
1744+
preconditioned_dims_selector=tuple(local_preconditioned_dims_selector),
1745+
preconditioner_list=tuple(
1746+
inv_factor_matrix
1747+
for idx, inv_factor_matrix in enumerate(
1748+
kronecker_factors.inv_factor_matrices
1749+
)
1750+
if idx != idx_of_k
1751+
),
1752+
)
1753+
outer_product_list.append(
1754+
torch.tensordot(
1755+
preconditioned_grad,
1756+
preconditioned_grad,
1757+
# Contracts across all dimensions except for k.
1758+
dims=[[*chain(range(k), range(k + 1, order))]] * 2, # type: ignore[has-type]
1759+
)
1760+
)
1761+
return tuple(outer_product_list)
1762+
1763+
1764+
class EigendecomposedKLShampooPreconditionerList(
1765+
EigendecomposedShampooPreconditionerList
1766+
):
1767+
"""Eigendecomposed KL-Shampoo preconditioners for list of parameters."""
1768+
1769+
@profile_decorator
1770+
def _compute_outer_product_list(
1771+
self,
1772+
grad: Tensor,
1773+
order: int,
1774+
preconditioned_dims_selector: tuple[bool, ...],
1775+
kronecker_factors: EigendecomposedShampooKroneckerFactorsUnwrapped,
1776+
) -> tuple[Tensor, ...]:
1777+
# TODO: remove assertion when rank_deficient_stability_config is generalized to MatrixFunctionConfig
1778+
assert isinstance(
1779+
self._preconditioner_config.amortized_computation_config,
1780+
EigendecompositionConfig,
1781+
)
1782+
rank_deficient_stability_config = self._preconditioner_config.amortized_computation_config.rank_deficient_stability_config
1783+
1784+
# Construct outer product list for updating Kronecker factors.
1785+
outer_product_list = []
1786+
for idx_of_k, k in enumerate(
1787+
compress_list(range(order), preconditioned_dims_selector)
1788+
):
1789+
# KL-Shampoo uses the gradient preconditioned (along all dimensions that are contracted) with the inverse root of the factor matrices to compute the outer products.
1790+
local_preconditioned_dims_selector = list(preconditioned_dims_selector)
1791+
local_preconditioned_dims_selector[k] = False
1792+
preconditioned_grad = self._precondition_grad(
1793+
grad=grad,
1794+
preconditioned_dims_selector=tuple(local_preconditioned_dims_selector),
1795+
preconditioner_list=tuple(
1796+
matrix_inverse_root_from_eigendecomposition(
1797+
L=eigenvalues,
1798+
Q=eigenvectors,
1799+
root=Fraction(root),
1800+
epsilon=self._epsilon,
1801+
rank_deficient_stability_config=rank_deficient_stability_config,
1802+
)
1803+
for idx, (eigenvalues, eigenvectors, root) in enumerate(
1804+
zip(
1805+
kronecker_factors.factor_matrices_eigenvalues,
1806+
kronecker_factors.factor_matrices_eigenvectors,
1807+
kronecker_factors.roots,
1808+
strict=True,
1809+
)
1810+
)
1811+
if idx != idx_of_k
1812+
),
1813+
)
1814+
outer_product = torch.tensordot(
1815+
preconditioned_grad,
1816+
preconditioned_grad,
1817+
# Contracts across all dimensions except for k.
1818+
dims=[[*chain(range(k), range(k + 1, order))]] * 2, # type: ignore[has-type]
1819+
)
1820+
outer_product_list.append(outer_product)
1821+
return tuple(outer_product_list)

0 commit comments

Comments
 (0)