Skip to content

Commit 4ee9d6b

Browse files
authored
update 3d_classification (#804)
Signed-off-by: KumoLiu <yunl@nvidia.com>
1 parent 5399480 commit 4ee9d6b

9 files changed

+42
-59
lines changed

3d_classification/densenet_training_array.ipynb

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@
107107
" RandRotate90,\n",
108108
" Resize,\n",
109109
" ScaleIntensity,\n",
110-
" EnsureType\n",
111110
")\n",
112111
"\n",
113112
"pin_memory = torch.cuda.is_available()\n",
@@ -210,24 +209,24 @@
210209
},
211210
{
212211
"cell_type": "code",
213-
"execution_count": 6,
212+
"execution_count": 5,
214213
"metadata": {},
215214
"outputs": [
216215
{
217216
"name": "stdout",
218217
"output_type": "stream",
219218
"text": [
220-
"<class 'torch.Tensor'> torch.Size([3, 1, 96, 96, 96]) tensor([[1., 0.],\n",
219+
"<class 'monai.data.meta_tensor.MetaTensor'> (3, 1, 96, 96, 96) tensor([[1., 0.],\n",
221220
" [1., 0.],\n",
222221
" [1., 0.]]) torch.Size([3, 2])\n"
223222
]
224223
}
225224
],
226225
"source": [
227226
"# Define transforms\n",
228-
"train_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), RandRotate90(), EnsureType()])\n",
227+
"train_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), RandRotate90()])\n",
229228
"\n",
230-
"val_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), EnsureType()])\n",
229+
"val_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96))])\n",
231230
"\n",
232231
"# Define nifti dataset, data loader\n",
233232
"check_ds = ImageDataset(image_files=images, labels=labels, transform=train_transforms)\n",
@@ -500,9 +499,9 @@
500499
],
501500
"metadata": {
502501
"kernelspec": {
503-
"display_name": "Python [conda env:monai]",
502+
"display_name": "Python 3 (ipykernel)",
504503
"language": "python",
505-
"name": "conda-env-monai-py"
504+
"name": "python3"
506505
},
507506
"language_info": {
508507
"codemirror_mode": {
@@ -514,7 +513,7 @@
514513
"name": "python",
515514
"nbconvert_exporter": "python",
516515
"pygments_lexer": "ipython3",
517-
"version": "3.9.7"
516+
"version": "3.8.13"
518517
}
519518
},
520519
"nbformat": 4,

3d_classification/ignite/densenet_evaluation_array.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,11 @@
1717
import torch
1818
from ignite.engine import _prepare_batch, create_supervised_evaluator
1919
from ignite.metrics import Accuracy
20-
from torch.utils.data import DataLoader
2120

2221
import monai
23-
from monai.data import ImageDataset
22+
from monai.data import ImageDataset, DataLoader
2423
from monai.handlers import CheckpointLoader, ClassificationSaver, StatsHandler
25-
from monai.transforms import AddChannel, Compose, Resize, ScaleIntensity, EnsureType
24+
from monai.transforms import EnsureChannelFirst, Compose, Resize, ScaleIntensity
2625

2726

2827
def main():
@@ -50,9 +49,9 @@ def main():
5049
labels = np.array([0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)
5150

5251
# define transforms for image
53-
val_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), EnsureType()])
52+
val_transforms = Compose([ScaleIntensity(), EnsureChannelFirst(), Resize((96, 96, 96))])
5453
# define image dataset
55-
val_ds = ImageDataset(image_files=images, labels=labels, transform=val_transforms, image_only=False)
54+
val_ds = ImageDataset(image_files=images, labels=labels, transform=val_transforms, image_only=True)
5655
# create DenseNet121
5756
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5857
net = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)
@@ -78,7 +77,7 @@ def prepare_batch(batch, device=None, non_blocking=False):
7877
# for the array data format, assume the 3rd item of batch data is the meta_data
7978
prediction_saver = ClassificationSaver(
8079
output_dir="tempdir",
81-
batch_transform=lambda batch: batch[2],
80+
batch_transform=lambda batch: batch[0].meta,
8281
output_transform=lambda output: output[0].argmax(1),
8382
)
8483
prediction_saver.attach(evaluator)

