-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmodel.py
More file actions
24 lines (22 loc) · 1012 Bytes
/
model.py
File metadata and controls
24 lines (22 loc) · 1012 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch.nn.functional as F
from torch import nn
from tcn import TemporalConvNet
import IPython as IP
class TCN(nn.Module):
    """Temporal convolutional network followed by a linear read-out head.

    A ``TemporalConvNet`` feature extractor is applied to the input sequence;
    only the features at the final time step are fed to a linear layer
    (followed, by default, by a ReLU) to produce the prediction.
    """

    def __init__(self, input_size, output_size, num_channels, kernel_size,
                 dropout, output_relu=True):
        """Build the TCN backbone and the linear output head.

        Args:
            input_size: Number of input channels (C_in), forwarded to
                ``TemporalConvNet``.
            output_size: Number of output features of the linear head.
            num_channels: Per-level channel counts forwarded to
                ``TemporalConvNet``; its last entry sets the input width of
                the linear head. Example values noted by the original author:
                ``input_size=1, output_size=10, num_channels=[25, ..., 25],
                kernel_size=7``.
            kernel_size: Convolution kernel size, forwarded to
                ``TemporalConvNet``.
            dropout: Dropout rate, forwarded to ``TemporalConvNet``.
            output_relu: If True (default; the original behaviour), clamp the
                output with ReLU. If False, return the raw linear output.
        """
        super().__init__()
        self.tcn = TemporalConvNet(input_size, num_channels,
                                   kernel_size=kernel_size, dropout=dropout)
        self.relu = nn.ReLU()
        # NOTE(review): assumes TemporalConvNet emits num_channels[-1]
        # channels at its output -- confirm against tcn.py.
        self.linear = nn.Linear(num_channels[-1], output_size)
        self.output_relu = output_relu

    def forward(self, inputs):
        """Run the network and predict from the last time step.

        Args:
            inputs: Tensor of shape (N, C_in, L_in), i.e.
                (batch, channels, seq_length).

        Returns:
            Tensor of shape (N, output_size); non-negative when
            ``output_relu`` is enabled.
        """
        features = self.tcn(inputs)
        # Keep only the last time step: (N, C, L) -> (N, C).
        o = self.linear(features[:, :, -1])
        # Original author's rationale for the ReLU: discard negative EMG
        # values. getattr() defaults to True so that whole models pickled
        # before ``output_relu`` existed still behave as before.
        if getattr(self, "output_relu", True):
            return self.relu(o)
        return o