diff --git a/rnn-lstm-gru/main.py b/rnn-lstm-gru/main.py index 299b035..96b5c7d 100644 --- a/rnn-lstm-gru/main.py +++ b/rnn-lstm-gru/main.py @@ -57,22 +57,22 @@ def forward(self, x): h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device) #c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device) - # x: (n, 28, 28), h0: (2, n, 128) + # x: (batch_size, sequence_length=28, input_size=28), h0: (num_layers=2, batch_size, hidden_size=128) # Forward propagate RNN out, _ = self.rnn(x, h0) # or: #out, _ = self.lstm(x, (h0,c0)) - # out: tensor of shape (batch_size, seq_length, hidden_size) - # out: (n, 28, 128) + # out: tensor of shape (batch_size, sequence_length, hidden_size) + # out: (batch_size, 28, 128) # Decode the hidden state of the last time step out = out[:, -1, :] - # out: (n, 128) + # out: (batch_size, hidden_size=128) out = self.fc(out) - # out: (n, 10) + # out: (batch_size, num_classes=10) return out model = RNN(input_size, hidden_size, num_layers, num_classes).to(device) @@ -85,8 +85,8 @@ def forward(self, x): n_total_steps = len(train_loader) for epoch in range(num_epochs): for i, (images, labels) in enumerate(train_loader): - # origin shape: [N, 1, 28, 28] - # resized: [N, 28, 28] + # origin shape: [batch_size, 1, 28, 28] + # resized: [batch_size, sequence_length=28, input_size=28] images = images.reshape(-1, sequence_length, input_size).to(device) labels = labels.to(device)