|
17 | 17 | import torch |
18 | 18 | from ignite.engine import Events, _prepare_batch, create_supervised_evaluator, create_supervised_trainer |
19 | 19 | from ignite.handlers import EarlyStopping, ModelCheckpoint |
20 | | -from torch.utils.data import DataLoader |
21 | 20 |
|
22 | 21 | import monai |
23 | | -from monai.data import decollate_batch |
| 22 | +from monai.data import decollate_batch, DataLoader |
24 | 23 | from monai.handlers import ROCAUC, StatsHandler, TensorBoardStatsHandler, stopping_fn_from_metric |
25 | | -from monai.transforms import Activations, AddChanneld, AsDiscrete, Compose, LoadImaged, RandRotate90d, Resized, ScaleIntensityd, EnsureTyped, EnsureType |
| 24 | +from monai.transforms import Activations, AsDiscrete, Compose, LoadImaged, RandRotate90d, Resized, ScaleIntensityd |
26 | 25 |
|
27 | 26 |
|
28 | 27 | def main(): |
@@ -64,21 +63,17 @@ def main(): |
64 | 63 | # define transforms for image |
65 | 64 | train_transforms = Compose( |
66 | 65 | [ |
67 | | - LoadImaged(keys=["img"]), |
68 | | - AddChanneld(keys=["img"]), |
| 66 | + LoadImaged(keys=["img"], ensure_channel_first=True), |
69 | 67 | ScaleIntensityd(keys=["img"]), |
70 | 68 | Resized(keys=["img"], spatial_size=(96, 96, 96)), |
71 | 69 | RandRotate90d(keys=["img"], prob=0.8, spatial_axes=[0, 2]), |
72 | | - EnsureTyped(keys=["img"]), |
73 | 70 | ] |
74 | 71 | ) |
75 | 72 | val_transforms = Compose( |
76 | 73 | [ |
77 | | - LoadImaged(keys=["img"]), |
78 | | - AddChanneld(keys=["img"]), |
| 74 | + LoadImaged(keys=["img"], ensure_channel_first=True), |
79 | 75 | ScaleIntensityd(keys=["img"]), |
80 | 76 | Resized(keys=["img"], spatial_size=(96, 96, 96)), |
81 | | - EnsureTyped(keys=["img"]), |
82 | 77 | ] |
83 | 78 | ) |
84 | 79 |
|
@@ -126,8 +121,8 @@ def prepare_batch(batch, device=None, non_blocking=False): |
126 | 121 | # add evaluation metric to the evaluator engine |
127 | 122 | val_metrics = {metric_name: ROCAUC()} |
128 | 123 |
|
129 | | - post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)]) |
130 | | - post_pred = Compose([EnsureType(), Activations(softmax=True)]) |
| 124 | + post_label = Compose([AsDiscrete(to_onehot=2)]) |
| 125 | + post_pred = Compose([Activations(softmax=True)]) |
131 | 126 | # Ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration, |
132 | 127 | # user can add output_transform to return other values |
133 | 128 | evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch, output_transform=lambda x, y, y_pred: ([post_pred(i) for i in decollate_batch(y_pred)], [post_label(i) for i in decollate_batch(y)])) |
|
0 commit comments