Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions era5_training/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@ This will give you:

```
inputs/
└── 1x1_inputfeatures_u_v_theta_w_uw_vw_era5_training_data_hourly_2010_constant_mu_sigma_scaling01.nc
model-huggingface/
├── ann_cnn_1x1_global_global_era5_uvthetaw__train_epoch8.pt
├── ann_cnn_1x1_global_global_era5_uvthetaw__train_epoch94.pt
└── attnunet_era5_global_global_uvthetaw_mseloss_train_epoch119.pt
└── 1x1_inputfeatures_u_v_theta_w_uw_vw_gcp_era5_training_data_hourly_2015_L93_constant_mu_sigma_scaling01.nc
model-huggingface
├── retrained_L93_ann_cnn_1x1_global_global_era5_uvthetaw__train_epoch70.pt
└── retrained_L93_attnunet_era5_global_global_uvthetaw_mseloss_train_epoch100.pt
```

## Torchscript the models
Expand All @@ -36,7 +35,7 @@ First we would need torchscript'd models. These can be generated by running `inf
### Unet

```bash
python inference.py -M attention -d global -v global -f uvthetaw -e 119 -m 1 -s 1 -t era5 -i inputs/ -c model-huggingface/ -o outputs/ --script
python inference.py -M attention -d global -v global -f uvthetaw -e 100 -m 1 -s 1 -t era5 -i inputs/ -c model-huggingface/ -o outputs/ --script
```

This will generate some test data and a torchscripted model, to be used by `infer.f90` and `infer.py` later on.
Expand All @@ -51,7 +50,7 @@ test-data/
### Ann

```bash
python inference.py -M ann -d global -v global -f uvthetaw -e 8 -m 1 -s 1 -t era5 -i inputs/ -c model-huggingface/ -o outputs/ --script
python inference.py -M ann -d global -v global -f uvthetaw -e 70 -m 1 -s 1 -t era5 -i inputs/ -c model-huggingface/ -o outputs/ --script
```

This will generate some test data and a torchscripted model, to be used by `infer.f90` and `infer.py` later on.
Expand Down Expand Up @@ -86,7 +85,7 @@ python infer.py -M ann -t test-data/ -s .
To test the newly generated TorchScript models, use the following command:

```bash
bash compile-and-run.sh intel
bash compile-and-run.sh gcc
```

This will compile `infer.f90` into `infer.exe`. This requires CUDA to be installed on your system. It also requires `ftorch` to
Expand Down
41 changes: 29 additions & 12 deletions era5_training/batch_ann.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/bash -l
#PBS -N 1x1_uvthw
#PBS -N scripting
#PBS -A USTN0009
#PBS -l select=1:ncpus=4:ngpus=1:mem=80GB
#PBS -l walltime=01:00:00
Expand Down Expand Up @@ -33,19 +33,36 @@ source ~/nonlocal_gwfluxes/.nlgw/bin/activate
# -o /glade/derecho/scratch/agupta/torch_saved_models/


#python inference.py \
# -M attention \
# -d global \
# -v global \
# -f uvthetaw \
# -e 119 \
# -m 1 \
# -s 3 \
# -t era5 \
# -i /glade/derecho/scratch/agupta/era5_training_data/ \
# -c /glade/derecho/scratch/agupta/hugging_face_checkpoints/ \
# -o /glade/derecho/scratch/agupta/gw_inference_files/


python inference.py \
-M attention \
-d global \
-v global \
-f uvthetaw \
-e 119 \
-m 1 \
-s 3 \
-t era5 \
-i /glade/derecho/scratch/agupta/era5_training_data/ \
-c /glade/derecho/scratch/agupta/hugging_face_checkpoints/ \
-o /glade/derecho/scratch/agupta/gw_inference_files/
-M ann \
-d global \
-v global \
-f uvthetaw \
-e 70 \
-s 1 \
-t era5 \
-m 1 \
-i inputs/ \
-c model-huggingface/ \
-o outputs/ \
--script


#python inference.py -M ann -d global -v global -f uvthetaw -e 85 -m 1 -s 1 -t era5 -i /glade/derecho/scratch/agupta/new_training_data/ -c /glade/derecho/scratch/agupta/hugging_face_checkpoints/ -o /glade/derecho/scratch/agupta/gw_inference_files/ --script



Expand Down
44 changes: 25 additions & 19 deletions era5_training/batch_unet.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,31 @@ source ~/nonlocal_gwfluxes/.nlgw/bin/activate
#python training_attention_unet.py stratosphere_only uvthetawN2


python training.py \
-M attention \
-d global \
-v stratosphere_update \
-f uvw \
-i /glade/derecho/scratch/agupta/era5_training_data/ \
-o /glade/derecho/scratch/agupta/torch_saved_models/


