-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprune.py
More file actions
65 lines (48 loc) · 2.19 KB
/
prune.py
File metadata and controls
65 lines (48 loc) · 2.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import logging
import torch
import numpy as np
import process
def prune_by_ratio(weight, bias, ratio, device):
    """Build magnitude-based masks that prune roughly ``ratio`` of ``weight``.

    Weights are ranked by absolute value. The magnitude at position
    ``int(numel * ratio)`` of that ascending ranking is the threshold. Every
    weight whose magnitude is >= the threshold is kept; ties at the threshold
    are kept. The rest are masked out.

    Args:
        weight: weight tensor to prune; row ``weight[i]`` is paired with
            ``bias[i]`` (assumes one bias entry per row -- TODO confirm for
            non-Linear layers).
        bias: bias tensor whose shape the returned bias mask matches.
        ratio: fraction of weights to prune; ``int(numel * ratio) >= numel``
            (e.g. ``ratio=1.0``) prunes everything.
        device: device the returned masks are moved to.

    Returns:
        ``(weight_mask, bias_mask, prune_ratio)``:
        - ``weight_mask``: float32 0/1 mask shaped like ``weight``.
        - ``bias_mask``: float32 0/1 mask shaped like ``bias``. An entry is 0
          only when every weight in its row was pruned.
        - ``prune_ratio``: fraction of weights zeroed by ``weight_mask``, as a
          Python float.
    """
    num_weight = torch.numel(weight)
    k = int(num_weight * ratio)
    magnitude = weight.abs()
    if k >= num_weight:
        # Everything is pruned; indexing the sorted magnitudes would overflow.
        keep = torch.zeros_like(magnitude, dtype=torch.bool)
    else:
        # torch.sort works for CPU/GPU tensors alike (np.sort needs host memory).
        threshold = torch.sort(magnitude.flatten())[0][k]
        keep = magnitude >= threshold
    weight_mask = keep.float().to(device)
    # A bias entry is pruned only when no weight in its row survived.
    row_alive = keep.reshape(keep.size(0), -1).any(dim=1)
    bias_mask = row_alive.float().reshape(bias.size()).to(device)
    prune_ratio = (num_weight - torch.nonzero(weight_mask).size(0)) / num_weight
    return weight_mask, bias_mask, prune_ratio
def prune_by_nueron(weight, bias, number, device, ascending=False):
    """Build masks that prune the ``number`` largest (or smallest) weights.

    Despite the name, individual weight entries are ranked by their signed
    value (not magnitude), not whole rows. With ``ascending=False`` the
    ``number`` largest values are pruned; with ``ascending=True`` the
    ``number`` smallest are pruned. Values tied with the ``number``-th ranked
    value are pruned too, so slightly more than ``number`` entries may be
    removed when ties exist.

    Args:
        weight: weight tensor to prune; row ``weight[i]`` is paired with
            ``bias[i]`` (assumes one bias entry per row -- TODO confirm).
        bias: bias tensor whose shape the returned bias mask matches.
        number: how many weights to prune; ``<= 0`` prunes nothing and
            ``>= numel`` prunes everything.
        device: device the returned masks are moved to.
        ascending: prune the smallest values instead of the largest.

    Returns:
        ``(weight_mask, bias_mask, prune_ratio)``: float32 0/1 masks shaped
        like ``weight`` and ``bias``, and the fraction of weights zeroed by
        ``weight_mask`` as a Python float. A bias entry is 0 only when every
        weight in its row was pruned.
    """
    num_weight = torch.numel(weight)
    # After multiplying by ``sign`` the weights to prune sort first (ascending).
    sign = 1 if ascending else -1
    signed = sign * weight
    if number <= 0:
        keep = torch.ones_like(weight, dtype=torch.bool)
    elif number >= num_weight:
        keep = torch.zeros_like(weight, dtype=torch.bool)
    else:
        # Value of the ``number``-th entry to prune (1-based), in signed space.
        threshold = torch.sort(signed.flatten())[0][number - 1]
        keep = signed > threshold
    weight_mask = keep.float().to(device)
    # A bias entry is pruned only when no weight in its row survived.
    row_alive = keep.reshape(keep.size(0), -1).any(dim=1)
    bias_mask = row_alive.float().reshape(bias.size()).to(device)
    prune_ratio = (num_weight - torch.nonzero(weight_mask).size(0)) / num_weight
    return weight_mask, bias_mask, prune_ratio
def prune(model, criterion, loader, prune_layers, **kwargs):
    """Sweep pruning levels on selected layers and record validation accuracy.

    For every module whose name is in ``prune_layers``, prunes the top
    ``number`` weights (via ``prune_by_nueron``) for ``number`` in
    ``range(0, out_features, max(1, out_features // 1000))``. After each
    pruning step it calls ``process.validate`` and then restores the layer's
    original parameters before the next step.

    Args:
        model: module whose ``named_modules()`` are scanned; each selected
            module must expose ``out_features``, ``weight`` and ``bias``.
        criterion: passed through to ``process.validate``.
        loader: passed through to ``process.validate``.
        prune_layers: container of module names to sweep.
        **kwargs: ``device`` is read and forwarded to the masks and to
            ``process.validate``.

    Returns:
        List of ``[number, top1]`` pairs, where ``top1`` is the first value
        returned by ``process.validate`` (presumably top-1 accuracy -- confirm).
    """
    device = kwargs.get('device')
    acc = []
    for name, m in model.named_modules():
        if name not in prune_layers:
            continue
        logging.info("Testing layer %s", name)
        # Clone so the snapshot never aliases the live parameter: on CPU,
        # ``.cpu()`` returns the same storage and the in-place ``*=`` below
        # would otherwise corrupt the backup.
        weight = m.weight.data.detach().cpu().clone()
        bias = m.bias.data.detach().cpu().clone()
        # Guard against a zero step (ValueError) when out_features < 1000.
        step = max(1, m.out_features // 1000)
        for number in range(0, m.out_features, step):
            logging.info('Pruning top %d neurons', number)
            weight_mask, bias_mask, _ = prune_by_nueron(weight, bias, number, device)
            try:
                m.weight.data *= weight_mask
                m.bias.data *= bias_mask
                top1, _, _ = process.validate(loader, model, criterion, device=device)
            finally:
                # Restore in place (copy_ handles the CPU->device transfer) so
                # the snapshot stays pristine even if validation raises.
                m.weight.data.copy_(weight)
                m.bias.data.copy_(bias)
            acc.append([number, top1])
    return acc