|
94 | 94 | "from monai.networks.nets import UNet\n", |
95 | 95 | "from monai.transforms import (\n", |
96 | 96 | " Activationsd,\n", |
97 | | - " AsChannelFirstd,\n", |
| 97 | + " EnsureChannelFirstd,\n", |
98 | 98 | " AsDiscreted,\n", |
99 | 99 | " Compose,\n", |
100 | 100 | " KeepLargestConnectedComponentd,\n", |
|
119 | 119 | "First of all, let's take a look at the possible data shape in `engine.state.batch` and `engine.state.output`.\n", |
120 | 120 | "\n", |
121 | 121 | "### engine.state.batch\n", |
122 | | - "(1) For a common ignite program, `batch` is usually the iterable output of PyTorch DataLoader, for example: `{\"image\": Tensor, \"label\" Tensor, \"image_meta_dict\": Dict}` where `image` and `label` are batch-first arrays, `image_meta_dict` is a dictionary of meta information for the input images, every item is a batch:\n", |
| 122 | + "(1) For a common ignite program, `batch` is usually the iterable output of PyTorch DataLoader, for example: `{\"image\": MetaTensor, \"label\" MetaTensor, \"image_meta_dict\": Dict}` where `image` and `label` are batch-first arrays, `image_meta_dict` is a dictionary of meta information for the input images, every item is a batch:\n", |
123 | 123 | "```\n", |
124 | 124 | "image.shape = [2, 4, 64, 64, 64] # here 2 is batch size, 4 is channels\n", |
125 | 125 | "label.shape = [2, 3, 64, 64, 64]\n", |
|
129 | 129 | "(2) For MONAI engines, it will automatically `decollate` the batch data into a list of `channel-first` data after every iteration. For more details about `decollate`, please refer to: https://github.com/Project-MONAI/tutorials/blob/main/modules/decollate_batch.ipynb.\n", |
130 | 130 | "\n", |
131 | 131 | "The `engine.state.batch` example in (1) will be decollated into a list of dictionaries:\n", |
132 | | - "`[{\"image\": Tensor, \"label\" Tensor, \"image_meta_dict\": Dict}, {\"image\": Tensor, \"label\" Tensor, \"image_meta_dict\": Dict}]`.\n", |
| 132 | + "`[{\"image\": MetaTensor, \"label\" MetaTensor, \"image_meta_dict\": Dict}, {\"image\": MetaTensor, \"label\" MetaTensor, \"image_meta_dict\": Dict}]`.\n", |
133 | 133 | "\n", |
134 | 134 | "each item of the list can be:\n", |
135 | 135 | "```\n", |
|
139 | 139 | "```\n", |
140 | 140 | "\n", |
141 | 141 | "### engine.state.output\n", |
142 | | - "(1) For a common ignite program, `output` is usually the output data of current iteration, for example: `{\"pred\": Tensor, \"label\": Tensor, \"loss\": scalar}` where `pred` and `label` are batch-first arrays, `loss` is a scalar value of current iteration:\n", |
| 142 | + "(1) For a common ignite program, `output` is usually the output data of current iteration, for example: `{\"pred\": MetaTensor, \"label\": MetaTensor, \"loss\": scalar}` where `pred` and `label` are batch-first arrays, `loss` is a scalar value of current iteration:\n", |
143 | 143 | "```\n", |
144 | 144 | "pred.shape = [2, 3, 64, 64, 64] # here 2 is batch size, 3 is channels\n", |
145 | 145 | "label.shape = [2, 3, 64, 64, 64]\n", |
|
148 | 148 | "\n", |
149 | 149 | "(2) For MONAI engines, it will also automatically `decollate` the output data into a list of `channel-first` data after every iteration.\n", |
150 | 150 | "The `engine.state.output` example in (1) will be decollated into a list of dictionaries:\n", |
151 | | - "`[{\"pred\": Tensor, \"label\": Tensor, \"loss\" 0.4534}, {\"pred\": Tensor, \"label\": Tensor, \"loss\" 0.4534}]`. Please note that it replicated the scalar value of `loss` to every item of the decollated list." |
| 151 | + "`[{\"pred\": MetaTensor, \"label\": MetaTensor, \"loss\" 0.4534}, {\"pred\": MetaTensor, \"label\": MetaTensor, \"loss\" 0.4534}]`. Please note that it replicated the scalar value of `loss` to every item of the decollated list." |
152 | 152 | ] |
153 | 153 | }, |
154 | 154 | { |
|
159 | 159 | "\n", |
160 | 160 | "Now let's analyze the cases of extracting data from `engine.state.batch` or `engine.state.output`. To simplify the operation, we developed a utility function `monai.handlers.from_engine` to automatically handle all the common cases.\n", |
161 | 161 | "\n", |
162 | | - "(1) To get the meta data from dictionary format `engine.state.batch`, set arg `batch_transform=lambda x: x[\"image_meta_dict\"]`.\n", |
| 162 | + "(1) To get the meta data from dictionary format `engine.state.batch`, set arg `batch_transform=lambda x: x.meta`.\n", |
163 | 163 | "\n", |
164 | | - "(2) To get the meta data from decollated list of dictionaries `engine.state.batch`, set arg `lambda x: [i[\"image_meta_dict\"] for i in x]` or `from_engine(\"image_meta_dict\")`.\n", |
| 164 | + "(2) To get the meta data from decollated list of dictionaries `engine.state.batch`, set arg `lambda x: [i.meta for i in x]` or `from_engine(\"image_meta_dict\")`.\n", |
165 | 165 | "\n", |
166 | 166 | "(3) Metrics usually expect a `Tuple(pred, label)` input, if `engine.state.output` is a dictionary, set arg `output_transform=lambda x: (x[\"pred\"], x[\"label\"])`. If decollated list, set arg `lambda x: ([i[\"pred\"] for i in x], [i[\"label\"] for i in x])` or `from_engine([\"pred\", \"label\"])`.\n", |
167 | 167 | "\n", |
|
244 | 244 | "train_transforms = Compose(\n", |
245 | 245 | " [\n", |
246 | 246 | " LoadImaged(keys=[\"image\", \"label\"]),\n", |
247 | | - " AsChannelFirstd(keys=[\"image\", \"label\"], channel_dim=-1),\n", |
| 247 | + " EnsureChannelFirstd(keys=[\"image\", \"label\"], channel_dim=-1),\n", |
248 | 248 | " ScaleIntensityd(keys=\"image\"),\n", |
249 | 249 | " RandCropByPosNegLabeld(\n", |
250 | 250 | " keys=[\"image\", \"label\"], label_key=\"label\", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4\n", |
|
255 | 255 | "val_transforms = Compose(\n", |
256 | 256 | " [\n", |
257 | 257 | " LoadImaged(keys=[\"image\", \"label\"]),\n", |
258 | | - " AsChannelFirstd(keys=[\"image\", \"label\"], channel_dim=-1),\n", |
| 258 | + " EnsureChannelFirstd(keys=[\"image\", \"label\"], channel_dim=-1),\n", |
259 | 259 | " ScaleIntensityd(keys=\"image\"),\n", |
260 | 260 | " EnsureTyped(keys=[\"image\", \"label\"]),\n", |
261 | 261 | " ]\n", |
|
425 | 425 | "name": "python", |
426 | 426 | "nbconvert_exporter": "python", |
427 | 427 | "pygments_lexer": "ipython3", |
428 | | - "version": "3.8.12" |
| 428 | + "version": "3.8.13" |
429 | 429 | } |
430 | 430 | }, |
431 | 431 | "nbformat": 4, |
|
0 commit comments