Skip to content

Commit 5399480

Browse files
authored
update 2d_segmentation (#803)
1 parent 53ae072 commit 5399480

File tree

4 files changed

+14
-26
lines changed

4 files changed

+14
-26
lines changed

2d_segmentation/torch/unet_evaluation_array.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,13 @@
1717

1818
import torch
1919
from PIL import Image
20-
from torch.utils.data import DataLoader
2120

2221
from monai import config
23-
from monai.data import ArrayDataset, create_test_image_2d, decollate_batch
22+
from monai.data import ArrayDataset, create_test_image_2d, decollate_batch, DataLoader
2423
from monai.inferers import sliding_window_inference
2524
from monai.metrics import DiceMetric
2625
from monai.networks.nets import UNet
27-
from monai.transforms import Activations, AddChannel, AsDiscrete, Compose, LoadImage, SaveImage, ScaleIntensity, EnsureType
26+
from monai.transforms import Activations, AddChannel, AsDiscrete, Compose, LoadImage, SaveImage, ScaleIntensity
2827

2928

3029
def main(tempdir):
@@ -41,13 +40,13 @@ def main(tempdir):
4140
segs = sorted(glob(os.path.join(tempdir, "seg*.png")))
4241

4342
# define transforms for image and segmentation
44-
imtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])
45-
segtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])
43+
imtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity()])
44+
segtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity()])
4645
val_ds = ArrayDataset(images, imtrans, segs, segtrans)
4746
# sliding window inference for one image at every iteration
4847
val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
4948
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
50-
post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
49+
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
5150
saver = SaveImage(output_dir="./output", output_ext=".png", output_postfix="seg")
5251
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5352
model = UNet(

2d_segmentation/torch/unet_evaluation_dict.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,13 @@
1717

1818
import torch
1919
from PIL import Image
20-
from torch.utils.data import DataLoader
2120

2221
import monai
23-
from monai.data import create_test_image_2d, list_data_collate, decollate_batch
22+
from monai.data import create_test_image_2d, list_data_collate, decollate_batch, DataLoader
2423
from monai.inferers import sliding_window_inference
2524
from monai.metrics import DiceMetric
2625
from monai.networks.nets import UNet
27-
from monai.transforms import Activations, AddChanneld, AsDiscrete, Compose, LoadImaged, SaveImage, ScaleIntensityd, EnsureTyped, EnsureType
26+
from monai.transforms import Activations, AddChanneld, AsDiscrete, Compose, LoadImaged, SaveImage, ScaleIntensityd
2827

2928

3029
def main(tempdir):
@@ -47,14 +46,13 @@ def main(tempdir):
4746
LoadImaged(keys=["img", "seg"]),
4847
AddChanneld(keys=["img", "seg"]),
4948
ScaleIntensityd(keys=["img", "seg"]),
50-
EnsureTyped(keys=["img", "seg"]),
5149
]
5250
)
5351
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
5452
# sliding window inference need to input 1 image in every iteration
5553
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
5654
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
57-
post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
55+
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
5856
saver = SaveImage(output_dir="./output", output_ext=".png", output_postfix="seg")
5957
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
6058
model = UNet(

2d_segmentation/torch/unet_training_array.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@
1717

1818
import torch
1919
from PIL import Image
20-
from torch.utils.data import DataLoader
2120
from torch.utils.tensorboard import SummaryWriter
2221

2322
import monai
24-
from monai.data import ArrayDataset, create_test_image_2d, decollate_batch
23+
from monai.data import ArrayDataset, create_test_image_2d, decollate_batch, DataLoader
2524
from monai.inferers import sliding_window_inference
2625
from monai.metrics import DiceMetric
2726
from monai.transforms import (
@@ -33,7 +32,6 @@
3332
RandRotate90,
3433
RandSpatialCrop,
3534
ScaleIntensity,
36-
EnsureType,
3735
)
3836
from monai.visualize import plot_2d_or_3d_image
3937

@@ -60,7 +58,6 @@ def main(tempdir):
6058
ScaleIntensity(),
6159
RandSpatialCrop((96, 96), random_size=False),
6260
RandRotate90(prob=0.5, spatial_axes=(0, 1)),
63-
EnsureType(),
6461
]
6562
)
6663
train_segtrans = Compose(
@@ -70,11 +67,10 @@ def main(tempdir):
7067
ScaleIntensity(),
7168
RandSpatialCrop((96, 96), random_size=False),
7269
RandRotate90(prob=0.5, spatial_axes=(0, 1)),
73-
EnsureType(),
7470
]
7571
)
76-
val_imtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])
77-
val_segtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])
72+
val_imtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity()])
73+
val_segtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity()])
7874

7975
# define array dataset, data loader
8076
check_ds = ArrayDataset(images, train_imtrans, segs, train_segtrans)
@@ -89,7 +85,7 @@ def main(tempdir):
8985
val_ds = ArrayDataset(images[-20:], val_imtrans, segs[-20:], val_segtrans)
9086
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=torch.cuda.is_available())
9187
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
92-
post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
88+
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
9389
# create UNet, DiceLoss and Adam optimizer
9490
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9591
model = monai.networks.nets.UNet(

2d_segmentation/torch/unet_training_dict.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@
1717

1818
import torch
1919
from PIL import Image
20-
from torch.utils.data import DataLoader
2120
from torch.utils.tensorboard import SummaryWriter
2221

2322
import monai
24-
from monai.data import create_test_image_2d, list_data_collate, decollate_batch
23+
from monai.data import create_test_image_2d, list_data_collate, decollate_batch, DataLoader
2524
from monai.inferers import sliding_window_inference
2625
from monai.metrics import DiceMetric
2726
from monai.transforms import (
@@ -33,8 +32,6 @@
3332
RandCropByPosNegLabeld,
3433
RandRotate90d,
3534
ScaleIntensityd,
36-
EnsureTyped,
37-
EnsureType,
3835
)
3936
from monai.visualize import plot_2d_or_3d_image
4037

@@ -65,15 +62,13 @@ def main(tempdir):
6562
keys=["img", "seg"], label_key="seg", spatial_size=[96, 96], pos=1, neg=1, num_samples=4
6663
),
6764
RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 1]),
68-
EnsureTyped(keys=["img", "seg"]),
6965
]
7066
)
7167
val_transforms = Compose(
7268
[
7369
LoadImaged(keys=["img", "seg"]),
7470
AddChanneld(keys=["img", "seg"]),
7571
ScaleIntensityd(keys=["img", "seg"]),
76-
EnsureTyped(keys=["img", "seg"]),
7772
]
7873
)
7974

@@ -99,7 +94,7 @@ def main(tempdir):
9994
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
10095
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
10196
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
102-
post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
97+
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
10398
# create UNet, DiceLoss and Adam optimizer
10499
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
105100
model = monai.networks.nets.UNet(

0 commit comments

Comments
 (0)