Skip to content

Commit 95dc0f4

Browse files
authored
Merge pull request #16 from MathGaron/feature/modules
add common modules
2 parents 6c9c301 + 2ff7308 commit 95dc0f4

File tree

6 files changed

+143
-26
lines changed

6 files changed

+143
-26
lines changed

examples/classification/cat_dog_net.py

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,33 @@
11
import torch.nn.functional as F
22
import torch.nn as nn
3-
from pytorch_toolbox.network.network_base import NetworkBase
3+
from pytorch_toolbox.network_base import NetworkBase
4+
from pytorch_toolbox.modules.conv2d_module import ConvBlock
5+
from pytorch_toolbox.modules.fc_module import FCBlock
46

57

68
class CatDogNet(NetworkBase):
79
def __init__(self):
810
super(CatDogNet, self).__init__()
9-
self.conv1 = nn.Conv2d(3, 24, 5)
10-
self.conv1_bn = nn.BatchNorm2d(24)
11-
self.dropout1 = nn.Dropout2d(0.25)
12-
self.conv2 = nn.Conv2d(24, 48, 3)
13-
self.conv2_bn = nn.BatchNorm2d(48)
14-
self.dropout2 = nn.Dropout2d(0.25)
15-
self.conv3 = nn.Conv2d(48, 48, 3)
16-
self.conv3_bn = nn.BatchNorm2d(48)
17-
self.dropout3 = nn.Dropout2d(0.25)
18-
self.conv4 = nn.Conv2d(48, 96, 3)
19-
self.conv4_bn = nn.BatchNorm2d(96)
11+
self.conv1 = ConvBlock(3, 24, 5, dropout=True, batchnorm=True, maxpool=True, activation=F.elu)
12+
self.conv2 = ConvBlock(24, 48, 3, dropout=True, batchnorm=True, maxpool=True, activation=F.elu)
13+
self.conv3 = ConvBlock(48, 48, 3, dropout=True, batchnorm=True, maxpool=True, activation=F.elu)
14+
self.conv4 = ConvBlock(48, 48, 3, dropout=True, batchnorm=True, maxpool=True, activation=F.elu)
2015

2116
self.view_size = 96 * 6 * 6
22-
self.fc1 = nn.Linear(self.view_size, 250)
23-
self.fc_bn1 = nn.BatchNorm1d(250)
24-
self.fc2 = nn.Linear(250, 2)
2517

26-
self.dropout1 = nn.Dropout()
27-
self.dropout2 = nn.Dropout()
18+
self.fc1 = FCBlock(self.view_size, 250, dropout=True, batchnorm=True, activation=F.elu)
19+
self.fc2 = FCBlock(250, 2, dropout=False, batchnorm=False, activation=None)
2820

2921
self.criterion = nn.NLLLoss()
3022

3123
def forward(self, x):
32-
x = self.dropout1(F.max_pool2d(F.elu(self.conv1(x)), 2))
33-
x = self.dropout2(F.max_pool2d(F.elu(self.conv2_bn(self.conv2(x))), 2))
34-
x = self.dropout3(F.max_pool2d(F.elu(self.conv3_bn(self.conv3(x))), 2))
35-
x = F.max_pool2d(F.elu(self.conv4_bn(self.conv4(x))), 2)
36-
x = x.view(-1, self.view_size)
37-
x = self.dropout1(x)
38-
x = F.elu(self.fc_bn1(self.fc1(x)))
39-
x = self.dropout2(x)
40-
x = F.log_softmax(self.fc2(x))
24+
x = self.conv1(x)
25+
x = self.conv2(x)
26+
x = self.conv3(x)
27+
x = self.conv4(x)
28+
x = self.fc1(x)
29+
x = self.fc2(x)
30+
x = F.log_softmax(x)
4131
return x
4232

4333
def loss(self, predictions, targets):

pytorch_toolbox/modules/__init__.py

