|
| 1 | +# Modified from: |
| 2 | +# https://github.com/nipreps/synthstrip/blob/main/nipreps/synthstrip/cli.py |
| 3 | +# Original copyright (c) 2024, NiPreps developers |
| 4 | +# Licensed under the Apache License, Version 2.0 |
| 5 | +# Changes made by the BrainLesion Preprocessing team (2025) |
| 6 | + |
| 7 | +from pathlib import Path |
| 8 | +from typing import Optional, Union, cast |
| 9 | + |
| 10 | +import nibabel as nib |
| 11 | +import numpy as np |
| 12 | +import scipy |
| 13 | +import torch |
| 14 | +from nibabel.nifti1 import Nifti1Image |
| 15 | +from nipreps.synthstrip.model import StripModel |
| 16 | +from nitransforms.linear import Affine |
| 17 | + |
| 18 | +from brainles_preprocessing.brain_extraction.brain_extractor import BrainExtractor |
| 19 | +from brainles_preprocessing.utils.zenodo import fetch_synthstrip |
| 20 | + |
| 21 | + |
| 22 | +class SynthStripExtractor(BrainExtractor): |
| 23 | + |
| 24 | + def __init__(self, border: int = 1): |
| 25 | + """ |
| 26 | + Brain extraction using SynthStrip with preprocessing conforming to model requirements. |
| 27 | +
|
| 28 | + This is an optional dependency - to use this extractor, you need to install the `brainles_preprocessing` package with the `synthstrip` extra: `pip install brainles_preprocessing[synthstrip]` |
| 29 | +
|
| 30 | + Adapted from https://github.com/nipreps/synthstrip |
| 31 | +
|
| 32 | + Args: |
| 33 | + border (int): Mask border threshold in mm. Defaults to 1. |
| 34 | + """ |
| 35 | + |
| 36 | + super().__init__() |
| 37 | + self.border = border |
| 38 | + |
| 39 | + def _setup_model(self, device: torch.device) -> StripModel: |
| 40 | + """ |
| 41 | + Load SynthStrip model and prepare it for inference on the specified device. |
| 42 | +
|
| 43 | + Args: |
| 44 | + device: Device to load the model onto. |
| 45 | +
|
| 46 | + Returns: |
| 47 | + A configured and ready-to-use StripModel. |
| 48 | + """ |
| 49 | + # necessary for speed gains (according to original nipreps authors) |
| 50 | + torch.backends.cudnn.benchmark = True |
| 51 | + torch.backends.cudnn.deterministic = True |
| 52 | + |
| 53 | + with torch.no_grad(): |
| 54 | + model = StripModel() |
| 55 | + model.to(device) |
| 56 | + model.eval() |
| 57 | + |
| 58 | + # Load the model weights |
| 59 | + weights_folder = fetch_synthstrip() |
| 60 | + weights = weights_folder / "synthstrip.1.pt" |
| 61 | + checkpoint = torch.load(weights, map_location=device) |
| 62 | + model.load_state_dict(checkpoint["model_state_dict"]) |
| 63 | + |
| 64 | + return model |
| 65 | + |
| 66 | + def _conform(self, input_nii: Nifti1Image) -> Nifti1Image: |
| 67 | + """ |
| 68 | + Resample the input image to match SynthStrip's expected input space. |
| 69 | +
|
| 70 | + Args: |
| 71 | + input_nii (Nifti1Image): Input NIfTI image to conform. |
| 72 | +
|
| 73 | + Raises: |
| 74 | + ValueError: If the input NIfTI image does not have a valid affine. |
| 75 | +
|
| 76 | + Returns: |
| 77 | + A new NIfTI image with conformed shape and affine. |
| 78 | + """ |
| 79 | + |
| 80 | + shape = np.array(input_nii.shape[:3]) |
| 81 | + affine = input_nii.affine |
| 82 | + |
| 83 | + if affine is None: |
| 84 | + raise ValueError("Input NIfTI image must have a valid affine.") |
| 85 | + |
| 86 | + # Get corner voxel centers in index coords |
| 87 | + corner_centers_ijk = ( |
| 88 | + np.array( |
| 89 | + [ |
| 90 | + (i, j, k) |
| 91 | + for k in (0, shape[2] - 1) |
| 92 | + for j in (0, shape[1] - 1) |
| 93 | + for i in (0, shape[0] - 1) |
| 94 | + ] |
| 95 | + ) |
| 96 | + + 0.5 |
| 97 | + ) |
| 98 | + |
| 99 | + # Get corner voxel centers in mm |
| 100 | + corners_xyz = ( |
| 101 | + affine |
| 102 | + @ np.hstack((corner_centers_ijk, np.ones((len(corner_centers_ijk), 1)))).T |
| 103 | + ) |
| 104 | + |
| 105 | + # Target affine is 1mm voxels in LIA orientation |
| 106 | + target_affine = np.diag([-1.0, 1.0, -1.0, 1.0])[:, (0, 2, 1, 3)] |
| 107 | + |
| 108 | + # Target shape |
| 109 | + extent = corners_xyz.min(1)[:3], corners_xyz.max(1)[:3] |
| 110 | + target_shape = ((extent[1] - extent[0]) / 1.0 + 0.999).astype(int) |
| 111 | + |
| 112 | + # SynthStrip likes dimensions be multiple of 64 (192, 256, or 320) |
| 113 | + target_shape = np.clip( |
| 114 | + np.ceil(np.array(target_shape) / 64).astype(int) * 64, 192, 320 |
| 115 | + ) |
| 116 | + |
| 117 | + # Ensure shape ordering is LIA too |
| 118 | + target_shape[2], target_shape[1] = target_shape[1:3] |
| 119 | + |
| 120 | + # Coordinates of center voxel do not change |
| 121 | + input_c = affine @ np.hstack((0.5 * (shape - 1), 1.0)) |
| 122 | + target_c = target_affine @ np.hstack((0.5 * (target_shape - 1), 1.0)) |
| 123 | + |
| 124 | + # Rebase the origin of the new, plumb affine |
| 125 | + target_affine[:3, 3] -= target_c[:3] - input_c[:3] |
| 126 | + |
| 127 | + nii = Affine( |
| 128 | + reference=Nifti1Image( |
| 129 | + np.zeros(target_shape), |
| 130 | + target_affine, |
| 131 | + None, |
| 132 | + ), |
| 133 | + ).apply(input_nii) |
| 134 | + return cast(Nifti1Image, nii) |
| 135 | + |
| 136 | + def _resample_like( |
| 137 | + self, |
| 138 | + image: Nifti1Image, |
| 139 | + target: Nifti1Image, |
| 140 | + output_dtype: Optional[np.dtype] = None, |
| 141 | + cval: Union[int, float] = 0, |
| 142 | + ) -> Nifti1Image: |
| 143 | + """ |
| 144 | + Resample the input image to match the target's grid using an identity transform. |
| 145 | +
|
| 146 | + Args: |
| 147 | + image: The image to be resampled. |
| 148 | + target: The reference image. |
| 149 | + output_dtype: Output data type. |
| 150 | + cval: Value to use for constant padding. |
| 151 | +
|
| 152 | + Returns: |
| 153 | + A resampled NIfTI image. |
| 154 | + """ |
| 155 | + result = Affine(reference=target).apply( |
| 156 | + image, |
| 157 | + output_dtype=output_dtype, |
| 158 | + cval=cval, |
| 159 | + ) |
| 160 | + return cast(Nifti1Image, result) |
| 161 | + |
| 162 | + def extract( |
| 163 | + self, |
| 164 | + input_image_path: Union[str, Path], |
| 165 | + masked_image_path: Union[str, Path], |
| 166 | + brain_mask_path: Union[str, Path], |
| 167 | + device: Union[torch.device, str] = "cuda", |
| 168 | + num_threads: int = 1, |
| 169 | + **kwargs, |
| 170 | + ) -> None: |
| 171 | + """ |
| 172 | + Extract the brain from an input image using SynthStrip. |
| 173 | +
|
| 174 | + Args: |
| 175 | + input_image_path (Union[str, Path]): Path to the input image. |
| 176 | + masked_image_path (Union[str, Path]): Path to the output masked image. |
| 177 | + brain_mask_path (Union[str, Path]): Path to the output brain mask. |
| 178 | + device (Union[torch.device, str], optional): Device to use for computation. Defaults to "cuda". |
| 179 | + num_threads (int, optional): Number of threads to use for computation in CPU mode. Defaults to 1. |
| 180 | +
|
| 181 | + Returns: |
| 182 | + None: The function saves the masked image and brain mask to the specified paths. |
| 183 | + """ |
| 184 | + |
| 185 | + device = torch.device(device) if isinstance(device, str) else device |
| 186 | + model = self._setup_model(device=device) |
| 187 | + |
| 188 | + if device.type == "cpu" and num_threads > 0: |
| 189 | + torch.set_num_threads(num_threads) |
| 190 | + |
| 191 | + # normalize intensities |
| 192 | + image = nib.load(input_image_path) |
| 193 | + image = cast(Nifti1Image, image) |
| 194 | + conformed = self._conform(image) |
| 195 | + in_data = conformed.get_fdata(dtype="float32") |
| 196 | + in_data -= in_data.min() |
| 197 | + in_data = np.clip(in_data / np.percentile(in_data, 99), 0, 1) |
| 198 | + in_data = in_data[np.newaxis, np.newaxis] |
| 199 | + |
| 200 | + # predict the surface distance transform |
| 201 | + input_tensor = torch.from_numpy(in_data).to(device) |
| 202 | + with torch.no_grad(): |
| 203 | + sdt = model(input_tensor).cpu().numpy().squeeze() |
| 204 | + |
| 205 | + # unconform the sdt and extract mask |
| 206 | + sdt_target = self._resample_like( |
| 207 | + Nifti1Image(sdt, conformed.affine, None), |
| 208 | + image, |
| 209 | + output_dtype=np.dtype("int16"), |
| 210 | + cval=100, |
| 211 | + ) |
| 212 | + sdt_data = np.asanyarray(sdt_target.dataobj).astype("int16") |
| 213 | + |
| 214 | + # find largest CC (just do this to be safe for now) |
| 215 | + components = scipy.ndimage.label(sdt_data.squeeze() < self.border)[0] |
| 216 | + bincount = np.bincount(components.flatten())[1:] |
| 217 | + mask = components == (np.argmax(bincount) + 1) |
| 218 | + mask = scipy.ndimage.morphology.binary_fill_holes(mask) |
| 219 | + |
| 220 | + # write the masked output |
| 221 | + img_data = image.get_fdata() |
| 222 | + bg = np.min([0, img_data.min()]) |
| 223 | + img_data[mask == 0] = bg |
| 224 | + Nifti1Image(img_data, image.affine, image.header).to_filename( |
| 225 | + masked_image_path, |
| 226 | + ) |
| 227 | + |
| 228 | + # write the brain mask |
| 229 | + hdr = image.header.copy() |
| 230 | + hdr.set_data_dtype("uint8") |
| 231 | + Nifti1Image(mask, image.affine, hdr).to_filename(brain_mask_path) |
0 commit comments