-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
18 lines (17 loc) · 963 Bytes
/
model.py
File metadata and controls
18 lines (17 loc) · 963 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch.nn as nn
# Denoising Convolutional Neural Network from DnCNN paper
class DnCNN(nn.Module):
    """Stack of bias-free 2-D convolutions whose output is subtracted from the input.

    Layout of ``self.dncnn`` (an ``nn.Sequential``):

    1. ``Conv2d(channels -> features)`` followed by ``ReLU``
    2. ``num_of_layers - 2`` repetitions of
       ``Conv2d(features -> features)``, ``BatchNorm2d(features)``, ``ReLU``
    3. ``Conv2d(features -> channels)``

    Args:
        channels: Number of channels of the input (and of the output).
        num_of_layers: Total number of convolution layers, counting the
            first and the last one.
        kernel_size: Kernel size shared by every convolution.
        padding: Padding shared by every convolution. With the defaults
            (kernel 3, padding 1) each convolution keeps the spatial size,
            which the subtraction in ``forward`` relies on.
        features: Number of feature maps used between the first and the
            last convolution.
    """

    def __init__(self, channels, num_of_layers=17, kernel_size=3, padding=1, features=64):
        super().__init__()

        def conv(n_in, n_out):
            # Every convolution in the network shares kernel, padding and bias=False.
            return nn.Conv2d(n_in, n_out, kernel_size=kernel_size, padding=padding, bias=False)

        # Modules are created strictly in network order so that parameter
        # initialisation draws from the RNG in a fixed sequence.
        stack = [conv(channels, features), nn.ReLU(inplace=True)]
        for _ in range(num_of_layers - 2):
            stack += [
                conv(features, features),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=True),
            ]
        stack.append(conv(features, channels))

        self.dncnn = nn.Sequential(*stack)

    def forward(self, x):
        """Return ``x`` minus the output of the convolutional stack applied to ``x``."""
        # NOTE(review): presumably the stack estimates the noise residual, as in
        # the DnCNN paper, so the difference is the denoised input - confirm
        # against the training code.
        return x - self.dncnn(x)