|
53 | 53 | }, |
54 | 54 | { |
55 | 55 | "cell_type": "code", |
56 | | - "execution_count": null, |
| 56 | + "execution_count": 1, |
57 | 57 | "metadata": {}, |
58 | 58 | "outputs": [], |
59 | 59 | "source": [ |
|
69 | 69 | }, |
70 | 70 | { |
71 | 71 | "cell_type": "code", |
72 | | - "execution_count": null, |
| 72 | + "execution_count": 2, |
73 | 73 | "metadata": {}, |
74 | 74 | "outputs": [], |
75 | 75 | "source": [ |
|
78 | 78 | }, |
79 | 79 | { |
80 | 80 | "cell_type": "code", |
81 | | - "execution_count": null, |
| 81 | + "execution_count": 3, |
82 | 82 | "metadata": { |
83 | 83 | "id": "KvbbZuhmquRR" |
84 | 84 | }, |
|
92 | 92 | }, |
93 | 93 | { |
94 | 94 | "cell_type": "code", |
95 | | - "execution_count": null, |
| 95 | + "execution_count": 4, |
96 | 96 | "metadata": { |
97 | 97 | "id": "gduPdIturUIB" |
98 | 98 | }, |
99 | 99 | "outputs": [], |
100 | 100 | "source": [ |
101 | | - "from pathlib import Path\n", |
102 | 101 | "from datetime import datetime\n", |
| 102 | + "import os\n", |
| 103 | + "import tempfile\n", |
| 104 | + "from glob import glob\n", |
103 | 105 | "\n", |
104 | 106 | "import torch\n", |
105 | 107 | "from torch.utils.data import random_split, DataLoader\n", |
|
117 | 119 | "%load_ext tensorboard" |
118 | 120 | ] |
119 | 121 | }, |
| 122 | + { |
| 123 | + "cell_type": "markdown", |
| 124 | + "metadata": {}, |
| 125 | + "source": [ |
| 126 | + "## Setup data directory\n", |
| 127 | + "\n", |
| 128 | + "You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable. \n", |
| 129 | + "This allows you to save results and reuse downloads. \n", |
| 130 | + "If not specified, a temporary directory will be used." |
| 131 | + ] |
| 132 | + }, |
| 133 | + { |
| 134 | + "cell_type": "code", |
| 135 | + "execution_count": 5, |
| 136 | + "metadata": {}, |
| 137 | + "outputs": [ |
| 138 | + { |
| 139 | + "name": "stdout", |
| 140 | + "output_type": "stream", |
| 141 | + "text": [ |
| 142 | + "/mnt/data/rbrown/Documents/Data/MONAI\n" |
| 143 | + ] |
| 144 | + } |
| 145 | + ], |
| 146 | + "source": [ |
| 147 | + "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n", |
| 148 | + "root_dir = tempfile.mkdtemp() if directory is None else directory\n", |
| 149 | + "print(root_dir)" |
| 150 | + ] |
| 151 | + }, |
120 | 152 | { |
121 | 153 | "cell_type": "markdown", |
122 | 154 | "metadata": { |
|
145 | 177 | }, |
146 | 178 | { |
147 | 179 | "cell_type": "code", |
148 | | - "execution_count": null, |
| 180 | + "execution_count": 6, |
149 | 181 | "metadata": { |
150 | 182 | "id": "KuhTaRl3vf37" |
151 | 183 | }, |
152 | 184 | "outputs": [], |
153 | 185 | "source": [ |
154 | | - "\n", |
155 | | - "\n", |
156 | 186 | "class MedicalDecathlonDataModule(pl.LightningDataModule):\n", |
157 | 187 | " def __init__(self, task, batch_size, train_val_ratio):\n", |
158 | 188 | " super().__init__()\n", |
159 | 189 | " self.task = task\n", |
160 | 190 | " self.batch_size = batch_size\n", |
161 | | - " self.dataset_dir = Path(task)\n", |
| 191 | + " self.base_dir = root_dir\n", |
| 192 | + " self.dataset_dir = os.path.join(root_dir, task)\n", |
162 | 193 | " self.train_val_ratio = train_val_ratio\n", |
163 | 194 | " self.subjects = None\n", |
164 | 195 | " self.test_subjects = None\n", |
|
175 | 206 | " return shapes.max(axis=0)\n", |
176 | 207 | "\n", |
177 | 208 | " def download_data(self):\n", |
178 | | - " if not self.dataset_dir.is_dir():\n", |
179 | | - " url = 'https://msd-for-monai.s3-us-west-2.amazonaws.com/Task04_Hippocampus.tar'\n", |
180 | | - " monai.apps.download_and_extract(url=url, output_dir=\".\")\n", |
| 209 | + " if not os.path.isdir(self.dataset_dir):\n", |
| 210 | + " url = f'https://msd-for-monai.s3-us-west-2.amazonaws.com/{self.task}.tar'\n", |
| 211 | + " monai.apps.download_and_extract(url=url, output_dir=self.base_dir)\n", |
181 | 212 | "\n", |
182 | | - " def get_niis(d):\n", |
183 | | - " return sorted(p for p in d.glob('*.nii*') if not p.name.startswith('.'))\n", |
184 | | - "\n", |
185 | | - " image_training_paths = get_niis(self.dataset_dir / 'imagesTr')\n", |
186 | | - " label_training_paths = get_niis(self.dataset_dir / 'labelsTr')\n", |
187 | | - " image_test_paths = get_niis(self.dataset_dir / 'imagesTs')\n", |
| 213 | + " image_training_paths = sorted(glob(os.path.join(self.dataset_dir, 'imagesTr', \"*.nii*\")))\n", |
| 214 | + " label_training_paths = sorted(glob(os.path.join(self.dataset_dir, 'labelsTr', \"*.nii*\")))\n", |
| 215 | + " image_test_paths = sorted(glob(os.path.join(self.dataset_dir, 'imagesTs', \"*.nii*\")))\n", |
188 | 216 | " return image_training_paths, label_training_paths, image_test_paths\n", |
189 | 217 | "\n", |
190 | 218 | " def prepare_data(self):\n", |
|
260 | 288 | }, |
261 | 289 | { |
262 | 290 | "cell_type": "code", |
263 | | - "execution_count": null, |
| 291 | + "execution_count": 7, |
264 | 292 | "metadata": { |
265 | 293 | "id": "hcHf9w2nLfyC" |
266 | 294 | }, |
|
284 | 312 | }, |
285 | 313 | { |
286 | 314 | "cell_type": "code", |
287 | | - "execution_count": null, |
| 315 | + "execution_count": 8, |
288 | 316 | "metadata": { |
289 | 317 | "colab": { |
290 | 318 | "base_uri": "https://localhost:8080/" |
|
293 | 321 | "outputId": "7cb39051-4c26-4811-b838-8a5e938e53a3" |
294 | 322 | }, |
295 | 323 | "outputs": [ |
296 | | - { |
297 | | - "name": "stderr", |
298 | | - "output_type": "stream", |
299 | | - "text": [ |
300 | | - "Downloading...\n", |
301 | | - "From: https://drive.google.com/uc?id=1RzPB1_bqzQhlWvU-YGvZzhx2omcDh38C\n", |
302 | | - "To: /content/Task04_Hippocampus.tar\n", |
303 | | - "28.4MB [00:00, 82.8MB/s]\n" |
304 | | - ] |
305 | | - }, |
306 | 324 | { |
307 | 325 | "name": "stdout", |
308 | 326 | "output_type": "stream", |
|
341 | 359 | }, |
342 | 360 | { |
343 | 361 | "cell_type": "code", |
344 | | - "execution_count": null, |
| 362 | + "execution_count": 9, |
345 | 363 | "metadata": { |
346 | 364 | "id": "1Ov3H12p6Qx1" |
347 | 365 | }, |
|
395 | 413 | }, |
396 | 414 | { |
397 | 415 | "cell_type": "code", |
398 | | - "execution_count": null, |
| 416 | + "execution_count": 10, |
399 | 417 | "metadata": { |
400 | 418 | "colab": { |
401 | 419 | "base_uri": "https://localhost:8080/" |
|
0 commit comments