How to predict on the test dataset using trainer.predict()? #13568
Hi, I have trained the model using the Trainer and was trying to use the trainer.predict() method to predict on the datamodule, but it throws the following error:

I have the following dataloader:

I have the following model defined:

Please help me predict on the test data. How can I leverage the Trainer to predict on the test data and get a classification report for the predicted output? Thanks!
@karndeepsingh To use Trainer.predict(), you must have predict_dataloader() defined in your LightningModule or LightningDataModule, as the error message states. If you'd like to run inference on your test set, you just need to define predict_dataloader() so that it returns a DataLoader over your test set:

    import torch

    # Define this method in your LightningDataModule (or LightningModule).
    def predict_dataloader(self):
        # Reuse the test split for prediction; shuffle=False keeps the sample order fixed.
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=4,
            shuffle=False,
        )