Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)

# SPINEPS – Automatic Whole Spine Segmentation of T2w MR images using a Two-Phase Approach to Multi-class Semantic and Instance Segmentation.
# and
# VERIDAH: Solving Enumeration Anomaly Aware Vertebra Labeling across Imaging Sequences

This is a segmentation pipeline to automatically, and robustly, segment the whole spine in T2w sagittal images.

Expand Down Expand Up @@ -268,6 +270,23 @@ In the subregion segmentation:

In the vertebra instance segmentation mask, each label X in [1, 25] are the unique vertebrae, while 100+X are their corresponding IVD and 200+X their endplates.

## VERIDAH:

To run the vertebra labeling after segmentation, specify a -model_labeling model (similar to -model_semantic and -model_instance).

If you use VERIDAH (labeling model) in addition to the segmentation models from SPINEPS, then a labeling model will run and give each vertebrae detected by SPINEPS a vertebra label. These are

| Label | Structure |
| :---: | --------- |
| 1 | C1 |
| 2 - 7 | C2 - C7 |
| 8 - 19 | T1 - T12 |
| 28 | T13 |
| 20 | L1 |
| 21 - 25 | L2 - L6 |
| 26 | Sacrum |

The labels 100+X still correspond to the vertebra's IVD and 200+X the respective endplate. For example, the label 119 is the IVD below the T12 vertebra.

## Using the Code

Expand Down
89 changes: 83 additions & 6 deletions spineps/architectures/pl_densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,82 @@
import os
import sys
from dataclasses import dataclass
from enum import Enum
from pathlib import Path

import pytorch_lightning as pl
import torch
from monai.networks.nets import DenseNet169
from monai.networks.nets import DenseNet121, DenseNet169
from monai.networks.nets.resnet import (
ResNet,
ResNetBlock,
_resnet,
get_inplanes,
resnet10,
resnet18,
resnet34,
resnet50,
resnet101,
resnet152,
)
from torch import nn
from TypeSaveArgParse import Class_to_ArgParse


def resnet2(
layers: list[int] | None = None,
**kwargs,
):
if layers is None:
layers = [1, 1]
return _resnet("resnet2", ResNetBlock, layers, get_inplanes(), False, False, **kwargs)


class MODEL(Enum):
DENSENET169 = DenseNet169
DENSENET121 = DenseNet121
RESNET10 = 10 # resnet10
RESNET18 = 18 # resnet18
RESNET34 = 34 # resnet34
RESNET50 = 50 # resnet50
RESNET101 = 101 # resnet101
RESNET152 = 152 # resnet152
RESNET2 = 2 # resnet2

def __call__(
self,
opt: ARGS_MODEL,
remove_classification_head: bool = True,
):
if "DENSENET" in self.name:
return get_densenet_architecture(
self.value,
in_channel=opt.in_channel,
out_channel=opt.num_classes,
pretrained=not opt.not_pretrained,
remove_classification_head=remove_classification_head,
)
elif "RESNET" in self.name:
d = {
10: resnet10,
18: resnet18,
34: resnet34,
50: resnet50,
101: resnet101,
152: resnet152,
2: resnet2,
}
return get_resnet_architecture(
d[self.value],
remove_classification_head=remove_classification_head,
)
else:
raise ValueError(f"Model {self.name} not supported.")


@dataclass
class ARGS_MODEL(Class_to_ArgParse):
backbone: MODEL = MODEL.DENSENET169.name
classification_conv: bool = False
classification_linear: bool = True
#
Expand Down Expand Up @@ -43,9 +108,8 @@ def __init__(self, opt: ARGS_MODEL, group_2_n_channel: dict[str, int]):
# save hyperparameter, everything below not visible
self.save_hyperparameters()

self.net, linear_in = get_architecture(
DenseNet169, opt.in_channel, opt.num_classes, pretrained=False, remove_classification_head=True
)
self.backbone = MODEL[opt.backbone]
self.net, linear_in = self.backbone(opt, remove_classification_head=True)
self.classification_heads = self.build_classification_heads(linear_in, opt.classification_conv, opt.classification_linear)
self.classification_keys = list(self.classification_heads.keys())
self.mse_weighting = opt.mse_weighting
Expand Down Expand Up @@ -89,7 +153,7 @@ def __str__(self) -> str:
return "VertebraLabelingModel"


