Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions pytorch-lightning/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,8 @@ def training_step(self, batch, batch_idx):
outputs = self(images)
loss = F.cross_entropy(outputs, labels)

tensorboard_logs = {'train_loss': loss}
# use key 'log'
return {"loss": loss, 'log': tensorboard_logs}
self.log("train_loss", loss, prog_bar=True)
return loss

# define what happens for testing here

Expand All @@ -68,22 +67,20 @@ def val_dataloader(self):
)
return test_loader

def validation_step(self, batch, batch_idx):
def validation_step(self, batch):
images, labels = batch
images = images.reshape(-1, 28 * 28)

# Forward pass
outputs = self(images)

loss = F.cross_entropy(outputs, labels)
return {"val_loss": loss}

def validation_epoch_end(self, outputs):
# outputs = list of dictionaries
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
tensorboard_logs = {'avg_val_loss': avg_loss}
# use key 'log'
return {'val_loss': avg_loss, 'log': tensorboard_logs}
self.validation_step_outputs.append(loss)
return loss

def on_validation_epoch_end(self):
epoch_average = torch.stack(self.validation_step_outputs).mean()
self.log("val_loss", epoch_average, prog_bar=True, on_step=False, on_epoch=True)
self.validation_step_outputs.clear()


def configure_optimizers(self):
Expand Down