|
32 | 32 | }, |
33 | 33 | { |
34 | 34 | "cell_type": "code", |
35 | | - "execution_count": 1, |
| 35 | + "execution_count": 4, |
36 | 36 | "id": "f2e1b91f", |
37 | 37 | "metadata": {}, |
38 | 38 | "outputs": [], |
39 | 39 | "source": [ |
40 | 40 | "# Install monai\n", |
41 | | - "!python -c \"import monai\" || pip install -q \"monai-weekly\"" |
| 41 | + "!python -c \"import monai\" || pip install -q \"monai-weekly[tqdm]\"" |
42 | 42 | ] |
43 | 43 | }, |
44 | 44 | { |
45 | 45 | "cell_type": "code", |
46 | | - "execution_count": 2, |
| 46 | + "execution_count": 5, |
47 | 47 | "id": "e9cd1b08", |
48 | 48 | "metadata": {}, |
49 | 49 | "outputs": [], |
50 | 50 | "source": [ |
51 | 51 | "# Import libs\n", |
52 | | - "from monai.inferers import SlidingWindowInferer\n", |
| 52 | + "from monai.inferers import SliceInferer\n", |
53 | 53 | "import torch\n", |
54 | | - "from typing import Callable, Any\n", |
55 | 54 | "from monai.networks.nets import UNet" |
56 | 55 | ] |
57 | 56 | }, |
|
60 | 59 | "id": "85f00a47", |
61 | 60 | "metadata": {}, |
62 | 61 | "source": [ |
63 | | - "## Overiding SlidingWindowInferer\n", |
64 | | - "The simplest way to achieve this functionality is to create a class `YourSlidingWindowInferer` that inherits from `SlidingWindowInferer` in `monai.inferers`" |
65 | | - ] |
66 | | - }, |
67 | | - { |
68 | | - "cell_type": "code", |
69 | | - "execution_count": 3, |
70 | | - "id": "01f8bfa3", |
71 | | - "metadata": {}, |
72 | | - "outputs": [], |
73 | | - "source": [ |
74 | | - "class YourSlidingWindowInferer(SlidingWindowInferer):\n", |
75 | | - " def __init__(self, spatial_dim: int = 0, *args, **kwargs):\n", |
76 | | - " # Set dim to slice the volume across, for example, `0` could slide over axial slices,\n", |
77 | | - " # `1` over coronal slices\n", |
78 | | - " # and `2` over sagittal slices.\n", |
79 | | - " self.spatial_dim = spatial_dim\n", |
80 | | - "\n", |
81 | | - " super().__init__(*args, **kwargs)\n", |
82 | | - "\n", |
83 | | - " def __call__(\n", |
84 | | - " self,\n", |
85 | | - " inputs: torch.Tensor,\n", |
86 | | - " network: Callable[..., torch.Tensor],\n", |
87 | | - " slice_axis: int = 0,\n", |
88 | | - " *args: Any,\n", |
89 | | - " **kwargs: Any,\n", |
90 | | - " ) -> torch.Tensor:\n", |
91 | | - "\n", |
92 | | - " assert (\n", |
93 | | - " self.spatial_dim < 3\n", |
94 | | - " ), \"`spatial_dim` can only be `[D, H, W]` with `0, 1, 2` respectively\"\n", |
95 | | - "\n", |
96 | | - " # Check if roi size (eg. 2D roi) and input volume sizes (3D input) mismatch\n", |
97 | | - " if len(self.roi_size) != len(inputs.shape[2:]):\n", |
98 | | - "\n", |
99 | | - " # If they mismatch and roi_size is 2D add another dimension to roi size\n", |
100 | | - " if len(self.roi_size) == 2:\n", |
101 | | - " self.roi_size = list(self.roi_size)\n", |
102 | | - " self.roi_size.insert(self.spatial_dim, 1)\n", |
103 | | - " else:\n", |
104 | | - " raise RuntimeError(\n", |
105 | | - " \"Currently, only 2D `roi_size` is supported, cannot broadcast to volume. \"\n", |
106 | | - " )\n", |
107 | | - "\n", |
108 | | - " return super().__call__(inputs, lambda x: self.network_wrapper(network, x))\n", |
109 | | - "\n", |
110 | | - " def network_wrapper(self, network, x, *args, **kwargs):\n", |
111 | | - " \"\"\"\n", |
112 | | - " Wrapper handles cases where inference needs to be done using\n", |
113 | | - " 2D models over 3D volume inputs.\n", |
114 | | - " \"\"\"\n", |
115 | | - " # If depth dim is 1 in [D, H, W] roi size, then the input is 2D and needs\n", |
116 | | - " # be handled accordingly\n", |
117 | | - "\n", |
118 | | - " if self.roi_size[self.spatial_dim] == 1:\n", |
119 | | - " # Pass 4D input [N, C, H, W]/[N, C, D, W]/[N, C, D, H] to the model as it is 2D.\n", |
120 | | - " x = x.squeeze(dim=self.spatial_dim + 2)\n", |
121 | | - " out = network(x, *args, **kwargs)\n", |
122 | | - " # Unsqueeze the network output so it is [N, C, D, H, W] as expected by\n", |
123 | | - " # the default SlidingWindowInferer class\n", |
124 | | - " return out.unsqueeze(dim=self.spatial_dim + 2)\n", |
125 | | - "\n", |
126 | | - " else:\n", |
127 | | - " return network(x, *args, **kwargs)" |
| 62 | + "## SliceInferer\n", |
| 63 | + "The simplest way to achieve this functionality is to extend the `SlidingWindowInferer` in `monai.inferers`. This is made available as `SliceInferer` in MONAI (https://docs.monai.io/en/latest/inferers.html#sliceinferer)." |
128 | 64 | ] |
129 | 65 | }, |
130 | 66 | { |
131 | 67 | "cell_type": "markdown", |
132 | 68 | "id": "bb0a63dd", |
133 | 69 | "metadata": {}, |
134 | 70 | "source": [ |
135 | | - "## Testing added functionality\n", |
136 | | - "Let's use the `YourSlidingWindowInferer` in a dummy example to execute the workflow described above." |
| 71 | + "## Usage" |
137 | 72 | ] |
138 | 73 | }, |
139 | 74 | { |
140 | 75 | "cell_type": "code", |
141 | | - "execution_count": 4, |
| 76 | + "execution_count": 6, |
142 | 77 | "id": "85b15305", |
143 | 78 | "metadata": {}, |
144 | 79 | "outputs": [ |
| 80 | + { |
| 81 | + "name": "stderr", |
| 82 | + "output_type": "stream", |
| 83 | + "text": [ |
| 84 | + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 64/64 [00:00<00:00, 107.33it/s]\n" |
| 85 | + ] |
| 86 | + }, |
| 87 | + { |
| 88 | + "name": "stdout", |
| 89 | + "output_type": "stream", |
| 90 | + "text": [ |
| 91 | + "Axial Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n" |
| 92 | + ] |
| 93 | + }, |
| 94 | + { |
| 95 | + "name": "stderr", |
| 96 | + "output_type": "stream", |
| 97 | + "text": [ |
| 98 | + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 256/256 [00:01<00:00, 177.69it/s]\n" |
| 99 | + ] |
| 100 | + }, |
145 | 101 | { |
146 | 102 | "name": "stdout", |
147 | 103 | "output_type": "stream", |
148 | 104 | "text": [ |
149 | | - "Axial Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n", |
150 | 105 | "Coronal Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n" |
151 | 106 | ] |
152 | 107 | } |
|
167 | 122 | "# Initialize a dummy 3D tensor volume with shape (N,C,D,H,W)\n", |
168 | 123 | "input_volume = torch.ones(1, 1, 64, 256, 256)\n", |
169 | 124 | "\n", |
170 | | - "# Create an instance of YourSlidingWindowInferer with roi_size as the 256x256 (HxW) and sliding over D axis\n", |
171 | | - "axial_inferer = YourSlidingWindowInferer(roi_size=(256, 256), sw_batch_size=1, cval=-1)\n", |
| 125 | + "# Create an instance of SliceInferer with roi_size as the 256x256 (HxW) and sliding over D axis\n", |
| 126 | + "axial_inferer = SliceInferer(roi_size=(256, 256), sw_batch_size=1, cval=-1, progress=True)\n", |
172 | 127 | "\n", |
173 | 128 | "output = axial_inferer(input_volume, net)\n", |
174 | 129 | "\n", |
175 | 130 | "# Output is a 3D volume with 2D slices aggregated\n", |
176 | 131 | "print(\"Axial Inferer Output Shape: \", output.shape)\n", |
177 | | - "# Create an instance of YourSlidingWindowInferer with roi_size as the 64x256 (DxW) and sliding over H axis\n", |
178 | | - "coronal_inferer = YourSlidingWindowInferer(\n", |
| 132 | + "# Create an instance of SliceInferer with roi_size as the 64x256 (DxW) and sliding over H axis\n", |
| 133 | + "coronal_inferer = SliceInferer(\n", |
179 | 134 | " roi_size=(64, 256),\n", |
180 | 135 | " sw_batch_size=1,\n", |
181 | 136 | " spatial_dim=1, # Spatial dim to slice along is added here\n", |
182 | 137 | " cval=-1,\n", |
| 138 | + " progress=True,\n", |
183 | 139 | ")\n", |
184 | 140 | "\n", |
185 | 141 | "output = coronal_inferer(input_volume, net)\n", |
186 | 142 | "\n", |
187 | 143 | "# Output is a 3D volume with 2D slices aggregated\n", |
188 | 144 | "print(\"Coronal Inferer Output Shape: \", output.shape)" |
189 | 145 | ] |
| 146 | + }, |
| 147 | + { |
| 148 | + "cell_type": "markdown", |
| 149 | + "id": "f2596d86", |
| 150 | + "metadata": {}, |
| 151 | + "source": [ |
| 152 | + "Note that with `axial_inferer` and `coronal_inferer`, the number of inference iterations is 64 and 256 repectively." |
| 153 | + ] |
190 | 154 | } |
191 | 155 | ], |
192 | 156 | "metadata": { |
|
205 | 169 | "name": "python", |
206 | 170 | "nbconvert_exporter": "python", |
207 | 171 | "pygments_lexer": "ipython3", |
208 | | - "version": "3.7.11" |
| 172 | + "version": "3.8.12" |
209 | 173 | } |
210 | 174 | }, |
211 | 175 | "nbformat": 4, |
|
0 commit comments