def get_architecture(
def get_densenet_architecture(
model,
in_channel: int = 1,
out_channel: int = 1,
Expand All @@ -102,8 +166,21 @@ def get_architecture(
out_channels=out_channel,
pretrained=pretrained,
)
linear_infeatures = 0
linear_infeatures = model.class_layers[-1].in_features
if remove_classification_head:
model.class_layers = model.class_layers[:-1]
return model, linear_infeatures


def get_resnet_architecture(
model,
remove_classification_head: bool = True,
):
model = model(
spatial_dims=3,
n_input_channels=1,
)
linear_infeatures = model.fc.in_features
if remove_classification_head:
model.fc = None
return model, linear_infeatures
47 changes: 38 additions & 9 deletions spineps/architectures/read_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,12 @@ class SubjectInfo:
first_lwk: int = 20
double_entries: list[int] = field(default_factory=list)

@property
def has_tea(self) -> bool:
if not self.has_anomaly_entry:
return None
Copy link

Copilot AI Jul 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The has_tea property returns None when has_anomaly_entry is False, but the return type annotation indicates bool. This should return False or the annotation should be bool | None.

Suggested change
return None
return False

Copilot uses AI. Check for mistakes.
return self.anomaly_entry["T11"] or self.anomaly_entry["T13"]

@property
def block(self) -> int:
return int(str(self.subject_name)[:3])
Expand All @@ -393,15 +399,17 @@ def get_subject_info(
subject_name: str | int,
anomaly_dict: dict,
vert_subfolders_int: list[int],
anomaly_factor_condition: int = 0,
subject_name_int: bool = True,
):
if subject_name_int:
subject_name = int(subject_name)
double_entries = []
labelmap = {}
has_anomaly_entry = False
anomaly_entry = {}
deleted_label = []
is_remove = False
if int(subject_name) in anomaly_dict:
if subject_name in anomaly_dict:
anomaly_entry = anomaly_dict[subject_name]
has_anomaly_entry = True
if anomaly_entry["DeleteLabel"] is not None:
Expand All @@ -411,22 +419,43 @@ def get_subject_info(

if bool(anomaly_entry["T11"]):
labelmap = {i: i + 1 for i in range(19, 26)}
double_entries = [17, 18, 20, 21]
elif bool(anomaly_entry["T13"]):
labelmap = {20: 28, 21: 20, 22: 21, 23: 22, 24: 23, 25: 24}
double_entries = [19, 28, 20, 21]
elif anomaly_factor_condition == 0:
double_entries = [18, 19, 20, 21]

if "LabelOverride" in anomaly_entry and anomaly_entry["LabelOverride"] is not None:
assert len(anomaly_entry["LabelOverride"]) == len(vert_subfolders_int), (
f"len({anomaly_entry['LabelOverride']}) != len({vert_subfolders_int})"
)
vert_subfolders_sorted = sorted(vert_subfolders_int, key=lambda x: x if x != 28 else 19.5)
labelmap = {i: k for i, k in zip(vert_subfolders_sorted, anomaly_entry["LabelOverride"], strict=False)} # noqa: C416

actual_labels = [labelmap.get(v, v) for v in vert_subfolders_int]

if 28 in actual_labels and 19 not in actual_labels:
print(f"{subject_name}: 28 in {actual_labels} but no 19")
Copy link

Copilot AI Jul 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use logger.print() instead of print() for consistency with the logging pattern used throughout the codebase.

Suggested change
print(f"{subject_name}: 28 in {actual_labels} but no 19")
logger.print(f"{subject_name}: 28 in {actual_labels} but no 19")

Copilot uses AI. Check for mistakes.
is_remove = True

# T11
if 18 in actual_labels and 19 not in actual_labels and 20 in actual_labels:
double_entries = [17, 18, 20, 21]
elif 28 in actual_labels:
double_entries = [19, 28, 20, 21]
else:
double_entries = [18, 19, 20, 21]

if len(anomaly_dict) == 0:
double_entries = []

#
# last_hwk = 7
# first_bwk = 8
last_bwk = max([v for v in actual_labels if 7 < v <= 19 or v == 28]) if max(actual_labels) >= 18 else None
bwks = [v for v in actual_labels if 7 < v <= 19 or v == 28]
last_bwk = max(bwks) if max(actual_labels) >= 18 and len(bwks) > 0 else None
# first_lwk = 20
last_lwk = max([v for v in actual_labels if 22 < v < 26]) if max(actual_labels) >= 23 else None
lwks = [v for v in actual_labels if 22 < v < 26]
last_lwk = max(lwks) if max(actual_labels) >= 23 and len(lwks) > 0 else None
return SubjectInfo(
subject_name=int(subject_name),
subject_name=subject_name,
has_anomaly_entry=has_anomaly_entry,
anomaly_entry=anomaly_entry,
actual_labels=actual_labels,
Expand Down
Loading
Loading