From 7ccd1c5fb5638a1edce6a8b1a9983af811b513fe Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Wed, 5 Nov 2025 13:17:14 -0500 Subject: [PATCH] Check if device is mps --- src/training/train_loop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/training/train_loop.py b/src/training/train_loop.py index e79a5d8..7c6911a 100644 --- a/src/training/train_loop.py +++ b/src/training/train_loop.py @@ -46,7 +46,8 @@ def train_one_epoch(epoch, _model, _optimizer, _loader, _device, _loss_func, _cl final_image_loss = float(nn.functional.l1_loss(pred, gt).detach().cpu()) total_l1_loss += final_image_loss del loss, pred, final_image_loss - torch.mps.empty_cache() + if torch.backends.mps.is_available(): + torch.mps.empty_cache() if (batch_idx + 1) % log_interval == 0: pbar.set_postfix({"loss": f"{total_loss/n_images:.4f}"})