Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ ignore = [
"PLR2044", # empty comment symbol
"B905", # strict= in zip()
"UP007", # Union instead of | in python 3.9
"PLC0415",
"UP045",
]

# Allow fix for all enabled rules (when `--fix`) is provided.
Expand Down
3 changes: 1 addition & 2 deletions spineps/architectures/unet3D.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from functools import partial
from inspect import isfunction

import torch
from einops import rearrange
Expand Down Expand Up @@ -188,8 +189,6 @@ def forward(self, x, time_emb=None):


def default(val, d):
from inspect import isfunction

if val is not None:
return val
return d() if isfunction(d) else d
4 changes: 2 additions & 2 deletions spineps/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def entry_point():
"-model_semantic",
"-ms",
# type=str.lower,
default="auto",
default="t2w",
# choices=model_subreg_choices,
metavar="",
help="The model used for the subregion segmentation. You can also pass an absolute path the model folder",
Expand All @@ -148,7 +148,7 @@ def entry_point():
"-model_instance",
"-mv",
# type=str.lower,
default="auto",
default="instance",
# choices=model_vert_choices,
metavar="",
help="The model used for the vertebra segmentation. You can also pass an absolute path the model folder",
Expand Down
2 changes: 1 addition & 1 deletion spineps/phase_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def get_corpus_coms(
segvert = corpus_cc.extract_label(target_vert_id, inplace=False)
try:
logger.print("get_separating_components to split vertebra", verbose=verbose)
(spart, tpart, spart_dil, tpart_dil, stpart) = get_separating_components(segvert, connectivity=3)
(spart, tpart, spart_dil, tpart_dil, _) = get_separating_components(segvert, connectivity=3)

logger.print("Splitting by plane")
plane_split_nii = get_plane_split(segvert, corpus_nii, spart, tpart, spart_dil, tpart_dil)
Expand Down
28 changes: 21 additions & 7 deletions spineps/phase_labeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ def perform_labeling_step(
vert_nii: NII,
subreg_nii: NII | None = None,
proc_lab_force_no_tl_anomaly: bool = False,
disable_c1: bool = True,
):
model.load()
if model.predictor is None:
model.load()

if 26 in vert_nii.unique():
has_sacrum = vert_nii.volumes()[26] > 500 # noqa: F841
Expand All @@ -44,12 +46,15 @@ def perform_labeling_step(
# crop for corpus instead of whole vertebra
corpus_nii = subreg_nii.extract_label((Location.Vertebra_Corpus, Location.Vertebra_Corpus_border))
vert_nii_c = vert_nii * corpus_nii
else:
vert_nii_c = vert_nii
# run model
labelmap = run_model_for_vert_labeling(
model,
img_nii,
vert_nii_c,
proc_lab_force_no_tl_anomaly=proc_lab_force_no_tl_anomaly,
disable_c1=disable_c1,
)[0]
# TODO make all vertebrae without visible corpus to visibility 0 but take into account for labeling
for i in vert_nii.unique():
Expand All @@ -66,14 +71,21 @@ def run_model_for_vert_labeling(
vert_nii: NII,
verbose: bool = False,
proc_lab_force_no_tl_anomaly: bool = False,
disable_c1: bool = True,
):
# reorient
img = img_nii.reorient(model.inference_config.model_expected_orientation, verbose=False)
vert = vert_nii.reorient(model.inference_config.model_expected_orientation, verbose=False)
# zms_pir = img.zoom
zms_pir = img.zoom

# crop
crop = vert.compute_crop(dist=128 / min(img.zoom))
img.apply_crop_(crop)
vert.apply_crop_(crop)

# rescale
# img.rescale_(model.calc_recommended_resampling_zoom(zms_pir), verbose=False)
# vert.rescale_(model.calc_recommended_resampling_zoom(zms_pir), verbose=False)
img.rescale_(model.calc_recommended_resampling_zoom(zms_pir), verbose=False)
vert.rescale_(model.calc_recommended_resampling_zoom(zms_pir), verbose=False)
#
img.assert_affine(other=vert)
# extract vertebrae
Expand All @@ -83,10 +95,11 @@ def run_model_for_vert_labeling(
# run model
predictions = model.run_all_seg_instances(img, vert)

fcost, fpath, fpath_post, costlist, min_costs_path, args = find_vert_path_from_predictions(
fcost, fpath, fpath_post, costlist, min_costs_path, _args = find_vert_path_from_predictions(
predictions=predictions,
proc_lab_force_no_tl_anomaly=proc_lab_force_no_tl_anomaly,
verbose=verbose,
disable_c1=disable_c1,
)
assert len(orig_label) == len(fpath_post), f"{len(orig_label)} != {len(fpath_post)}"
labelmap = {orig_label[idx]: fpath_post[idx] for idx in range(len(orig_label))}
Expand All @@ -100,7 +113,7 @@ def run_model_for_vert_labeling_cutouts(
disable_c1: bool = True,
boost_c2: float = 3.0,
allow_cervical_skip: bool = True,
verbose: bool = False,
verbose: bool = True,
):
# reorient
# img = img_nii.reorient(model.inference_config.model_expected_orientation, verbose=False)
Expand All @@ -117,7 +130,7 @@ def run_model_for_vert_labeling_cutouts(
orig_label = list(img_arrays.keys())
# run model
predictions = model.run_all_arrays(img_arrays)
fcost, fpath, fpath_post, costlist, min_costs_path, args = find_vert_path_from_predictions(
fcost, fpath, fpath_post, costlist, min_costs_path, _args = find_vert_path_from_predictions(
predictions=predictions,
verbose=verbose,
disable_c1=disable_c1,
Expand Down Expand Up @@ -257,6 +270,7 @@ def find_vert_path_from_predictions(
vertrel_column_norm: bool = True,
vertrel_gaussian_sigma: float = 0.6, # 0.6 # 0 means no gaussian
#
focus_tl_gap: bool = True, # focus on T11/T13 gap (if T11/t13 case is detected, predict again using crops and then check again)
argmax_combined_cost_matrix_instead_of_path_algorithm: bool = False,
proc_lab_force_no_tl_anomaly: bool = False,
#
Expand Down
2 changes: 1 addition & 1 deletion spineps/phase_post.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def add_ivd_ep_vert_label(whole_vert_nii: NII, seg_nii: NII, verbose=True):

# find which vert got how many ivd CCs
to_mapped_labels = list(mapping_cc_to_vert_label.values())
for i, l in enumerate(vert_labels):
for l in vert_labels:
if l not in to_mapped_labels and l != 1:
logger.print(f"Vertebra {v_idx2name[l]} got no IVD component assigned", Log_Type.STRANGE)
count = to_mapped_labels.count(l)
Expand Down
2 changes: 1 addition & 1 deletion spineps/phase_semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def predict_semantic_mask(
def remove_nonsacrum_beyond_canal_height(seg_nii: NII):
seg_nii.assert_affine(orientation=("P", "I", "R"))
canal_nii = seg_nii.extract_label([Location.Spinal_Canal.value, Location.Spinal_Cord.value])
crop_i = canal_nii.compute_crop(dist=16)[1]
crop_i = canal_nii.compute_crop(dist=64 / seg_nii.zoom[1])[1]
seg_arr = seg_nii.get_seg_array()
sacrum_arr = seg_nii.extract_label(26).get_seg_array()
seg_arr[:, 0 : crop_i.start, :] = 0
Expand Down
2 changes: 1 addition & 1 deletion unit_tests/test_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_search_path_simple(self):
dtype=int,
)
rel_cost = -rel_cost
fcost, fpath, min_costs_path = find_most_probably_sequence(
fcost, fpath, _min_costs_path = find_most_probably_sequence(
cost,
region_rel_cost=rel_cost,
regions=[0, 3],
Expand Down
4 changes: 2 additions & 2 deletions unit_tests/test_postproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

class Test_Post_Processing(unittest.TestCase):
def test_phase_postprocess(self):
mri, subreg, vert, label = get_test_mri()
mri, subreg, vert, _label = get_test_mri()
print(vert.unique())
subreg_cleaned, vert_cleaned = phase_postprocess_combined(mri, subreg, vert, model_labeling=None, debug_data={})
self.assertTrue(subreg_cleaned.assert_affine(other=vert_cleaned))
Expand All @@ -29,7 +29,7 @@ def test_phase_postprocess(self):
self.assertEqual(len(vert_labels), 8)

def test_calc_centroids(self):
mri, subreg, vert, label = get_test_mri()
_mri, subreg, vert, _label = get_test_mri()

poi = predict_centroids_from_both(vert, subreg, models=[], parameter={"TEST": "TEST"})
self.assertTrue(poi is not None)
10 changes: 5 additions & 5 deletions unit_tests/test_proc_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,20 @@

class Test_proc_functions(unittest.TestCase):
def test_n4_bias(self):
mri, subreg, vert, label = get_test_mri()
mri, _subreg, _vert, _label = get_test_mri()
mri.normalize_to_range_()
mri_min = mri.min()
mri_max = mri.max()
self.assertEqual(mri_min, 0)
self.assertEqual(mri_max, 387)
mri_n4biased, mask = n4_bias(mri)
mri_n4biased, _mask = n4_bias(mri)
mri_min = mri_n4biased.min()
mri_max = mri_n4biased.max()
self.assertEqual(mri_min, 0)
self.assertEqual(mri_max, 252)

def test_clean_artifacts(self):
mri, subreg, vert, label = get_test_mri()
_mri, subreg, vert, label = get_test_mri()
l3 = vert.extract_label(label)
l3 = subreg.apply_mask(l3)
l3_volumes = l3.volumes()
Expand All @@ -45,7 +45,7 @@ def test_clean_artifacts(self):
self.assertEqual(a, b)

def test_clean_artifacts_no_zeros(self):
mri, subreg, vert, label = get_test_mri()
_mri, _subreg, vert, label = get_test_mri()
l3 = vert.extract_label(label)
l3[l3 == 0] = 1
l3_volumes = l3.volumes()
Expand All @@ -56,7 +56,7 @@ def test_clean_artifacts_no_zeros(self):
self.assertEqual(a, b)

def test_clean_artifacts_zeros(self):
mri, subreg, vert, label = get_test_mri()
_mri, subreg, vert, label = get_test_mri()
l3 = vert.extract_label(label)
l3 = subreg.apply_mask(l3) * 0
l3_volumes = l3.volumes()
Expand Down
10 changes: 5 additions & 5 deletions unit_tests/test_semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_compatibility(self):

from spineps.seg_enums import Acquisition, InputType, Modality

mri, subreg, vert, label = get_test_mri()
mri, _subreg, _vert, _label = get_test_mri()
input_path = get_tests_dir().joinpath("sample_mri", "sub-mri_label-6_T2w.nii.gz")
model = Segmentation_Model_Dummy()

Expand All @@ -99,7 +99,7 @@ def test_compatibility(self):
self.assertFalse(compatible)

def test_phase_preprocess(self):
mri, subreg, vert, label = get_test_mri()
mri, _subreg, _vert, _label = get_test_mri()
for pad_size in range(7):
origin_diff = max([d * float(pad_size) for d in mri.zoom]) + 1e-4
# print(origin_diff)
Expand All @@ -113,11 +113,11 @@ def test_phase_preprocess(self):
self.assertEqual(s + (2 * pad_size), preprossed_input.shape[idx])

def test_segment_scan(self):
mri, subreg, vert, label = get_test_mri()
mri, subreg, _vert, _label = get_test_mri()
model = Segmentation_Model_Dummy()
model.run = MagicMock(return_value={OutputType.seg: subreg, OutputType.softmax_logits: None})
debug_data = {}
seg_nii, softmax_logits, errcode = predict_semantic_mask(
seg_nii, _softmax_logits, errcode = predict_semantic_mask(
mri,
model,
debug_data=debug_data,
Expand All @@ -134,7 +134,7 @@ def test_segment_scan(self):
self.assertEqual(errcode, ErrCode.OK)

def test_run_inference(self):
mri, subreg, vert, label = get_test_mri()
mri, subreg, _vert, _label = get_test_mri()
model = Segmentation_Model_Dummy().load()
s_arr = subreg.get_seg_array()
model.predictor.predict_single_npy_array = MagicMock(return_value=(s_arr, s_arr[np.newaxis, :]))
Expand Down
Loading