diff --git a/SwinUNETR/BTCV/README.md b/SwinUNETR/BTCV/README.md index 623f9b5e..73c5532f 100644 --- a/SwinUNETR/BTCV/README.md +++ b/SwinUNETR/BTCV/README.md @@ -15,6 +15,23 @@ Dependencies can be installed using: pip install -r requirements.txt ``` +# Huggingface inference API + +To install necessary dependencies, run the below in bash. +``` +git clone https://github.com/darraghdog/Project-MONAI-research-contributions pmrc +pip install -r pmrc/requirements.txt +cd pmrc/SwinUNETR/BTCV +``` + +To load the model from the hub. +``` +from swinunetr import SwinUnetrModelForInference +model = SwinUnetrModelForInference.from_pretrained('darragh/swinunetr-btcv-tiny') +``` + +You can also use `predict.py` to run inference for sample dicom medical images. + # Models Please download the self-supervised pre-trained weights for Swin UNETR backbone (CVPR paper [1]) from this link. diff --git a/SwinUNETR/BTCV/dataset/imagesSampleTs/img0061.nii.gz b/SwinUNETR/BTCV/dataset/imagesSampleTs/img0061.nii.gz new file mode 100755 index 00000000..c722cdbe Binary files /dev/null and b/SwinUNETR/BTCV/dataset/imagesSampleTs/img0061.nii.gz differ diff --git a/SwinUNETR/BTCV/dataset/imagesSampleTs/img0062.nii.gz b/SwinUNETR/BTCV/dataset/imagesSampleTs/img0062.nii.gz new file mode 100755 index 00000000..042818d0 Binary files /dev/null and b/SwinUNETR/BTCV/dataset/imagesSampleTs/img0062.nii.gz differ diff --git a/SwinUNETR/BTCV/predict.py b/SwinUNETR/BTCV/predict.py new file mode 100755 index 00000000..a3488945 --- /dev/null +++ b/SwinUNETR/BTCV/predict.py @@ -0,0 +1,98 @@ +import os +import glob +import shutil +import torch +import argparse +import cv2 +import mediapy +import numpy as np +from skimage import color, img_as_ubyte +from monai import transforms, data +from swinunetr import SwinUnetrModelForInference, SwinUnetrConfig + +parser = argparse.ArgumentParser(description='Swin UNETR segmentation pipeline') +parser.add_argument('--device', type=str, default='cuda' if 
torch.cuda.is_available() else 'cpu', help='device for model - cpu/gpu') +parser.add_argument('--a_min', default=-175.0, type=float, help='a_min in ScaleIntensityRanged') +parser.add_argument('--a_max', default=250.0, type=float, help='a_max in ScaleIntensityRanged') +parser.add_argument('--b_min', default=0.0, type=float, help='b_min in ScaleIntensityRanged') +parser.add_argument('--b_max', default=1.0, type=float, help='b_max in ScaleIntensityRanged') +parser.add_argument('--infer_overlap', default=0.5, type=float, help='sliding window inference overlap') +parser.add_argument('--space_x', default=1.5, type=float, help='spacing in x direction') +parser.add_argument('--space_y', default=1.5, type=float, help='spacing in y direction') +parser.add_argument('--space_z', default=2.0, type=float, help='spacing in z direction') +parser.add_argument('--roi_x', default=96, type=int, help='roi size in x direction') +parser.add_argument('--roi_y', default=96, type=int, help='roi size in y direction') +parser.add_argument('--roi_z', default=96, type=int, help='roi size in z direction') +parser.add_argument('--last_n_frames', default=64, type=int, help='Limit the frames inference. 
-1 for all frames.') +args = parser.parse_args() + +ffmpeg_path = shutil.which('ffmpeg') +mediapy.set_ffmpeg(ffmpeg_path) + +model = SwinUnetrModelForInference.from_pretrained('darragh/swinunetr-btcv-tiny') +model.eval() +model.to(args.device) + +test_files = glob.glob('dataset/imagesSampleTs/*.nii.gz') +test_files = [{'image': f} for f in test_files] + +test_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["image"]), + transforms.AddChanneld(keys=["image"]), + transforms.Spacingd(keys="image", + pixdim=(args.space_x, args.space_y, args.space_z), + mode="bilinear"), + transforms.ScaleIntensityRanged(keys=["image"], + a_min=args.a_min, + a_max=args.a_max, + b_min=args.b_min, + b_max=args.b_max, + clip=True), + #transforms.Resized(keys=["image"], spatial_size = (256,256,-1)), + transforms.ToTensord(keys=["image"]), + ]) + +test_ds = test_transform(test_files) +test_loader = data.DataLoader(test_ds, + batch_size=1, + shuffle=False) + +for i, batch in enumerate(test_loader): + + tst_inputs = batch["image"].to(args.device) + if args.last_n_frames>0: + tst_inputs = tst_inputs[:,:,:,:,-args.last_n_frames:] + + with torch.no_grad(): + outputs = model(tst_inputs, + (args.roi_x, + args.roi_y, + args.roi_z), + 8, + overlap=args.infer_overlap, + mode="gaussian") + + tst_outputs = torch.softmax(outputs.logits, 1) + tst_outputs = torch.argmax(tst_outputs, axis=1) + + fnames = batch['image_meta_dict']['filename_or_obj'] + + # Write frames to video + + for fname, inp, outp in zip(fnames, tst_inputs, tst_outputs): + + dicom_name = fname.split('/')[-1] + video_name = f'videos/{dicom_name}.mp4' + frames = [] + for idx in range(inp.shape[-1]): + # Segmentation (cpu() is a no-op on CPU, required on CUDA) + seg = outp[:,:,idx].cpu().numpy().astype(np.uint8) + # Input dicom frame + img = (inp[0,:,:,idx]*255).cpu().numpy().astype(np.uint8) + img = cv2.cvtColor(img,cv2.COLOR_GRAY2RGB) + frame = color.label2rgb(seg,img, bg_label = 0) + frame = img_as_ubyte(frame) + frame = np.concatenate((img, frame), 1) + frames.append(frame) + 
mediapy.write_video(video_name, frames, fps=4) diff --git a/SwinUNETR/BTCV/requirements.txt b/SwinUNETR/BTCV/requirements.txt index 2718cc44..49de077a 100644 --- a/SwinUNETR/BTCV/requirements.txt +++ b/SwinUNETR/BTCV/requirements.txt @@ -1,6 +1,12 @@ +transformers==4.20.1 +torch==1.10.0 + git+https://github.com/Project-MONAI/MONAI#egg.gitmonai@0.8.1+271.g07de215c nibabel==3.1.1 tqdm==4.59.0 einops==0.4.1 tensorboardX==2.1 -scipy==1.2.1 +scipy==1.5.0 +mediapy==1.0.3 +scikit-image==0.17.2 +opencv-python==4.6.0.66 diff --git a/SwinUNETR/BTCV/swinunetr/__init__.py b/SwinUNETR/BTCV/swinunetr/__init__.py new file mode 100755 index 00000000..fd14c007 --- /dev/null +++ b/SwinUNETR/BTCV/swinunetr/__init__.py @@ -0,0 +1,53 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2021 NAVER CLOVA Team. All rights reserved. +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import TYPE_CHECKING + +from transformers.file_utils import ( + _LazyModule, + is_torch_available, +) + +_import_structure = { + "configuration_swinunetr": ["SWINUNETR_PRETRAINED_CONFIG_ARCHIVE_MAP", + "SwinUnetrConfig"], +} + + +if is_torch_available(): + _import_structure["modeling_swinunetr"] = [ + "SWINUNETR_PRETRAINED_MODEL_ARCHIVE_LIST", + "SwinUnetrModelForInference", + ] + +if TYPE_CHECKING: + from .configuration_swinunetr import SWINUNETR_PRETRAINED_CONFIG_ARCHIVE_MAP, SwinUnetrConfig + + if is_torch_available(): + from .modeling_swinunetr import ( + SWINUNETR_PRETRAINED_MODEL_ARCHIVE_LIST, + SwinUnetrModelForInference, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, globals()["__file__"], _import_structure + ) diff --git a/SwinUNETR/BTCV/swinunetr/configuration_swinunetr.py b/SwinUNETR/BTCV/swinunetr/configuration_swinunetr.py new file mode 100755 index 00000000..236023dc --- /dev/null +++ b/SwinUNETR/BTCV/swinunetr/configuration_swinunetr.py @@ -0,0 +1,94 @@ +# coding=utf-8 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" Swin Unnetr configuration """ + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +SWINUNETR_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "swinunetr-btcv-tiny": "https://huggingface.co/darragh/swinunetr-btcv-tiny/raw/main/config.json", + "swinunetr-btcv-small": "https://huggingface.co/darragh/swinunetr-btcv-small/raw/main/config.json", +} + + +class SwinUnetrConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.BertModel` or a + :class:`~transformers.TFBertModel`. It is used to instantiate a model according to the specified arguments, + defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration + to that of the BERT `bert-base-uncased `__ architecture. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model + outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. + + + Args: + img_size: dimension of input image. + in_channels: dimension of input channels. + out_channels: dimension of output channels. + feature_size: dimension of network feature size. + depths: number of layers in each stage. + num_heads: number of attention heads. + norm_name: feature normalization type and arguments. + drop_rate: dropout rate. + attn_drop_rate: attention dropout rate. + dropout_path_rate: drop path rate. + normalize: normalize output intermediate features in each stage. + use_checkpoint: use gradient checkpointing for reduced memory usage. + spatial_dims: number of spatial dims. 
+ + Examples:: + + >>> TBD + """ + model_type = "swinunetr" + + def __init__( + self, + architecture= "SwinUNETR", + img_size= 96, + in_channels= 1, + out_channels= 14, + depths= (2, 2, 2, 2), + num_heads= (3, 6, 12, 24), + feature_size= 12, + norm_name= "instance", + drop_rate= 0.0, + attn_drop_rate= 0.0, + dropout_path_rate= 0.0, + normalize= True, + use_checkpoint= False, + spatial_dims= 3, + **kwargs + ): + super().__init__( + + architecture= architecture, + img_size= img_size, + in_channels= in_channels, + out_channels= out_channels, + depths= depths, + num_heads= num_heads, + feature_size= feature_size, + norm_name= norm_name, + drop_rate= drop_rate, + attn_drop_rate= attn_drop_rate, + dropout_path_rate= dropout_path_rate, + normalize= normalize, + use_checkpoint= use_checkpoint, + spatial_dims= spatial_dims, + **kwargs, + ) diff --git a/SwinUNETR/BTCV/swinunetr/modeling_swinunetr.py b/SwinUNETR/BTCV/swinunetr/modeling_swinunetr.py new file mode 100644 index 00000000..cd97b32c --- /dev/null +++ b/SwinUNETR/BTCV/swinunetr/modeling_swinunetr.py @@ -0,0 +1,132 @@ +from transformers.file_utils import ( + ModelOutput, +) + +from transformers.modeling_utils import ( + PreTrainedModel, +) +from transformers.utils import logging + +from .configuration_swinunetr import SwinUnetrConfig +import torch +from torch import nn +from monai.inferers import sliding_window_inference +from monai.networks.nets import SwinUNETR +from monai.utils import BlendMode + +import warnings +from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Union + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "darragh/swinunetr-btcv-tiny" +_CONFIG_FOR_DOC = "swinunetrConfig" + +SWINUNETR_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "swinunetr-btcv-tiny", + "swinunetr-btcv-small", +] + +class SwinUnetrPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = SwinUnetrConfig + base_model_prefix = "swinunetr" + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=0.02) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + +class SwinUnetrModelForInference(SwinUnetrPreTrainedModel): + """ + Swin UNETR based on: "Hatamizadeh et al., + Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images + " + Source : https://docs.monai.io/en/stable/_modules/monai/networks/nets/swin_unetr.html + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + + self.config = config + + self.model = SwinUNETR( + img_size= config.img_size, + in_channels= config.in_channels, + out_channels=config. out_channels, + depths= config.depths, + num_heads= config.num_heads, + feature_size= config.feature_size, + norm_name= config.norm_name, + drop_rate= config.drop_rate, + attn_drop_rate= config.attn_drop_rate, + dropout_path_rate= config.dropout_path_rate, + normalize= config.normalize, + use_checkpoint= config.use_checkpoint, + spatial_dims= config.spatial_dims, + ) + + self.init_weights() + + def forward( + self, + inputs: torch.Tensor, + roi_size: Union[Sequence[int], int], + sw_batch_size: int, + overlap: float = 0.25, + mode: Union[BlendMode, str] = BlendMode.CONSTANT + ): + r""" + Sliding window inference on `inputs` with `predictor`. + + The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors. + Each output in the tuple or dict value is allowed to have different resolutions with respect to the input. 
+ e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes + could be ([128,64,256], [64,32,128]). + In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still + an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters + so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension). + + When roi_size is larger than the inputs' spatial size, the input image are padded during inference. + To maintain the same spatial sizes, the output image will be cropped to the original input size. + + Args: + inputs: input image to be processed (assuming NCHW[D]) + roi_size: the spatial window size for inferences. + When its components have None or non-positives, the corresponding inputs dimension will be used. + if the components of the `roi_size` are non-positive values, the transform will use the + corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + sw_batch_size: the batch size to run window slices. + overlap: Amount of overlap between scans. + mode: {``"constant"``, ``"gaussian"``} + How to blend output of overlapping windows. Defaults to ``"constant"``. + + - ``"constant``": gives equal weight to all predictions. + - ``"gaussian``": gives less weight to predictions on edges of windows. + kwargs: optional keyword args to be passed to ``predictor``. + + Note: + - input must be channel-first and have a batch dim, supports N-D sliding window. + + """ + + logits = sliding_window_inference(inputs, + roi_size, + sw_batch_size, + self.model, + overlap, + mode) + + return ModelOutput(logits = logits) diff --git a/SwinUNETR/BTCV/videos/__init__.py b/SwinUNETR/BTCV/videos/__init__.py new file mode 100644 index 00000000..e69de29b