|
@@ -123,7 +123,6 @@
 "    Dataset,\n",
 "    pad_list_data_collate,\n",
 "    TestTimeAugmentation,\n",
-"    decollate_batch,\n",
 ")\n",
 "from monai.inferers import sliding_window_inference\n",
 "from monai.losses import DiceLoss\n",
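The dropped import is consistent with the later hunks: the notebook now iterates a batched output directly and re-collates with pad_list_data_collate, so decollate_batch is no longer needed. A minimal sketch of that round trip, assuming equal-sized samples:

import torch
from monai.data import pad_list_data_collate

batch = torch.rand(4, 1, 96, 96)            # a batch of 4 samples
samples = list(batch)                       # iterating replaces decollate_batch here
rebatched = pad_list_data_collate(samples)  # back to shape (4, 1, 96, 96)
assert rebatched.shape == batch.shape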
|
@@ -228,8 +227,7 @@
 "    def __call__(self, data):\n",
 "        d = dict(data)\n",
 "        im = d[self.label_key]\n",
-"        _im = im.detach().cpu().numpy()\n",
-"        q = np.sum((_im > 0).reshape(-1, _im.shape[-1]), axis=0)\n",
+"        q = np.sum((im.array > 0).reshape(-1, im.array.shape[-1]), axis=0)\n",
 "        _slice = np.where(q == np.max(q))[0][0]\n",
 "        for key in self.keys:\n",
 "            d[key] = d[key][..., _slice]\n",
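This transform finds the index along the last axis with the most foreground label voxels and crops every key down to that 2D slice; the new version reads the data through MetaTensor's .array property instead of the detach().cpu().numpy() chain. The selection logic in isolation, on a hypothetical numpy volume:

import numpy as np

label = np.zeros((1, 64, 64, 40))  # channel, H, W, slices (illustrative shape)
label[0, 20:40, 20:40, 17] = 1     # slice 17 holds the most foreground
q = np.sum((label > 0).reshape(-1, label.shape[-1]), axis=0)
best = np.where(q == np.max(q))[0][0]
assert best == 17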
|
@@ -247,7 +245,7 @@
 "            fname = os.path.basename(\n",
 "                data[key + \"_meta_dict\"][\"filename_or_obj\"])\n",
 "            path = os.path.join(self.path, key, fname)\n",
-"            nib.save(nib.Nifti1Image(data[key].detach().cpu().numpy(), np.eye(4)), path)\n",
+"            nib.save(nib.Nifti1Image(data[key].array, np.eye(4)), path)\n",
 "            d[key] = path\n",
 "        return d\n",
 "\n",
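Same substitution here: .array hands nibabel a numpy array of the tensor's data in one step. A self-contained sketch of the save path, with an illustrative output filename:

import nibabel as nib
import numpy as np
from monai.data import MetaTensor

img = MetaTensor(np.random.rand(64, 64, 40).astype(np.float32))
# .array yields the data as a numpy array, so no manual detach/cpu/numpy chain
nib.save(nib.Nifti1Image(img.array, np.eye(4)), "example.nii.gz")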
|
@@ -443,7 +441,7 @@
 "def infer_seg(images, model, roi_size=(96, 96), sw_batch_size=4):\n",
 "    val_outputs = sliding_window_inference(\n",
 "        images, roi_size, sw_batch_size, model)\n",
-"    return torch.stack([post_trans(i) for i in decollate_batch(val_outputs)])\n",
+"    return pad_list_data_collate([post_trans(i) for i in val_outputs])\n",
 "\n",
 "\n",
 "# Create network, loss fn., etc.\n",
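Iterating val_outputs yields one per-sample tensor per batch element, and pad_list_data_collate re-batches the post-processed samples, padding them first if post-processing ever changes spatial sizes, which torch.stack would not tolerate. A hedged usage sketch with a stand-in network and a typical sigmoid-plus-threshold post_trans (the notebook's actual definitions may differ):

import torch
from monai.networks.nets import UNet
from monai.transforms import Activations, AsDiscrete, Compose

post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
model = UNet(spatial_dims=2, in_channels=1, out_channels=1,
             channels=(8, 16, 32), strides=(2, 2))
images = torch.rand(2, 1, 128, 128)  # illustrative batch
seg = infer_seg(images, model)       # (2, 1, 128, 128) binary masks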
|