#python inference.py \
# -M attention \
# -d global \
# -v stratosphere_update \
# -f uvw \
# -e 100 \
# -s 1 \
# -t era5 \
# -m 1 \
# -i /glade/derecho/scratch/agupta/era5_training_data/ \
#python training.py \
# -M attention \
# -d global \
# -v stratosphere_update \
# -f uvw \
# -i /glade/derecho/scratch/agupta/era5_training_data/ \
# -o /glade/derecho/scratch/agupta/torch_saved_models/


python inference.py \
-M attention \
-d global \
-v global \
-f uvthetaw \
-e 100 \
-s 1 \
-t era5 \
-m 1 \
-i inputs/ \
-c model-huggingface/ \
-o outputs/ \
--script


# -i /glade/derecho/scratch/agupta/era5_training_data/ \
# -c /glade/derecho/scratch/agupta/torch_saved_models/ \
# -o /glade/derecho/scratch/agupta/gw_inference_files/

Expand Down
36 changes: 8 additions & 28 deletions era5_training/compile-and-run.sh
Original file line number Diff line number Diff line change
@@ -1,34 +1,15 @@
COMP=$1
FC=ifort
FFLAGS=""

if [[ ${COMP} == "intel" ]]; then
FC=ifort
FFLAGS=""

# source /glade/u/home/tmeltzer/cam-test/debug_env.sh

module purge
module load cesmdev/1.0 ncarenv/23.06 craype/2.7.20 linaro-forge/23.0 intel/2023.0.0 mkl/2023.0.0
module load ncarcompilers/1.0.0 cmake/3.26.3 cray-mpich/8.1.25 hdf5-mpi/1.12.2
module load netcdf-mpi/4.9.2 parallel-netcdf/1.12.3 parallelio/2.6.2-debug esmf/8.6.0b04-debug
elif [[ ${COMP} == "gcc" ]]; then

FC=gfortran
FFLAGS="-ffree-line-length-none"

