From fab2b812fe1bc6ca28786c0f4bfe65da2194aa4f Mon Sep 17 00:00:00 2001 From: ga84mun Date: Sat, 29 Nov 2025 00:40:44 +0000 Subject: [PATCH] fix loading bug for new torch versions --- .gitignore | 3 ++- pyproject.toml | 2 +- spineps/lab_model.py | 2 +- spineps/seg_model.py | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 8281c7d..f890c15 100755 --- a/.gitignore +++ b/.gitignore @@ -157,4 +157,5 @@ lightning_logs/ test.txt derivatives_seg robert_test.py -poetry.lock \ No newline at end of file +poetry.lock +*.DS_Store \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index a57e752..88eea95 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ SciPy = "^1.11.2" torchmetrics = "^1.1.2" tqdm = "^4.66.1" einops= "^0.6.1" -nnunetv2 = "2.4.2" +nnunetv2 = "^2.4.2" TPTBox = "^0.4.0" antspyx = "0.4.2" rich = "^13.6.0" diff --git a/spineps/lab_model.py b/spineps/lab_model.py index 64e81f3..625c45d 100755 --- a/spineps/lab_model.py +++ b/spineps/lab_model.py @@ -97,7 +97,7 @@ def load(self, folds: tuple[str, ...] | None = None) -> Self: # noqa: ARG002 chktpath = search_path(self.model_folder, "**/*val_f1=*valf1-weights.ckpt") assert len(chktpath) >= 1, chktpath - model = PLClassifier.load_from_checkpoint(checkpoint_path=chktpath[-1]) + model = PLClassifier.load_from_checkpoint(checkpoint_path=chktpath[-1], weights_only=False) if hasattr(model.opt, "final_size"): self.final_size = model.opt.final_size self.transform = Compose( diff --git a/spineps/seg_model.py b/spineps/seg_model.py index 3a8c86f..2da3544 100755 --- a/spineps/seg_model.py +++ b/spineps/seg_model.py @@ -313,7 +313,7 @@ def load(self, folds: tuple[str, ...] | None = None) -> Self: # noqa: ARG002 chktpath = search_path(self.model_folder, "**/*weights*.ckpt") assert len(chktpath) == 1 - model = PLNet.load_from_checkpoint(checkpoint_path=chktpath[0]) + model = PLNet.load_from_checkpoint(checkpoint_path=chktpath[0], weights_only=False) model.eval() self.device = torch.device("cuda:0" if torch.cuda.is_available() and not self.use_cpu else "cpu") model.to(self.device)