9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -149,6 +149,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for Batched radius search, which enables Domino
and GeoTransolver with local features and batch size > 1.
- Added the underfill recipe.
- Added Functional Generative Networks (FGN) weather training example
(`examples/weather/fgn`). Implements the latent-conditioned U-Net
stochastic generator from
[arXiv:2506.10772](https://arxiv.org/abs/2506.10772) (WeatherNext 2)
as a PhysicsNeMo `Module`, trained with fair-CRPS loss on ERA5 via the
earth2studio ARCO data source. Supports autoregressive rollout training
with per-channel normalization, FSDP + ShardTensor domain parallelism,
deep-ensemble inference (paper §2.2.1), and validation diagnostics
(CRPS, RMSE, spread-skill, rank histograms, power spectra).

### Changed

9 changes: 9 additions & 0 deletions examples/weather/fgn/.gitignore
@@ -0,0 +1,9 @@
*.mdlus
*.png
*.pt
*.tfevents*
*wandb/
rundir/
logs/
*.npz
FGN.md
266 changes: 266 additions & 0 deletions examples/weather/fgn/README.md
@@ -0,0 +1,266 @@
<!-- markdownlint-disable -->
# Functional Generative Networks for Weather Forecasting

A PhysicsNeMo implementation of Functional Generative Networks (FGN) for
probabilistic global weather forecasting, following the approach of:

> Alet et al., "Skillful joint probabilistic weather forecasting from marginals"
> ([arXiv:2506.10772](https://arxiv.org/abs/2506.10772))

FGN is the architecture behind the production
[WeatherNext 2](https://developers.google.com/weathernext/guides/models) model,
which delivers 64-member ensemble forecasts (4 independently trained seeds ×
16 trajectories each) at 0.25° global resolution. The full variable schema is
described in the
[WeatherNext 2 model specs](https://developers.google.com/weathernext/guides/model-specs-vmg).

FGN generates ensemble weather forecasts by perturbing a deterministic backbone
with a low-dimensional latent noise vector `z ~ N(0, I_32)` injected through
conditional layer normalization (CLN) at every layer, producing globally coherent
ensemble spread from a marginal (fair-CRPS) training loss. Multiple independently
trained model seeds form a deep ensemble (J=4 seeds, 16 trajectories each = 64
members in production) capturing both aleatoric and epistemic uncertainty.

## Problem Overview

FGN autoregressively predicts the next 6-hour atmospheric state from the two
previous states (`X_{t-2}`, `X_{t-1}`), sampled from ERA5 (pre-training) and
HRES-fc0 (fine-tuning) at 0.25° global resolution. Each forward pass is
non-diffusive: one pass per forecast step, with a fresh `z` drawn per step per
ensemble member. AR fine-tuning with rollouts up to 8 steps (Table A.2) improves
temporal coherence without requiring a diffusion sampler.
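
The rollout described above can be sketched in a few lines. This is a minimal illustration, not the example's trainer. `rollout` and `toy_step` are hypothetical names, `step_fn` stands in for the network, and states are plain lists of floats rather than tensors.

```python
import random


def rollout(step_fn, x_tm2, x_tm1, num_steps, latent_dim=32, seed=0):
    """Autoregressively roll one ensemble member forward.

    step_fn is a hypothetical stand-in for the network: it maps the two
    previous states and a latent vector z to the next state.
    """
    rng = random.Random(seed)
    states = []
    for _ in range(num_steps):
        # Fresh z ~ N(0, I) for every step of every member.
        z = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
        x_next = step_fn(x_tm2, x_tm1, z)
        states.append(x_next)
        # Slide the two-state history window forward by one step.
        x_tm2, x_tm1 = x_tm1, x_next
    return states


def toy_step(x_tm2, x_tm1, z):
    # Toy dynamics (assumption): linear extrapolation plus a latent-driven shift.
    shift = 0.01 * z[0]
    return [2.0 * b - a + shift for a, b in zip(x_tm2, x_tm1)]


# Four ensemble members, each with its own noise stream, rolled out 8 steps.
members = [
    rollout(toy_step, [0.0, 1.0], [0.5, 1.5], num_steps=8, seed=m)
    for m in range(4)
]
```

Each member needs exactly one call to `step_fn` per forecast step, which is the sense in which the forward pass is non-diffusive.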

This example implements:

- Latent-conditioned `FGNUNet` backbone (`utils/nn.py`) with AdaGN modulation
- ARCO-backed real dataset using `earth2studio.data.ARCO` (`datasets/arco.py`)
- Fair-CRPS training loss with paper-faithful per-variable and area weights (`utils/loss.py`)
- Autoregressive rollout training with BPTT (`utils/trainer.py`)
- Multi-stage AR schedule runner (Table A.2: `8k·1AR → 4k·2AR → 1k·{3..8}AR`)
- Validation metrics and plots: CRPS, RMSE, spread-skill, rank histograms,
power spectra (`utils/metrics.py`)
- FSDP + ShardTensor distributed training via `ParallelHelper` (`utils/parallel.py`)
- Deep ensemble inference across multiple independently trained checkpoints
- Per-channel normalization stats with Welford online estimation
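
For readers unfamiliar with the fair-CRPS loss listed above, the following is a minimal scalar sketch of the standard fair (unbiased) CRPS estimator. It is not the code in `utils/loss.py`, which also applies per-variable and area weights that are not reproduced here. `fair_crps` is an illustrative name.

```python
def fair_crps(members, obs):
    """Fair CRPS for one scalar observation and an ensemble of N >= 2 members."""
    n = len(members)
    if n < 2:
        raise ValueError("fair CRPS needs at least two ensemble members")
    # Mean absolute error of the members against the observation.
    skill = sum(abs(x - obs) for x in members) / n
    # Pairwise spread with the unbiased 1 / (2 N (N - 1)) normalization.
    spread = sum(abs(xi - xj) for xi in members for xj in members) / (
        2 * n * (n - 1)
    )
    return skill - spread


# With N=2 members the spread term reduces to |x1 - x2| / 2.
score = fair_crps([1.0, 3.0], obs=2.0)
```

The `N - 1` in the denominator is what makes the estimator usable with very small ensembles, such as the two members per training example configured via `training.loss.num_samples`.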

## Dataset

Training data is fetched live from the [ARCO ERA5](https://cloud.google.com/storage/docs/public-datasets/era5)
dataset via `earth2studio.data.ARCO`. No local download is required for training.

The dataset covers the full Table A.1 schema. The 83 input channels comprise 78
atmospheric channels (6 variables × 13 pressure levels, 50–1000 hPa) and 5 surface
channels that are both input and predicted (`t2m`, `u10m`, `v10m`, `msl`, `sst`).
`tp06` (6-h accumulated precipitation) is predicted only, giving 84 output channels.
Static inputs (surface geopotential, land-sea mask) and clock features (local time,
year progress sin/cos) are added automatically.

All variables use compact Earth2Studio / PhysicsNeMo names: `u10m`, `v10m`, `t2m`,
`msl`, `sst`, `tp06`, `z{level}`, `q{level}`, `t{level}`, `u{level}`, `v{level}`,
`w{level}`.
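
The naming scheme can be expanded programmatically. The sketch below makes several assumptions: the 13 levels are the standard WeatherBench/GraphCast set, the 83 count refers to input channels with `tp06` as an extra output, and the channel ordering is arbitrary. The actual list and order used by `datasets/arco.py` may differ.

```python
# Assumption: standard 13 pressure levels (hPa) spanning 50-1000 hPa.
PRESSURE_LEVELS = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]
ATMOS_VARS = ["z", "q", "t", "u", "v", "w"]
SURFACE_VARS = ["t2m", "u10m", "v10m", "msl", "sst"]
OUTPUT_ONLY_VARS = ["tp06"]


def build_channel_names():
    """Return (input_names, output_names) in an illustrative order."""
    atmos = [f"{var}{level}" for var in ATMOS_VARS for level in PRESSURE_LEVELS]
    inputs = atmos + SURFACE_VARS
    # tp06 is predicted but never fed back in as an input.
    outputs = inputs + OUTPUT_ONLY_VARS
    return inputs, outputs


inputs, outputs = build_channel_names()
```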

> **Note:** The production WeatherNext 2 output also includes `u100m` / `v100m`
> (100 m wind components). ERA5 via ARCO does not provide 100 m winds, so they
> are omitted from this ERA5-based training example.

### Normalization Stats

Pre-compute per-channel mean and standard deviation before training:

```bash
python scripts/compute_arco_stats.py \
--start 2020-01-01 --end 2023-12-31 \
--output rundir/fgn_2024_val/stats_2024.npz
```

Pass the resulting `.npz` file to the trainer via `dataset.stats_path`.
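
Welford's online algorithm, mentioned in the feature list, updates the mean and variance one value at a time, so the full dataset never has to be held in memory. The class below is a scalar sketch of that algorithm, not the contents of `scripts/compute_arco_stats.py`. `RunningStats` is an illustrative name, and the choice of population (rather than sample) standard deviation is an assumption.

```python
import math


class RunningStats:
    """Welford online mean / variance for a stream of scalars."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, value):
        self.count += 1
        delta = value - self.mean
        self.mean += delta / self.count
        # Uses the already-updated mean, which keeps the recurrence stable.
        self.m2 += delta * (value - self.mean)

    @property
    def std(self):
        # Population standard deviation (assumption; divide by count - 1
        # instead for the sample estimate).
        return math.sqrt(self.m2 / self.count) if self.count else float("nan")


stats = RunningStats()
for v in [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]:
    stats.update(v)
```

In practice one such accumulator would be kept per channel, with each grid value of each time step fed through `update` (or a vectorized equivalent).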

## Getting Started

### Requirements

```bash
pip install -r requirements.txt
```

PyTorch 2.10 or higher is required for domain parallelism.

### Smoke Test

Run the self-contained synthetic test suite (no GPU, no network access):

```bash
pytest test_training.py
```

Multi-GPU tests require `torchrun`:

```bash
torchrun --standalone --nproc_per_node=2 --no-python pytest test_training.py
```

## Configuration

Training is configured with [Hydra](https://hydra.cc) and validated with Pydantic
(`utils/config.py`). Configs live under `config/`:

- `config/fgn.yaml` — base defaults (model, training, dataset structure)
- `config/fgn_arco.yaml` — ARCO real-data training config (inherits from `fgn.yaml`)
- `config/test_fgn.yaml` — fast synthetic smoke-test config

Key config knobs:

| Setting | Description |
|---|---|
| `model.hidden_channels` | U-Net channel width (64 for quick runs, 256+ for full scale; `config/fgn.yaml` defaults to 32) |
| `model.latent_dim` | Latent noise dimension (32 in the paper; `config/fgn.yaml` defaults to 16) |
| `training.batch_size` | Global batch size; local per-GPU = `batch_size / data_parallel_size` |
| `training.ar_steps` | AR rollout length for loss (1 = single-step pre-training) |
| `training.loss.num_samples` | Ensemble members per training example (N=2 per paper) |
| `training.domain_parallel_size` | GPUs per sample for domain parallelism (1 = pure DDP) |
| `dataset.stats_path` | Path to `.npz` normalization stats |

Training outputs (checkpoints, logs, plots) are saved to:

```
rundir/{training.experiment_name}/{training.run_id}/
```

## Training

### Single GPU

```bash
python train.py --config-name fgn_arco \
dataset.stats_path=rundir/fgn_2024_val/stats_2024.npz \
training.experiment_name=fgn_run \
training.batch_size=1
```

### Multi-GPU (torchrun)

```bash
torchrun --standalone --nnodes=1 --nproc_per_node=2 \
train.py --config-name fgn_arco \
dataset.stats_path=rundir/fgn_2024_val/stats_2024.npz \
training.experiment_name=fgn_run \
training.batch_size=2
```

With 2 GPUs and `domain_parallel_size=1` (DDP), `batch_size` is the global batch
size — each GPU processes `batch_size / 2` samples.

### SLURM

```bash
sbatch scripts/train_fgn.sh
```

Override defaults via environment variables:

```bash
sbatch --export=ALL,EXP_NAME=fgn_2024,RUN_ID=1,STEPS=10000 scripts/train_fgn.sh
```

See `scripts/train_fgn.sh` for all overridable variables (`EXP_NAME`, `RUN_ID`,
`STEPS`, `CFG`, `STATS_PATH`, `NGPU`).

### Resuming

`training.resume_checkpoint` defaults to `latest`, which automatically resumes
from the most recent checkpoint in the run directory.

### Domain Parallelism

For models too large to fit one sample on a single GPU, enable domain parallelism:

```bash
torchrun --standalone --nproc_per_node=4 \
train.py --config-name fgn_arco \
training.domain_parallel_size=2 \
training.batch_size=2
```

With `domain_parallel_size=2` and 4 GPUs: 2 domain-parallel pairs, each handling
1 sample (`batch_size / data_parallel_size = 2 / 2 = 1`).
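
The relationship between GPU count, `domain_parallel_size` and `batch_size` can be checked with a small helper. This is a sketch of the arithmetic stated above. `parallel_layout` is a hypothetical function, not part of `utils/parallel.py`.

```python
def parallel_layout(world_size, domain_parallel_size, batch_size):
    """Return (data_parallel_size, samples_per_domain_group)."""
    if world_size % domain_parallel_size:
        raise ValueError("world_size must be divisible by domain_parallel_size")
    # Each group of domain_parallel_size GPUs cooperates on one sample.
    data_parallel_size = world_size // domain_parallel_size
    if batch_size % data_parallel_size:
        raise ValueError("batch_size must be divisible by data_parallel_size")
    return data_parallel_size, batch_size // data_parallel_size


# 4 GPUs, domain_parallel_size=2, global batch_size=2.
layout = parallel_layout(world_size=4, domain_parallel_size=2, batch_size=2)
```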

### AR Fine-Tuning Schedule (Table A.2)

The trainer implements the paper's multi-stage AR schedule automatically when
`training.ar_steps` increases across runs. Start with single-step pre-training,
then resume with progressively longer rollouts:

| Stage | `ar_steps` | Steps | Notes |
|---|---|---|---|
| 1 | 1 | 8000 | Single-step pre-train |
| 2 | 2 | 4000 | Resume from stage 1 |
| 3–8 | 3–8 | 1000 each | Resume from previous |
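
The table can also be written as data, which makes it easy to generate the per-stage `training.ar_steps` override. This is a sketch under assumptions: `ar_schedule` is an illustrative helper, and how the step budget of each stage is passed to the trainer is not shown here.

```python
def ar_schedule():
    """Return the Table A.2 schedule as (ar_steps, train_steps) pairs."""
    stages = [(1, 8000), (2, 4000)]
    # Stages 3-8: one 1000-step stage per additional rollout step.
    stages += [(ar, 1000) for ar in range(3, 9)]
    return stages


stages = ar_schedule()
total_steps = sum(steps for _, steps in stages)
# Hydra-style override for each stage; the training.ar_steps key comes from
# the configuration table in this README.
overrides = [f"training.ar_steps={ar}" for ar, _ in stages]
```

Summed over all eight stages, the schedule comes to 18,000 optimizer steps.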

## Inference

Run stochastic ensemble inference from a trained checkpoint:

```bash
python inference.py --config-name inference_fgn \
inference.checkpoint=rundir/fgn_run/0/checkpoints/FGNUNet.mdlus
```

For deep ensemble inference across multiple independently trained seeds:

```bash
python inference.py --config-name inference_fgn \
"inference.checkpoints=[seed0/FGNUNet.mdlus, seed1/FGNUNet.mdlus, seed2/FGNUNet.mdlus, seed3/FGNUNet.mdlus]"
```

Trajectories are distributed across checkpoints following paper §2.2.1.
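
One simple way to spread trajectories over checkpoints is round-robin assignment, sketched below. The exact scheme used by `inference.py` is not shown in this README, so the round-robin choice and the `assign_trajectories` name are assumptions.

```python
def assign_trajectories(num_trajectories, checkpoints):
    """Assign trajectory indices to checkpoints round-robin (assumed scheme)."""
    if not checkpoints:
        raise ValueError("at least one checkpoint is required")
    assignment = {ckpt: [] for ckpt in checkpoints}
    for traj in range(num_trajectories):
        assignment[checkpoints[traj % len(checkpoints)]].append(traj)
    return assignment


seeds = ["seed0/FGNUNet.mdlus", "seed1/FGNUNet.mdlus",
         "seed2/FGNUNet.mdlus", "seed3/FGNUNet.mdlus"]
# 64 members over 4 seeds gives 16 trajectories per checkpoint.
plan = assign_trajectories(64, seeds)
```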

### Bad-Seed Detection

Before including a checkpoint in a deep ensemble, check its spectral properties:

```bash
python scripts/check_spectra.py \
--checkpoint rundir/fgn_run/0/checkpoints/FGNUNet.mdlus \
--stats rundir/fgn_2024_val/stats_2024.npz
```

## Adding Custom Datasets

Implement the `FGNDataset` interface from `datasets/dataset.py`:

```python
import numpy as np

# Import path assumes the script runs from examples/weather/fgn.
from datasets.dataset import FGNDataset


class MyDataset(FGNDataset):
    def state_channels(self) -> list[str]: ...
    def background_channels(self) -> list[str]: ...
    def image_shape(self) -> tuple[int, int]: ...
    def __len__(self) -> int: ...
    def __getitem__(self, idx): ...

    # Optional:
    def get_invariants(self) -> np.ndarray | None: ...
    def output_only_channels(self) -> list[int]: ...
```

`__getitem__` should return a dict with keys `history` (shape `(T, C, H, W)`),
`target` (shape `(K, C, H, W)`), and optionally `background`. Register your
dataset by placing it in `datasets/` — it is discovered automatically via
`pkgutil.iter_modules` at import time.

## Memory Management

At 0.25° (721×1440), each training sample is large. Recommended settings for an
80 GB H100:

- `training.batch_size=2` (1 per GPU) with 2 GPUs, `domain_parallel_size=1`
- bf16 AMP is enabled automatically (`torch.autocast` with `dtype=torch.bfloat16`)
- `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` (set in `train_fgn.sh`)

For larger models (hidden_channels ≥ 256), use `domain_parallel_size=2` with
4+ GPUs, or enable gradient checkpointing via `model.checkpoint_level`.
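
To get a feel for the sizes involved, the raw footprint of a single state tensor can be estimated directly. This is back-of-the-envelope arithmetic only. It assumes 83 channels on the 721×1440 grid and ignores activations, gradients and optimizer state, which dominate actual training memory. `state_bytes` is an illustrative helper.

```python
def state_bytes(channels=83, height=721, width=1440, bytes_per_value=4):
    """Raw size in bytes of one (C, H, W) state tensor."""
    return channels * height * width * bytes_per_value


fp32_mib = state_bytes() / 2**20                    # float32 storage
bf16_mib = state_bytes(bytes_per_value=2) / 2**20   # bfloat16 storage
```

A single float32 state is roughly a third of a gigabyte before any network activations are stored, and a training sample holds several such states (two history frames plus `ar_steps` targets).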

## References

- [Skillful joint probabilistic weather forecasting from marginals](https://arxiv.org/abs/2506.10772)
- [WeatherNext 2 model overview](https://developers.google.com/weathernext/guides/models)
- [WeatherNext 2 variable schema](https://developers.google.com/weathernext/guides/model-specs-vmg)
- [Generative Ensemble Downscaling with Diffusion Models (CorrDiff)](https://arxiv.org/abs/2308.14453)
- [Kilometer-Scale Convection Allowing Model Emulation (StormCast)](https://arxiv.org/abs/2408.10958)
- [GraphCast: Learning skillful medium-range global weather forecasting](https://arxiv.org/abs/2212.12794)
71 changes: 71 additions & 0 deletions examples/weather/fgn/config/eval_fgn.yaml
@@ -0,0 +1,71 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Standalone evaluation config for FGN (§4 of arXiv:2506.10772).
#
# Usage:
# python eval.py --config-name eval_fgn \
# dataset.stats_path=rundir/fgn_2024_val/stats_2024.npz \
# eval.checkpoint=rundir/fgn_2024_long/0/checkpoints/FGNUNet.0.5000.mdlus
#
# To evaluate a deep ensemble pass a list:
# eval.checkpoints=[seed0/FGNUNet.mdlus,seed1/FGNUNet.mdlus]

defaults:
  - fgn  # model + dataset schema defaults
  - _self_

dataset:
  name: arco.ArcoFGNDataset
  state_variables: null
  invariant_variables:
    - z
    - lsm
  step_hours: 6
  history_frames: 2
  # future_frames is overridden by eval.future_steps at runtime
  future_frames: ${eval.future_steps}
  val_start: "2024-10-01"
  val_end: "2025-01-01"
  # train_{start,end} unused in eval but required by the dataset schema
  train_start: "2024-01-01"
  train_end: "2024-10-01"
  spatial_stride: 1
  static_date: "2016-01-01"
  arco_cache: true
  stats_path: ???
  tp_accumulation_hours: null

model:
  latent_dim: 16
  hidden_channels: 64

# training section is required by TrainMainConfig but unused during eval;
# keep it minimal so Pydantic passes validation.
training:
  experiment_name: fgn_eval
  run_id: "0"
  batch_size: 1
  total_train_steps: 1
  ar_steps: ${eval.future_steps}

eval:
  # Path to a single .mdlus checkpoint, or "latest" to auto-detect.
  checkpoint: "latest"
  # For deep-ensemble eval, list multiple checkpoints; overrides checkpoint.
  checkpoints: null
  # Number of AR steps forward (each step = step_hours hours).
  # 20 steps = 5 days at 6h. Max ~40 steps (10 days) fits in memory.
  future_steps: 20
  # Ensemble members per checkpoint. Paper uses 56 total (14×4 seeds).
  ensemble_size: 8
  # Batch size for the eval DataLoader. Keep at 1 for full-resolution.
  batch_size: 1
  # Number of DataLoader workers.
  num_workers: 0
  # Output directory for plots + eval_metrics.npz.
  outdir: "rundir/fgn_2024_long/0/eval"
  # Pooled-CRPS cell sizes (number of 0.25° grid cells per side).
  # [4,8,16,32] ≈ [120, 240, 480, 960] km.
  pool_sizes: [4, 8, 16, 32]
30 changes: 30 additions & 0 deletions examples/weather/fgn/config/fgn.yaml
@@ -0,0 +1,30 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Base FGN config — required by train.py's @hydra.main(config_name="fgn").
# Provides model defaults and a skeleton dataset section; override dataset
# and training fields on the command line or via a derived config such as
# fgn_arco.yaml.
#
# Minimal usage (all dataset fields required as overrides):
# python train.py \
# dataset.name=arco.ArcoFGNDataset \
# dataset.stats_path=/path/to/stats.npz \
# [dataset.train_start=... dataset.train_end=... ...]

defaults:
  - training: default
  - _self_

dataset:
  name: ???  # required, e.g. arco.ArcoFGNDataset

model:
  model_name: fgn
  history_frames: 2
  latent_dim: 16
  hidden_channels: 32
  background_channels: auto
  invariant_channels: auto
  group_norm_groups: 8