|
129 | 129 | "from monai.networks.blocks import Warp\n", |
130 | 130 | "from monai.networks.nets import LocalNet\n", |
131 | 131 | "from monai.transforms import (\n", |
132 | | - " AddChanneld,\n", |
133 | 132 | " Compose,\n", |
134 | 133 | " LoadImaged,\n", |
135 | 134 | " RandAffined,\n", |
136 | 135 | " Resized,\n", |
137 | 136 | " ScaleIntensityRanged,\n", |
138 | | - " EnsureTyped,\n", |
139 | 137 | ")\n", |
140 | 138 | "from monai.utils import set_determinism, first\n", |
141 | 139 | "\n", |
|
302 | 300 | "source": [ |
303 | 301 | "## Setup transforms for training and validation\n", |
304 | 302 | "Here we use several transforms to augment the dataset:\n", |
305 | | - "1. LoadImaged loads the lung CT images and labels from NIfTI format files.\n", |
306 | | - "2. AddChanneld as the original data doesn't have channel dim, add 1 dim to construct \"channel first\" shape.\n", |
307 | | - "3. ScaleIntensityRanged extracts intensity range [-285, 3770] and scales to [0, 1].\n", |
308 | | - "4. RandAffined efficiently performs rotate, scale, shear, translate, etc. together based on PyTorch affine transform.\n", |
309 | | - "5. Resized resize images to the same size.\n", |
310 | | - "6. EnsureTyped converts the numpy array to PyTorch Tensor for further steps." |
| 303 | + "1. LoadImaged loads the lung CT images and labels from NIfTI format files. \"ensure_channel_first=True\" ensure that the first dim is channel.\n", |
| 304 | + "2. ScaleIntensityRanged extracts intensity range [-285, 3770] and scales to [0, 1].\n", |
| 305 | + "3. RandAffined efficiently performs rotate, scale, shear, translate, etc. together based on PyTorch affine transform.\n", |
| 306 | + "4. Resized resize images to the same size." |
311 | 307 | ] |
312 | 308 | }, |
313 | 309 | { |
|
324 | 320 | "train_transforms = Compose(\n", |
325 | 321 | " [\n", |
326 | 322 | " LoadImaged(\n", |
327 | | - " keys=[\"fixed_image\", \"moving_image\", \"fixed_label\", \"moving_label\"]\n", |
328 | | - " ),\n", |
329 | | - " AddChanneld(\n", |
330 | | - " keys=[\"fixed_image\", \"moving_image\", \"fixed_label\", \"moving_label\"]\n", |
| 323 | + " keys=[\"fixed_image\", \"moving_image\", \"fixed_label\", \"moving_label\"],\n", |
| 324 | + " ensure_channel_first=True\n", |
331 | 325 | " ),\n", |
332 | 326 | " ScaleIntensityRanged(\n", |
333 | 327 | " keys=[\"fixed_image\", \"moving_image\"],\n", |
|
345 | 339 | " align_corners=(True, True, None, None),\n", |
346 | 340 | " spatial_size=(96, 96, 104)\n", |
347 | 341 | " ),\n", |
348 | | - " EnsureTyped(\n", |
349 | | - " keys=[\"fixed_image\", \"moving_image\", \"fixed_label\", \"moving_label\"]\n", |
350 | | - " ),\n", |
351 | 342 | " ]\n", |
352 | 343 | ")\n", |
353 | 344 | "val_transforms = Compose(\n", |
354 | 345 | " [\n", |
355 | 346 | " LoadImaged(\n", |
356 | | - " keys=[\"fixed_image\", \"moving_image\", \"fixed_label\", \"moving_label\"]\n", |
357 | | - " ),\n", |
358 | | - " AddChanneld(\n", |
359 | | - " keys=[\"fixed_image\", \"moving_image\", \"fixed_label\", \"moving_label\"]\n", |
| 347 | + " keys=[\"fixed_image\", \"moving_image\", \"fixed_label\", \"moving_label\"],\n", |
| 348 | + " ensure_channel_first=True\n", |
360 | 349 | " ),\n", |
361 | 350 | " ScaleIntensityRanged(\n", |
362 | 351 | " keys=[\"fixed_image\", \"moving_image\"],\n", |
|
369 | 358 | " align_corners=(True, True, None, None),\n", |
370 | 359 | " spatial_size=(96, 96, 104)\n", |
371 | 360 | " ),\n", |
372 | | - " EnsureTyped(\n", |
373 | | - " keys=[\"fixed_image\", \"moving_image\", \"fixed_label\", \"moving_label\"]\n", |
374 | | - " ),\n", |
375 | 361 | " ]\n", |
376 | 362 | ")" |
377 | 363 | ] |
|
693 | 679 | "\n", |
694 | 680 | " val_ddf, val_pred_image, val_pred_label = forward(\n", |
695 | 681 | " val_data, model)\n", |
| 682 | + " val_pred_label[val_pred_label > 1] = 1\n", |
696 | 683 | "\n", |
697 | 684 | " val_fixed_image = val_data[\"fixed_image\"].to(device)\n", |
698 | 685 | " val_fixed_label = val_data[\"fixed_label\"].to(device)\n", |
|
723 | 710 | " optimizer.zero_grad()\n", |
724 | 711 | "\n", |
725 | 712 | " ddf, pred_image, pred_label = forward(batch_data, model)\n", |
| 713 | + " pred_label[pred_label > 1] = 1\n", |
726 | 714 | "\n", |
727 | 715 | " fixed_image = batch_data[\"fixed_image\"].to(device)\n", |
728 | 716 | " fixed_label = batch_data[\"fixed_label\"].to(device)\n", |
|
1281 | 1269 | "name": "python", |
1282 | 1270 | "nbconvert_exporter": "python", |
1283 | 1271 | "pygments_lexer": "ipython3", |
1284 | | - "version": "3.8.12" |
| 1272 | + "version": "3.8.13" |
1285 | 1273 | } |
1286 | 1274 | }, |
1287 | 1275 | "nbformat": 4, |
|
0 commit comments