Whitespace-only changes.
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import torch.nn as nn
2+
import torch
3+
4+
5+
class ConvBlock(nn.Module):
6+
"""
7+
Generic convolution block
8+
9+
"""
10+
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
11+
padding=0, dropout=False, dropout_rate=0.25,
12+
batchnorm=True, maxpool=True, maxpool_size=2, activation=None):
13+
"""
14+
15+
:param in_channels:
16+
:param out_channels:
17+
:param kernel_size:
18+
:param stride:
19+
:param padding:
20+
:param dropout:
21+
:param dropout_rate:
22+
:param batchnorm:
23+
:param maxpool:
24+
:param maxpool_size:
25+
:param activation: Function from function api
26+
"""
27+
super(ConvBlock, self).__init__()
28+
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
29+
self.batch_norm = nn.BatchNorm2d(out_channels) if batchnorm else None
30+
self.dropout = nn.Dropout2d(dropout_rate) if dropout else None
31+
self.maxpool = nn.MaxPool2d(maxpool_size) if maxpool else None
32+
self.activation = activation
33+
34+
def forward(self, x):
35+
36+
x = self.conv(x)
37+
if self.batch_norm:
38+
x = self.batch_norm(x)
39+
if self.activation:
40+
x = self.activation(x)
41+
if self.maxpool:
42+
x = self.maxpool(x)
43+
if self.dropout:
44+
x = self.dropout(x)
45+
return x
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import torch.nn as nn
2+
import torch
3+
4+
5+
class FCBlock(nn.Module):
6+
"""
7+
generic Fully connected block
8+
"""
9+
def __init__(self, in_features, out_features, dropout=True, dropout_rate=0.5,
10+
batchnorm=True, activation=None):
11+
"""
12+
13+
:param in_features:
14+
:param out_features:
15+
:param dropout:
16+
:param dropout_rate:
17+
:param batchnorm:
18+
:param activation: Function from function api
19+
"""
20+
super(FCBlock, self).__init__()
21+
self.fc = nn.Linear(in_features, out_features)
22+
self.batch_norm = nn.BatchNorm1d(out_features) if batchnorm else None
23+
self.dropout = nn.Dropout(dropout_rate) if dropout else None
24+
self.activation = activation
25+
26+
def forward(self, x):
27+
28+
x = self.fc(x)
29+
if self.batch_norm:
30+
x = self.batch_norm(x)
31+
if self.activation:
32+
x = self.activation(x)
33+
if self.dropout:
34+
x = self.dropout(x)
35+
return x
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import torch.nn as nn
2+
import torch
3+
4+
5+
class Fire(nn.Module):
6+
"""
7+
From SqueezeNet : https://github.com/pytorch/vision/blob/master/torchvision/models/squeezenet.py
8+
9+
"""
10+
def __init__(self, inplanes, squeeze_planes, expand1x1_planes, expand3x3_planes):
11+
super(Fire, self).__init__()
12+
self.inplanes = inplanes
13+
self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
14+
self.squeeze_activation = nn.ELU(inplace=True)
15+
self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, kernel_size=1)
16+
self.expand1x1_activation = nn.ELU(inplace=True)
17+
self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, kernel_size=3, padding=1)
18+
self.expand3x3_activation = nn.ELU(inplace=True)
19+
20+
def forward(self, x):
21+
x = self.squeeze_activation(self.squeeze(x))
22+
return torch.cat([
23+
self.expand1x1_activation(self.expand1x1(x)),
24+
self.expand3x3_activation(self.expand3x3(x))
25+
], 1)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from torch import nn
2+
3+
4+
class SELayer(nn.Module):
5+
"""
6+
From Squeeze-and-Excitation Networks : https://github.com/moskomule/senet.pytorch/blob/master/se_module.py
7+
"""
8+
def __init__(self, channel, reduction=16):
9+
super(SELayer, self).__init__()
10+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
11+
self.fc = nn.Sequential(
12+
nn.Linear(channel, reduction),
13+
nn.ReLU(inplace=True),
14+
nn.Linear(reduction, channel),
15+
nn.Sigmoid()
16+
)
17+
18+
def forward(self, x):
19+
b, c, _, _ = x.size()
20+
y = self.avg_pool(x).view(b, c)
21+
y = self.fc(y).view(b, c, 1, 1)
22+
return x * y

0 commit comments

Comments
 (0)