Training a regression network with weighted sampling of image windows

This page describes how to download the data for, and use, a weighted sampler that trains an image regression network.

Reference:

Berger et al., "An Adaptive Sampling Scheme to Efficiently Train Fully Convolutional Networks for Semantic Segmentation", https://arxiv.org/abs/1709.02764

Downloading model zoo files

The training data and initial error maps can be downloaded with:

net_download mr_ct_regression_model_zoo

(Replace net_download with python net_download.py if you cloned the NiftyNet repository.)

Initial training

The command-line parameters --starting_iter 0 --max_iter 1000 train the model from scratch for 1000 iterations:

python net_run.py train \
  -a niftynet.contrib.regression_weighted_sampler.isample_regression.ISampleRegression \
  -c ~/niftynet/extensions/mr_ct_regression/net_isampler.ini \
  --starting_iter 0 --max_iter 1000

Generating error maps

The command-line parameters --spatial_window_size 240,240,1 --batch_size 4 set the inference window size and batch size for efficiency, and --inference_iter 1000 selects the checkpoint saved at iteration 1000. With --error_map True, the errors (elementwise squared differences) are written to ~/niftynet/models/mr_ct_regression/error_maps.

python net_run.py inference \
  -a niftynet.contrib.regression_weighted_sampler.isample_regression.ISampleRegression \
  -c ~/niftynet/extensions/mr_ct_regression/net_isampler.ini \
  --inference_iter 1000 --spatial_window_size 240,240,1 --batch_size 4 --error_map True
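The error maps are described as elementwise squared differences. A minimal NumPy sketch of that computation, assuming the predicted and reference images are already loaded as arrays of equal shape (`squared_error_map` is a hypothetical name for illustration, not a NiftyNet function):

```python
import numpy as np


def squared_error_map(prediction, target):
    """Return the elementwise squared difference between prediction and target.

    Hypothetical helper for illustration only; NiftyNet writes its own error
    maps to disk when inference runs with --error_map True.
    """
    prediction = np.asarray(prediction, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    if prediction.shape != target.shape:
        raise ValueError(f"shape mismatch: {prediction.shape} vs {target.shape}")
    diff = prediction - target
    return diff * diff


if __name__ == "__main__":
    pred = np.array([[0.0, 1.0], [2.0, 3.0]])
    ref = np.array([[0.0, 3.0], [2.0, 0.0]])
    print(squared_error_map(pred, ref))
```

Large values mark voxels where the current model is least accurate, which are the regions the next training round samples more heavily.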

Continuing training with sampling according to the error maps

The command-line parameter --starting_iter -1 resumes training from the most recently saved checkpoint (here, iteration 1000), and --max_iter 1500 continues training up to iteration 1500.

python net_run.py train \
  -a niftynet.contrib.regression_weighted_sampler.isample_regression.ISampleRegression \
  -c ~/niftynet/extensions/mr_ct_regression/net_isampler.ini \
  --starting_iter -1 --max_iter 1500
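Conceptually, the error map acts as an unnormalised sampling weight, so training windows are drawn more often from regions with high error. The sketch below illustrates weighted sampling of window centres under stated assumptions: `sample_window_centers` is a hypothetical helper, not the ISampleRegression sampler, and restricting centres to positions where the whole window fits inside the volume is an assumption of this sketch.

```python
import numpy as np


def sample_window_centers(weights, window_size, n_windows, rng=None):
    """Draw window centres with probability proportional to `weights`.

    Hypothetical helper, not NiftyNet's sampler. Assumptions: `weights` is a
    non-negative, unnormalised map (e.g. a squared-error map), and only centres
    whose full window fits inside the volume are eligible.
    """
    rng = np.random.default_rng() if rng is None else rng
    weights = np.asarray(weights, dtype=np.float64)
    window_size = np.asarray(window_size, dtype=int)
    if weights.ndim != window_size.size:
        raise ValueError("window_size needs one entry per axis of the weight map")
    if np.any(window_size > np.asarray(weights.shape)):
        raise ValueError("window is larger than the weight map")
    if np.any(weights < 0):
        raise ValueError("weights must be non-negative")

    # A centre c covers [c - half, c - half + w); keep only centres whose
    # window stays inside [0, dim) on every axis.
    half = window_size // 2
    valid = tuple(
        slice(h, dim - w + h + 1) for h, w, dim in zip(half, window_size, weights.shape)
    )
    masked = np.zeros_like(weights)
    masked[valid] = weights[valid]

    total = masked.sum()
    if total <= 0:
        raise ValueError("no eligible centre has positive weight")

    # Sample flat indices proportionally to the weights, then convert to coordinates.
    flat = rng.choice(masked.size, size=n_windows, replace=True, p=masked.ravel() / total)
    return np.stack(np.unravel_index(flat, weights.shape), axis=-1)


if __name__ == "__main__":
    error_map = np.zeros((6, 6, 1))
    error_map[4, 1, 0] = 9.0  # a single high-error voxel
    centres = sample_window_centers(error_map, (3, 3, 1), 5, rng=np.random.default_rng(0))
    print(centres)  # every row is [4 1 0]: all the weight sits on one voxel
```

A window size such as (240, 240, 1) corresponds to 2D slices; the exact border handling and normalisation used during NiftyNet training may differ from this sketch.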

Putting it all together

To alternate between error-map generation and training with the updated sampling weights (commands run from a cloned NiftyNet source tree):

python net_download.py mr_ct_regression_model_zoo
python net_run.py train \
  -a niftynet.contrib.regression_weighted_sampler.isample_regression.ISampleRegression \
  -c ~/niftynet/extensions/mr_ct_regression/net_isampler.ini \
  --starting_iter 0 --max_iter 1000
for max_iter in $(seq 1500 500 10000)
do
  python net_run.py inference \
    -a niftynet.contrib.regression_weighted_sampler.isample_regression.ISampleRegression \
    -c ~/niftynet/extensions/mr_ct_regression/net_isampler.ini \
    --inference_iter -1 --spatial_window_size 240,240,1 --batch_size 4 --error_map True

  python net_run.py train \
    -a niftynet.contrib.regression_weighted_sampler.isample_regression.ISampleRegression \
    -c ~/niftynet/extensions/mr_ct_regression/net_isampler.ini \
    --starting_iter -1 --max_iter $max_iter
done

This script trains for 10000 iterations in total: after the initial 1000 iterations, the error maps (and therefore the sampling weights) are regenerated every 500 iterations.
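The round structure of the loop can be checked with a few lines of Python that mirror `seq 1500 500 10000`. This is a sketch that assumes a checkpoint is saved at the end of every training run, as the iteration-1000 checkpoint used after the initial training implies; `refresh_schedule` is a hypothetical name.

```python
def refresh_schedule(initial_iters=1000, step=500, final_iter=10000):
    """List (checkpoint used for error maps, next --max_iter) pairs.

    Mirrors `seq 1500 500 10000`: each round regenerates the error maps from
    the latest checkpoint, then trains up to the next --max_iter value.
    """
    rounds = []
    checkpoint = initial_iters
    for max_iter in range(initial_iters + step, final_iter + 1, step):
        rounds.append((checkpoint, max_iter))
        checkpoint = max_iter
    return rounds


if __name__ == "__main__":
    schedule = refresh_schedule()
    print(len(schedule), schedule[0], schedule[-1])  # 18 (1000, 1500) (9500, 10000)
```

So the error maps are regenerated 18 times, from the checkpoints at iterations 1000, 1500, ..., 9500.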

To view the training and validation curves in TensorBoard:

tensorboard --logdir ~/niftynet/models/mr_ct_regression/logs

Generating regression output

Finally, regression maps for the test set can be generated by running inference with --error_map False instead of True:

python net_run.py inference \
  -a niftynet.contrib.regression_weighted_sampler.isample_regression.ISampleRegression \
  -c ~/niftynet/extensions/mr_ct_regression/net_isampler.ini \
  --inference_iter -1 --spatial_window_size 240,240,1 --batch_size 4 --error_map False

To generate results on the training, validation, and test sets together, set --dataset_split_file nofile to override the split file at ~/niftynet/models/mr_ct_regression/dataset_split_file.txt:

python net_run.py inference \
  -a niftynet.contrib.regression_weighted_sampler.isample_regression.ISampleRegression \
  -c ~/niftynet/extensions/mr_ct_regression/net_isampler.ini \
  --inference_iter -1 --spatial_window_size 240,240,1 --batch_size 4 --error_map False --dataset_split_file nofile

The output can be found at ~/niftynet/models/mr_ct_regression/isampler_output/.