-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsetup_net.py
More file actions
89 lines (70 loc) · 2.69 KB
/
setup_net.py
File metadata and controls
89 lines (70 loc) · 2.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch.nn as nn
def setup_net(pars):
    """Build a 5-layer conv backbone plus a classifier head and a projection head.

    Args:
        pars: config namespace; reads ``clf_dataset``, ``nonlinear``,
            ``clfnonlinear``, ``headnonlinear`` and ``headsize``.
            Side effect: sets ``pars.NUM_LAYER = 5``.

    Returns:
        (net, classifier, head): the conv backbone, a class-logit head that
        consumes the backbone's feature map, and a projection head that maps
        the same feature map to a ``pars.headsize``-dim embedding.
    """
    net = nn.Sequential()
    classifier = nn.Sequential()
    head = nn.Sequential()
    # Number of target classes per dataset (default: ImageNet-style 1000).
    if pars.clf_dataset == 'Cifar100':
        NUM_CLASS = 100
    elif pars.clf_dataset == 'Cifar10':
        NUM_CLASS = 10
    else:
        NUM_CLASS = 1000
    HW = 32             # spatial size; halved by every MaxPool2d(2) below
    NUM_CHANNEL = 32    # channel width; widened at layers 2 (x2) and 4 (x8)
    pars.NUM_LAYER = 5  # recorded on pars for downstream consumers

    def make_activation():
        # Fresh module per use: the original shared one activation instance
        # across several Sequentials, registering the same child under
        # multiple parents. Hardtanh/ReLU are stateless, so per-layer
        # instances are forward-identical.
        return nn.Hardtanh() if pars.nonlinear == 'hardtanh' else nn.ReLU()

    for i in range(pars.NUM_LAYER):
        layer = nn.Sequential()
        if i == 0:
            layer.add_module('conv', nn.Conv2d(3, int(NUM_CHANNEL), 3, padding=1))
            layer.add_module('activation', make_activation())
        elif (i == 1) or (i == 3):
            # NOTE(review): no activation after these convs — preserved
            # as-is; confirm this omission is intentional.
            layer.add_module('conv', nn.Conv2d(int(NUM_CHANNEL), int(NUM_CHANNEL), 3, padding=1))
            layer.add_module('maxpool', nn.MaxPool2d(2))
            HW /= 2
        elif i == 2:
            layer.add_module('conv', nn.Conv2d(int(NUM_CHANNEL), int(NUM_CHANNEL * 2), 3, padding=1))
            layer.add_module('activation', make_activation())
            NUM_CHANNEL *= 2
        else:
            layer.add_module('conv', nn.Conv2d(int(NUM_CHANNEL), int(NUM_CHANNEL * 8), 3, padding=1))
            layer.add_module('maxpool', nn.MaxPool2d(2))
            NUM_CHANNEL *= 8
            HW /= 2
        net.add_module('layer%d' % i, layer)

    # Classifier head: flatten the final feature map, project to class logits.
    aux = nn.Sequential(
        nn.Flatten(),
    )
    aux.add_module('fc', nn.Linear(int(NUM_CHANNEL * HW * HW), NUM_CLASS))
    # Projection head: flatten, project to pars.headsize.
    auxhead = nn.Sequential(
        nn.Flatten(),
    )
    auxhead.add_module('fc', nn.Linear(
        int(NUM_CHANNEL * HW * HW), pars.headsize))
    classifier.add_module('aux', aux)
    if pars.clfnonlinear == 'softmax':
        # dim=1 made explicit: implicit-dim nn.Softmax() is deprecated, and
        # for the 2-D flattened input the inferred dim was already 1.
        classifier.add_module('softmax', nn.Softmax(dim=1))
    head.add_module('auxhead', auxhead)
    if pars.headnonlinear == 'tanh':
        head.add_module('activation', nn.Tanh())
    return net, classifier, head
def setup_decoder(pars):
    """Build a decoder that maps a head embedding back to a 3x32x32 image.

    Args:
        pars: config namespace; reads ``headsize``, ``decoder_channel`` and
            ``decoder_layer``.

    Returns:
        An ``nn.Sequential`` wrapping a single ``auxdecoder`` submodule:
        Linear -> ReLU -> Unflatten to (decoder_channel, 32, 32), then
        ``decoder_layer - 1`` intermediate ConvTranspose2d+ReLU stages, a
        final ConvTranspose2d down to 3 channels, and a Sigmoid.
    """
    stack = nn.Sequential()
    stack.add_module('fc', nn.Linear(
        pars.headsize, pars.decoder_channel * 32 * 32))
    stack.add_module('relu', nn.ReLU())
    stack.add_module('unflatten', nn.Unflatten(1, (pars.decoder_channel, 32, 32)))
    # Intermediate deconv stages; the range is empty when decoder_layer <= 1.
    for stage in range(pars.decoder_layer - 1):
        stack.add_module(f'deconv{stage}', nn.ConvTranspose2d(
            pars.decoder_channel, pars.decoder_channel, 3, padding=1))
        stack.add_module(f'relu{stage}', nn.ReLU())
    # Final stage: project down to 3 image channels, squash into [0, 1].
    stack.add_module('deconv', nn.ConvTranspose2d(
        pars.decoder_channel, 3, 3, padding=1))
    stack.add_module('sigmoid', nn.Sigmoid())
    wrapper = nn.Sequential()
    wrapper.add_module('auxdecoder', stack)
    return wrapper