module purge
module load ncarenv/24.12 gcc/12.4.0 cmake cuda/12.3.2 netcdf/4.9.3
else
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[0;33m'
NC='\033[0m' # No Color
echo -e "${RED}ERROR:${YELLOW} required option missing. Please specify [${GREEN}gcc${YELLOW}] or [${GREEN}intel${YELLOW}] as compiler.${NC}"
exit 1
fi
module --force purge
# these come from the environment listed in software_environment.txt in the CESM Case directory
module load cesmdev/1.0 ncarenv/23.06 craype/2.7.20 intel/2023.0.0 mkl/2023.0.0 ncarcompilers/1.0.0
module load cmake/3.26.3 cray-mpich/8.1.25 hdf5-mpi/1.12.2 netcdf-mpi/4.9.2 parallel-netcdf/1.12.3
module load parallelio/2.6.2 esmf/8.6.0b04

source ../.nlgw/bin/activate

FTORCH_ROOT="/glade/u/home/tmeltzer/FTorch/bin/ftorch_${COMP}"
FTORCH_ROOT="${HOME}/fresh/ftorch-install"
NETCDF_LIB="${NETCDF}/lib"
export LD_LIBRARY_PATH="${NETCDF_LIB}:${FTORCH_ROOT}/lib64:${LD_LIBRARY_PATH}"

Expand All @@ -45,7 +26,6 @@ echo $COMMAND

${COMMAND}

# gdb -q --args ./infer.exe attention test-data/ .
./infer.exe attention test-data/ .
echo
echo "========================================="
Expand Down
7 changes: 3 additions & 4 deletions era5_training/get-model-and-data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@ mkdir -p inputs

echo "retrieving model weights..."
cd model-huggingface
wget https://huggingface.co/amangupta2/iccs_coupling_checkpoints/resolve/main/ann_cnn_1x1_global_global_era5_uvthetaw__train_epoch8.pt
wget https://huggingface.co/amangupta2/iccs_coupling_checkpoints/resolve/main/ann_cnn_1x1_global_global_era5_uvthetaw__train_epoch94.pt
wget https://huggingface.co/amangupta2/iccs_coupling_checkpoints/resolve/main/attnunet_era5_global_global_uvthetaw_mseloss_train_epoch119.pt
wget https://huggingface.co/amangupta2/iccs_coupling_checkpoints/resolve/main/retrained_L93_ann_cnn_1x1_global_global_era5_uvthetaw__train_epoch70.pt
wget https://huggingface.co/amangupta2/iccs_coupling_checkpoints/resolve/main/retrained_L93_attnunet_era5_global_global_uvthetaw_mseloss_train_epoch100.pt
cd ..

echo "retrieving test input..."
(cd inputs && wget https://g-b56e81.7a577b.6fbd.data.globus.org/1x1_inputfeatures_u_v_theta_w_uw_vw_era5_training_data_hourly_2010_constant_mu_sigma_scaling01.nc)
(cd inputs && wget https://g-b56e81.7a577b.6fbd.data.globus.org/1x1_inputfeatures_u_v_theta_w_uw_vw_gcp_era5_training_data_hourly_2015_L93_constant_mu_sigma_scaling01.nc)
5 changes: 3 additions & 2 deletions era5_training/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@ def main():
input_data = load_nc_dataset(args.test_data_dir / Path(prefix + "-input.nc"))
pred_reference = load_nc_dataset(args.test_data_dir / Path(prefix + "-predict.nc"))

model_path = args.scripted_model_dir / Path(f"nlgw_{prefix}_gpu_scripted.pt")
model_path = args.scripted_model_dir / Path(f"nlgw_{prefix}_{device}_scripted.pt")
print(f"loading model {model_path}...")
model = torch.jit.load(model_path)

# run model inference
pred = model(torch.tensor(input_data).to(device))
with torch.no_grad():
pred = model(torch.tensor(input_data).to(device))

pred = pred.cpu().detach().numpy()
print("pred.shape = ", pred.shape)
Expand Down
22 changes: 14 additions & 8 deletions era5_training/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@
print(f"output_dir={args.output_dir}")
print(f"script={args.script}")

bs_train = 20 # 80 (80 works for most). (does not work for global uvthetaw)
bs_train = 5 # 20 # 80 (80 works for most). (does not work for global uvthetaw)
bs_test = bs_train

# --------------------------------------------------
Expand Down Expand Up @@ -136,11 +136,13 @@
odir = str(args.output_dir) + "/"
pref = str(args.ckpt_dir) + "/" # "/scratch/users/ag4680/torch_saved_models/attention_unet/"
if model == "ann":
ckpt = f"ann_cnn_{stencil}x{stencil}_{domain}_{vertical}_era5_{features}__train_epoch{epoch}.pt"
# ckpt = f"retrained_ann_cnn_{stencil}x{stencil}_{domain}_{vertical}_era5_{features}__train_epoch{epoch}.pt"
ckpt = f"retrained_L93_ann_cnn_{stencil}x{stencil}_{domain}_{vertical}_era5_{features}__train_epoch{epoch}.pt"
log_filename = f"./{teston}_inference_ann_cnn_{stencil}x{stencil}_{domain}_{vertical}_{features}_ckpt_epoch_{epoch}.txt"
elif model == "attention":
ckpt = (
f"attnunet_era5_{domain}_{vertical}_{features}_mseloss_train_epoch{str(epoch).zfill(2)}.pt"
# f"attnunet_era5_{domain}_{vertical}_{features}_mseloss_train_epoch{str(epoch).zfill(2)}.pt"
f"retrained_L93_attnunet_era5_{domain}_{vertical}_{features}_mseloss_train_epoch{epoch}.pt"
)
log_filename = (
f"./{teston}_inference_attnunet_{domain}_{vertical}_{features}_ckpt_epoch_{epoch}.txt"
Expand All @@ -157,7 +159,7 @@
# Define test files
# ------- To test on one year of ERA5 data
test_files = []
test_years = np.array([2010])
test_years = np.array([2015])
test_month = args.month # int(sys.argv[4]) # np.arange(1,13)
logger.info(f"Inference for month {test_month}")
if teston == "era5":
Expand All @@ -174,7 +176,7 @@
)
elif vertical == "global" or vertical == "stratosphere_update":
if stencil == 1:
pre = idir + f"1x1_inputfeatures_u_v_theta_w_uw_vw_era5_training_data_hourly_"
pre = idir + f"1x1_inputfeatures_u_v_theta_w_uw_vw_gcp_era5_training_data_hourly_"
else:
pre = (
idir
Expand All @@ -183,7 +185,10 @@

for year in test_years:
for months in np.arange(test_month, test_month + 1):
test_files.append(f"{pre}{year}_constant_mu_sigma_scaling{str(months).zfill(2)}.nc")
# test_files.append(f"{pre}{year}_constant_mu_sigma_scaling{str(months).zfill(2)}.nc") # usual
test_files.append(
f"{pre}{year}_L93_constant_mu_sigma_scaling{str(months).zfill(2)}.nc"
) # L93

elif teston == "ifs":
if vertical == "stratosphere_only":
Expand Down Expand Up @@ -215,10 +220,11 @@
features=features,
)
testloader = torch.utils.data.DataLoader(
testset, batch_size=bs_test, drop_last=False, shuffle=False, num_workers=8
testset, batch_size=bs_test, drop_last=False, shuffle=False, num_workers=0
)

idim = testset.idim

odim = testset.odim
hdim = 4 * idim

Expand Down Expand Up @@ -252,7 +258,7 @@
files=test_files, domain=domain, vertical=vertical, manual_shuffle=False, features=features
)
testloader = torch.utils.data.DataLoader(
testset, batch_size=bs_train, drop_last=False, shuffle=False, num_workers=8
testset, batch_size=bs_train, drop_last=False, shuffle=False, num_workers=0
)

ch_in = testset.idim
Expand Down
Loading