A machine learning emulator of a CPM based on a diffusion model.
This is the code for the paper Addison et al. (2024) "Machine learning emulation of precipitation from km-scale regional climate simulations using a diffusion model".
Diffusion model implementation forked from the PyTorch implementation accompanying the paper "Score-Based Generative Modeling through Stochastic Differential Equations" by Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole.
Assumes you have pixi installed for managing dependencies.
- Clone repo and cd into it
- [Optional] Install U-Net code:
  `git clone --depth 1 https://github.com/henryaddison/Pytorch-UNet.git src/ml_downscaling_emulator/unet`. This is only necessary if you wish to use the deterministic comparison models.
- Configure application behaviour with environment variables. See `.env.example` for variables that can be set.
Any datasets are assumed to be found in `${DERIVED_DATA}/moose/nc-datasets/{dataset_name}/`. In particular, the config key `config.data.dataset_name` is the name of the dataset to use to train the model.
To add new packages or update their versions, update the dependencies in `pixi.toml` and then run `pixi install`, or add them using `pixi add NEW_DEP`. Then commit any changes to `pixi.toml` and `pixi.lock`.
Datasets for use with the emulator can be created using https://github.com/henryaddison/mlde-data. That repo contains further information about dataset specification. The datasets used in the paper can be found on Zenodo.
NB the interface commonly takes just the name of a dataset. Each dataset is expected to be found at `${DERIVED_DATA}/moose/nc-datasets/{dataset_name}/` (where `DERIVED_DATA` is a configurable environment variable).
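As an illustration of this convention, a small helper along these lines (purely hypothetical, not part of the repo's API) resolves a dataset name to the directory the tooling expects:

```python
import os
from pathlib import Path

def dataset_dir(dataset_name: str) -> Path:
    """Resolve a dataset name to its expected location,
    i.e. ${DERIVED_DATA}/moose/nc-datasets/{dataset_name}/."""
    derived_data = os.environ["DERIVED_DATA"]  # configurable environment variable
    return Path(derived_data) / "moose" / "nc-datasets" / dataset_name

# e.g. with DERIVED_DATA=/data, a dataset called "my-dataset"
# is expected at /data/moose/nc-datasets/my-dataset
```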
`pixi run tests/smoke-test` uses a simpler network to test the full training and sampling regime. Recommended to run with a sample of the dataset.
Train models through `bin/main.py`, e.g. to train the model used in the paper:

```
pixi run python bin/main.py --config src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py --workdir ${DERIVED_DATA}/path/to/models/paper-12em --mode train
```
`main.py` takes the following flags:

- `--mode`: running mode (`train`). When set to `train`, it starts the training of a new model, or resumes the training of an existing model if its meta-checkpoints (for resuming after pre-emption in a cloud environment) exist in `workdir/checkpoints-meta`.
- `--workdir`: the path that stores all artifacts of one experiment, such as checkpoints, transforms and samples. Recommended to be a subdirectory of `${DERIVED_DATA}`.
- `--config`: the path to the training configuration file. Config files for emulators are provided in `src/ml_downscaling_emulator/configs/`. They are formatted according to `ml_collections` and are heavily based on NCSN++ config files.

Naming conventions of config files: the path of a config file is a combination of the following dimensions:
- SDE: `subvpsde`
- data source: `ukcp_local`
- variable: `pr`
- ensemble members: `12em` (all 12) or `1em` (single)
- model: `cncsnpp`
- `continuous`: train the model with continuously sampled time steps.
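To make the naming convention concrete, a sketch like the following (illustrative only, not code from this repo) splits a config file path into those dimensions, assuming the two-token data source `ukcp_local` listed above:

```python
def parse_config_name(path: str) -> dict:
    """Split a config path like 'subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py'
    into the naming dimensions described above. Illustrative only."""
    sde, _, filename = path.rpartition("/")
    parts = filename.removesuffix(".py").split("_")
    # 'ukcp_local' itself contains an underscore, so rejoin the first two tokens
    source, variable, members, model, schedule = ("_".join(parts[:2]), *parts[2:])
    return {
        "sde": sde,
        "data_source": source,
        "variable": variable,
        "ensemble_members": members,
        "model": model,
        "time_steps": schedule,
    }
```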
Functionalities can be configured through config files, or more conveniently, through the command-line support of the `ml_collections` package.
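The idea behind `ml_collections`-style dotted overrides (`--config.section.key=value` on the command line) can be mimicked on a plain nested dict with this stdlib-only sketch; the field names below are made up for illustration, not taken from this repo's configs:

```python
def apply_override(config: dict, dotted_key: str, value):
    """Set a nested config value from a dotted key, mimicking an
    ml_collections-style override such as --config.training.batch_size=16."""
    *parents, leaf = dotted_key.split(".")
    node = config
    for key in parents:
        node = node[key]  # walk down the nested sections
    node[leaf] = value
    return config

# hypothetical config values, for illustration only
cfg = {"training": {"batch_size": 128, "snapshot_freq": 5000}}
apply_override(cfg, "training.batch_size", 16)
```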
Once you have trained a model, create samples from it with `bin/predict.py`, e.g.

```
pixi run python bin/predict.py --checkpoint epoch_20 --dataset bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season --split test --ensemble-member 01 --input-transform-dataset bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season --input-transform-key pixelmmsstan --num-samples 1 ${DERIVED_DATA}/path/to/models/paper-12em
```

This example command will:
- use the checkpoint of the model in `${DERIVED_DATA}/path/to/models/paper-12em/checkpoints/{checkpoint}.pth` and the model config from training in `${DERIVED_DATA}/path/to/models/paper-12em/config.yml`.
- store generated samples in `${DERIVED_DATA}/path/to/models/paper-12em/samples/{dataset}/{input_transform_data}-{input_transform_key}/{split}/{ensemble_member}/`. Sample files are named like `predictions-{uuid}.nc`.
- generate samples conditioned on examples from ensemble member `01` in the `test` subset of the `bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season` dataset.
- transform the inputs based on the `bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season` dataset using the `pixelmmsstan` approach.
- generate 1 set of samples.
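Putting the sample layout described above together, a helper along these lines (hypothetical, not part of the repo) constructs the directory where a run's samples land:

```python
from pathlib import Path

def samples_dir(workdir: str, dataset: str, input_transform_dataset: str,
                input_transform_key: str, split: str, ensemble_member: str) -> Path:
    """Build the samples output directory following the layout
    {workdir}/samples/{dataset}/{input_transform_data}-{input_transform_key}/{split}/{ensemble_member}/."""
    return (Path(workdir) / "samples" / dataset
            / f"{input_transform_dataset}-{input_transform_key}"
            / split / ensemble_member)
```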