diff --git a/SwinUNETR/BTCV/README.md b/SwinUNETR/BTCV/README.md
index 623f9b5e..73c5532f 100644
--- a/SwinUNETR/BTCV/README.md
+++ b/SwinUNETR/BTCV/README.md
@@ -15,6 +15,23 @@ Dependencies can be installed using:
pip install -r requirements.txt
```
+# Hugging Face inference API
+
+To install the necessary dependencies, run the following in bash:
+```
+git clone https://github.com/darraghdog/Project-MONAI-research-contributions pmrc
+pip install -r pmrc/requirements.txt
+cd pmrc/SwinUNETR/BTCV
+```
+
+To load the model from the Hugging Face Hub, run the following in Python:
+```
+from swinunetr import SwinUnetrModelForInference
+model = SwinUnetrModelForInference.from_pretrained('darragh/swinunetr-btcv-tiny')
+```
+
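+You can also run `predict.py` to perform inference on the sample NIfTI (`.nii.gz`) volumes in `dataset/imagesSampleTs`; the segmentation overlays are written as videos to the `videos` directory.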
+
# Models
Please download the self-supervised pre-trained weights for Swin UNETR backbone (CVPR paper [1]) from this link.
diff --git a/SwinUNETR/BTCV/dataset/imagesSampleTs/img0061.nii.gz b/SwinUNETR/BTCV/dataset/imagesSampleTs/img0061.nii.gz
new file mode 100755
index 00000000..c722cdbe
Binary files /dev/null and b/SwinUNETR/BTCV/dataset/imagesSampleTs/img0061.nii.gz differ
diff --git a/SwinUNETR/BTCV/dataset/imagesSampleTs/img0062.nii.gz b/SwinUNETR/BTCV/dataset/imagesSampleTs/img0062.nii.gz
new file mode 100755
index 00000000..042818d0
Binary files /dev/null and b/SwinUNETR/BTCV/dataset/imagesSampleTs/img0062.nii.gz differ
diff --git a/SwinUNETR/BTCV/predict.py b/SwinUNETR/BTCV/predict.py
new file mode 100755
index 00000000..a3488945
--- /dev/null
+++ b/SwinUNETR/BTCV/predict.py
@@ -0,0 +1,120 @@
+"""Run Swin UNETR inference on the sample volumes and save overlay videos.
+
+Every ``*.nii.gz`` file under ``dataset/imagesSampleTs`` is pre-processed,
+segmented with the pretrained ``darragh/swinunetr-btcv-tiny`` checkpoint and
+written to ``videos/<file name>.mp4``: the input slice on the left and the
+colour-coded segmentation overlay on the right. Paths are relative to the
+current working directory.
+"""
+import os
+import glob
+import shutil
+import torch
+import argparse
+import cv2
+import mediapy
+import numpy as np
+from skimage import color, img_as_ubyte
+from monai import transforms, data
+from swinunetr import SwinUnetrModelForInference, SwinUnetrConfig
+
+parser = argparse.ArgumentParser(description='Swin UNETR segmentation pipeline')
+parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='device for model - cpu/gpu')
+parser.add_argument('--a_min', default=-175.0, type=float, help='a_min in ScaleIntensityRanged')
+parser.add_argument('--a_max', default=250.0, type=float, help='a_max in ScaleIntensityRanged')
+parser.add_argument('--b_min', default=0.0, type=float, help='b_min in ScaleIntensityRanged')
+parser.add_argument('--b_max', default=1.0, type=float, help='b_max in ScaleIntensityRanged')
+parser.add_argument('--infer_overlap', default=0.5, type=float, help='sliding window inference overlap')
+parser.add_argument('--space_x', default=1.5, type=float, help='spacing in x direction')
+parser.add_argument('--space_y', default=1.5, type=float, help='spacing in y direction')
+parser.add_argument('--space_z', default=2.0, type=float, help='spacing in z direction')
+parser.add_argument('--roi_x', default=96, type=int, help='roi size in x direction')
+parser.add_argument('--roi_y', default=96, type=int, help='roi size in y direction')
+parser.add_argument('--roi_z', default=96, type=int, help='roi size in z direction')
+parser.add_argument('--last_n_frames', default=64, type=int, help='Limit the frames inference. -1 for all frames.')
+args = parser.parse_args()
+
+# shutil.which returns None when ffmpeg is not on PATH; in that case leave
+# mediapy's own ffmpeg setting untouched instead of overwriting it with None.
+ffmpeg_path = shutil.which('ffmpeg')
+if ffmpeg_path is not None:
+    mediapy.set_ffmpeg(ffmpeg_path)
+
+model = SwinUnetrModelForInference.from_pretrained('darragh/swinunetr-btcv-tiny')
+model.eval()
+model.to(args.device)
+
+# Sorted so the processing order does not depend on the filesystem.
+test_files = sorted(glob.glob('dataset/imagesSampleTs/*.nii.gz'))
+test_files = [{'image': f} for f in test_files]
+
+test_transform = transforms.Compose(
+    [
+        transforms.LoadImaged(keys=["image"]),
+        transforms.AddChanneld(keys=["image"]),
+        transforms.Spacingd(keys="image",
+                            pixdim=(args.space_x, args.space_y, args.space_z),
+                            mode="bilinear"),
+        transforms.ScaleIntensityRanged(keys=["image"],
+                                        a_min=args.a_min,
+                                        a_max=args.a_max,
+                                        b_min=args.b_min,
+                                        b_max=args.b_max,
+                                        clip=True),
+        #transforms.Resized(keys=["image"], spatial_size = (256,256,-1)),
+        transforms.ToTensord(keys=["image"]),
+    ])
+
+test_ds = test_transform(test_files)
+test_loader = data.DataLoader(test_ds,
+                              batch_size=1,
+                              shuffle=False)
+
+# The output directory must exist before any video is written into it.
+os.makedirs('videos', exist_ok=True)
+
+for batch in test_loader:
+
+    tst_inputs = batch["image"]
+    if args.last_n_frames > 0:
+        # Keep only the last N slices along the final axis to bound runtime.
+        tst_inputs = tst_inputs[:, :, :, :, -args.last_n_frames:]
+    # The model lives on args.device, so the input has to be moved there too.
+    tst_inputs = tst_inputs.to(args.device)
+
+    with torch.no_grad():
+        outputs = model(tst_inputs,
+                        (args.roi_x,
+                         args.roi_y,
+                         args.roi_z),
+                        8,
+                        overlap=args.infer_overlap,
+                        mode="gaussian")
+
+    tst_outputs = torch.softmax(outputs.logits, 1)
+    tst_outputs = torch.argmax(tst_outputs, axis=1)
+
+    # The NumPy conversion below needs CPU tensors, whatever device was used.
+    tst_inputs = tst_inputs.cpu()
+    tst_outputs = tst_outputs.cpu()
+
+    fnames = batch['image_meta_dict']['filename_or_obj']
+
+    # Write frames to video
+
+    for fname, inp, outp in zip(fnames, tst_inputs, tst_outputs):
+
+        dicom_name = os.path.basename(fname)
+        video_name = f'videos/{dicom_name}.mp4'
+        frames = []
+        for idx in range(inp.shape[-1]):
+            # Segmentation
+            seg = outp[:, :, idx].numpy().astype(np.uint8)
+            # Input slice scaled to 8-bit (assumes intensities in [0, 1])
+            img = (inp[0, :, :, idx] * 255).numpy().astype(np.uint8)
+            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+            frame = color.label2rgb(seg, img, bg_label=0)
+            frame = img_as_ubyte(frame)
+            frame = np.concatenate((img, frame), 1)
+            frames.append(frame)
+        mediapy.write_video(video_name, frames, fps=4)
diff --git a/SwinUNETR/BTCV/requirements.txt b/SwinUNETR/BTCV/requirements.txt
index 2718cc44..49de077a 100644
--- a/SwinUNETR/BTCV/requirements.txt
+++ b/SwinUNETR/BTCV/requirements.txt
@@ -1,6 +1,12 @@
+transformers==4.20.1
+torch==1.10.0
+
git+https://github.com/Project-MONAI/MONAI#egg.gitmonai@0.8.1+271.g07de215c
nibabel==3.1.1
tqdm==4.59.0
einops==0.4.1
tensorboardX==2.1
-scipy==1.2.1
+scipy==1.5.0
+mediapy==1.0.3
+scikit-image==0.17.2
+opencv-python==4.6.0.66
diff --git a/SwinUNETR/BTCV/swinunetr/__init__.py b/SwinUNETR/BTCV/swinunetr/__init__.py
new file mode 100755
index 00000000..fd14c007
--- /dev/null
+++ b/SwinUNETR/BTCV/swinunetr/__init__.py
@@ -0,0 +1,59 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2021 NAVER CLOVA Team. All rights reserved.
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from transformers.file_utils import (
+    _LazyModule,
+    is_torch_available,
+)
+
+# Submodule name -> public names it exports. Handed to _LazyModule below so a
+# submodule is only imported when one of its names is first requested.
+_import_structure = {
+    "configuration_swinunetr": ["SWINUNETR_PRETRAINED_CONFIG_ARCHIVE_MAP",
+                                "SwinUnetrConfig"],
+}
+
+
+# The modeling submodule is only advertised when torch is installed.
+if is_torch_available():
+    _import_structure["modeling_swinunetr"] = [
+        "SWINUNETR_PRETRAINED_MODEL_ARCHIVE_LIST",
+        "SwinUnetrModelForInference",
+    ]
+
+if TYPE_CHECKING:
+    # Eager imports for static type checkers; these must mirror
+    # _import_structure, so the model names come from modeling_swinunetr.
+    from .configuration_swinunetr import SWINUNETR_PRETRAINED_CONFIG_ARCHIVE_MAP, SwinUnetrConfig
+
+    if is_torch_available():
+        from .modeling_swinunetr import (
+            SWINUNETR_PRETRAINED_MODEL_ARCHIVE_LIST,
+            SwinUnetrModelForInference,
+        )
+
+else:
+    import sys
+
+    # At runtime, swap this module for a lazy proxy.
+    sys.modules[__name__] = _LazyModule(
+        __name__, globals()["__file__"], _import_structure
+    )
diff --git a/SwinUNETR/BTCV/swinunetr/configuration_swinunetr.py b/SwinUNETR/BTCV/swinunetr/configuration_swinunetr.py
new file mode 100755
index 00000000..236023dc
--- /dev/null
+++ b/SwinUNETR/BTCV/swinunetr/configuration_swinunetr.py
@@ -0,0 +1,95 @@
+# coding=utf-8
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Swin UNETR configuration """
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+SWINUNETR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+    "swinunetr-btcv-tiny": "https://huggingface.co/darragh/swinunetr-btcv-tiny/raw/main/config.json",
+    "swinunetr-btcv-small": "https://huggingface.co/darragh/swinunetr-btcv-small/raw/main/config.json",
+}
+
+
+class SwinUnetrConfig(PretrainedConfig):
+    r"""
+    Configuration of a Swin UNETR segmentation network.
+
+    Every argument is stored as an attribute of the same name on the
+    configuration object. Configuration objects inherit from
+    :class:`~transformers.PretrainedConfig`; read its documentation for the
+    options shared by all configurations.
+
+    Args:
+        architecture: name of the network architecture.
+        img_size: dimension of input image.
+        in_channels: dimension of input channels.
+        out_channels: dimension of output channels.
+        depths: number of layers in each stage.
+        num_heads: number of attention heads.
+        feature_size: dimension of network feature size.
+        norm_name: feature normalization type and arguments.
+        drop_rate: dropout rate.
+        attn_drop_rate: attention dropout rate.
+        dropout_path_rate: drop path rate.
+        normalize: normalize output intermediate features in each stage.
+        use_checkpoint: use gradient checkpointing for reduced memory usage.
+        spatial_dims: number of spatial dims.
+        **kwargs: forwarded unchanged to :class:`~transformers.PretrainedConfig`.
+
+    Examples::
+
+        >>> config = SwinUnetrConfig(feature_size=24)
+        >>> config.feature_size
+        24
+    """
+    model_type = "swinunetr"
+
+    def __init__(
+        self,
+        architecture="SwinUNETR",
+        img_size=96,
+        in_channels=1,
+        out_channels=14,
+        depths=(2, 2, 2, 2),
+        num_heads=(3, 6, 12, 24),
+        feature_size=12,
+        norm_name="instance",
+        drop_rate=0.0,
+        attn_drop_rate=0.0,
+        dropout_path_rate=0.0,
+        normalize=True,
+        use_checkpoint=False,
+        spatial_dims=3,
+        **kwargs
+    ):
+        # Base-class options first, then the Swin UNETR specific fields.
+        super().__init__(**kwargs)
+
+        self.architecture = architecture
+        self.img_size = img_size
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.depths = depths
+        self.num_heads = num_heads
+        self.feature_size = feature_size
+        self.norm_name = norm_name
+        self.drop_rate = drop_rate
+        self.attn_drop_rate = attn_drop_rate
+        self.dropout_path_rate = dropout_path_rate
+        self.normalize = normalize
+        self.use_checkpoint = use_checkpoint
+        self.spatial_dims = spatial_dims
diff --git a/SwinUNETR/BTCV/swinunetr/modeling_swinunetr.py b/SwinUNETR/BTCV/swinunetr/modeling_swinunetr.py
new file mode 100644
index 00000000..cd97b32c
--- /dev/null
+++ b/SwinUNETR/BTCV/swinunetr/modeling_swinunetr.py
@@ -0,0 +1,129 @@
+from transformers.file_utils import (
+    ModelOutput,
+)
+
+from transformers.modeling_utils import (
+    PreTrainedModel,
+)
+from transformers.utils import logging
+
+from .configuration_swinunetr import SwinUnetrConfig
+import torch
+from torch import nn
+from monai.inferers import sliding_window_inference
+from monai.networks.nets import SwinUNETR
+from monai.utils import BlendMode
+
+import warnings
+from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Union
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "darragh/swinunetr-btcv-tiny"
+_CONFIG_FOR_DOC = "swinunetrConfig"
+
+SWINUNETR_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "swinunetr-btcv-tiny",
+    "swinunetr-btcv-small",
+]
+
+
+class SwinUnetrPreTrainedModel(PreTrainedModel):
+    """
+    Base class that ties Swin UNETR models to :class:`SwinUnetrConfig` and
+    supplies the default weight initialisation.
+    """
+
+    config_class = SwinUnetrConfig
+    base_model_prefix = "swinunetr"
+
+    def _init_weights(self, module):
+        """Initialise ``nn.Linear`` and ``nn.LayerNorm`` modules; leave others untouched."""
+        if isinstance(module, nn.Linear):
+            # Weights drawn from N(0, 0.02); bias (when present) zeroed.
+            module.weight.data.normal_(mean=0.0, std=0.02)
+            if module.bias is not None:
+                module.bias.data.zero_()
+            return
+        if isinstance(module, nn.LayerNorm):
+            # Identity affine transform: zero shift, unit scale.
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+
+class SwinUnetrModelForInference(SwinUnetrPreTrainedModel):
+    """
+    Swin UNETR wrapped for sliding-window inference, based on:
+    "Hatamizadeh et al., Swin UNETR: Swin Transformers for Semantic
+    Segmentation of Brain Tumors in MRI Images".
+    Source : https://docs.monai.io/en/stable/_modules/monai/networks/nets/swin_unetr.html
+    """
+
+    def __init__(self, config, add_pooling_layer=True):
+        # ``add_pooling_layer`` is accepted but not used by this constructor.
+        super().__init__(config)
+
+        self.config = config
+
+        # Config fields handed one-to-one to the MONAI SwinUNETR constructor.
+        network_fields = (
+            "img_size",
+            "in_channels",
+            "out_channels",
+            "depths",
+            "num_heads",
+            "feature_size",
+            "norm_name",
+            "drop_rate",
+            "attn_drop_rate",
+            "dropout_path_rate",
+            "normalize",
+            "use_checkpoint",
+            "spatial_dims",
+        )
+        network_kwargs = {name: getattr(config, name) for name in network_fields}
+        self.model = SwinUNETR(**network_kwargs)
+
+        self.init_weights()
+
+    def forward(
+        self,
+        inputs: torch.Tensor,
+        roi_size: Union[Sequence[int], int],
+        sw_batch_size: int,
+        overlap: float = 0.25,
+        mode: Union[BlendMode, str] = BlendMode.CONSTANT
+    ):
+        r"""
+        Run MONAI sliding-window inference over ``inputs`` with the wrapped network.
+
+        When roi_size is larger than the inputs' spatial size, the input image is padded
+        during inference and the output is cropped back to the original input size.
+
+        Args:
+            inputs: input image to be processed (assuming NCHW[D]); it must be
+                channel-first and have a batch dimension.
+            roi_size: the spatial window size for inferences. Components that
+                are None or non-positive fall back to the corresponding input
+                dimension, e.g. ``roi_size=(32, -1)`` becomes ``(32, 64)`` if
+                the second spatial dimension of the image is ``64``.
+            sw_batch_size: the batch size to run window slices.
+            overlap: amount of overlap between scans.
+            mode: {``"constant"``, ``"gaussian"``}
+                How to blend output of overlapping windows. Defaults to ``"constant"``.
+
+                - ``"constant"``: gives equal weight to all predictions.
+                - ``"gaussian"``: gives less weight to predictions on edges of windows.
+
+        Returns:
+            A ``ModelOutput`` whose ``logits`` entry holds the sliding-window result.
+        """
+        stitched = sliding_window_inference(
+            inputs=inputs,
+            roi_size=roi_size,
+            sw_batch_size=sw_batch_size,
+            predictor=self.model,
+            overlap=overlap,
+            mode=mode,
+        )
+        return ModelOutput(logits=stitched)
diff --git a/SwinUNETR/BTCV/videos/__init__.py b/SwinUNETR/BTCV/videos/__init__.py
new file mode 100644
index 00000000..e69de29b