
Commit 461481f

Add pruning-aware training in torchao.prototype.pat
1 parent cb67b03 commit 461481f

22 files changed (+1583, -0 lines)


test/prototype/test_pat.py

Lines changed: 279 additions & 0 deletions
@@ -0,0 +1,279 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import random
import unittest

import torch
from torch import nn
from torch.testing._internal import common_utils

from torchao.prototype.pat.group import (
    AttentionHeadGrouperDim0,
    AttentionHeadGrouperDim1,
    PackedSVDGrouper,
    QKGrouper,
    QKSVDGrouper,
    SVDGrouper,
)
from torchao.prototype.pat.layers.masked_layernorm import MaskedLayerNorm
from torchao.prototype.pat.optim import ProxGroupLasso, ProxNuclearNorm, PruneOptimizer
from torchao.prototype.pat.utils import get_param_groups


class TestMaskedLayerNorm(common_utils.TestCase):
    @common_utils.parametrize("batch", [1, 4])
    @common_utils.parametrize("seq_len", [2, 8])
    @common_utils.parametrize("embed_dim", [16, 64])
    def test_masked_layernorm(self, batch=1, seq_len=2, embed_dim=16):
        dim2_nz = embed_dim // 2
        embed = torch.randn(batch, seq_len, embed_dim)
        embed[..., dim2_nz:] = 0

        masked_layer_norm = MaskedLayerNorm(embed_dim)
        layer_norm = nn.LayerNorm(dim2_nz)
        with torch.no_grad():
            layer_norm.weight.copy_(masked_layer_norm.weight[:dim2_nz])
            layer_norm.bias.copy_(masked_layer_norm.bias[:dim2_nz])

        out = masked_layer_norm(embed)
        expected_out = layer_norm(embed[..., :dim2_nz])
        torch.testing.assert_close(out[..., :dim2_nz], expected_out)

class MHADummyModel(nn.Module):
    def __init__(self, embed_dim, num_heads, n_cls):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, bias=False)
        self.classifier = nn.Linear(embed_dim, n_cls)

    def forward(self, x):
        attn_output, _ = self.mha(x, x, x)
        out = self.classifier(attn_output)
        return out


class TestQKGrouper(common_utils.TestCase):
    def __init__(self, methodName):
        super(TestQKGrouper, self).__init__(methodName)
        self.reg_lambda = 1.0
        self.prox_map = ProxGroupLasso(self.reg_lambda)

    @staticmethod
    def _get_qk(p, embed_dim, qk_reg_index):
        qk = p[:embed_dim] if qk_reg_index == 0 else p[embed_dim : (embed_dim * 2)]
        return qk

    def get_gamma(self, p):
        """Heuristic that uses the mean of the group to set gamma."""
        p_col = p[:, 0]
        gamma = (1 - p_col.mean()) * torch.linalg.vector_norm(p_col)
        gamma.div_(self.prox_map.tau(p_col))
        return gamma

    def _test_post_prune(self, p, qk_orig, embed_dim, qk_reg_index, gamma):
        qk = self._get_qk(p, embed_dim, qk_reg_index)
        nz_mask = qk.sum(dim=0).ne(0)
        self.assertTrue(nz_mask.eq(0).any(), "No columns of Q/K were pruned")

        # original columns that are <= gamma are pruned
        expect_nz_mask = qk_orig.gt(gamma).all(dim=0)
        torch.testing.assert_close(nz_mask, expect_nz_mask, atol=0, rtol=0)

    def _test_mha_inner(self, p, embed_dim, qk_reg_index):
        qk_orig = self._get_qk(p, embed_dim, qk_reg_index).clone()
        qk_no_prune = self._get_qk(p, embed_dim, int(not qk_reg_index)).clone()
        v_orig = p[(embed_dim * 2) :].clone()
        qk_pack_dim = 0
        with QKGrouper(p, qk_pack_dim, qk_reg_index) as grouper:
            self.assertTrue(grouper.p.equal(qk_orig))

            gamma = self.get_gamma(grouper.p)
            _ = torch.vmap(
                self.prox_map.apply_, in_dims=(grouper.in_dims, None), out_dims=0
            )(grouper.p, gamma)

        self._test_post_prune(p, qk_orig, embed_dim, qk_reg_index, gamma)

        # unregularized query or key was not modified
        no_prune = self._get_qk(p, embed_dim, int(not qk_reg_index))
        torch.testing.assert_close(no_prune, qk_no_prune, atol=0, rtol=0)

        # value was not modified
        v = p[(embed_dim * 2) :]
        torch.testing.assert_close(v, v_orig, atol=0, rtol=0)

    @common_utils.parametrize("embed_dim", [16, 64])
    @common_utils.parametrize("num_heads", [2, 4])
    @common_utils.parametrize("qk_reg_index", [0, 1])
    def test_pytorch_mha(self, embed_dim=16, num_heads=4, qk_reg_index=0):
        assert embed_dim % num_heads == 0, (
            f"{embed_dim=} must be divisible by {num_heads=}"
        )

        # single in_proj_weight of shape (embed_dim * 3, embed_dim)
        model = nn.MultiheadAttention(embed_dim, num_heads, bias=False)
        p = model.in_proj_weight.detach()
        self._test_mha_inner(p, embed_dim, qk_reg_index)

    @common_utils.parametrize("qk_reg_index", [0, 1])
    def test_e2e_optimizer(self, embed_dim=64, qk_reg_index=0):
        n_cls = 3
        model = MHADummyModel(embed_dim, num_heads=4, n_cls=n_cls)
        prune_config = {
            "mha.in_proj_weight": {
                "group_type": "QKGrouper",
                "prox_type": "ProxGroupLasso",
                "qk_pack_dim": 0,
                "qk_reg_index": qk_reg_index,
            }
        }
        param_groups = get_param_groups(model, prune_config, verbose=False)
        self.assertEqual(len(param_groups), 3)

        p = model.mha.in_proj_weight.detach()
        qk_orig = self._get_qk(p, embed_dim, qk_reg_index).clone()

        # set lr to gamma since we run a single step
        gamma = self.get_gamma(qk_orig)
        optimizer = PruneOptimizer(
            torch.optim.SGD(param_groups, lr=gamma), reg_lambda=self.reg_lambda
        )

        data = torch.randn(1, 8, embed_dim)
        label = torch.arange(0, n_cls) * data.mean(axis=-1, keepdim=True)
        output = model(data)
        loss = nn.functional.mse_loss(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        self._test_post_prune(p, qk_orig, embed_dim, qk_reg_index, gamma)

class TestAttentionHeadGrouper(common_utils.TestCase):
    def __init__(self, methodName):
        super(TestAttentionHeadGrouper, self).__init__(methodName)
        self.reg_lambda = 1.0
        self.prox_map = ProxGroupLasso(self.reg_lambda)

    @staticmethod
    def _get_view_shape_reduce_dim(dim, num_heads, head_pack_dim):
        if head_pack_dim == 0:
            view_shape = (num_heads, -1, dim)
            reduce_dim = (1, 2)
        else:
            view_shape = (dim, num_heads, -1)
            reduce_dim = (0, 2)
        return view_shape, reduce_dim

    def _test_post_prune(self, p, p_orig, head_pack_dim, view_shape, reduce_dim, gamma):
        nz_mask = p.view(*view_shape).sum(dim=reduce_dim).ne(0)
        self.assertTrue(nz_mask.eq(0).any(), "No groups of p were pruned")

        # original groups that are <= gamma are pruned
        expect_nz_mask = p_orig.view(*view_shape).gt(gamma).all(dim=reduce_dim)
        torch.testing.assert_close(nz_mask, expect_nz_mask, atol=0, rtol=0)

    def get_gamma(self, p, head_pack_dim, view_shape):
        """Heuristic that uses the mean of the group to set gamma."""
        p = p.view(*view_shape)
        p_group = p[0] if head_pack_dim == 0 else p[:, 0]
        gamma = (1 - p_group.mean()) * torch.linalg.vector_norm(p_group)
        gamma.div_(self.prox_map.tau(p_group))
        return gamma

    @common_utils.parametrize("dim", [64, 128])
    @common_utils.parametrize("head_pack_dim", [0, 1])
    def test_head_grouper(self, dim=16, head_pack_dim=0, head_dim_ratio=8):
        assert dim % head_dim_ratio == 0, (
            f"{dim=} must be divisible by {head_dim_ratio=}"
        )
        num_heads = dim // head_dim_ratio
        packed_dim = dim * num_heads
        shape = (dim, packed_dim) if head_pack_dim == 0 else (packed_dim, dim)
        model = nn.Linear(*shape, bias=False)
        p = model.weight.detach()
        p_orig = p.clone()
        view_shape, reduce_dim = self._get_view_shape_reduce_dim(
            dim, num_heads, head_pack_dim
        )
        grouper_cls = (
            AttentionHeadGrouperDim0 if head_pack_dim == 0 else AttentionHeadGrouperDim1
        )
        with grouper_cls(p, num_heads) as grouper:
            gamma = self.get_gamma(grouper.p, head_pack_dim, view_shape)
            _ = torch.vmap(
                self.prox_map.apply_, in_dims=(grouper.in_dims, None), out_dims=0
            )(grouper.p, gamma)
            self.assertEqual(grouper.p.size(head_pack_dim), num_heads)
        self._test_post_prune(p, p_orig, head_pack_dim, view_shape, reduce_dim, gamma)

class TestSVDGrouper(common_utils.TestCase):
    def __init__(self, methodName):
        super(TestSVDGrouper, self).__init__(methodName)
        self.reg_lambda = 1.0
        self.prox_map = ProxNuclearNorm(self.reg_lambda)

    @common_utils.parametrize("embed_dim", (16, 64))
    def test_grouper(self, embed_dim=16):
        model = torch.nn.Linear(embed_dim, embed_dim)
        p = model.weight
        with SVDGrouper(p) as grouper:
            gamma = grouper.p.mean()
            p_orig = grouper.p.clone()
            torch.vmap(
                self.prox_map.apply_, in_dims=(grouper.in_dims, None), out_dims=0
            )(grouper.p, gamma)
            expect_nz_mask = p_orig.gt(gamma)
            torch.testing.assert_close(grouper.p.ne(0), expect_nz_mask, atol=0, rtol=0)

    @common_utils.parametrize("embed_dim", (16, 64))
    @common_utils.parametrize("pack_dim", (0, 1))
    def test_qk_grouper(self, embed_dim=16, pack_dim=0):
        shape = [embed_dim, embed_dim]
        shape[int(not pack_dim)] *= 3
        model = torch.nn.Linear(*shape)
        p = model.weight
        with QKSVDGrouper(p, pack_dim=pack_dim) as grouper:
            gamma = grouper.p.mean()
            p_orig = grouper.p.clone()
            torch.vmap(
                self.prox_map.apply_, in_dims=(grouper.in_dims, None), out_dims=0
            )(grouper.p, gamma)
            expect_nz_mask = p_orig.gt(gamma)
            torch.testing.assert_close(grouper.p.ne(0), expect_nz_mask, atol=0, rtol=0)

    @common_utils.parametrize("embed_dim", (16, 64))
    @common_utils.parametrize("pack_dim", (0, 1))
    def test_packed_grouper(self, embed_dim=16, npack=3, pack_dim=0):
        shape = [embed_dim, embed_dim]
        shape[int(not pack_dim)] *= npack
        model = torch.nn.Linear(*shape)
        p = model.weight
        with PackedSVDGrouper(p, npack, pack_dim=pack_dim) as grouper:
            gamma = grouper.p.mean(0).mean()
            p_orig = grouper.p.clone()
            torch.vmap(
                self.prox_map.apply_, in_dims=(grouper.in_dims, None), out_dims=0
            )(grouper.p.flatten(), gamma)
            torch.testing.assert_close(
                grouper.p.ne(0), p_orig.gt(gamma), atol=0, rtol=0
            )
            self.assertEqual(p.data_ptr(), grouper._p.data_ptr())


common_utils.instantiate_parametrized_tests(TestMaskedLayerNorm)
common_utils.instantiate_parametrized_tests(TestQKGrouper)
common_utils.instantiate_parametrized_tests(TestAttentionHeadGrouper)
common_utils.instantiate_parametrized_tests(TestSVDGrouper)

if __name__ == "__main__":
    random.seed(0)
    torch.manual_seed(0)
    unittest.main()

torchao/prototype/pat/README.md

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
# PAT: Pruning-Aware Training

PAT is a library based on group Lasso regularization. It directly induces structured sparsity during training, removing the need for custom pruning metrics and/or multiple rounds of training.

PAT's simple optimizer-only interface supports easy integration into existing training pipelines. The code is organized into two main components:
* grouper: defines the granularity of pruning (e.g., filter, channel, layer)
* proximal mapping: projects groups of weights onto sparse values (see the sketch below)

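For intuition, a group Lasso proximal mapping amounts to block soft-thresholding: each group of weights is shrunk toward zero, and groups whose norm falls below the threshold are zeroed out exactly, which is what produces structured sparsity during training. The snippet below is a minimal, self-contained sketch of that operation with rows as groups; it is an illustration only, not the library's `ProxGroupLasso` implementation.

```python
import torch


def group_lasso_prox_(weight: torch.Tensor, threshold: float) -> torch.Tensor:
    """In-place block soft-thresholding with rows of `weight` as groups.

    Each row is scaled by max(0, 1 - threshold / ||row||_2), so rows whose
    L2 norm is below `threshold` become exactly zero (i.e., are pruned).
    """
    norms = weight.norm(dim=1, keepdim=True)  # per-row L2 norms
    scale = (1.0 - threshold / norms.clamp_min(1e-12)).clamp_min(0.0)
    return weight.mul_(scale)


w = torch.randn(8, 16)
group_lasso_prox_(w, threshold=4.0)
print(int((w == 0).all(dim=1).sum()), "of 8 rows pruned to zero")
```

Groupers generalize this by choosing which slices of a tensor form a group: rows or columns, convolution filters, attention heads, or singular values.
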
## Optimizer-only interface

This package provides a `PruneOptimizer` that simply wraps around a base optimizer inheriting from `torch.optim.Optimizer`. The following code snippet illustrates how to set up PAT:

```python
import torch
import torchvision

from pat.optim import PruneOptimizer

model = torchvision.models.resnet18().cuda()

# split params into prunable and non-prunable groups
weights = [p for name, p in model.named_parameters() if name.endswith("weight")]
others = [p for name, p in model.named_parameters() if not name.endswith("weight")]

# apply row-wise group Lasso regularization to the weights
param_groups = [
    {
        "params": weights,
        "group_type": "pat.group.Dim0Grouper",
        "prox_type": "pat.prox.ProxGroupLasso",
        "reg_lambda": 2e-4,
    },
    {"params": others},
]

# create base optimizer (SGD, Adam or AdamW)
base_optimizer = torch.optim.SGD(
    param_groups, lr=0.1, momentum=0.9, weight_decay=1e-4
)

# create PruneOptimizer
optimizer = PruneOptimizer(base_optimizer)
```

After creating `PruneOptimizer`, one can use it as a regular PyTorch optimizer.

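As a rough sketch of what that looks like in practice (continuing from the snippet above, and assuming a hypothetical `dataloader` that yields `(inputs, targets)` batches), a training step is written exactly as it would be with the base optimizer; the pruning update is applied as part of `optimizer.step()`, as exercised by `test/prototype/test_pat.py` in this commit:

```python
model.train()
for inputs, targets in dataloader:  # `dataloader` is a placeholder for your own loader
    inputs, targets = inputs.cuda(), targets.cuda()
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    loss.backward()
    optimizer.step()  # also applies the proximal (pruning) update
```
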
## Pruning configuration

Pruning configs are dictionaries that define which parameter groups to prune and how to prune them. Each key-value pair in the config maps to a prunable parameter group of `PruneOptimizer`. The keys are used to match model parameters, while the values specify the pruning granularity and proximal map. The key can be one of the following types (an example config follows the list):

- parameter name (string): e.g., `blocks.0.attn.qkv.weight`
- regex pattern (string): e.g., `:.*attn\.qkv\.weight`
- module type and parameter name suffix ((class, string) tuple): e.g., `(torch.nn.Linear, 'weight')`
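
As an example (mirroring `test/prototype/test_pat.py` from this commit), the config below keys on the parameter name `mha.in_proj_weight` of a small model with an `nn.MultiheadAttention` submodule, and `get_param_groups` turns it into parameter groups for `PruneOptimizer`. The model class and hyperparameter values here are illustrative only.

```python
import torch
from torch import nn

from torchao.prototype.pat.optim import PruneOptimizer
from torchao.prototype.pat.utils import get_param_groups


class TinyModel(nn.Module):
    """Hypothetical model used only for this example."""

    def __init__(self, embed_dim=64, num_heads=4, n_cls=3):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, bias=False)
        self.classifier = nn.Linear(embed_dim, n_cls)

    def forward(self, x):
        attn_out, _ = self.mha(x, x, x)
        return self.classifier(attn_out)


model = TinyModel()

# prune columns of the query block of the packed in_proj weight with group Lasso
prune_config = {
    "mha.in_proj_weight": {
        "group_type": "QKGrouper",
        "prox_type": "ProxGroupLasso",
        "qk_pack_dim": 0,   # Q/K/V are packed along dim 0 of in_proj_weight
        "qk_reg_index": 0,  # regularize the first packed block (Q)
    }
}
param_groups = get_param_groups(model, prune_config, verbose=False)
optimizer = PruneOptimizer(torch.optim.SGD(param_groups, lr=0.1), reg_lambda=1.0)
```
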

torchao/prototype/pat/__init__.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from .group import (  # noqa: F401
    AttentionHeadGrouperDim0,
    AttentionHeadGrouperDim1,
    ConvFilterGrouper,
    Dim0Grouper,
    Dim1Grouper,
    ElemGrouper,
    LayerGrouper,
    PackedSVDGrouper,
    QKGrouper,
    QKSVDGrouper,
    SVDGrouper,
)
from .optim import (  # noqa: F401
    NMSGDOptimizer,
    ProxGroupLasso,
    ProxGroupLassoReduce,
    ProxLasso,
    ProxNuclearNorm,
    PruneOptimizer,
)
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from .attention import (  # noqa: F401
    AttentionHeadGrouperDim0,
    AttentionHeadGrouperDim1,
    QKGrouper,
)
from .conv import ConvFilterGrouper  # noqa: F401
from .dim import Dim0Grouper, Dim1Grouper  # noqa: F401
from .grouper import (  # noqa: F401
    ElemGrouper,
    Grouper,
    LayerGrouper,
)
from .low_rank import PackedSVDGrouper, QKSVDGrouper, SVDGrouper  # noqa: F401