3d_classification/ignite/densenet_evaluation_dict.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
import torch
1818
from ignite.engine import _prepare_batch, create_supervised_evaluator
1919
from ignite.metrics import Accuracy
20-
from torch.utils.data import DataLoader
2120

2221
import monai
22+
from monai.data import DataLoader
2323
from monai.handlers import CheckpointLoader, ClassificationSaver, StatsHandler
24-
from monai.transforms import AddChanneld, Compose, LoadImaged, Resized, ScaleIntensityd, EnsureTyped
24+
from monai.transforms import Compose, LoadImaged, Resized, ScaleIntensityd
2525

2626

2727
def main():
@@ -52,11 +52,9 @@ def main():
5252
# define transforms for image
5353
val_transforms = Compose(
5454
[
55-
LoadImaged(keys=["img"]),
56-
AddChanneld(keys=["img"]),
55+
LoadImaged(keys=["img"], ensure_channel_first=True),
5756
ScaleIntensityd(keys=["img"]),
5857
Resized(keys=["img"], spatial_size=(96, 96, 96)),
59-
EnsureTyped(keys=["img"]),
6058
]
6159
)
6260

@@ -85,7 +83,7 @@ def prepare_batch(batch, device=None, non_blocking=False):
8583
prediction_saver = ClassificationSaver(
8684
output_dir="tempdir",
8785
name="evaluator",
88-
batch_transform=lambda batch: batch["img_meta_dict"],
86+
batch_transform=lambda batch: batch["img"].meta,
8987
output_transform=lambda output: output[0].argmax(1),
9088
)
9189
prediction_saver.attach(evaluator)

3d_classification/ignite/densenet_training_array.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,11 @@
1818
from ignite.engine import Events, create_supervised_evaluator, create_supervised_trainer
1919
from ignite.handlers import EarlyStopping, ModelCheckpoint
2020
from ignite.metrics import Accuracy
21-
from torch.utils.data import DataLoader
2221

2322
import monai
24-
from monai.data import ImageDataset, decollate_batch
23+
from monai.data import ImageDataset, decollate_batch, DataLoader
2524
from monai.handlers import StatsHandler, TensorBoardStatsHandler, stopping_fn_from_metric
26-
from monai.transforms import AddChannel, Compose, RandRotate90, Resize, ScaleIntensity, EnsureType
25+
from monai.transforms import EnsureChannelFirst, Compose, RandRotate90, Resize, ScaleIntensity
2726

2827

2928
def main():
@@ -61,8 +60,8 @@ def main():
6160
labels = np.array([0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)
6261

6362
# define transforms
64-
train_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), RandRotate90(), EnsureType()])
65-
val_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), EnsureType()])
63+
train_transforms = Compose([ScaleIntensity(), EnsureChannelFirst(), Resize((96, 96, 96)), RandRotate90()])
64+
val_transforms = Compose([ScaleIntensity(), EnsureChannelFirst(), Resize((96, 96, 96))])
6665

6766
# define image dataset, data loader
6867
check_ds = ImageDataset(image_files=images, labels=labels, transform=train_transforms)

3d_classification/ignite/densenet_training_dict.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,11 @@
1717
import torch
1818
from ignite.engine import Events, _prepare_batch, create_supervised_evaluator, create_supervised_trainer
1919
from ignite.handlers import EarlyStopping, ModelCheckpoint
20-
from torch.utils.data import DataLoader
2120

