diff --git a/README.md b/README.md
index 4522a25..28eb0b6 100644
--- a/README.md
+++ b/README.md
@@ -7,6 +7,8 @@
 [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
 
 # SPINEPS – Automatic Whole Spine Segmentation of T2w MR images using a Two-Phase Approach to Multi-class Semantic and Instance Segmentation.
+# and
+# VERIDAH: Solving Enumeration-Anomaly-Aware Vertebra Labeling across Imaging Sequences
 
 This is a segmentation pipeline to automatically, and robustly, segment the whole spine in T2w sagittal images.
 
@@ -268,6 +270,23 @@
 In the subregion segmentation:
 
 In the vertebra instance segmentation mask, each label X in [1, 25] are the unique vertebrae, while 100+X are their corresponding IVD and 200+X their endplates.
 
+## VERIDAH
+
+To run the vertebra labeling after segmentation, specify a -model_labeling model (similar to -model_semantic and -model_instance).
+
+If you use VERIDAH (the labeling model) in addition to the segmentation models from SPINEPS, a labeling model will run and give each vertebra detected by SPINEPS a vertebra label. These are:
+
+| Label | Structure |
+| :---: | --------- |
+| 1 | C1 |
+| 2 - 7 | C2 - C7 |
+| 8 - 19 | T1 - T12 |
+| 28 | T13 |
+| 20 | L1 |
+| 21 - 25 | L2 - L6 |
+| 26 | Sacrum |
+
+The labels 100+X still correspond to the vertebra's IVD and 200+X to the respective endplate. For example, the label 119 is the IVD below the T12 vertebra.
 
 ## Using the Code
 
diff --git a/spineps/architectures/pl_densenet.py b/spineps/architectures/pl_densenet.py
index f6ab532..c219e36 100644
--- a/spineps/architectures/pl_densenet.py
+++ b/spineps/architectures/pl_densenet.py
@@ -3,17 +3,82 @@
 import os
 import sys
 from dataclasses import dataclass
+from enum import Enum
 from pathlib import Path
 
 import pytorch_lightning as pl
 import torch
-from monai.networks.nets import DenseNet169
+from monai.networks.nets import DenseNet121, DenseNet169
+from monai.networks.nets.resnet import (
+    ResNet,
+    ResNetBlock,
+    _resnet,
+    get_inplanes,
+    resnet10,
+    resnet18,
+    resnet34,
+    resnet50,
+    resnet101,
+    resnet152,
+)
 from torch import nn
 from TypeSaveArgParse import Class_to_ArgParse
 
 
+def resnet2(
+    layers: list[int] | None = None,
+    **kwargs,
+):
+    if layers is None:
+        layers = [1, 1]
+    return _resnet("resnet2", ResNetBlock, layers, get_inplanes(), False, False, **kwargs)
+
+
+class MODEL(Enum):
+    DENSENET169 = DenseNet169
+    DENSENET121 = DenseNet121
+    RESNET10 = 10  # resnet10
+    RESNET18 = 18  # resnet18
+    RESNET34 = 34  # resnet34
+    RESNET50 = 50  # resnet50
+    RESNET101 = 101  # resnet101
+    RESNET152 = 152  # resnet152
+    RESNET2 = 2  # resnet2
+
+    def __call__(
+        self,
+        opt: ARGS_MODEL,
+        remove_classification_head: bool = True,
+    ):
+        if "DENSENET" in self.name:
+            return get_densenet_architecture(
+                self.value,
+                in_channel=opt.in_channel,
+                out_channel=opt.num_classes,
+                pretrained=not opt.not_pretrained,
+                remove_classification_head=remove_classification_head,
+            )
+        elif "RESNET" in self.name:
+            d = {
+                10: resnet10,
+                18: resnet18,
+                34: resnet34,
+                50: resnet50,
+                101: resnet101,
+                152: resnet152,
+                2: resnet2,
+            }
+            return get_resnet_architecture(
+                d[self.value],
+                remove_classification_head=remove_classification_head,
+            )
+        else:
+            raise ValueError(f"Model {self.name} not supported.")
+
+
 @dataclass
 class ARGS_MODEL(Class_to_ArgParse):
+    backbone: str = MODEL.DENSENET169.name  # name of a MODEL member, resolved via MODEL[backbone]
     classification_conv: bool = False
     classification_linear: bool = True
     #
@@ -43,9 +108,8 @@ def __init__(self, opt: ARGS_MODEL, group_2_n_channel: dict[str, int]):
         # save hyperparameter, everything below not visible
         self.save_hyperparameters()
-        self.net, linear_in = get_architecture(
-            DenseNet169, opt.in_channel, opt.num_classes, pretrained=False, remove_classification_head=True
-        )
+        self.backbone = MODEL[opt.backbone]
+        self.net, linear_in = self.backbone(opt, remove_classification_head=True)
         self.classification_heads = self.build_classification_heads(linear_in, opt.classification_conv, opt.classification_linear)
         self.classification_keys = list(self.classification_heads.keys())
         self.mse_weighting = opt.mse_weighting
@@ -89,7 +153,7 @@ def __str__(self) -> str:
         return "VertebraLabelingModel"
 
 
-def get_architecture(
+def get_densenet_architecture(
     model,
     in_channel: int = 1,
     out_channel: int = 1,
@@ -102,8 +166,21 @@ def get_architecture(
         out_channels=out_channel,
         pretrained=pretrained,
     )
-    linear_infeatures = 0
     linear_infeatures = model.class_layers[-1].in_features
     if remove_classification_head:
         model.class_layers = model.class_layers[:-1]
     return model, linear_infeatures
+
+
+def get_resnet_architecture(
+    model,
+    remove_classification_head: bool = True,
+):
+    model = model(
+        spatial_dims=3,
+        n_input_channels=1,
+    )
+    linear_infeatures = model.fc.in_features
+    if remove_classification_head:
+        model.fc = nn.Identity()  # keep forward() callable; assigning None would break the final fc call
+    return model, linear_infeatures
diff --git a/spineps/architectures/read_labels.py b/spineps/architectures/read_labels.py
index 433259e..96bc4c6 100644
--- a/spineps/architectures/read_labels.py
+++ b/spineps/architectures/read_labels.py
@@ -383,6 +383,12 @@ class SubjectInfo:
     first_lwk: int = 20
     double_entries: list[int] = field(default_factory=list)
 
+    @property
+    def has_tea(self) -> bool | None:
+        if not self.has_anomaly_entry:
+            return None
+        return self.anomaly_entry["T11"] or self.anomaly_entry["T13"]
+
     @property
     def block(self) -> int:
         return int(str(self.subject_name)[:3])
@@ -393,15 +399,17 @@ def get_subject_info(
     subject_name: str | int,
     anomaly_dict: dict,
     vert_subfolders_int: list[int],
-    anomaly_factor_condition: int = 0,
+    subject_name_int: bool = True,
 ):
+    if subject_name_int:
+        subject_name = int(subject_name)
     double_entries = []
     labelmap = {}
     has_anomaly_entry = False
     anomaly_entry = {}
     deleted_label = []
     is_remove = False
-    if int(subject_name) in anomaly_dict:
+    if subject_name in anomaly_dict:
         anomaly_entry = anomaly_dict[subject_name]
         has_anomaly_entry = True
         if anomaly_entry["DeleteLabel"] is not None:
@@ -411,22 +419,43 @@ def get_subject_info(
 
         if bool(anomaly_entry["T11"]):
             labelmap = {i: i + 1 for i in range(19, 26)}
-            double_entries = [17, 18, 20, 21]
         elif bool(anomaly_entry["T13"]):
             labelmap = {20: 28, 21: 20, 22: 21, 23: 22, 24: 23, 25: 24}
-            double_entries = [19, 28, 20, 21]
-        elif anomaly_factor_condition == 0:
-            double_entries = [18, 19, 20, 21]
+
+        if "LabelOverride" in anomaly_entry and anomaly_entry["LabelOverride"] is not None:
+            assert len(anomaly_entry["LabelOverride"]) == len(vert_subfolders_int), (
+                f"len({anomaly_entry['LabelOverride']}) != len({vert_subfolders_int})"
+            )
+            vert_subfolders_sorted = sorted(vert_subfolders_int, key=lambda x: x if x != 28 else 19.5)
+            labelmap = {i: k for i, k in zip(vert_subfolders_sorted, anomaly_entry["LabelOverride"], strict=False)}  # noqa: C416
 
     actual_labels = [labelmap.get(v, v) for v in vert_subfolders_int]
+
+    if 28 in actual_labels and 19 not in actual_labels:
+        print(f"{subject_name}: 28 in {actual_labels} but no 19")
+        is_remove = True
+
+    # T11
+    if 18 in actual_labels and 19 not in actual_labels and 20 in actual_labels:
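+        # enumeration anomaly: T11 (18) and L1 (20) are present but T12 (19) is missing, so the vertebrae around the gap count as ambiguous double entries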
+        double_entries = [17, 18, 20, 21]
+    elif 28 in actual_labels:
+        double_entries = [19, 28, 20, 21]
+    else:
+        double_entries = [18, 19, 20, 21]
+
+    if len(anomaly_dict) == 0:
+        double_entries = []
+
     #
     # last_hwk = 7
     # first_bwk = 8
-    last_bwk = max([v for v in actual_labels if 7 < v <= 19 or v == 28]) if max(actual_labels) >= 18 else None
+    bwks = [v for v in actual_labels if 7 < v <= 19 or v == 28]
+    last_bwk = max(bwks) if max(actual_labels) >= 18 and len(bwks) > 0 else None
     # first_lwk = 20
-    last_lwk = max([v for v in actual_labels if 22 < v < 26]) if max(actual_labels) >= 23 else None
+    lwks = [v for v in actual_labels if 22 < v < 26]
+    last_lwk = max(lwks) if max(actual_labels) >= 23 and len(lwks) > 0 else None
 
     return SubjectInfo(
-        subject_name=int(subject_name),
+        subject_name=subject_name,
         has_anomaly_entry=has_anomaly_entry,
         anomaly_entry=anomaly_entry,
         actual_labels=actual_labels,
diff --git a/spineps/lab_model.py b/spineps/lab_model.py
index a7dbc87..64e81f3 100755
--- a/spineps/lab_model.py
+++ b/spineps/lab_model.py
@@ -1,11 +1,13 @@
 from __future__ import annotations
 
+import math
 import os
 from pathlib import Path
 
 import numpy as np
 import torch
 from monai.transforms import CenterSpatialCropd, Compose, NormalizeIntensityd, ToTensor
+from scipy.ndimage import rotate
 from TPTBox import NII, Log_Type, No_Logger, np_utils
 from typing_extensions import Self
 
@@ -17,6 +19,57 @@
 logger = No_Logger(prefix="VertLabelingClassifier")
 
 
+def unit_vector(vector):
+    """Returns the unit vector of the vector."""
+    return vector / np.linalg.norm(vector)
+
+
+def angle_between(v1, v2, signed=True):
+    """Returns the angle in radians between vectors 'v1' and 'v2'::
+
+    >>> angle_between((1, 0, 0), (0, 1, 0))
+    1.5707963267948966
+    >>> angle_between((1, 0, 0), (1, 0, 0))
+    0.0
+    >>> angle_between((1, 0, 0), (-1, 0, 0))
+    3.141592653589793
+    """
+    v1_u = unit_vector(v1)
+    v2_u = unit_vector(v2)
+    angle = np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0))
+
+    if signed:
+        sign = np.array(np.sign(np.cross(v1, v2).dot((1, 1, 1))))
+        # 0 means collinear: 0 or 180. Let's call that clockwise.
+        sign[sign == 0] = 1
+        angle = sign * angle
+    return angle
+
+
+def rotate_patch_sagitally(patch: np.ndarray, angle: float, msk: bool = False, cval: int = 0) -> np.ndarray:
+    """
+    Rotates a patch sagittally given an angle (assuming the patch is in (I,P,L) orientation).
+
+    Parameters
+    ----------
+    patch: np.ndarray
+        a numpy array with (I,P,L) orientation
+    angle: float
+        angle of rotation in degrees
+    msk: bool, optional
+        flag to determine the interpolation type. The interpolation order is 0 if the input is a mask.
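+    cval: int, optional
+        constant fill value used outside the input boundaries (forced to 0 when msk is True)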
+
+    Returns
+    -------
+    np.ndarray
+        the rotated patch
+    """
+    if msk:
+        cval = 0
+        order = 0
+    else:
+        order = 3
+    rotated_patch = rotate(patch, angle=-angle, reshape=False, order=order, mode="constant", cval=cval)  # type: ignore
+    return rotated_patch
+
+
 class VertLabelingClassifier(Segmentation_Model):
     def __init__(
         self,
@@ -104,12 +157,25 @@ def run_all_seg_instances(self, img: NII, seg: NII) -> dict[int, dict[str, np.nd
         seg = seg.reorient()
         # TODO assert order of seg labels are order from top to bottom
         predictions = {}
+
+        coms = seg.reorient(("I", "P", "L")).center_of_masses()
+        sorted_ctds = sorted([[a, *b] for a, b in coms.items()], key=lambda x: x[1])
+
         for v in seg.unique():
-            logits_soft, pred_cls = self.run_given_seg_pos(img, seg, vert_label=v)
+            # Find the index of the given vertebra in the sorted list
+            idx = next(i for i, ct in enumerate(sorted_ctds) if ct[0] == v)
+
+            # Get the centroids above and below
+            ctd1 = sorted_ctds[idx - 1][1:] if idx > 0 else sorted_ctds[idx][1:]
+            ctd2 = sorted_ctds[idx + 1][1:] if idx < len(sorted_ctds) - 1 else sorted_ctds[idx][1:]
+            myradians = angle_between(np.asarray(ctd2) - np.asarray(ctd1), (1, 0, 0))  # type: ignore
+            mydegrees = math.degrees(myradians)
+
+            logits_soft, pred_cls = self.run_given_seg_pos(img, seg, vert_label=v, angle=mydegrees)
             predictions[v] = {"soft": logits_soft, "pred": pred_cls}
         return predictions
 
-    def run_given_seg_pos(self, img: NII | np.ndarray, seg: NII, vert_label: int | None = None):
+    def run_given_seg_pos(self, img: NII, seg: NII, vert_label: int | None = None, angle: float | None = None):
         if vert_label is not None:
             seg = seg.extract_label(vert_label)
         elif len(seg.unique()) > 1:
@@ -120,18 +186,48 @@ def run_given_seg_pos(self, img: NII | np.ndarray, seg: NII, vert_label: int | N
         for i in range(len(crop)):
             size_t = crop[i].stop - crop[i].start
             center_of_crop.append(crop[i].start + (size_t // 2))
-        return self.run_given_center_pos(img, center_of_crop)
+        return self.run_given_center_pos(img, seg, center_of_crop, angle=angle)  # type: ignore
 
-    def run_given_center_pos(self, img: NII | np.ndarray, center_pos: tuple[int, int, int]):
+    def run_given_center_pos(self, img: NII, seg: NII, center_pos: tuple[int, int, int], angle: float | None = None):
+        extra_rotation_padding = 64
+        extra_rotation_padding_halfed = extra_rotation_padding // 2
+        #
         # cut array then runs prediction
         arr = img.get_array() if isinstance(img, NII) else img
         arr_cut, cutout_coords_slices, padding = np_utils.np_calc_crop_around_centerpoint(
             center_pos,
             arr,
-            self.cutout_size,
+            (self.cutout_size[0] + extra_rotation_padding, self.cutout_size[1] + extra_rotation_padding, self.cutout_size[2]),
         )
-        # sem_cut = np.pad(vert_v.get_seg_array()[cutout_coords_slices], padding)
-        return self._run_array(arr_cut)  # sem_cut
+        sem_cut = np.pad(seg[cutout_coords_slices], padding)
+        # final cutout size (200, 160, 32)
+
+        ori = img.orientation
+        img_v = img.set_array(arr_cut).reorient_(("I", "P", "L"))
+        seg_v = seg.set_array(sem_cut).reorient_(("I", "P", "L"))
+
+        # angle = 0
+        if angle is not None and angle != 0:
+            arr_cut = rotate_patch_sagitally(img_v.get_array(), -angle, msk=False)
+            sem_cut = rotate_patch_sagitally(seg_v.get_seg_array(), -angle, msk=True)
+
+        # crop down to final cutout size (200, 160, 32)
+        arr_cut = arr_cut[
+            extra_rotation_padding_halfed:-extra_rotation_padding_halfed,
+            extra_rotation_padding_halfed:-extra_rotation_padding_halfed,
+            :,
+        ]
+        sem_cut = sem_cut[
+            extra_rotation_padding_halfed:-extra_rotation_padding_halfed,
+            extra_rotation_padding_halfed:-extra_rotation_padding_halfed,
+            :,
+        ]
+
+        img_v.set_array_(arr_cut).reorient_(ori)
+        seg_v.set_array_(sem_cut).reorient_(ori)
+        # img_v.save("/DATA/NAS/ongoing_projects/hendrik/img_v.nii.gz")
+        # seg_v.save("/DATA/NAS/ongoing_projects/hendrik/seg_v.nii.gz")
+        return self._run_array(img_v.get_array(), seg_v.get_seg_array())
 
     def _run_nii(self, img_nii: NII):
         # TODO check resolution
@@ -146,16 +242,27 @@ def run_all_arrays(self, img_arrays: dict[int, np.ndarray]) -> dict[int, dict[st
             predictions[v] = {"soft": logits_soft, "pred": pred_cls}
         return predictions
 
-    def _run_array(self, img_arr: np.ndarray | torch.Tensor):  # , seg_arr: np.ndarray):
+    def _run_array(self, img_arr: np.ndarray, seg_arr: np.ndarray | torch.Tensor | None = None):
         assert img_arr.ndim == 3, f"Dimension mismatch, {img_arr.shape}, expected 3 dimensions"
         #
-        img_arr = self.totensor(img_arr).unsqueeze_(0)
-        d = self.transform({"img": img_arr, "seg": img_arr})
+        img_arr = self.totensor(img_arr)
+        # add channel
+        img_arr.unsqueeze_(0)
+
+        if seg_arr is not None:
+            seg_arr = self.totensor(seg_arr)
+            seg_arr.unsqueeze_(0)
+        else:
+            seg_arr = img_arr.clone()
+
+        d = self.transform({"img": img_arr, "seg": seg_arr})
         # TODO seg channelwise and stuff
         model_input = d["img"]
+        # print(model_input.shape)
         model_input.unsqueeze_(0)
+        # print(model_input.shape)
         model_input = model_input.to(torch.float32)
         model_input = model_input.to(self.device)
diff --git a/spineps/phase_instance.py b/spineps/phase_instance.py
index 3cdfc23..ba7e319 100755
--- a/spineps/phase_instance.py
+++ b/spineps/phase_instance.py
@@ -288,10 +288,7 @@ def get_corpus_coms(
         segvert = corpus_cc.extract_label(target_vert_id, inplace=False)
         try:
             logger.print("get_separating_components to split vertebra", verbose=verbose)
-            (spart, tpart, spart_dil, tpart_dil, stpart) = get_separating_components(
-                segvert,
-                connectivity=3,
-            )
+            (spart, tpart, spart_dil, tpart_dil, stpart) = get_separating_components(segvert, connectivity=3)
 
             logger.print("Splitting by plane")
             plane_split_nii = get_plane_split(segvert, corpus_nii, spart, tpart, spart_dil, tpart_dil)
@@ -308,6 +305,55 @@ def get_corpus_coms(
 
 
 def get_separating_components(segvert: np.ndarray, max_iter: int = 10, connectivity: int = 3):
+    """
+    Attempts to split a binary volumetric segmentation into two spatially separate components (S and T)
+    by iterative erosion and connected component analysis.
+
+    This function is designed for cases where an initial segmentation is a single connected component,
+    but the goal is to identify two meaningful subregions. It uses morphological erosion to find a
+    splitting point and then recovers the two regions through dilation.
+
+    Parameters
+    ----------
+    segvert : np.ndarray
+        A 3D binary (or labeled) numpy array representing the segmented volume to split.
+    max_iter : int, optional
+        Maximum number of erosion iterations allowed to find separable components. Default is 10.
+    connectivity : int, optional
+        Connectivity used for morphological operations (e.g., 1=6-connectivity, 2=18, 3=26). Default is 3.
+
+    Returns
+    -------
+    spart : np.ndarray
+        Binary mask of the first separated component (S).
+    tpart : np.ndarray
+        Binary mask of the second separated component (T).
+    spart_dil : np.ndarray
+        Dilated version of spart until contact with tpart.
+    tpart_dil : np.ndarray
+        Dilated version of tpart until contact with spart.
+    stpart : np.ndarray
+        Combined map of dilated S and T, with values:
+        - 0: background
+        - 1: spart_dil only
+        - 2: tpart_dil only
+        - 3: overlapping region between spart_dil and tpart_dil
+
+    Raises
+    ------
+    Exception
+        If the input volume cannot be split into two parts, if one of the
+        resulting parts is empty, or if the maximum number of erosion
+        iterations (`max_iter`) is reached without successful separation.
+
+    Notes
+    -----
+    - The function assumes that `np_erode_msk`, `np_dilate_msk`, `np_connected_components`,
+      `np_volume`, `np_filter_connected_components` are available in the environment.
+    - This method is particularly useful for anatomical structures that are initially connected
+      (e.g., left and right organs) but should be separated for downstream analysis.
+    """
     check_connectivity = 3
     vol = segvert.copy()
     vol_old = vol.copy()
@@ -387,6 +433,52 @@ def get_plane_split(
     spart_dil: np.ndarray,
     tpart_dil: np.ndarray,
 ):
+    """
+    Computes an approximate separating plane between two regions (spart and tpart)
+    based on their dilated overlap and returns it as a NIfTI image.
+
+    This function determines the collision region between the dilated versions
+    of two separated segmentation components. It then estimates a plane orthogonal
+    to the vector between their centers of mass and passing through the point
+    of contact. The resulting binary plane is filled in the axial direction
+    and returned in NIfTI format for visualization or further processing.
+
+    Parameters
+    ----------
+    segvert : np.ndarray
+        Original 3D binary or labeled segmentation volume.
+    compare_nii : NII
+        NIfTI image object used as a reference for orientation and spatial metadata.
+    spart : np.ndarray
+        Binary mask of the first component (S).
+    tpart : np.ndarray
+        Binary mask of the second component (T).
+    spart_dil : np.ndarray
+        Dilated mask of spart.
+    tpart_dil : np.ndarray
+        Dilated mask of tpart.
+
+    Returns
+    -------
+    plane_filled_nii : NII
+        A NIfTI image containing a filled binary plane separating spart and tpart,
+        reoriented to match the input NIfTI image. If no collision is detected,
+        returns an empty image.
+
+    Notes
+    -----
+    - The function uses the collision area between the dilated masks to find
+      a center of mass and constructs a plane orthogonal to the vector between
+      the COMs of spart and tpart.
+    - Filling the plane ensures better visualization and compatibility with downstream tasks.
+    - If the dilated masks do not overlap, an empty volume is returned and a warning is logged.
+
+    TODO
+    ----
+    - Improve accuracy by using the line connecting both COMs and projecting the collision
+      point onto this vector, rather than relying on the COM of the overlap region.
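+
+    Example
+    -------
+    Illustrative sketch only (inputs as produced by `get_separating_components` in `get_corpus_coms`); skipped as a doctest:
+
+    >>> spart, tpart, spart_dil, tpart_dil, stpart = get_separating_components(segvert, connectivity=3)  # doctest: +SKIP
+    >>> plane_split_nii = get_plane_split(segvert, corpus_nii, spart, tpart, spart_dil, tpart_dil)  # doctest: +SKIP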
+ """ + s_dilint = spart_dil.astype(np.uint8) t_dilint = tpart_dil.astype(np.uint8) collision_arr = s_dilint + t_dilint @@ -506,6 +598,7 @@ def collect_vertebra_predictions( Location.Vertebra_Disc.value: 0, Location.Endplate.value: 0, 26: 0, + 51: 0, }, verbose=False, ) diff --git a/spineps/phase_labeling.py b/spineps/phase_labeling.py index 0c0e9e8..b90677d 100644 --- a/spineps/phase_labeling.py +++ b/spineps/phase_labeling.py @@ -36,6 +36,10 @@ def perform_labeling_step( ): model.load() + if 26 in vert_nii.unique(): + has_sacrum = vert_nii.volumes()[26] > 500 # noqa: F841 + # TODO remove sacrum for labeling and make a separate step for sacrum labeling + if subreg_nii is not None: # crop for corpus instead of whole vertebra corpus_nii = subreg_nii.extract_label((Location.Vertebra_Corpus, Location.Vertebra_Corpus_border)) @@ -204,7 +208,11 @@ def prepare_vertrel_columns(vertrel_matrix: np.ndarray, gaussian_sigma: float = if gaussian_sigma > 0.0 and np.sum(vertrel_matrix) > 0.0: vertrel_matrix[:, i] = gaussian_filter1d(vertrel_matrix[:, i], sigma=gaussian_sigma, mode="nearest", radius=gaussian_radius) # normalize per column / label in this case - vertrel_matrix[:, i] = vertrel_matrix[:, i] / (np.sum(vertrel_matrix[:, i]) + DIVIDE_BY_ZERO_OFFSET) + vertrel_sum = np.sum(vertrel_matrix[:, i]) + DIVIDE_BY_ZERO_OFFSET + if vertrel_sum > 1.0: + vertrel_matrix[:, i] = vertrel_matrix[:, i] / vertrel_sum + elif vertrel_sum < 1.0: + vertrel_matrix[:, i] = vertrel_matrix[:, i] / (1.0 + vertrel_sum) return vertrel_matrix @@ -225,18 +233,21 @@ def prepare_vertrel(vertrel_softmax_values: np.ndarray, gaussian_sigma: float = def find_vert_path_from_predictions( predictions, - visible_w: float = 1.0, + visible_w: float = 0.5, vert_w: float = 0.9, # 0.9 vertgrp_w: float = 0.8, region_w: float = 1.1, # 1.1 - vertrel_w: float = 0.3, # 0.3 + vertrel_w: float = 0.6, # 0.3 vertt13_w: float = 0.4, disable_c1: bool = True, boost_c2: float = 1.0, # 3.0 allow_cervical_skip: bool = False, + allow_thoracic_skip: bool = False, + allow_lumbar_skip: bool = False, # punish_multiple_sequence: float = 0.0, punish_skip_sequence: float = 0.0, + punish_skip_at_region_sequence: float = 0.0, # region_gaussian_sigma: float = 0.0, # 0 means no gaussian vert_gaussian_sigma: float = 0.8, # 0.8 0 means no gaussian @@ -253,6 +264,7 @@ def find_vert_path_from_predictions( ): args = locals() assert 0 <= visible_w, visible_w # noqa: SIM300 + assert visible_w <= 1.0, f"visible_w must be <= 1.0, got {visible_w}" assert 0 <= vert_w, vert_w # noqa: SIM300 assert 0 <= region_w, region_w # noqa: SIM300 assert 0 <= vertrel_w, vertrel_w # noqa: SIM300 @@ -263,6 +275,7 @@ def find_vert_path_from_predictions( cost_matrix = np.zeros((n_vert, 24)) # TODO 24 fix? relative_cost_matrix = np.zeros((n_vert, 6)) # TODO 6 fix? 
     visible_chain = prepare_visible(predictions, visible_w)
+    # print(visible_chain)
     predict_keys = list(predictions[list(predictions.keys())[0]]["soft"].keys())  # noqa: RUF015
     assert "VERT" in predict_keys or "VERTEXACT" in predict_keys or "VERTGRP" in predict_keys, (
@@ -365,6 +378,13 @@ def find_vert_path_from_predictions(
     else:
         allow_multiple_at_class = [18, 23] if not proc_lab_force_no_tl_anomaly else [23]  # T12 and L5
         allow_skip_at_class = [17] if not proc_lab_force_no_tl_anomaly else []  # T11
+        allow_skip_at_region = []
+        if allow_cervical_skip:
+            allow_skip_at_region.append(0)
+        if allow_thoracic_skip:
+            allow_skip_at_region.append(1)
+        if allow_lumbar_skip:
+            allow_skip_at_region.append(2)
     fcost, fpath, min_costs_path = find_most_probably_sequence(
         # input
         cost_matrix,
@@ -380,8 +400,9 @@ def find_vert_path_from_predictions(
         allow_multiple_at_class=allow_multiple_at_class,  # T12 and L5
         allow_skip_at_class=allow_skip_at_class,  # T11
         #
-        allow_skip_at_region=[0] if allow_cervical_skip else [],
-        punish_skip_at_region_sequence=0.2 if allow_cervical_skip else 0.0,
+        allow_skip_at_region=allow_skip_at_region,
+        punish_skip_at_region_sequence=punish_skip_at_region_sequence,
+        verbose=False,
     )
     # post processing
     fpath_post = fpath_post_processing(fpath)
diff --git a/spineps/phase_post.py b/spineps/phase_post.py
index 0516cf7..98f3380 100644
--- a/spineps/phase_post.py
+++ b/spineps/phase_post.py
@@ -241,7 +241,7 @@ def add_ivd_ep_vert_label(whole_vert_nii: NII, seg_nii: NII, verbose=True):
     orientation = whole_vert_nii.orientation
     vert_t = whole_vert_nii.reorient()
     seg_t = seg_nii.reorient()
-    vert_labels = vert_t.unique()  # without zero
+    vert_labels = [t for t in vert_t.unique() if t <= 26 or t == 28]  # without zero
     vert_arr = vert_t.get_seg_array()
     subreg_arr = seg_t.get_seg_array()
 
@@ -286,8 +286,8 @@ def add_ivd_ep_vert_label(whole_vert_nii: NII, seg_nii: NII, verbose=True):
     # find which vert got how many ivd CCs
     to_mapped_labels = list(mapping_cc_to_vert_label.values())
-    for l in vert_labels:
-        if l not in to_mapped_labels:
+    for i, l in enumerate(vert_labels):
+        if l not in to_mapped_labels and l != 1:
             logger.print(f"Vertebra {v_idx2name[l]} got no IVD component assigned", Log_Type.STRANGE)
         count = to_mapped_labels.count(l)
         if count > 1:
@@ -455,7 +455,8 @@ def detect_and_solve_merged_vertebra(seg_nii: NII, vert_nii: NII):
     volumes = subreg_cc.volumes()
     stats = {i: (g[1], True, volumes[i]) for i, g in coms.items()}
 
-    vert_coms = vert_nii.center_of_masses()
+    corpus_nii = seg_sem.extract_label([Location.Vertebra_Corpus_border.value, Location.Arcus_Vertebrae.value]) * vert_nii
+    vert_coms = corpus_nii.center_of_masses()
     vert_volumes = vert_nii.volumes()
 
     for i, g in vert_coms.items():
diff --git a/spineps/phase_pre.py b/spineps/phase_pre.py
index 619f1e2..540ea5b 100644
--- a/spineps/phase_pre.py
+++ b/spineps/phase_pre.py
@@ -26,7 +26,7 @@ def preprocess_input(
     try:
         # Enforce to range [0, 1500]
        if proc_normalize_input:
-            mri_nii.normalize_to_range_(min_value=0, max_value=9000, verbose=False)
+            mri_nii.normalize_to_range_(min_value=0, max_value=1500, verbose=False)
             crop = mri_nii.compute_crop(dist=0) if proc_crop_input else (slice(None, None), slice(None, None), slice(None, None))
         else:
             crop = (
diff --git a/spineps/utils/find_min_cost_path.py b/spineps/utils/find_min_cost_path.py
index 24fa77f..d56666f 100644
--- a/spineps/utils/find_min_cost_path.py
+++ b/spineps/utils/find_min_cost_path.py
@@ -5,6 +5,7 @@
 from warnings import warn
 
 import numpy as np
+from TPTBox import Log_Type, No_Logger
 
 
 def argmin(lst):
@@ -30,6 +31,7 @@ def internal_to_real_path(p):
     return pat
 
 
+# TODO: make clear recursion calls with extra cost for the path?
 def find_most_probably_sequence(  # noqa: C901
     cost: np.ndarray | list[int],
     #
@@ -51,7 +53,11 @@ def find_most_probably_sequence(  # noqa: C901
     #
     allow_skip_at_region: list[int] | None = None,
     punish_skip_at_region_sequence: float = 0.2,
+    #
+    verbose: bool = False,
 ) -> tuple[float, list[int], list]:
+    logger = No_Logger()
+    logger.default_verbose = verbose
     # default mutable arguments
     if allow_skip_at_region is None:
         allow_skip_at_region = [0]
@@ -109,6 +115,7 @@ def add_option_path(options, r, c, extracost):
 
     # main recursive loop
     def minCostAlgo(r, c):
+        logger.print(f"Called vert {r}, label {c}")
         # get current region
         region_cur = c_to_region_idx(c, regions)
         # start point
@@ -116,22 +123,26 @@ def minCostAlgo(r, c):
             # go over each possible start column
             options = []
             for cc in range(min_start_class, n_classes):
-                add_option_path(options, 0, cc, 0)
+                with logger:
+                    # logger.default_verbose = cc in [7, 8, 9]
+                    add_option_path(options, 0, cc, 0)
                 # options.append(minCostAlgo(r=0, c=cc))
             minidx, minval = argmin([o[0] for o in options])
             return minval, options[minidx][1]
         # stepped over the line
         elif c < 0 or r < 0 or c >= shape[1] or r >= shape[0]:
+            logger.print(f"Out of bounds vert {r}, label {c}")
             return sys.maxsize, [(r, c)]
         # last row, path end
         elif r == shape[0] - 1:
+            # logger.print(f"End of path vert {r}, label {c}")
             # path_tothis.append((r, c))
             cost_value = costlist[r][c]
             p = [(r, c)]
             # transition cost of vertrel
             cost_value += rel_cost(r, c, p, region_cur)
-            # if cost_value < 0:
-            #     print(f"Endpoint {r}, {c} to {cost_value}, {p}")
+            if cost_value < 0:
+                logger.print(f"End of path vert {r}, label {c} to {cost_value}, {internal_to_real_path(p)}")
             return (cost_value, p)
         # check min of move directions
         else:
@@ -141,21 +152,25 @@ def minCostAlgo(r, c):
             # rel_costadd = rel_cost(r, c, [(r, c)], region_cur)
             options = []
             # normal diagonal edge
-            add_option_path(options, r + 1, c + 1, 0)
+            with logger:
+                add_option_path(options, r + 1, c + 1, 0)
             # allow two subsequent of same class
             if c in allow_multiple_at_class:
                 cost_add = punish_multiple_sequence
                 if c == 18:
                     cost_add += t13_cost_single(r + 1, c)
-                add_option_path(options, r + 1, c, cost_add)
+                with logger:
+                    add_option_path(options, r + 1, c, cost_add)
             # Allow skips at certain classes
             if c in allow_skip_at_class:
                 cost_add = punish_skip_sequence
-                add_option_path(options, r + 1, c + 2, cost_add)
+                with logger:
+                    add_option_path(options, r + 1, c + 2, cost_add)
             # Allow skips in certain regions
             if region_cur in allow_skip_at_region and c != regions_ranges[region_cur][1] - 1:
                 cost_add = punish_skip_at_region_sequence
-                add_option_path(options, r + 1, c + 2, punish_skip_at_region_sequence)
+                with logger:
+                    add_option_path(options, r + 1, c + 2, cost_add)
             # find min
             minidx, minval = argmin([o[0] for o in options])
             pnext = options[minidx][1]
@@ -172,8 +187,8 @@ def minCostAlgo(r, c):
                 break
         # setting to memory
         min_costs_path[r][c] = (cost_value, p)
-        # if cost_value < 0:
-        #     print(f"Setting {r}, {c} to {cost_value}, {p}")
+        if cost_value < 0:
+            logger.print(f"Setting vert {r}, label {c} to {cost_value}, {internal_to_real_path(p)}")
         return min_costs_path[r][c]
 
     # def t13_cost(r, c, pnext, p, region_cur):
@@ -211,11 +226,11 @@ def rel_cost(r, c, pnext, region_cur):
             if rel_cost == 0:
                 continue
             if last == 0 and c == regions_ranges[region_cur][0]:
-                # print(f"Added F {rel_cost} to {r}, {c}, {internal_to_real_path(pnext)}")
+                logger.print(f"Added F {rel_cost} to vert {r}, label {c}, {internal_to_real_path(pnext)}")
                 cost_add += rel_cost
                 # break
-            elif last == 1 and c_to_region_idx(pnext[-1][1], regions) >= region_cur + 1:
-                # print(f"Added L {rel_cost} to {r}, {c}, {internal_to_real_path(pnext)}")
+            elif last == 1 and (c_to_region_idx(pnext[-1][1], regions) >= region_cur + 1):  # or pnext[-1][1] == c):
+                logger.print(f"Added L {rel_cost} to vert {r}, label {c}, {internal_to_real_path(pnext)}")
                 cost_add += rel_cost
     return cost_add
diff --git a/unit_tests/test_architectures.py b/unit_tests/test_architectures.py
index d134d36..4245c6d 100644
--- a/unit_tests/test_architectures.py
+++ b/unit_tests/test_architectures.py
@@ -149,7 +149,6 @@ def test_simple_testing_case(self):
             subject_name=1337,
             anomaly_dict={},
             vert_subfolders_int=vert_subfolders_int,
-            anomaly_factor_condition=1,
         )
 
         for v in vert_subfolders_int: