-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathmodule.py
More file actions
105 lines (84 loc) · 3.1 KB
/
module.py
File metadata and controls
105 lines (84 loc) · 3.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
import torch.nn as nn
class Linear(nn.Linear):
    """
    Linear layer for time-major sequence input.

    Optional post-linear ops are applied in order:
    LayerNorm -> activation -> dropout.

    Args:
        x: (N, T, C_in)
    Returns:
        y: (N, T, C_out)
    """
    def __init__(self, in_features, out_features, bias=True, activation_fn=None, ln=None, drop_rate=0.):
        super(Linear, self).__init__(in_features, out_features, bias=bias)
        # `activation_fn` is a module class (e.g. nn.ReLU); instantiate it once.
        # (The original assigned this attribute twice — duplicate removed.)
        self.activation_fn = activation_fn(inplace=True) if activation_fn is not None else None
        # Truthiness check: the original used `ln is not None`, which wrongly
        # built a LayerNorm for ln=False; this matches Conv1d's `ln` flag.
        self.layer_norm = nn.LayerNorm(out_features) if ln else None
        self.drop_out = nn.Dropout(drop_rate) if drop_rate > 0 else None
    def forward(self, x):
        y = super(Linear, self).forward(x)
        y = self.layer_norm(y) if self.layer_norm is not None else y
        y = self.activation_fn(y) if self.activation_fn is not None else y
        y = self.drop_out(y) if self.drop_out is not None else y
        return y
class Conv1d(nn.Conv1d):
    """
    1-D convolution for time-major sequence input.

    Args:
        x: (N, T, C_in)
    Returns:
        y: (N, T, C_out)
    """
    def __init__(self, in_channels, out_channels, kernel_size, activation_fn=None, drop_rate=0.,
                 stride=1, padding='same', dilation=1, groups=1, bias=True, ln=False):
        if padding == 'same':
            padding = kernel_size // 2 * dilation
        # An even kernel with this padding yields one extra output frame;
        # remember to trim it in forward().
        self.even_kernel = kernel_size % 2 == 0
        super(Conv1d, self).__init__(in_channels, out_channels, kernel_size,
                                     stride=stride, padding=padding, dilation=dilation,
                                     groups=groups, bias=bias)
        self.activation_fn = None if activation_fn is None else activation_fn(inplace=True)
        self.drop_out = None if drop_rate <= 0 else nn.Dropout(drop_rate)
        self.layer_norm = nn.LayerNorm(out_channels) if ln else None
    def forward(self, x):
        # nn.Conv1d expects channels-first, so hop to (N, C, T) and back.
        y = super(Conv1d, self).forward(x.transpose(1, 2)).transpose(1, 2)
        if self.layer_norm is not None:
            y = self.layer_norm(y)
        if self.activation_fn is not None:
            y = self.activation_fn(y)
        if self.drop_out is not None:
            y = self.drop_out(y)
        if self.even_kernel:
            # Drop the surplus trailing frame produced by the even kernel.
            y = y[:, :-1, :]
        return y
class Conv1dResBlock(Conv1d):
    """
    1-D convolution block with a residual (skip) connection.

    Args:
        x: (N, T, C_in)
    Returns:
        y: (N, T, C_out)
    """
    def __init__(self, in_channels, out_channels, kernel_size, activation_fn=None, drop_rate=0.,
                 stride=1, padding='same', dilation=1, groups=1, bias=True, ln=False):
        super(Conv1dResBlock, self).__init__(in_channels, out_channels, kernel_size, activation_fn,
                                             drop_rate, stride, padding, dilation, groups=groups,
                                             bias=bias, ln=ln)
    def forward(self, x):
        # NOTE(review): the skip sum assumes in_channels == out_channels and a
        # T-preserving conv (stride=1, 'same' padding) — confirm callers comply.
        return super(Conv1dResBlock, self).forward(x) + x
class Upsample(nn.Upsample):
    """
    Upsampling via interpolation along the time axis.

    Args:
        x: (N, T, C)
    Returns:
        y: (N, S * T, C)
        (S: scale_factor)
    """
    def __init__(self, scale_factor=2, mode='nearest'):
        super(Upsample, self).__init__(scale_factor=scale_factor, mode=mode)
    def forward(self, x):
        # nn.Upsample interpolates trailing dims, so present (N, C, T) and restore.
        return super(Upsample, self).forward(x.transpose(1, 2)).transpose(1, 2)