diff --git a/src/training/train_loop.py b/src/training/train_loop.py index e79a5d8..7c6911a 100644 --- a/src/training/train_loop.py +++ b/src/training/train_loop.py @@ -46,7 +46,8 @@ def train_one_epoch(epoch, _model, _optimizer, _loader, _device, _loss_func, _cl final_image_loss = float(nn.functional.l1_loss(pred, gt).detach().cpu()) total_l1_loss += final_image_loss del loss, pred, final_image_loss - torch.mps.empty_cache() + if torch.backends.mps.is_available(): + torch.mps.empty_cache() if (batch_idx + 1) % log_interval == 0: pbar.set_postfix({"loss": f"{total_loss/n_images:.4f}"})