2221
import monai
23-
from monai.data import decollate_batch
22+
from monai.data import decollate_batch, DataLoader
2423
from monai.handlers import ROCAUC, StatsHandler, TensorBoardStatsHandler, stopping_fn_from_metric
25-
from monai.transforms import Activations, AddChanneld, AsDiscrete, Compose, LoadImaged, RandRotate90d, Resized, ScaleIntensityd, EnsureTyped, EnsureType
24+
from monai.transforms import Activations, AsDiscrete, Compose, LoadImaged, RandRotate90d, Resized, ScaleIntensityd
2625

2726

2827
def main():
@@ -64,21 +63,17 @@ def main():
6463
# define transforms for image
6564
train_transforms = Compose(
6665
[
67-
LoadImaged(keys=["img"]),
68-
AddChanneld(keys=["img"]),
66+
LoadImaged(keys=["img"], ensure_channel_first=True),
6967
ScaleIntensityd(keys=["img"]),
7068
Resized(keys=["img"], spatial_size=(96, 96, 96)),
7169
RandRotate90d(keys=["img"], prob=0.8, spatial_axes=[0, 2]),
72-
EnsureTyped(keys=["img"]),
7370
]
7471
)
7572
val_transforms = Compose(
7673
[
77-
LoadImaged(keys=["img"]),
78-
AddChanneld(keys=["img"]),
74+
LoadImaged(keys=["img"], ensure_channel_first=True),
7975
ScaleIntensityd(keys=["img"]),
8076
Resized(keys=["img"], spatial_size=(96, 96, 96)),
81-
EnsureTyped(keys=["img"]),
8277
]
8378
)
8479

@@ -126,8 +121,8 @@ def prepare_batch(batch, device=None, non_blocking=False):
126121
# add evaluation metric to the evaluator engine
127122
val_metrics = {metric_name: ROCAUC()}
128123

129-
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])
130-
post_pred = Compose([EnsureType(), Activations(softmax=True)])
124+
post_label = Compose([AsDiscrete(to_onehot=2)])
125+
post_pred = Compose([Activations(softmax=True)])
131126
# Ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
132127
# user can add output_transform to return other values
133128
evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch, output_transform=lambda x, y, y_pred: ([post_pred(i) for i in decollate_batch(y_pred)], [post_label(i) for i in decollate_batch(y)]))

3d_classification/torch/densenet_evaluation_array.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,10 @@
1515

1616
import numpy as np
1717
import torch
18-
from torch.utils.data import DataLoader
1918

2019
import monai
21-
from monai.data import CSVSaver, ImageDataset
22-
from monai.transforms import AddChannel, Compose, Resize, ScaleIntensity, EnsureType
20+
from monai.data import CSVSaver, ImageDataset, DataLoader
21+
from monai.transforms import AddChannel, Compose, Resize, ScaleIntensity
2322

2423

2524
def main():
@@ -47,10 +46,10 @@ def main():
4746
labels = np.array([0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)
4847

4948
# Define transforms for image
50-
val_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), EnsureType()])
49+
val_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96))])
5150

5251
# Define image dataset
53-
val_ds = ImageDataset(image_files=images, labels=labels, transform=val_transforms, image_only=False)
52+
val_ds = ImageDataset(image_files=images, labels=labels, transform=val_transforms, image_only=True)
5453
# create a validation data loader
5554
val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
5655

@@ -70,7 +69,7 @@ def main():
7069
value = torch.eq(val_outputs, val_labels)
7170
metric_count += len(value)
7271
num_correct += value.sum().item()
73-
saver.save_batch(val_outputs, val_data[2])
72+
saver.save_batch(val_outputs, val_images.meta)
7473
metric = num_correct / metric_count
7574
print("evaluation metric:", metric)
7675
saver.finalize()

3d_classification/torch/densenet_evaluation_dict.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,10 @@
1515

1616
import numpy as np
1717
import torch
18-
from torch.utils.data import DataLoader
1918

