Skip to content

Commit c87be1b

Browse files
authored
Merge pull request #76 from Hendrik-code/new_features_robert
New features robert
2 parents 5f20f41 + 413cfb5 commit c87be1b

8 files changed

Lines changed: 46 additions & 28 deletions

File tree

TPTBox/core/nii_wrapper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
)
6363

6464
if TYPE_CHECKING:
65-
pass
65+
from torch import device
6666
MODES = Literal["constant", "nearest", "reflect", "wrap"]
6767
_unpacked_nii = tuple[np.ndarray, AFFINE, nib.nifti1.Nifti1Header]
6868
_formatwarning = warnings.formatwarning
@@ -1042,7 +1042,7 @@ def to_simpleITK(self):
10421042
from TPTBox.core.sitk_utils import nii_to_sitk
10431043
return nii_to_sitk(self)
10441044

1045-
def to_deepali(self,align_corners: bool = True,dtype=None,device = "cuda"):
1045+
def to_deepali(self,align_corners: bool = True,dtype=None,device:device|str = "cpu"):
10461046
import torch
10471047
try:
10481048
from deepali.data import Image as deepaliImage # type: ignore

TPTBox/core/poi_fun/poi_abstract.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
POI_ID = Union[
2020
tuple[int, int],
2121
slice,
22-
tuple[Abstract_lvl | int, Abstract_lvl | int],
22+
tuple[Union[Abstract_lvl, int], Union[Abstract_lvl, int]],
2323
tuple[Abstract_lvl, Abstract_lvl],
2424
tuple[Abstract_lvl, int],
2525
tuple[int, Abstract_lvl],

TPTBox/core/vert_constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ def _get_id(cls, s: str | int, no_raise=True) -> int:
107107
try:
108108
return cls[s].value
109109
except KeyError:
110+
for c in cls:
111+
if c.name.lower() == s.lower():
112+
return c.value
110113
if not no_raise:
111114
raise
112115
return int(s)

TPTBox/registration/deformable/deformable_reg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,9 +231,9 @@ def save(self, path: str | Path):
231231
pickle.dump(self.get_dump(), w)
232232

233233
@classmethod
234-
def load(cls, path):
234+
def load(cls, path, gpu=0, ddevice: DEVICES = "cuda"):
235235
with open(path, "rb") as w:
236-
return cls.load_(pickle.load(w))
236+
return cls.load_(pickle.load(w), gpu, ddevice)
237237

238238
@classmethod
239239
def load_(cls, w, gpu=0, ddevice: DEVICES = "cuda"):

TPTBox/segmentation/TotalVibeSeg/auto_download.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,11 @@ def user_guard(func: Any) -> Any:
6767

6868

6969
@user_guard
70-
def _download_weights(idx=85) -> None:
70+
def _download_weights(idx=85, addendum="") -> None:
7171
weights_dir = get_weights_dir(idx)
72-
weights_url = WEIGHTS_URL_ + f"{idx:03}.zip"
72+
weights_url = WEIGHTS_URL_ + f"{idx:03}{addendum}.zip"
7373
_download(weights_url, weights_dir, text="pretrained weights")
74+
addendum_download(idx)
7475

7576

7677
def _download(weights_url, weights_dir, text="") -> None:
@@ -100,10 +101,21 @@ def update_progress(block_num: int, block_size: int, total_size: int) -> None:
100101
os.remove(zip_path) # noqa: PTH107
101102

102103

104+
def addendum_download(idx):
105+
weights_dir = get_weights_dir(idx)
106+
next_zip = weights_dir / "other_downloads.json"
107+
if next_zip.exists():
108+
with open(next_zip) as f:
109+
add = json.load(f)
110+
[_download_weights(idx, addendum=a) for a in add]
111+
next_zip.unlink()
112+
113+
103114
def download_weights(idx, model_path: Path | None = None) -> Path:
104115
weights_dir = get_weights_dir(idx, model_path)
105116

106117
# Check if weights are downloaded
107118
if not weights_dir.exists():
108119
_download_weights(idx)
120+
addendum_download(idx)
109121
return weights_dir

TPTBox/segmentation/TotalVibeSeg/inference_nnunet.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,12 @@ def get_ds_info(idx) -> dict:
2222
try:
2323
nnunet_path = next(next(iter(model_path.glob(f"*{idx}*"))).glob("*__nnUNetPlans*"))
2424
except StopIteration:
25-
Print_Logger().print(f"Please add Dataset {idx} to {model_path}", Log_Type.FAIL)
26-
model_path.mkdir(exist_ok=True, parents=True)
27-
sys.exit()
25+
try:
26+
nnunet_path = next(next(iter(model_path.glob(f"*{idx}*"))).glob("*__nnUNet*ResEnc*"))
27+
except StopIteration:
28+
Print_Logger().print(f"Please add Dataset {idx} to {model_path}", Log_Type.FAIL)
29+
model_path.mkdir(exist_ok=True, parents=True)
30+
sys.exit()
2831
with open(Path(nnunet_path, "dataset.json")) as f:
2932
ds_info = json.load(f)
3033
return ds_info

TPTBox/segmentation/nnUnet_utils/predictor.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -166,22 +166,22 @@ def mapp(d: dict):
166166
if "ignore" in dataset_json["labels"]:
167167
num_output_channels -= 1
168168

169-
network = get_network_from_plans(
170-
plans_manager, # type: ignore
171-
dataset_json,
172-
configuration_manager, # type: ignore
173-
num_input_channels,
174-
num_output_channels=num_output_channels,
175-
deep_supervision=False,
176-
)
177-
self.network = network
178-
179-
self.plans_manager = plans_manager
180-
self.configuration_manager = configuration_manager
181-
self.list_of_parameters = parameters # Lists of model folds
182-
self.dataset_json = dataset_json
183-
self.trainer_name = trainer_name
184-
self.allowed_mirroring_axes = inference_allowed_mirroring_axes
169+
network = get_network_from_plans(
170+
plans_manager, # type: ignore
171+
dataset_json,
172+
configuration_manager, # type: ignore
173+
num_input_channels,
174+
num_output_channels=num_output_channels,
175+
deep_supervision=False,
176+
)
177+
self.network = network
178+
179+
self.plans_manager = plans_manager
180+
self.configuration_manager = configuration_manager
181+
self.list_of_parameters = parameters # Lists of model folds
182+
self.dataset_json = dataset_json
183+
self.trainer_name = trainer_name
184+
self.allowed_mirroring_axes = inference_allowed_mirroring_axes
185185
self.label_manager = plans_manager.get_label_manager(dataset_json)
186186
if (
187187
("nnUNet_compile" in os.environ.keys())

unit_tests/test_testsamples.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,10 +214,10 @@ def _test_deformable(self, dim=0, save=True): # type: ignore
214214
poi[123, 44] = (random.randint(0, mri.shape[0] - 1), random.randint(0, mri.shape[1] - 1), random.randint(0, mri.shape[2] - 1))
215215
poi[123, 45] = (random.randint(0, mri.shape[0] - 1), random.randint(0, mri.shape[1] - 1), random.randint(0, mri.shape[2] - 1))
216216

217-
deform = Deformable_Registration(mov, mov, reference_image=mov)
217+
deform = Deformable_Registration(mov, mov, reference_image=mov, ddevice="cpu")
218218
if save:
219219
deform.save(test_save)
220-
deform = Deformable_Registration.load(test_save)
220+
deform = Deformable_Registration.load(test_save, ddevice="cpu")
221221
mov2 = mov.copy()
222222
mov2.seg = True
223223
mov2[mov > -10000] = 0

0 commit comments

Comments
 (0)