-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathneural_network.py
More file actions
34 lines (24 loc) · 1.01 KB
/
neural_network.py
File metadata and controls
34 lines (24 loc) · 1.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
"""
Neural Network class for MNIST classification
"""
import torch
class NeuralNetwork(torch.nn.Module):
    """Fully connected feed-forward network.

    The model is a stack of ``torch.nn.Linear`` layers: one per entry in
    ``hidden_layers`` followed by a final output layer. ReLU is applied after
    every layer except the last, so the returned values are the raw outputs
    of the final linear layer (no softmax or other activation is applied).
    """

    def __init__(self, input_size=784, hidden_layers=(64, 32), output_size=10):
        """Create the linear layers.

        Args:
            input_size: Number of features per sample after flattening.
            hidden_layers: Iterable of hidden-layer widths, in order. May be
                empty, in which case the model is a single linear layer.
            output_size: Number of units in the final layer.
        """
        super().__init__()

        # The default for ``hidden_layers`` is a tuple rather than a list so
        # that no mutable object is shared between constructor calls.
        # Consecutive pairs of widths give each layer's (in, out) sizes.
        widths = [input_size, *hidden_layers, output_size]
        self.layers = torch.nn.ModuleList(
            [
                torch.nn.Linear(fan_in, fan_out)
                for fan_in, fan_out in zip(widths, widths[1:])
            ]
        )

    def train_forward(self, data):
        """Run the network on a batch and return the last layer's output.

        Args:
            data: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened into a single feature dimension.

        Returns:
            Tensor with one row per sample and ``output_size`` columns.
        """
        # ``reshape`` (not ``view``) so non-contiguous inputs, e.g. permuted
        # image batches, are flattened instead of raising a RuntimeError.
        data = data.reshape(data.size(0), -1)

        # ReLU after every layer except the last.
        for layer in self.layers[:-1]:
            data = torch.nn.functional.relu(layer(data))

        # Last layer is applied without an activation.
        return self.layers[-1](data)

    def forward(self, data):
        """Standard ``torch.nn.Module`` entry point.

        Delegates to :meth:`train_forward` so that ``model(data)`` and
        module hooks work; the result is identical to calling
        ``train_forward`` directly.
        """
        return self.train_forward(data)