2019
import monai
21-
from monai.data import CSVSaver
22-
from monai.transforms import AddChanneld, Compose, LoadImaged, Resized, ScaleIntensityd, EnsureTyped
20+
from monai.data import CSVSaver, DataLoader
21+
from monai.transforms import AddChanneld, Compose, LoadImaged, Resized, ScaleIntensityd
2322

2423

2524
def main():
@@ -54,7 +53,6 @@ def main():
5453
AddChanneld(keys=["img"]),
5554
ScaleIntensityd(keys=["img"]),
5655
Resized(keys=["img"], spatial_size=(96, 96, 96)),
57-
EnsureTyped(keys=["img"]),
5856
]
5957
)
6058

@@ -78,7 +76,7 @@ def main():
7876
value = torch.eq(val_outputs, val_labels)
7977
metric_count += len(value)
8078
num_correct += value.sum().item()
81-
saver.save_batch(val_outputs, val_data["img_meta_dict"])
79+
saver.save_batch(val_outputs, val_data["img"].meta)
8280
metric = num_correct / metric_count
8381
print("evaluation metric:", metric)
8482
saver.finalize()

3d_classification/torch/densenet_training_array.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,11 @@
1515

1616
import numpy as np
1717
import torch
18-
from torch.utils.data import DataLoader
1918
from torch.utils.tensorboard import SummaryWriter
2019

2120
import monai
22-
from monai.data import ImageDataset
23-
from monai.transforms import AddChannel, Compose, RandRotate90, Resize, ScaleIntensity, EnsureType
21+
from monai.data import ImageDataset, DataLoader
22+
from monai.transforms import AddChannel, Compose, RandRotate90, Resize, ScaleIntensity
2423

2524

2625
def main():
@@ -58,8 +57,8 @@ def main():
5857
labels = np.array([0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)
5958

6059
# Define transforms
61-
train_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), RandRotate90(), EnsureType()])
62-
val_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), EnsureType()])
60+
train_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), RandRotate90()])
61+
val_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96))])
6362

6463
# Define image dataset, data loader
6564
check_ds = ImageDataset(image_files=images, labels=labels, transform=train_transforms)

3d_classification/torch/densenet_training_dict.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,12 @@
1515

1616
import numpy as np
1717
import torch
18-
from torch.utils.data import DataLoader
1918
from torch.utils.tensorboard import SummaryWriter
2019

2120
import monai
22-
from monai.data import decollate_batch
21+
from monai.data import decollate_batch, DataLoader
2322
from monai.metrics import ROCAUCMetric
24-
from monai.transforms import Activations, AddChanneld, AsDiscrete, Compose, LoadImaged, RandRotate90d, Resized, ScaleIntensityd, EnsureTyped, EnsureType
23+
from monai.transforms import Activations, AddChanneld, AsDiscrete, Compose, LoadImaged, RandRotate90d, Resized, ScaleIntensityd
2524

2625

2726
def main():
@@ -68,7 +67,6 @@ def main():
6867
ScaleIntensityd(keys=["img"]),
6968
Resized(keys=["img"], spatial_size=(96, 96, 96)),
7069
RandRotate90d(keys=["img"], prob=0.8, spatial_axes=[0, 2]),
71-
EnsureTyped(keys=["img"]),
7270
]
7371
)
7472
val_transforms = Compose(
@@ -77,11 +75,10 @@ def main():
7775
AddChanneld(keys=["img"]),
7876
ScaleIntensityd(keys=["img"]),
7977
Resized(keys=["img"], spatial_size=(96, 96, 96)),
80-
EnsureTyped(keys=["img"]),
8178
]
8279
)
83-
post_pred = Compose([EnsureType(), Activations(softmax=True)])
84-
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])
80+
post_pred = Compose([Activations(softmax=True)])
81+
post_label = Compose([AsDiscrete(to_onehot=2)])
8582

8683
# Define dataset, data loader
8784
check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)

0 commit comments

Comments
 (0)