diff --git a/test/prototype/test_pat.py b/test/prototype/test_pat.py new file mode 100644 index 0000000000..f34730bf99 --- /dev/null +++ b/test/prototype/test_pat.py @@ -0,0 +1,279 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import random +import unittest + +import torch +from torch import nn +from torch.testing._internal import common_utils + +from torchao.prototype.pat.group import ( + AttentionHeadGrouperDim0, + AttentionHeadGrouperDim1, + PackedSVDGrouper, + QKGrouper, + QKSVDGrouper, + SVDGrouper, +) +from torchao.prototype.pat.layers.masked_layernorm import MaskedLayerNorm +from torchao.prototype.pat.optim import ProxGroupLasso, ProxNuclearNorm, PruneOptimizer +from torchao.prototype.pat.utils import get_param_groups + + +class TestMaskedLayerNorm(common_utils.TestCase): + @common_utils.parametrize("batch", [1, 4]) + @common_utils.parametrize("seq_len", [2, 8]) + @common_utils.parametrize("embed_dim", [16, 64]) + def test_masked_layernorm(self, batch=1, seq_len=2, embed_dim=16): + dim2_nz = embed_dim // 2 + embed = torch.randn(batch, seq_len, embed_dim) + embed[..., dim2_nz:] = 0 + + masked_layer_norm = MaskedLayerNorm(embed_dim) + layer_norm = nn.LayerNorm(dim2_nz) + with torch.no_grad(): + layer_norm.weight.copy_(masked_layer_norm.weight[:dim2_nz]) + layer_norm.bias.copy_(masked_layer_norm.bias[:dim2_nz]) + + out = masked_layer_norm(embed) + expected_out = layer_norm(embed[..., :dim2_nz]) + torch.testing.assert_close(out[..., :dim2_nz], expected_out) + + +class MHADummyModel(nn.Module): + def __init__(self, embed_dim, num_heads, n_cls): + super().__init__() + self.mha = nn.MultiheadAttention(embed_dim, num_heads, bias=False) + self.classifier = nn.Linear(embed_dim, n_cls) + + def forward(self, x): + attn_output, _ = self.mha(x, x, x) + out = self.classifier(attn_output) + return out + + +class TestQKGrouper(common_utils.TestCase): + def __init__(self, methodName): + super(TestQKGrouper, self).__init__(methodName) + self.reg_lambda = 1.0 + self.prox_map = ProxGroupLasso(self.reg_lambda) + + @staticmethod + def _get_qk(p, embed_dim, qk_reg_index): + qk = p[:embed_dim] if qk_reg_index == 0 else p[embed_dim : (embed_dim * 2)] + return qk + + def get_gamma(self, p): + """Heuristic that uses the mean of the group to set gamma.""" + p_col = p[:, 0] + gamma = (1 - p_col.mean()) * torch.linalg.vector_norm(p_col) + gamma.div_(self.prox_map.tau(p_col)) + return gamma + + def _test_post_prune(self, p, qk_orig, embed_dim, qk_reg_index, gamma): + qk = self._get_qk(p, embed_dim, qk_reg_index) + nz_mask = qk.sum(dim=0).ne(0) + self.assertTrue(nz_mask.eq(0).any(), "No columns of Q/K were pruned") + + # original columns that are <= gamma are pruned + expect_nz_mask = qk_orig.gt(gamma).all(dim=0) + torch.testing.assert_close(nz_mask > 1, expect_nz_mask, atol=0, rtol=0) + + def _test_mha_inner(self, p, embed_dim, qk_reg_index): + qk_orig = self._get_qk(p, embed_dim, qk_reg_index).clone() + qk_no_prune = self._get_qk(p, embed_dim, int(not qk_reg_index)).clone() + v_orig = p[(embed_dim * 2) :].clone() + qk_pack_dim = 0 + with QKGrouper(p, qk_pack_dim, qk_reg_index) as grouper: + self.assertTrue(grouper.p.equal(qk_orig)) + + gamma = self.get_gamma(grouper.p) + _ = torch.vmap( + self.prox_map.apply_, in_dims=(grouper.in_dims, None), out_dims=0 + )(grouper.p, gamma) + + self._test_post_prune(p, qk_orig, embed_dim, qk_reg_index, gamma) + + # unregularized query or key was not modified + no_prune = self._get_qk(p, embed_dim, int(not qk_reg_index)) + torch.testing.assert_close(no_prune, qk_no_prune, atol=0, rtol=0) + + # value was not modified + v = p[(embed_dim * 2) :] + torch.testing.assert_close(v, v_orig, atol=0, rtol=0) + + @common_utils.parametrize("embed_dim", [16, 64]) + @common_utils.parametrize("num_heads", [2, 4]) + @common_utils.parametrize("qk_reg_index", [0, 1]) + def test_pytorch_mha(self, embed_dim=16, num_heads=4, qk_reg_index=0): + assert embed_dim % num_heads == 0, ( + f"{embed_dim=} must be divisible by {num_heads=}" + ) + + # single in_proj_weight of shape (embed_dim * 3, embed_dim) + model = nn.MultiheadAttention(embed_dim, num_heads, bias=False) + p = model.in_proj_weight.detach() + self._test_mha_inner(p, embed_dim, qk_reg_index) + + @common_utils.parametrize("qk_reg_index", [0, 1]) + def test_e2e_optimizer(self, embed_dim=64, qk_reg_index=0): + n_cls = 3 + model = MHADummyModel(embed_dim, num_heads=4, n_cls=n_cls) + prune_config = { + "mha.in_proj_weight": { + "group_type": "QKGrouper", + "prox_type": "ProxGroupLasso", + "qk_pack_dim": 0, + "qk_reg_index": qk_reg_index, + } + } + param_groups = get_param_groups(model, prune_config, verbose=False) + self.assertEqual(len(param_groups), 3) + + p = model.mha.in_proj_weight.detach() + qk_orig = self._get_qk(p, embed_dim, qk_reg_index).clone() + + # set lr to gamma since we run a single step + gamma = self.get_gamma(qk_orig) + optimizer = PruneOptimizer( + torch.optim.SGD(param_groups, lr=gamma), reg_lambda=self.reg_lambda + ) + + data = torch.randn(1, 8, embed_dim) + label = torch.arange(0, n_cls) * data.mean(axis=-1, keepdim=True) + output = model(data) + loss = nn.functional.mse_loss(output, label) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + self._test_post_prune(p, qk_orig, embed_dim, qk_reg_index, gamma) + + +class TestAttentionHeadGrouper(common_utils.TestCase): + def __init__(self, methodName): + super(TestAttentionHeadGrouper, self).__init__(methodName) + self.reg_lambda = 1.0 + self.prox_map = ProxGroupLasso(self.reg_lambda) + + @staticmethod + def _get_view_shape_reduce_dim(dim, num_heads, head_pack_dim): + if head_pack_dim == 0: + view_shape = (num_heads, -1, dim) + reduce_dim = (1, 2) + else: + view_shape = (dim, num_heads, -1) + reduce_dim = (0, 2) + return view_shape, reduce_dim + + def _test_post_prune(self, p, p_orig, head_pack_dim, view_shape, reduce_dim, gamma): + nz_mask = p.view(*view_shape).sum(dim=reduce_dim).ne(0) + self.assertTrue(nz_mask.eq(0).any(), "No groups of p were pruned") + + # original groups that are <= gamma are pruned + expect_nz_mask = p_orig.view(*view_shape).gt(gamma).all(dim=reduce_dim) + torch.testing.assert_close(nz_mask > 1, expect_nz_mask, atol=0, rtol=0) + + def get_gamma(self, p, head_pack_dim, view_shape): + """Heuristic that uses the mean of the group to set gamma.""" + p = p.view(*view_shape) + p_group = p[0] if head_pack_dim == 0 else p[:, 0] + gamma = (1 - p_group.mean()) * torch.linalg.vector_norm(p_group) + gamma.div_(self.prox_map.tau(p_group)) + return gamma + + @common_utils.parametrize("dim", [64, 128]) + @common_utils.parametrize("head_pack_dim", [0, 1]) + def test_head_grouper(self, dim=16, head_pack_dim=0, head_dim_ratio=8): + assert dim % head_dim_ratio == 0, ( + f"{dim=} must be divisible by {head_dim_ratio=}" + ) + num_heads = dim // 8 + packed_dim = dim * num_heads + shape = (dim, packed_dim) if head_pack_dim == 0 else (packed_dim, dim) + model = nn.Linear(*shape, bias=False) + p = model.weight.detach() + p_orig = p.clone() + view_shape, reduce_dim = self._get_view_shape_reduce_dim( + dim, num_heads, head_pack_dim + ) + grouper_cls = ( + AttentionHeadGrouperDim0 if head_pack_dim == 0 else AttentionHeadGrouperDim1 + ) + with grouper_cls(p, num_heads) as grouper: + gamma = self.get_gamma(grouper.p, head_pack_dim, view_shape) + _ = torch.vmap( + self.prox_map.apply_, in_dims=(grouper.in_dims, None), out_dims=0 + )(grouper.p, gamma) + self.assertEqual(grouper.p.size(head_pack_dim), num_heads) + self._test_post_prune(p, p_orig, head_pack_dim, view_shape, reduce_dim, gamma) + + +class TestSVDGrouper(common_utils.TestCase): + def __init__(self, methodName): + super(TestSVDGrouper, self).__init__(methodName) + self.reg_lambda = 1.0 + self.prox_map = ProxNuclearNorm(self.reg_lambda) + + @common_utils.parametrize("embed_dim", (16, 64)) + def test_grouper(self, embed_dim=16): + model = torch.nn.Linear(embed_dim, embed_dim) + p = model.weight + with SVDGrouper(p) as grouper: + gamma = grouper.p.mean() + p_orig = grouper.p.clone() + torch.vmap( + self.prox_map.apply_, in_dims=(grouper.in_dims, None), out_dims=0 + )(grouper.p, gamma) + expect_nz_mask = p_orig.gt(gamma) + torch.testing.assert_close(grouper.p.ne(0), expect_nz_mask, atol=0, rtol=0) + + @common_utils.parametrize("embed_dim", (16, 64)) + @common_utils.parametrize("pack_dim", (0, 1)) + def test_qk_grouper(self, embed_dim=16, pack_dim=0): + shape = [embed_dim, embed_dim] + shape[int(not pack_dim)] *= 3 + model = torch.nn.Linear(*shape) + p = model.weight + with QKSVDGrouper(p, pack_dim=pack_dim) as grouper: + gamma = grouper.p.mean() + p_orig = grouper.p.clone() + torch.vmap( + self.prox_map.apply_, in_dims=(grouper.in_dims, None), out_dims=0 + )(grouper.p, gamma) + expect_nz_mask = p_orig.gt(gamma) + torch.testing.assert_close(grouper.p.ne(0), expect_nz_mask, atol=0, rtol=0) + + @common_utils.parametrize("embed_dim", (16, 64)) + @common_utils.parametrize("pack_dim", (0, 1)) + def test_packed_grouper(self, embed_dim=16, npack=3, pack_dim=0): + shape = [embed_dim, embed_dim] + shape[int(not pack_dim)] *= npack + model = torch.nn.Linear(*shape) + p = model.weight + with PackedSVDGrouper(p, npack, pack_dim=pack_dim) as grouper: + gamma = grouper.p.mean(0).mean() + p_orig = grouper.p.clone() + torch.vmap( + self.prox_map.apply_, in_dims=(grouper.in_dims, None), out_dims=0 + )(grouper.p.flatten(), gamma) + torch.testing.assert_close( + grouper.p.ne(0), p_orig.gt(gamma), atol=0, rtol=0 + ) + self.assertEqual(p.data_ptr(), grouper._p.data_ptr()) + + +common_utils.instantiate_parametrized_tests(TestMaskedLayerNorm) +common_utils.instantiate_parametrized_tests(TestQKGrouper) +common_utils.instantiate_parametrized_tests(TestAttentionHeadGrouper) +common_utils.instantiate_parametrized_tests(TestSVDGrouper) + +if __name__ == "__main__": + random.seed(0) + torch.manual_seed(0) + unittest.main() diff --git a/torchao/prototype/pat/README.md b/torchao/prototype/pat/README.md new file mode 100644 index 0000000000..1cb6c93d9c --- /dev/null +++ b/torchao/prototype/pat/README.md @@ -0,0 +1,50 @@ +# PAT: Pruning-Aware Training + +PAT is a library based on group Lasso regularization. It directly induces structured sparsity during training, removing the need for custom pruning metrics and/or multiple rounds of training. + +PAT's simple optimizer-only interface supports easy integration into existing training pipelines. The code is organized into two main components: +* grouper: defines the granularity of pruning (e.g., filter, channel, layer) +* proximal mapping: projects groups of weights onto sparse values + +## Optimizer-only interface + +This package provides a `PruneOptimizer` that simply wraps around a base optimizer inheriting from `torch.optim.Optimizer`. The following code snippet illustrates how to set up PAT: + +```python +from pat.optim import PruneOptimizer + +model = torchvision.models.resnet18().cuda() + +# split params into prunable and non-prunable groups +weights = [p for name, p in model.named_parameters() if name.endswith("weight")] +others = [p for name, p in model.named_parameters() if not name.endswith("weight")] + +# apply row-wise group Lasso regularization to the weights +param_groups = [ + { + "params": weights", + "group_type": "pat.group.Dim0Grouper", + "prox_type": "pat.prox.ProxGroupLasso", + "reg_lambda": 2e-4, + }, + {"params": others}, +] + +# create base optimizer (SGD, Adam or AdamW) +base_optimizer = torch.optim.SGD( + param_groups, lr=0.1, momentum=0.9, weight_decay=1e-4 +) + +# create PruneOptimizer +optimizer = PruneOptimizer(base_optimizer) +``` + +After creating `PruneOptimizer`, one can use it as a regular PyTorch optimizer. + +## Pruning configuration + +Pruning configs are dictionaries that define which parameter groups to prune and how to prune them. Each key-value pair in the config maps to a prunable parameter group of `PruneOptimizer`. The keys are used to match model parameters, while the values specify the pruning granularity and proximal map. The key can be one of the following types: + +- parameter name (string): e.g., `blocks.0.attn.qkv.weight` +- regex pattern (string): e.g., `:.*attn\.qkv\.weight` +- module type and parameter name suffix ((class, string) tuple): e.g., `(torch.nn.Linear, 'weight')` diff --git a/torchao/prototype/pat/__init__.py b/torchao/prototype/pat/__init__.py new file mode 100644 index 0000000000..52d71d0141 --- /dev/null +++ b/torchao/prototype/pat/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from .group import ( # noqa: F401 + AttentionHeadGrouperDim0, + AttentionHeadGrouperDim1, + ConvFilterGrouper, + Dim0Grouper, + Dim1Grouper, + ElemGrouper, + LayerGrouper, + PackedSVDGrouper, + QKGrouper, + QKSVDGrouper, + SVDGrouper, +) +from .optim import ( # noqa: F401 + NMSGDOptimizer, + ProxGroupLasso, + ProxGroupLassoReduce, + ProxLasso, + ProxNuclearNorm, + PruneOptimizer, +) diff --git a/torchao/prototype/pat/group/__init__.py b/torchao/prototype/pat/group/__init__.py new file mode 100644 index 0000000000..5e1fe2b637 --- /dev/null +++ b/torchao/prototype/pat/group/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from .attention import ( # noqa: F401 + AttentionHeadGrouperDim0, + AttentionHeadGrouperDim1, + QKGrouper, +) +from .conv import ConvFilterGrouper # noqa: F401 +from .dim import Dim0Grouper, Dim1Grouper # noqa: F401 +from .grouper import ( # noqa: F401 + ElemGrouper, + Grouper, + LayerGrouper, +) +from .low_rank import PackedSVDGrouper, QKSVDGrouper, SVDGrouper # noqa: F401 diff --git a/torchao/prototype/pat/group/attention.py b/torchao/prototype/pat/group/attention.py new file mode 100644 index 0000000000..8a3bb83c67 --- /dev/null +++ b/torchao/prototype/pat/group/attention.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + + +from torch import Tensor + +from .dim import Dim0Grouper, Dim1Grouper +from .packed import PackedGrouperMixin + + +class QKGrouper(PackedGrouperMixin, Dim1Grouper): + """Grouper applied only to query and key weights. Assumes that query, key, + value weights are packed along `qk_pack_dim` dimension. + + Args: + p (Tensor): The packed query, key, value weights. + qk_pack_dim (int): Dimension along which query and key are packed. + qk_reg_index (int, optional): 0 for query, 1 for key. Default: 0. + """ + + def __init__( + self, + p: Tensor, + qk_pack_dim: int = 0, + qk_reg_index: int = 0, + ): + super().__init__(p, 3, qk_pack_dim) + + if qk_reg_index == 0: # query + start, end = 0, self.embed_dim + else: # key + start, end = self.embed_dim, self.embed_dim * 2 + + super(PackedGrouperMixin, self).__init__( + p[start:end] if qk_pack_dim == 0 else p[:, start:end] + ) + + +class AttentionHeadGrouperDim0(Dim0Grouper): + """Grouper for attention heads. + + Args: + p (Tensor): Input tensor with packed attention heads. + num_heads (int): Number of attention heads. + """ + + def __init__(self, p: Tensor, num_heads: int): + super().__init__(p) + + self.num_heads = num_heads + self.head_dim = p.size(0) // num_heads + + def __enter__(self): + self.p.data = self.p.data.view(self.num_heads, -1) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.p.data = self.p.data.view(self.num_heads * self.head_dim, -1) + super().__exit__(exc_type, exc_val, exc_tb) + + +class AttentionHeadGrouperDim1(Dim1Grouper): + def __init__(self, p: Tensor, num_heads: int): + self._orig_p = p + self.head_dim = p.size(1) // num_heads + p = p.view(-1, num_heads, self.head_dim).transpose(1, 2).contiguous() + super().__init__(p) + + self.num_heads = num_heads + + def __enter__(self): + self.p.data = self.p.data.view(-1, self.num_heads) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.p.data = self.p.data.view(-1, self.head_dim, self.num_heads) + self.p = self.p.data.transpose(1, 2).contiguous().view(self._orig_p.shape) + self._orig_p.data.copy_(self.p.data) + super().__exit__(exc_type, exc_val, exc_tb) diff --git a/torchao/prototype/pat/group/conv.py b/torchao/prototype/pat/group/conv.py new file mode 100644 index 0000000000..163ccf307b --- /dev/null +++ b/torchao/prototype/pat/group/conv.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from torch import Tensor + +from .grouper import Grouper + + +class ConvFilterGrouper(Grouper): + def __init__(self, p: Tensor): + assert p.dim() == 4, "ConvFilterGrouper only supports 4D tensors" + super().__init__(p, in_dims=0) + + def __enter__(self): + self.p.data = self.p.data.view(self.orig_shape[0] * self.orig_shape[1], -1) + return self diff --git a/torchao/prototype/pat/group/dim.py b/torchao/prototype/pat/group/dim.py new file mode 100644 index 0000000000..6d1a5d9611 --- /dev/null +++ b/torchao/prototype/pat/group/dim.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from torch import Tensor + +from .grouper import Grouper + + +class DimGrouperMixin: + def __init__(self, start_dim: int = 1, end_dim: int = -1): + self.start_dim = start_dim + self.end_dim = end_dim + + def __enter__(self): + if self.p.dim() > 2: + self.p.data = self.p.data.flatten( + start_dim=self.start_dim, end_dim=self.end_dim + ) + return self + + +class Dim0Grouper(DimGrouperMixin, Grouper): + def __init__(self, p: Tensor, start_dim: int = 1, end_dim: int = -1): + super().__init__(start_dim, end_dim) + super(DimGrouperMixin, self).__init__(p, in_dims=0) + + +class Dim1Grouper(DimGrouperMixin, Grouper): + def __init__(self, p: Tensor, start_dim: int = 1, end_dim: int = -1): + super().__init__(start_dim, end_dim) + super(DimGrouperMixin, self).__init__(p, in_dims=1) diff --git a/torchao/prototype/pat/group/grouper.py b/torchao/prototype/pat/group/grouper.py new file mode 100644 index 0000000000..e62d0fa143 --- /dev/null +++ b/torchao/prototype/pat/group/grouper.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC +from typing import Optional, Union + +from torch import Tensor + + +class Grouper(ABC): + def __init__(self, p: Tensor, in_dims: Optional[Union[int, tuple]] = None) -> None: + self.p = p + self.orig_shape = p.shape + self.in_dims = in_dims + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.p.data = self.p.data.view(self.orig_shape) + + def group_size(self): + return self.p.numel() // self.p.size(self.in_dims) + + def n_groups(self): + return self.p.numel() // self.group_size() + + +class ElemGrouper(Grouper): + def group_size(self): + return 1 + + +class LayerGrouper(Grouper): + def group_size(self): + return self.p.numel() diff --git a/torchao/prototype/pat/group/low_rank.py b/torchao/prototype/pat/group/low_rank.py new file mode 100644 index 0000000000..60d811fbf7 --- /dev/null +++ b/torchao/prototype/pat/group/low_rank.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch +from torch import Tensor + +from ..utils import use_deterministic_algorithms +from .grouper import Grouper +from .packed import PackedGrouperMixin + + +class SVDGrouper(Grouper): + """Apply SVD to regularize the singular values of a parameter tensor.""" + + def __init__(self, p: Tensor, pack_dim: Optional[int] = None): + self._p = p + self.orig_shape = p.shape + + # Reshape input to 2D, then regularize its singular values + p = p.data.squeeze() + if pack_dim is None and p.dim() > 2: + p = p.flatten(start_dim=1) + + with use_deterministic_algorithms(): + (self.U, self.p, self.Vh) = torch.linalg.svd(p, full_matrices=False) + + self.in_dims = 0 + + @torch.no_grad() + def __exit__(self, exc_type, exc_val, exc_tb): + self._p.copy_( + torch.linalg.multi_dot([self.U, torch.diag(self.p), self.Vh]).view( + self.orig_shape + ) + ) + + +class QKSVDGrouper(PackedGrouperMixin, SVDGrouper): + def __init__(self, p: Tensor, pack_dim: int = 0): + super().__init__(p, 3, pack_dim) + + self.qk_dim = self.embed_dim * 2 + super(PackedGrouperMixin, self).__init__( + p[: self.qk_dim] if pack_dim == 0 else p[:, : self.qk_dim], + pack_dim=pack_dim, + ) + + @torch.no_grad() + def __exit__(self, exc_type, exc_val, exc_tb): + p = torch.linalg.multi_dot([self.U, torch.diag(self.p), self.Vh]) + if self.pack_dim == 0: + self._p[: self.qk_dim].copy_(p) + else: + self._p[:, : self.qk_dim].copy_(p) + + +class PackedSVDGrouper(PackedGrouperMixin, SVDGrouper): + """Wrapper around SVDGrouper to handle packed tensors.""" + + def __init__(self, p: Tensor, npack: int, pack_dim: int = 0): + super().__init__(p, npack, pack_dim) + + if pack_dim == 1: + p = p.t() + + super(PackedGrouperMixin, self).__init__( + p.view(npack, -1, self.embed_dim), pack_dim=pack_dim + ) + + @torch.no_grad() + def __exit__(self, exc_type, exc_val, exc_tb): + self._p.copy_(self.U @ torch.diag_embed(self.p) @ self.Vh) diff --git a/torchao/prototype/pat/group/packed.py b/torchao/prototype/pat/group/packed.py new file mode 100644 index 0000000000..d48d5fc12b --- /dev/null +++ b/torchao/prototype/pat/group/packed.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import Tensor + + +class PackedGrouperMixin: + def __init__(self, p: Tensor, npack: int, pack_dim: int = 0): + assert p.dim() == 2, f"Expected 2D tensor, got {p.dim()=}" + assert pack_dim < p.dim(), f"Invalid {pack_dim=} for {p.shape=}" + + if pack_dim == 0: + embed_dim = p.size(1) + expect_shape = (embed_dim * npack, embed_dim) + else: + embed_dim = p.size(0) + expect_shape = (embed_dim, embed_dim * npack) + assert p.shape == torch.Size(expect_shape), ( + f"Expected {expect_shape=}, got {p.shape=}" + ) + self.embed_dim = embed_dim + self.pack_dim = pack_dim diff --git a/torchao/prototype/pat/layers/__init__.py b/torchao/prototype/pat/layers/__init__.py new file mode 100644 index 0000000000..37d7cc08c6 --- /dev/null +++ b/torchao/prototype/pat/layers/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from .masked_layernorm import MaskedLayerNorm + +__all__ = ["MaskedLayerNorm"] diff --git a/torchao/prototype/pat/layers/masked_layernorm.py b/torchao/prototype/pat/layers/masked_layernorm.py new file mode 100644 index 0000000000..4960ff680d --- /dev/null +++ b/torchao/prototype/pat/layers/masked_layernorm.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Union + +import torch +import torch._prims_common as utils +from torch import Tensor, nn +from torch._prims_common.wrappers import _maybe_convert_to_dtype + + +def _nz_normalize(a: Tensor, norm_dims: Union[int, list[int]], eps: float) -> Tensor: + """Computes the normalized tensor, ignoring zeroed out values. + See torch._refs._normalize for more reference. + """ + computation_dtype = utils.get_computation_dtype(a.dtype) + a = _maybe_convert_to_dtype(a, computation_dtype) + + norm_dims = utils.canonicalize_dims(a.ndim, norm_dims) + nz_mask = a.ne(0).detach() + count = torch.sum(nz_mask, dim=norm_dims, keepdim=True).clamp_(min=1) + mean = torch.sum(a, dim=norm_dims, keepdim=True) / count + + a_center = torch.where(nz_mask, a - mean, torch.zeros_like(a)) + biased_var = torch.sum(a_center.pow(2), dim=norm_dims, keepdim=True) / count + out = a_center * torch.rsqrt(biased_var + eps) + return out + + +class MaskedLayerNorm(nn.LayerNorm): + """Layer normalization that ignores zeroed out elements in the input tensor. + See torch._refs.native_layer_norm for reference. + """ + + def forward(self, input: Tensor) -> Tensor: + normalized_ndim = len(self.normalized_shape) + axis = input.ndim - normalized_ndim + reduction_dims = list(range(axis, input.ndim)) + out = _nz_normalize(input, reduction_dims, self.eps) + + if self.elementwise_affine: + out = out * self.weight + self.bias + out = _maybe_convert_to_dtype(out, input.dtype) + return out diff --git a/torchao/prototype/pat/optim/__init__.py b/torchao/prototype/pat/optim/__init__.py new file mode 100644 index 0000000000..966bd3d198 --- /dev/null +++ b/torchao/prototype/pat/optim/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial + +from torch.optim import Optimizer + +from .group_lasso import ProxGroupLasso, ProxGroupLassoReduce # noqa: F401 +from .lasso import ProxLasso # noqa: F401 +from .nm_sgd import NMSGDOptimizer +from .nuclear_norm import ProxNuclearNorm # noqa: F401 +from .proxmap import ProxMap # noqa: F401 +from .pruneopt import PruneOptimizer + + +def build_prune_optimizer( + base_optimizer: Optimizer, + prune_reg_lambda: float, + prune_warmup_steps: int = 0, + nm_gamma: float = 0.0, +) -> PruneOptimizer: + if nm_gamma > 0: + prune_opt_cls = partial(NMSGDOptimizer, nm_gamma=nm_gamma) + else: + prune_opt_cls = PruneOptimizer + + return prune_opt_cls( + base_optimizer, + warmup_steps=prune_warmup_steps, + reg_lambda=prune_reg_lambda, + ) diff --git a/torchao/prototype/pat/optim/group_lasso.py b/torchao/prototype/pat/optim/group_lasso.py new file mode 100644 index 0000000000..0c53dde7be --- /dev/null +++ b/torchao/prototype/pat/optim/group_lasso.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Union + +import torch +from torch import Tensor + +from ..utils import HAS_DTENSOR, is_dtensor +from .proxmap import ProxMap + +if HAS_DTENSOR: + from torch.distributed.tensor.experimental import local_map + from torch.distributed.tensor.placement_types import Replicate + + +class ProxGroupLasso(ProxMap): + @staticmethod + def tau(p: Tensor) -> float: + """Assumes that p is a group within the full tensor""" + return math.sqrt(p.numel()) + + def _get_norm(self, p): + return torch.linalg.vector_norm(p) + + def apply_(self, p: Tensor, gamma: Union[Tensor, float]) -> Tensor: + super().apply_(p, gamma) + p_norm = self._get_norm(p) + mult = torch.maximum( + 1 - self.threshold(p, gamma) / p_norm, torch.zeros_like(p_norm) + ) + p.mul_(mult) + return mult.eq(0).sum() + + +class ProxGroupLassoReduce(ProxGroupLasso): + @staticmethod + def partial_norm(p): + return p.square().sum() + + def _get_norm(self, p): + assert is_dtensor(p), f"Expected DTensor input but got {type(p)}" + partial_norm = local_map( + self.partial_norm, + out_placements=(Replicate() for _ in p.placements), + device_mesh=p.device_mesh, + )(p) + if partial_norm.dim() > 0: + partial_norm = partial_norm.sum() + return partial_norm.sqrt() diff --git a/torchao/prototype/pat/optim/lasso.py b/torchao/prototype/pat/optim/lasso.py new file mode 100644 index 0000000000..baf886d879 --- /dev/null +++ b/torchao/prototype/pat/optim/lasso.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Union + +from torch import Tensor + +from .proxmap import ProxMap + + +class ProxLasso(ProxMap): + @staticmethod + def tau(p: Tensor) -> float: + return 1.0 + + def apply_(self, p: Tensor, gamma: Union[Tensor, float]) -> Tensor: + super().apply_(p, gamma) + mult = (1 - self.threshold(p, gamma) / p.abs()).clamp(min=0) + p.mul_(mult) + return mult.eq(0).sum() diff --git a/torchao/prototype/pat/optim/nm_sgd.py b/torchao/prototype/pat/optim/nm_sgd.py new file mode 100644 index 0000000000..9199d858ee --- /dev/null +++ b/torchao/prototype/pat/optim/nm_sgd.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch.optim import Optimizer + +from .pruneopt import PruneOptimizer + + +class NMSGDOptimizer(PruneOptimizer): + """From "A Normal Map-Based Proximal Stochastic Gradient Method": https://arxiv.org/pdf/2305.05828v2 + Other parameters: + norm_gamma: float, default 0.0 + If > 0, then normalize gamma by parameter group dimension. + This is the same as the "normalized" option in N:M sparsity. + """ + + def __init__( + self, + base_optimizer: Optimizer, + warmup_steps: int = 0, + reg_lambda: float = 0.0, + nm_gamma: float = 0.0, + ) -> None: + super().__init__( + base_optimizer=base_optimizer, + warmup_steps=warmup_steps, + reg_lambda=reg_lambda, + ) + self.nm_gamma = nm_gamma + for group in self.regularized_param_groups(): + group["gamma"] = self.nm_gamma + + def _set_gamma(self, group): + pass + + @torch._disable_dynamo + def restore_latent_params(self) -> None: + """Restore latent parameters as optimizer parameters""" + gamma_inv = 1.0 / self.nm_gamma + for group in self.regularized_param_groups(): + for p in group["params"]: + if p.requires_grad: + p.grad.add_(self.state[p]["latent"] - p, alpha=gamma_inv) + p.copy_(self.state[p]["latent"]) diff --git a/torchao/prototype/pat/optim/nuclear_norm.py b/torchao/prototype/pat/optim/nuclear_norm.py new file mode 100644 index 0000000000..f2d971a7b2 --- /dev/null +++ b/torchao/prototype/pat/optim/nuclear_norm.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Union + +import torch +from torch import Tensor + +from .proxmap import ProxMap + + +class ProxNuclearNorm(ProxMap): + @staticmethod + def tau(p: Tensor) -> float: + return 1.0 + + def apply_(self, p: Tensor, gamma: Union[Tensor, float]) -> Tensor: + super().apply_(p, gamma) + thresh = self.threshold(p, gamma) + zero_mask = p.le(thresh) + p.sub_(torch.where(zero_mask, p, thresh)) + return zero_mask.sum() diff --git a/torchao/prototype/pat/optim/proxmap.py b/torchao/prototype/pat/optim/proxmap.py new file mode 100644 index 0000000000..1f2d3f0d39 --- /dev/null +++ b/torchao/prototype/pat/optim/proxmap.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC, abstractmethod +from typing import Union + +import torch +from torch import Tensor + + +class ProxMap(ABC): + """Abstract base class that defines the proximal mapping interface""" + + def __init__(self, reg_lambda: float) -> None: + self.reg_lambda = reg_lambda + + @staticmethod + @abstractmethod + def tau(p: Tensor) -> float: + """Return group-level regularization strength""" + + def threshold(self, p: Tensor, gamma: Union[Tensor, float]) -> Union[Tensor, float]: + """Return pruning threshold""" + return self.reg_lambda * self.tau(p) * gamma + + def apply_(self, p: Tensor, gamma: Union[Tensor, float]) -> Tensor: + """Provide interface for pruning (modify p in-place): + pruner.apply_(p, q, step_count) + Inputs: + p (Tensor): full or group-level tensor to be pruned + gamma (float): typically the cumulative sum over step sizes + """ + if isinstance(gamma, float) and gamma == 0: + return torch.zeros(1, dtype=torch.long, device=p.device) diff --git a/torchao/prototype/pat/optim/pruneopt.py b/torchao/prototype/pat/optim/pruneopt.py new file mode 100644 index 0000000000..71da519fa4 --- /dev/null +++ b/torchao/prototype/pat/optim/pruneopt.py @@ -0,0 +1,402 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from collections import defaultdict +from collections.abc import Callable +from typing import Any + +import torch +from torch import Tensor +from torch.optim import Optimizer + +from ..utils import HAS_DTENSOR, instantiate_module, is_dtensor, is_main_process +from ..utils.distributed import _maybe_async_aggregate, _sum_async_streams +from ..utils.torch import get_index_linspace + +if HAS_DTENSOR: + from torch.distributed.tensor import distribute_tensor + from torch.distributed.tensor.experimental import local_map + from torch.distributed.tensor.placement_types import Partial, Replicate, Shard + + +class PruneOptimizer(Optimizer): + """PruneOptimizer assembles functionalities of the following objects: + a base optimizer (e.g., SGD or AdamW) + - update the latent variables for QAT + Other parameters: + warmup_steps: int >= 0 + """ + + def __init__( + self, + base_optimizer: Optimizer, + warmup_steps: int = 0, + reg_lambda: float = 0.0, + ) -> None: + # need to reconstruct these objects if loading checkpoint + self.base_optimizer = base_optimizer + + # need to store these attributes in state_dict for checkpoint + self.num_steps = 0 + self.warmup_steps = warmup_steps + + self.has_svd = False + for group in self.regularized_param_groups(): + group["gamma"] = 0.0 + group.setdefault("reg_lambda", reg_lambda) + if group.get("group_type", None) == "SVDGrouper": + self.has_svd = True + + self.relative_sparsity = 0 + self.relative_factored_frac = 0 + + # NOTE: Filling state dict here cause Adam(W) error, which assumes + # empty state[p] at first step() where optimizer states are initialized + + def __getattribute__(self, name: str): + try: + attr = super(Optimizer, self).__getattribute__(name) + except AttributeError: + attr = self.base_optimizer.__getattribute__(name) + return attr + + def __repr__(self) -> str: + base_optimizer = "\n ".join(self.base_optimizer.__repr__().split("\n")) + extra_repr = "\n ".join(("(", base_optimizer)) + return f"{self.__class__.__name__} {extra_repr}\n)" + + @property + def state(self) -> defaultdict[Tensor, Any]: # pyre-ignore[3] + return self._state if hasattr(self, "_state") else self.base_optimizer.state + + @torch._disable_dynamo + def state_dict(self) -> dict[str, Any]: + return self.base_optimizer.state_dict() + + @torch._disable_dynamo + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + self.base_optimizer.load_state_dict(state_dict) + + @property + def num_steps(self) -> int: + for group in self.regularized_param_groups(): + return group.get("num_steps", 0) + + @num_steps.setter + def num_steps(self, value: int) -> None: + for group in self.regularized_param_groups(): + group["num_steps"] = value + return + + @num_steps.deleter + def num_steps(self) -> None: + for group in self.regularized_param_groups(): + group.pop("num_steps", None) + return + + def regularized_param_groups(self): # pyre-ignore[3] + """Yield parameter groups that need to be pruned.""" + for group in self.param_groups: + if group.get("prox_type"): + yield group + + @staticmethod + def _get_grouper_kwargs(group) -> dict[str, Any]: + grouper_kwargs = {} + if group["group_type"].startswith("AttentionHeadGrouper"): + grouper_kwargs["num_heads"] = group["num_heads"] + elif group["group_type"] == "QKGrouper": + if "qk_pack_dim" in group: + grouper_kwargs["qk_pack_dim"] = group["qk_pack_dim"] + if "qk_reg_index" in group: + grouper_kwargs["qk_reg_index"] = group["qk_reg_index"] + elif group["group_type"] == "PackedSVDGrouper": + grouper_kwargs["npack"] = group["npack"] + if "pack_dim" in group: + grouper_kwargs["pack_dim"] = group["pack_dim"] + return grouper_kwargs + + @staticmethod + def _apply_prox( + grouper, prox_map, p, sv_count=None, **prox_kwargs + ) -> tuple[Tensor, bool]: + """ + Apply `prox_map` to the grouped parameter tensor `p` in place. Update + `sv_count` if provided. Handles both torch.Tensor and DTensor inputs, + mirroring `torch.vmap` semantics. Assumes prox_map.apply_ returns an + integer per group. + + Returns: + zero_elts: number of zero elements after applying prox map + zeros_are_summed: whether zero_elts is already globally summed + """ + gamma = prox_kwargs["gamma"] + zeros_are_summed = False + with grouper: + gamma_in_dims = None + if prox_kwargs["gamma_index_slope"] > 0: + # y = slope(2x - 1) + 1 + gamma = gamma * get_index_linspace( + prox_kwargs["gamma_index_slope"], + grouper.n_groups(), + device=p.device, + ) + gamma_in_dims = 0 + + if prox_kwargs["disable_vmap"]: + # Element- or layer-wise pruning + zero_elts = prox_map.apply_(grouper.p, gamma) + else: + if not prox_kwargs["is_svd_grouper"] and is_dtensor(p): + if not torch.is_tensor(gamma): + gamma = torch.tensor(gamma, device=p.device) + + gamma_placements = (Replicate(),) + if grouper.in_dims is not None and gamma.dim() > 0: + # Shard gamma according to grouper.in_dims + gamma_placements = (Shard(grouper.in_dims),) + if gamma.dim() <= grouper.in_dims: + gamma = gamma.unsqueeze(0) + gamma = distribute_tensor( + gamma, + device_mesh=p.device_mesh, + placements=gamma_placements, + ) + + # Derive input placements from grouper.p + p_in_placements = ( + Shard(grouper.in_dims) + if grouper.in_dims is not None and plc.is_shard() + else plc + for plc in grouper.p.placements + ) + + # Use local_map for DTensor-aware vectorization + zero_elts_per_group = local_map( + prox_map.apply_, + out_placements=[Partial()], + in_placements=( + p_in_placements, + gamma.placements if is_dtensor(gamma) else None, + ), + redistribute_inputs=True, + )(grouper.p, gamma) + + # Gather counts by calling redistribute implicitly + zero_elts = zero_elts_per_group.full_tensor().item() + else: + # torch.Tensor branch - use standard vmap + zero_elts_per_group = torch.vmap( + prox_map.apply_, + in_dims=(grouper.in_dims, gamma_in_dims), + out_dims=0, + )(grouper.p, gamma) + zero_elts = zero_elts_per_group.sum().item() + zeros_are_summed = True + + # Adjust for group-based pruning + if not prox_kwargs["is_svd_grouper"]: + zero_elts *= grouper.group_size() + + # Record for reconstruction and logging + if prox_kwargs["is_svd_grouper"]: + dim = 0 if sv_count.dim() > 1 else None + sv_count.copy_( + (grouper.p != 0).to(torch.uint8).sum(dim=dim) + if is_dtensor(p) + else torch.count_nonzero(grouper.p, dim=dim) + ) + + return zero_elts, zeros_are_summed + + def _set_gamma(self, group): + # AProx in practice: ensure shrinkage coefficient >= 1 + group["gamma"] += group["lr"] + + @torch.no_grad() + def step(self, closure: Callable[[], float] | None = None) -> float | None: + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + if self.num_steps < self.warmup_steps: + # warmup stage: running the base optimizer only + loss = self.base_optimizer.step(closure=closure) # pyre-ignore[6] + self.num_steps += 1 + return loss + + if self.num_steps == self.warmup_steps: + # first step of qat, save latent params, instead of restore + self.save_latent_params() + else: + # qat: restore latent params for update by the base optimizer + self.restore_latent_params() + + # call base optimizer step() method to update latent parameters + loss = self.base_optimizer.step(closure=closure) # pyre-ignore[6] + + if hasattr(self, "_state"): + assert self.warmup_steps == 0 + # restore the temporary state to the base optimizer's state + for p in self._state.keys(): + self.base_optimizer.state[p]["latent"] = self._state[p]["latent"] + del self._state + + regularized_params = 0 + regularized_unfactored_size = 0 + if torch.distributed.is_initialized(): + regularized_zeros_buf = [] + regularized_factored_size_buf = [] + + regularized_zeros = 0 + regularized_factored_size = 0 + + for group in self.regularized_param_groups(): + self._set_gamma(group) + + # apply shrinkage to latent parameters in place + prox_map = instantiate_module( + f"torchao.prototype.pat.optim.{group['prox_type']}" + )(group["reg_lambda"]) + + # grouper is a context manager that reshapes p if needed + grouper_cls = instantiate_module( + f"torchao.prototype.pat.group.{group['group_type']}" + ) + grouper_kwargs = self._get_grouper_kwargs(group) + prox_kwargs = { + "gamma": group["gamma"], + "gamma_index_slope": group.get("gamma_index_slope", 0.0), + "disable_vmap": group["group_type"].endswith( + ("ElemGrouper", "LayerGrouper") + ), + "is_svd_grouper": group["group_type"].endswith("SVDGrouper"), + } + for p in group["params"]: + if not p.requires_grad: + continue + + # save latent parameters + state = self.state[p] + state["latent"].copy_(p) + + # store the number of non-zero singular values + if prox_kwargs["is_svd_grouper"]: + npack = grouper_kwargs.get("npack", 1) + state.setdefault( + "sv_count", torch.zeros(npack, dtype=torch.int, device=p.device) + ) + + # update the full tensor if sharded + sharded_p = None + if is_dtensor(p) and prox_kwargs["is_svd_grouper"]: + sharded_p = p + p = p.full_tensor() + + # only rank 0 of the device mesh should run the grouper + sv_count = state.get("sv_count") + if sharded_p is None or sharded_p.device_mesh.get_rank() == 0: + grouper = grouper_cls(p, **grouper_kwargs) + zero_elts, zeros_are_summed = self._apply_prox( + grouper, prox_map, p, sv_count=sv_count, **prox_kwargs + ) + if zeros_are_summed: + state["sparsity_frac"] = zero_elts / grouper.p.numel() + else: + _maybe_async_aggregate(regularized_zeros_buf, zero_elts) + + if torch.is_tensor(zero_elts): + zero_elts = zero_elts.item() + + if prox_kwargs["is_svd_grouper"]: + unfactored_size = grouper.U.size(0) * grouper.Vh.size(1) + n_singular_vals = grouper.p.numel() - zero_elts + factored_size = ( + grouper.U.size(0) + grouper.Vh.size(1) + ) * n_singular_vals + group["factored_frac"] = factored_size / unfactored_size + # Only aggregate if not already globally summed + if zeros_are_summed: + regularized_factored_size += factored_size + else: + _maybe_async_aggregate( + regularized_factored_size_buf, + torch.tensor( + factored_size, dtype=torch.int, device=p.device + ), + ) + + regularized_unfactored_size += unfactored_size + + # Only factor matrices if it reduces params + regularized_zeros += max(unfactored_size - factored_size, 0) + regularized_params += unfactored_size + else: + regularized_zeros += zero_elts + regularized_params += grouper.p.numel() + + # copy the updated full tensor to the sharded tensor + if sharded_p is not None: + torch.distributed.barrier() + if isinstance(sv_count, Tensor): + torch.distributed.broadcast(sv_count, src=0) + sharded_p.copy_( + distribute_tensor( + p, + device_mesh=sharded_p.device_mesh, + placements=sharded_p.placements, + ) + ) + + self.num_steps += 1 + + if torch.distributed.is_initialized() and is_main_process(): + regularized_zeros += _sum_async_streams(regularized_zeros_buf) + regularized_factored_size += _sum_async_streams( + regularized_factored_size_buf + ) + + if is_main_process(): + self.relative_sparsity = ( + regularized_zeros / regularized_params + if regularized_params > 0 + else 0.0 + ) + self.relative_factored_frac = ( + regularized_factored_size / regularized_unfactored_size + if regularized_unfactored_size > 0 + else 0.0 + ) + + return loss + + @torch._disable_dynamo + def restore_latent_params(self) -> None: + """Restore latent parameters as optimizer parameters""" + for group in self.regularized_param_groups(): + for p in group["params"]: + if p.requires_grad: + p.copy_(self.state[p]["latent"]) + + @torch._disable_dynamo + def save_latent_params(self) -> None: + """Save updated latent parameters before applying prox-map""" + if self.warmup_steps == 0: + assert len(self.state) == 0, "Expected empty state at first step()" + # Maintain the invariant that `len(self.state) == 0` before first + # self.base_optimizer.step() call by using a temporary state buffer + self._state = defaultdict(dict) + + for group in self.regularized_param_groups(): + for p in group["params"]: + if p.requires_grad: + state = self.state[p] + if "latent" not in state: + state["latent"] = p.detach().clone() + else: + state["latent"].copy_(p) diff --git a/torchao/prototype/pat/utils/__init__.py b/torchao/prototype/pat/utils/__init__.py new file mode 100644 index 0000000000..9879453523 --- /dev/null +++ b/torchao/prototype/pat/utils/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from .distributed import HAS_DTENSOR, is_dtensor # noqa: F401 +from .torch import ( # noqa: F401 + get_param_groups, + insert_svd_modules_, + instantiate_module, + is_main_process, + use_deterministic_algorithms, +) diff --git a/torchao/prototype/pat/utils/distributed.py b/torchao/prototype/pat/utils/distributed.py new file mode 100644 index 0000000000..17a5f0c5de --- /dev/null +++ b/torchao/prototype/pat/utils/distributed.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List + +from torch import Tensor +from torch import distributed as dist + +try: + from torch.distributed.tensor import DTensor + + HAS_DTENSOR = True +except ImportError: + HAS_DTENSOR = False + + +def is_main_process(): + rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0 + return rank == 0 + + +def is_dtensor(x): + return HAS_DTENSOR and isinstance(x, DTensor) + + +class NoopHandle: + def wait(self): + pass + + +def _maybe_async_aggregate( + handle_buf: List[tuple[Tensor, dist.Work | NoopHandle]], input_tensor: Tensor +) -> None: + if dist.is_initialized() and not is_dtensor(input_tensor): + handle = dist.reduce(input_tensor, dst=0, async_op=True) + handle_buf.append((input_tensor, handle)) + else: + if is_dtensor(input_tensor): + input_tensor = input_tensor.full_tensor() + if is_main_process(): + handle_buf.append((input_tensor, NoopHandle())) + + +def _sum_async_streams(handle_buf: List[tuple[Tensor, dist.Work | NoopHandle]]) -> int: + assert isinstance(handle_buf, list), ( + f"Expected a list of async handles but got {type(handle_buf)}" + ) + output = 0 + for input_tensor, handle in handle_buf: + handle.wait() + output += input_tensor.item() + handle_buf.clear() + return output diff --git a/torchao/prototype/pat/utils/torch.py b/torchao/prototype/pat/utils/torch.py new file mode 100644 index 0000000000..fa45fb1e55 --- /dev/null +++ b/torchao/prototype/pat/utils/torch.py @@ -0,0 +1,184 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import re +from contextlib import contextmanager +from functools import partial +from importlib import import_module +from typing import Any, Dict, Optional, Set, Tuple + +import torch +from torch import nn + +from .distributed import is_main_process + +RE_PREFIX = ":" + + +def get_index_linspace( + index_slope: float, + n_indices: int, + device: torch.device, + max_val: Optional[float] = None, +): + gamma_multiplier = ( + torch.linspace(1 - index_slope, 1 + index_slope, n_indices, device=device) + .div_(2.0) + .clamp_(min=0.0, max=max_val) + ) + return gamma_multiplier + + +@contextmanager +def use_deterministic_algorithms(): + """Context manager to enable deterministic algorithms in PyTorch""" + deterministic_restore = torch.are_deterministic_algorithms_enabled() + torch.use_deterministic_algorithms(True) + try: + yield + finally: + torch.use_deterministic_algorithms(deterministic_restore) + + +class FuncDescriptor: + def __init__(self, func): + self.func = func + + def __get__(self, instance, owner): + return self.func(instance) + + +def instantiate_module(module_name: str): + prefix, name = module_name.rsplit(".", 1) + module = getattr(import_module(prefix), name) + return module + + +def get_param_groups( + model: nn.Module, + prune_config: Dict[Tuple[nn.Module, str], Any], + skip_wd_names: Optional[Set[str]] = None, + verbose: bool = True, +) -> Dict[str, Dict[str, Any]]: + # Create list of regex patterns for matching parameter names + re_pats = [ + re.compile(k[len(RE_PREFIX) :]) + for k in prune_config.keys() + if isinstance(k, str) and k.startswith(RE_PREFIX) + ] + + param_dict = {} + seen_tensors = set() + for param_name, param in model.named_parameters(): + module_name, _, param_basename = param_name.rpartition(".") + parent_module = model.get_submodule(module_name) if module_name else model + if param in seen_tensors: + continue + seen_tensors.add(param) + + group_key, group_val = None, None + for re_pat in re_pats: + if re_pat.match(param_name): + group_key = re_pat.pattern + group_val = prune_config[f"{RE_PREFIX}{group_key}"] + break + + # Check for exact parameter or module name matches + if group_key is None: + module_cls = parent_module.__class__ + if param_name in prune_config: + group_key = param_name + elif (module_cls, param_basename) in prune_config: + group_key = (module_cls, param_basename) + elif ( + param_basename == "bias" + or skip_wd_names + and param_basename in skip_wd_names + ): + group_key, group_val = "no_wd", {"weight_decay": 0} + else: + group_key, group_val = "wd", {} + + if group_val is None: + group_val = prune_config[group_key] + + param_dict.setdefault(group_key, group_val).setdefault("params", []).append( + param + ) + + param_groups = list(param_dict.values()) + + n_found_params = sum(len(v["params"]) for v in param_groups) + n_expect_params = len(list(model.parameters())) + assert n_found_params == n_expect_params, f"{n_found_params=}, {n_expect_params=}" + + if verbose and is_main_process(): + for k, v in param_dict.items(): + print(f"{k}: {len(v['params'])} params") + return param_groups + + +def latent_svd(self, name=""): + """Used when monkey patching the parameter to use SVD.""" + U = getattr(self, f"{name}_U") + S = getattr(self, f"{name}_S") + Vh = getattr(self, f"{name}_Vh") + orig_shape = torch.Size(getattr(self, f"{name}_orig_shape")) + return torch.linalg.multi_dot([U, torch.diag(S), Vh]).view(orig_shape) + + +def insert_svd_modules_(model: nn.Module, optimizer: torch.optim.Optimizer): + """Replaces the parameters of the model with their SVD decompositions.""" + param_set = { + p.data_ptr() + for group in optimizer.regularized_param_groups() + for p in group["params"] + if group["group_type"] == "SVDGrouper" + } + + def insert_inner_(model): + for mn, module in model.named_children(): + params_to_add = {} + for pn, p in module.named_parameters(recurse=False): + if p.data_ptr() not in param_set: + continue + + k = int(optimizer.state[p]["sv_count"].item()) + assert k > 0, f"Invalid sv_count={k}" + with instantiate_module("pat.group.SVDGrouper")(p) as grouper: + # patch parameter with SVD + module.register_buffer( + f"{pn}_orig_shape", torch.tensor(grouper.orig_shape) + ) + U, S, Vh = grouper.U[:, :k], grouper.p[:k], grouper.Vh[:k] + for name, value in zip( + (f"{pn}_U", f"{pn}_S", f"{pn}_Vh"), + (U, S, Vh), + ): + params_to_add[name] = value + + if is_main_process(): + print( + f"{tuple(grouper.orig_shape)} -> " + f"{tuple(U.shape)}, {tuple(S.shape)}, {tuple(Vh.shape)}" + ) + + module.__dict__.pop(pn, None) # delete the original parameter + setattr( + module.__class__, + pn, + FuncDescriptor(partial(latent_svd, name=pn)), + ) + + for name, value in params_to_add.items(): + module.register_parameter( + name, nn.Parameter(value, requires_grad=False) + ) + del params_to_add + + insert_inner_(module) + + insert_inner_(model)