|
104 | 104 | "\n", |
105 | 105 | "from monai.apps import download_and_extract\n", |
106 | 106 | "from monai.config import print_config\n", |
107 | | - "from monai.data import decollate_batch\n", |
| 107 | + "from monai.data import decollate_batch, DataLoader\n", |
108 | 108 | "from monai.metrics import ROCAUCMetric\n", |
109 | 109 | "from monai.networks.nets import DenseNet121\n", |
110 | 110 | "from monai.transforms import (\n", |
|
117 | 117 | " RandRotate,\n", |
118 | 118 | " RandZoom,\n", |
119 | 119 | " ScaleIntensity,\n", |
120 | | - " EnsureType,\n", |
121 | 120 | ")\n", |
122 | 121 | "from monai.utils import set_determinism\n", |
123 | 122 | "\n", |
|
361 | 360 | "source": [ |
362 | 361 | "train_transforms = Compose(\n", |
363 | 362 | " [\n", |
364 | | - " LoadImage(image_only=True),\n", |
| 363 | + " LoadImage(),\n", |
365 | 364 | " AddChannel(),\n", |
366 | 365 | " ScaleIntensity(),\n", |
367 | 366 | " RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),\n", |
368 | 367 | " RandFlip(spatial_axis=0, prob=0.5),\n", |
369 | 368 | " RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),\n", |
370 | | - " EnsureType(),\n", |
371 | 369 | " ]\n", |
372 | 370 | ")\n", |
373 | 371 | "\n", |
374 | 372 | "val_transforms = Compose(\n", |
375 | | - " [LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])\n", |
| 373 | + " [LoadImage(), AddChannel(), ScaleIntensity()])\n", |
376 | 374 | "\n", |
377 | | - "y_pred_trans = Compose([EnsureType(), Activations(softmax=True)])\n", |
378 | | - "y_trans = Compose([EnsureType(), AsDiscrete(to_onehot=num_class)])" |
| 375 | + "y_pred_trans = Compose([Activations(softmax=True)])\n", |
| 376 | + "y_trans = Compose([AsDiscrete(to_onehot=num_class)])" |
379 | 377 | ] |
380 | 378 | }, |
381 | 379 | { |
|
398 | 396 | "\n", |
399 | 397 | "\n", |
400 | 398 | "train_ds = MedNISTDataset(train_x, train_y, train_transforms)\n", |
401 | | - "train_loader = torch.utils.data.DataLoader(\n", |
| 399 | + "train_loader = DataLoader(\n", |
402 | 400 | " train_ds, batch_size=300, shuffle=True, num_workers=10)\n", |
403 | 401 | "\n", |
404 | 402 | "val_ds = MedNISTDataset(val_x, val_y, val_transforms)\n", |
405 | | - "val_loader = torch.utils.data.DataLoader(\n", |
| 403 | + "val_loader = DataLoader(\n", |
406 | 404 | " val_ds, batch_size=300, num_workers=10)\n", |
407 | 405 | "\n", |
408 | 406 | "test_ds = MedNISTDataset(test_x, test_y, val_transforms)\n", |
409 | | - "test_loader = torch.utils.data.DataLoader(\n", |
| 407 | + "test_loader = DataLoader(\n", |
410 | 408 | " test_ds, batch_size=300, num_workers=10)" |
411 | 409 | ] |
412 | 410 | }, |
|
1328 | 1326 | ], |
1329 | 1327 | "metadata": { |
1330 | 1328 | "kernelspec": { |
1331 | | - "display_name": "Python 3", |
| 1329 | + "display_name": "Python 3 (ipykernel)", |
1332 | 1330 | "language": "python", |
1333 | 1331 | "name": "python3" |
1334 | 1332 | }, |
|
1342 | 1340 | "name": "python", |
1343 | 1341 | "nbconvert_exporter": "python", |
1344 | 1342 | "pygments_lexer": "ipython3", |
1345 | | - "version": "3.7.10" |
| 1343 | + "version": "3.8.13" |
1346 | 1344 | } |
1347 | 1345 | }, |
1348 | 1346 | "nbformat": 4, |
|
0 commit comments