From aa9eae0bda034ebb166fd7b3a02441dd5385ec1f Mon Sep 17 00:00:00 2001 From: Amirali Date: Sat, 24 Feb 2024 21:01:19 +0400 Subject: [PATCH] edit lightning.py --- pytorch-lightning/lightning.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/pytorch-lightning/lightning.py b/pytorch-lightning/lightning.py index ed29397..7df362a 100644 --- a/pytorch-lightning/lightning.py +++ b/pytorch-lightning/lightning.py @@ -41,9 +41,8 @@ def training_step(self, batch, batch_idx): outputs = self(images) loss = F.cross_entropy(outputs, labels) - tensorboard_logs = {'train_loss': loss} - # use key 'log' - return {"loss": loss, 'log': tensorboard_logs} + self.log("train_loss", loss, prog_bar=True) + return loss # define what happens for testing here @@ -68,22 +67,20 @@ def val_dataloader(self): ) return test_loader - def validation_step(self, batch, batch_idx): + def validation_step(self, batch): images, labels = batch images = images.reshape(-1, 28 * 28) # Forward pass outputs = self(images) - loss = F.cross_entropy(outputs, labels) - return {"val_loss": loss} - - def validation_epoch_end(self, outputs): - # outputs = list of dictionaries - avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() - tensorboard_logs = {'avg_val_loss': avg_loss} - # use key 'log' - return {'val_loss': avg_loss, 'log': tensorboard_logs} + self.validation_step_outputs.append(loss) + return loss + + def on_validation_epoch_end(self): + epoch_average = torch.stack(self.validation_step_outputs).mean() + self.log("val_loss", epoch_average, prog_bar=True, on_step=False, on_epoch=True) + self.validation_step_outputs.clear() def configure_optimizers(self):