From 21bb1666a66adbe5d20e052c66ac3f90e94a90ca Mon Sep 17 00:00:00 2001 From: agentmarketbot Date: Sat, 11 Jan 2025 09:15:21 +0000 Subject: [PATCH] improve RNN model comments with dimension details Update tensor dimension comments in RNN model implementation to be more descriptive and clearer. Replace generic variable names (n) with specific parameter names (batch_size) and add explicit dimension sizes for sequence_length, input_size, hidden_size, and num_classes to improve code readability and understanding of the network architecture. --- rnn-lstm-gru/main.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/rnn-lstm-gru/main.py b/rnn-lstm-gru/main.py index 299b035..96b5c7d 100644 --- a/rnn-lstm-gru/main.py +++ b/rnn-lstm-gru/main.py @@ -57,22 +57,22 @@ def forward(self, x): h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device) #c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device) - # x: (n, 28, 28), h0: (2, n, 128) + # x: (batch_size, sequence_length=28, input_size=28), h0: (num_layers=2, batch_size, hidden_size=128) # Forward propagate RNN out, _ = self.rnn(x, h0) # or: #out, _ = self.lstm(x, (h0,c0)) - # out: tensor of shape (batch_size, seq_length, hidden_size) - # out: (n, 28, 128) + # out: tensor of shape (batch_size, sequence_length, hidden_size) + # out: (batch_size, 28, 128) # Decode the hidden state of the last time step out = out[:, -1, :] - # out: (n, 128) + # out: (batch_size, hidden_size=128) out = self.fc(out) - # out: (n, 10) + # out: (batch_size, num_classes=10) return out model = RNN(input_size, hidden_size, num_layers, num_classes).to(device) @@ -85,8 +85,8 @@ def forward(self, x): n_total_steps = len(train_loader) for epoch in range(num_epochs): for i, (images, labels) in enumerate(train_loader): - # origin shape: [N, 1, 28, 28] - # resized: [N, 28, 28] + # origin shape: [batch_size, 1, 28, 28] + # resized: [batch_size, sequence_length=28, input_size=28] images = images.reshape(-1, sequence_length, input_size).to(device) labels = labels.to(device)