Skip to content

Commit de12b79

Browse files
mersad95zdpre-commit-ci[bot]wyli
authored
VarNet demo (#904)
* initial unet recon demo commit Signed-off-by: mersad95zd <m.zalbagi@gmail.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixed formatting errors Signed-off-by: mersad95zd <m.zalbagi@gmail.com> * fixed formatting errors Signed-off-by: mersad95zd <m.zalbagi@gmail.com> * fixing formatting errors Signed-off-by: mersad95zd <m.zalbagi@gmail.com> * more experimental details added Signed-off-by: mersad95zd <m.zalbagi@gmail.com> * removed checkpoint from this PR; minor fix to checkpoint directory in inference.ipynb Signed-off-by: mersad95zd <m.zalbagi@gmail.com> * clarified common practice for fastMRI inference Signed-off-by: mersad95zd <m.zalbagi@gmail.com> * fixed model checkpoint name in all files Signed-off-by: mersad95zd <m.zalbagi@gmail.com> * added model checkpoint link to readme Signed-off-by: mersad95zd <m.zalbagi@gmail.com> * init varnet commit Signed-off-by: mersad95zd <m.zalbagi@gmail.com> * back to init commit Signed-off-by: mersad95zd <m.zalbagi@gmail.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * updated readmes Signed-off-by: mersad95zd <m.zalbagi@gmail.com> * removed unet files Signed-off-by: mersad95zd <m.zalbagi@gmail.com> * updated inference files based on the data split Signed-off-by: mersad95zd <m.zalbagi@gmail.com> * minor update to unet inference Signed-off-by: mersad95zd <m.zalbagi@gmail.com> Signed-off-by: mersad95zd <m.zalbagi@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Wenqi Li <831580+wyli@users.noreply.github.com>
1 parent 4b348ef commit de12b79

File tree

11 files changed

+895
-41
lines changed

11 files changed

+895
-41
lines changed

reconstruction/MRI_reconstruction/README.md renamed to reconstruction/MRI_reconstruction/unet_demo/README.md

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,24 +30,17 @@ This folder contains code to train and validate a U-Net for accelerated MRI reco
3030
# Dataset
3131

3232
The experiments are performed on the [fastMRI](https://fastmri.org/dataset) dataset. Users should request access to the dataset
33-
from the [owner's website](https://fastmri.org/dataset).
33+
from the [owner's website](https://fastmri.org/dataset). Remember to use the `$PATH` where you downloaded the data in `train.py`
34+
or `inference.ipynb` accordingly.
3435

35-
**Note.** Since the ground truth is not released with the test set of the fastMRI dataset, it is a common practice in the literature
36-
to perform inference on the validation set of the fastMRI dataset. This could be in the form of testing on the whole validation
37-
set (for example this work [https://www.ncbi.nlm.nih.gov/pmc/articles/PMC8767765/](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC8767765/)).
38-
<br>
39-
Another approach is to split the validation set into validation and test sets and keep the test portion for inference (for exmple this work [https://arxiv.org/pdf/2111.02549.pdf](https://arxiv.org/pdf/2111.02549.pdf)). Note that both approaches are conceptually similar
40-
in that splitting the validation set does not change the fact that the splits belong to the same distribution.
41-
<br>
42-
Other workarounds to this problem include (1) skipping validation during training and saving the model checkpoint of the last epoch for inference on the validation set, and (2) submitting model results to the [fastMRI public leaderboard](https://fastmri.org/leaderboards/).
36+
For our experiments we created a subset of the fastMRI dataset which contains a `500/179/133` split for `train/val/test`. Please download [fastmri_data_split.json](https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/fastmri_data_split.json) and put it here under `./data`.
4337

4438
# Model checkpoint
4539

4640
We have already provided a model checkpoint [unet_mri_reconstruction.pt](https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/unet_mri_reconstruction.pt) for a U-Net with `7,782,849` parameters. To obtain this checkpoint, we trained
47-
a U-Net with the default hyper-parameters in `train.py` on the T2 subset of the brain dataset (`500` training and `180` validation volumes). The user can train their model on an arbitrary portion of the dataset.
41+
a U-Net with the default hyper-parameters in `train.py` on the T2 subset of the brain dataset. The user can train their model on an arbitrary portion of the dataset.
4842

49-
Our checkpoint achieves `0.9496` SSIM on the fastMRI T2 validation subset which is comparabale to the original result reported on the
50-
[fastMRI public leaderboard](https://fastmri.org/leaderboards/) (which is `0.9474` SSIM). The training dynamics for our checkpoint is depicted in the figure below.
43+
The training dynamics for our checkpoint is depicted in the figure below.
5144

5245
<p align="center"><img src="./figures/dynamics.PNG" width="800" height="225"></p>
5346

@@ -71,5 +64,9 @@ Running `train.py` trains a U-Net. The default setup automatically detects a GPU
7164

7265
# Inference
7366

74-
The notebook `inference.ipynb` contains an example to perform validation. Average SSIM score over the validation set is computed and then
67+
The notebook `inference.ipynb` contains an example to perform inference. Average SSIM score over the test subset is computed and then
7568
one sample is picked for visualization.
69+
70+
Our checkpoint achieves `0.9436` SSIM on our test subset which is comparable to the original result reported on the
71+
[fastMRI public leaderboard](https://fastmri.org/leaderboards/) (which is `0.9474` SSIM). Note that the results reported
72+
on the leaderboard are for the unreleased test set. Moreover, the leaderboard model is trained on the validation set.

reconstruction/MRI_reconstruction/fastmri_ssim.py renamed to reconstruction/MRI_reconstruction/unet_demo/fastmri_ssim.py

File renamed without changes.

reconstruction/MRI_reconstruction/figures/dynamics.PNG renamed to reconstruction/MRI_reconstruction/unet_demo/figures/dynamics.PNG

File renamed without changes.

reconstruction/MRI_reconstruction/figures/workflow.PNG renamed to reconstruction/MRI_reconstruction/unet_demo/figures/workflow.PNG

File renamed without changes.

reconstruction/MRI_reconstruction/inference.ipynb renamed to reconstruction/MRI_reconstruction/unet_demo/inference.ipynb

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
"import torch\n",
3030
"import warnings\n",
3131
"import random\n",
32+
"import json\n",
3233
"from fastmri_ssim import skimage_ssim\n",
3334
"import matplotlib.pyplot as plt\n",
3435
"\n",
@@ -76,7 +77,7 @@
7677
" self.batch_size = 1 # can be set to >1 when input sizes are not different\n",
7778
" self.num_workers = 0\n",
7879
" self.cache_rate = 0.0 # what fraction of the data to be cached for faster loading\n",
79-
" self.data_path_val = '/data/fastmri/fastMRI/multicoil_val_t2/' # path to the validation set\n",
80+
" self.data_path_val = '/data/fastmri/multicoil_val/' # path to the validation set\n",
8081
" self.sample_rate = 0.9 # select 0.9 of the validation set for inference\n",
8182
" self.accelerations = [4] # acceleration factors used for valdiation.\n",
8283
" self.center_fractions = [0.08] # center_fractions used for valdiation.\n",
@@ -104,16 +105,29 @@
104105
"# Create validation data loader"
105106
]
106107
},
108+
{
109+
"cell_type": "code",
110+
"execution_count": null,
111+
"metadata": {},
112+
"outputs": [],
113+
"source": [
114+
"with open(\"./data/fastmri_data_split.json\", \"r\") as fn:\n",
115+
" data = json.load(fn)\n",
116+
"test_files = data['test_files']\n",
117+
"fastmri_val_set = list(Path(args.data_path_val).iterdir())\n",
118+
"test_files = [f for f in fastmri_val_set if str(f).split('/')[-1] in test_files]"
119+
]
120+
},
107121
{
108122
"cell_type": "code",
109123
"execution_count": 3,
110124
"metadata": {},
111125
"outputs": [],
112126
"source": [
113-
"val_files = list(Path(args.data_path_val).iterdir())\n",
114-
"random.shuffle(val_files)\n",
115-
"val_files = val_files[:int(args.sample_rate*len(val_files))] # select a subset of the data according to sample_rate\n",
116-
"val_files = [dict([(\"kspace\", val_files[i])]) for i in range(len(val_files))]\n",
127+
"random.shuffle(test_files)\n",
128+
"test_files = test_files[:int(args.sample_rate*len(test_files))] # select a subset of the data according to sample_rate\n",
129+
"test_files = [dict([(\"kspace\", test_files[i])]) for i in range(len(test_files))]\n",
130+
"print(f'#test files: {len(test_files)}')\n",
117131
"\n",
118132
"# define mask transform type (e.g., whether it is equispaced or random)\n",
119133
"if args.mask_type == 'random':\n",
@@ -129,7 +143,7 @@
129143
" spatial_dims=2,\n",
130144
" is_complex=True)\n",
131145
"\n",
132-
"val_transforms = Compose(\n",
146+
"test_transforms = Compose(\n",
133147
" [\n",
134148
" LoadImaged(keys=[\"kspace\"], reader=FastMRIReader, dtype=np.complex64),\n",
135149
" # user can also add other random transforms\n",
@@ -145,10 +159,10 @@
145159
" ]\n",
146160
")\n",
147161
"\n",
148-
"val_ds = CacheDataset(\n",
149-
" data=val_files, transform=val_transforms,\n",
162+
"test_ds = CacheDataset(\n",
163+
" data=test_files, transform=test_transforms,\n",
150164
" cache_rate=args.cache_rate, num_workers=args.num_workers)\n",
151-
"val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)"
165+
"test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)"
152166
]
153167
},
154168
{
@@ -203,30 +217,22 @@
203217
},
204218
{
205219
"cell_type": "code",
206-
"execution_count": 5,
220+
"execution_count": null,
207221
"metadata": {},
208-
"outputs": [
209-
{
210-
"name": "stdout",
211-
"output_type": "stream",
212-
"text": [
213-
"161 volume out of 161 done. \r"
214-
]
215-
}
216-
],
222+
"outputs": [],
217223
"source": [
218224
"outputs = defaultdict(list)\n",
219225
"targets = defaultdict(list)\n",
220226
"with torch.no_grad():\n",
221227
" val_ssim = list()\n",
222228
" step = 1\n",
223-
" for val_data in val_loader:\n",
229+
" for test_data in test_loader:\n",
224230
" input, target, mean, std, fname = (\n",
225-
" val_data[\"kspace_masked_ifft\"],\n",
226-
" val_data[\"reconstruction_rss\"],\n",
227-
" val_data[\"mean\"],\n",
228-
" val_data[\"std\"],\n",
229-
" val_data[\"kspace_meta_dict\"][\"filename\"]\n",
231+
" test_data[\"kspace_masked_ifft\"],\n",
232+
" test_data[\"reconstruction_rss\"],\n",
233+
" test_data[\"mean\"],\n",
234+
" test_data[\"std\"],\n",
235+
" test_data[\"kspace_meta_dict\"][\"filename\"]\n",
230236
" )\n",
231237
"\n",
232238
" # iterate through all slices:\n",
@@ -247,7 +253,7 @@
247253
" # save volume slices according to volume name given by fname\n",
248254
" outputs[fname[0]].append(output.data.cpu().numpy()[0][0]*_std+_mean)\n",
249255
" targets[fname[0]].append(tar.numpy()[0][0]*_std+_mean)\n",
250-
" print(step, ' volume out of', len(val_files), 'done.', '\\r', end='')\n",
256+
" print(step, ' volume out of', len(test_files), 'done.', '\\r', end='')\n",
251257
" step += 1\n",
252258
"\n",
253259
" # compute validation ssims values for all validation samples\n",
@@ -261,14 +267,14 @@
261267
},
262268
{
263269
"cell_type": "code",
264-
"execution_count": 7,
270+
"execution_count": 1,
265271
"metadata": {},
266272
"outputs": [
267273
{
268274
"name": "stdout",
269275
"output_type": "stream",
270276
"text": [
271-
"average SSIM score over the validation set: 0.9496\n"
277+
"average SSIM score over the validation set: 0.9436\n"
272278
]
273279
}
274280
],
File renamed without changes.
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Accelerated MRI reconstruction with the end-to-end variational network (e2e-VarNet)
2+
3+
<p align="center"><img src="./figures/workflow.PNG" width="800" height="400"></p>
4+
5+
6+
This folder contains code to train and validate an e2e-VarNet ([https://arxiv.org/pdf/2004.06688.pdf](https://arxiv.org/pdf/2004.06688.pdf)) for accelerated MRI reconstruction. Accelerated MRI reconstruction is a compressed sensing task where the goal is to recover a ground-truth image from an under-sampled measurement. The under-sampled measurement is based in the frequency domain and is often called the $k$-space.
7+
8+
***
9+
10+
### List of contents
11+
12+
* [Questions and bugs](#Questions-and-bugs)
13+
14+
* [Dataset](#Dataset)
15+
16+
* [Model checkpoint](#Model-checkpoint)
17+
18+
* [Training](#Training)
19+
20+
* [Inference](#Inference)
21+
22+
***
23+
24+
# Questions and bugs
25+
26+
- For questions relating to the use of MONAI, please us our [Discussions tab](https://github.com/Project-MONAI/MONAI/discussions) on the main repository of MONAI.
27+
- For bugs relating to MONAI functionality, please create an issue on the [main repository](https://github.com/Project-MONAI/MONAI/issues).
28+
- For bugs relating to the running of a tutorial, please create an issue in [this repository](https://github.com/Project-MONAI/Tutorials/issues).
29+
30+
# Dataset
31+
32+
Please see [dataset description](../unet_demo/README.md#dataset) for our dataset preparation.
33+
34+
35+
# Model checkpoint
36+
37+
We have already provided a model checkpoint [varnet_mri_reconstruction.pt](https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/varnet_mri_reconstruction.pt) for a VarNet with `30,069,558` parameters. To obtain this checkpoint, we trained
38+
a VarNet with the default hyper-parameters in `train.py` on our T2 subset of the brain dataset. The user can train their model on an arbitrary portion of the dataset.
39+
40+
The training dynamics for our checkpoint is depicted in the figure below.
41+
42+
<p align="center"><img src="./figures/dynamics.PNG" width="800" height="250"></p>
43+
44+
# Training
45+
46+
Running `train.py` trains a VarNet. The default setup automatically detects a GPU for training; if not available, CPU will be used.
47+
48+
# Run this to get a full list of training arguments
49+
python ./train.py -h
50+
51+
# This is an example of calling train.py
52+
python ./train.py
53+
--data_path_train train_dir \
54+
--data_path_val val_dir \
55+
--exp varnet_mri_recon \
56+
--exp_dir ./ \
57+
--mask_type equispaced \
58+
--num_epochs 50 \
59+
--num_workers 0 \
60+
--lr 0.00001
61+
62+
# Inference
63+
64+
The notebook `inference.ipynb` contains an example to perform inference. Average SSIM score over the test subset is computed and then
65+
one sample is picked for visualization.
66+
67+
Our checkpoint achieves `0.9650` SSIM on our test subset which is comparable to the original result reported on the
68+
[fastMRI public leaderboard](https://fastmri.org/leaderboards/) (which is `0.9606` SSIM). Note that the results reported
69+
on the leaderboard are for the unreleased test set. Moreover, the leaderboard model is trained on the validation set.
34.3 KB
Loading
302 KB
Loading

0 commit comments

Comments
 (0)