diff --git a/CHANGELOG.md b/CHANGELOG.md index b4e938a567..a2184e4ab3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -149,6 +149,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support for Batched radius search, which enables Domino and GeoTransolver with local features and batch size > 1. - Added the underfill recipe. +- Adds Functional Generative Networks (FGN) weather training example + (`examples/weather/fgn`). Implements the latent-conditioned U-Net + stochastic generator from + `arXiv:2506.10772 `_ (WeatherNext 2) + as a PhysicsNeMo ``Module``, trained with fair-CRPS loss on ERA5 via the + earth2studio ARCO data source. Supports autoregressive rollout training + with per-channel normalization, FSDP + ShardTensor domain parallelism, + deep-ensemble inference (paper §2.2.1), and validation diagnostics + (CRPS, RMSE, spread-skill, rank histograms, power spectra). ### Changed diff --git a/examples/weather/fgn/.gitignore b/examples/weather/fgn/.gitignore new file mode 100644 index 0000000000..988dd46269 --- /dev/null +++ b/examples/weather/fgn/.gitignore @@ -0,0 +1,12 @@ +*.mlus +*.png +*.pt +*.tfevents* +*wandb/ +rundir/ +logs/ +*.npz +FGN.md +# SLURM launcher scripts are cluster-specific; keep locally, don't track. +scripts/train_fgn.sh +scripts/eval_fgn.sh diff --git a/examples/weather/fgn/README.md b/examples/weather/fgn/README.md new file mode 100644 index 0000000000..20da94c99d --- /dev/null +++ b/examples/weather/fgn/README.md @@ -0,0 +1,266 @@ + +# Functional Generative Networks for Weather Forecasting + +A PhysicsNeMo implementation of Functional Generative Networks (FGN) for +probabilistic global weather forecasting, following the approach of: + +> Alet et al., "Skillful joint probabilistic weather forecasting from marginals" +> ([arXiv:2506.10772](https://arxiv.org/abs/2506.10772)) + +FGN is the architecture behind the production +[WeatherNext 2](https://developers.google.com/weathernext/guides/models) model, +which delivers 64-member ensemble forecasts (4 independently trained seeds × +16 trajectories each) at 0.25° global resolution. The full variable schema is +described in the +[WeatherNext 2 model specs](https://developers.google.com/weathernext/guides/model-specs-vmg). + +FGN generates ensemble weather forecasts by perturbing a deterministic backbone +with a low-dimensional latent noise vector `z ~ N(0, I_32)` injected through +conditional layer normalization (CLN) at every layer, producing globally coherent +ensemble spread from a marginal (fair-CRPS) training loss. Multiple independently +trained model seeds form a deep ensemble (J=4 seeds, 16 trajectories each = 64 +members in production) capturing both aleatoric and epistemic uncertainty. + +## Problem Overview + +FGN autoregressively predicts the next 6-hour atmospheric state from the two +previous states (`X_{t-2}`, `X_{t-1}`), sampled from ERA5 (pre-training) and +HRES-fc0 (fine-tuning) at 0.25° global resolution. Each forward pass is +non-diffusive: one pass per forecast step, with a fresh `z` drawn per step per +ensemble member. AR fine-tuning with rollouts up to 8 steps (Table A.2) improves +temporal coherence without requiring a diffusion sampler. + +This example implements: + +- Latent-conditioned `FGNUNet` backbone (`utils/nn.py`) with AdaGN modulation +- ARCO-backed real dataset using `earth2studio.data.ARCO` (`datasets/arco.py`) +- Fair-CRPS training loss with paper-faithful per-variable and area weights (`utils/loss.py`) +- Autoregressive rollout training with BPTT (`utils/trainer.py`) +- Multi-stage AR schedule runner (Table A.2: `8k·1AR → 4k·2AR → 1k·{3..8}AR`) +- Validation metrics and plots: CRPS, RMSE, spread-skill, rank histograms, + power spectra (`utils/metrics.py`) +- FSDP + ShardTensor distributed training via `ParallelHelper` (`utils/parallel.py`) +- Deep ensemble inference across multiple independently trained checkpoints +- Per-channel normalization stats with Welford online estimation + +## Dataset + +Training data is fetched live from the [ARCO ERA5](https://cloud.google.com/storage/docs/public-datasets/era5) +dataset via `earth2studio.data.ARCO`. No local download is required for training. + +The dataset covers the full 83-channel Table A.1 schema: 78 atmospheric channels +(6 variables × 13 pressure levels: 50–1000 hPa) plus 5 input/predicted surface +channels (`t2m`, `u10m`, `v10m`, `msl`, `sst`) and `tp06` (6-h accumulated +precipitation, predicted-only). Static inputs (surface geopotential, land-sea mask) +and clock features (local time, year progress sin/cos) are added automatically. + +All variables use compact Earth2Studio / PhysicsNeMo names: `u10m`, `v10m`, `t2m`, +`msl`, `sst`, `tp06`, `z{level}`, `q{level}`, `t{level}`, `u{level}`, `v{level}`, +`w{level}`. + +> **Note:** The production WeatherNext 2 output also includes `u100m` / `v100m` +> (100 m wind components). ERA5 via ARCO does not provide 100 m winds, so they +> are omitted from this ERA5-based training example. + +### Normalization Stats + +Pre-compute per-channel mean and standard deviation before training: + +```bash +python scripts/compute_arco_stats.py \ + --start 2020-01-01 --end 2023-12-31 \ + --output rundir/fgn_2024_val/stats_2024.npz +``` + +Pass the resulting `.npz` file to the trainer via `dataset.stats_path`. + +## Getting Started + +### Requirements + +```bash +pip install -r requirements.txt +``` + +PyTorch 2.10 or higher is required for domain parallelism. + +### Smoke Test + +Run the self-contained synthetic test suite (no GPU, no network access): + +```bash +pytest test_training.py +``` + +Multi-GPU tests require `torchrun`: + +```bash +torchrun --standalone --nproc_per_node=2 --no-python pytest test_training.py +``` + +## Configuration + +Training is configured with [Hydra](https://hydra.cc) and validated with Pydantic +(`utils/config.py`). Configs live under `config/`: + +- `config/fgn.yaml` — base defaults (model, training, dataset structure) +- `config/fgn_arco.yaml` — ARCO real-data training config (inherits from `fgn.yaml`) +- `config/test_fgn.yaml` — fast synthetic smoke-test config + +Key config knobs: + +| Setting | Description | +|---|---| +| `model.hidden_channels` | U-Net channel width (64 for quick runs, 256+ for full scale) | +| `model.latent_dim` | Latent noise dimension (32, per paper) | +| `training.batch_size` | Global batch size; local per-GPU = `batch_size / data_parallel_size` | +| `training.ar_steps` | AR rollout length for loss (1 = single-step pre-training) | +| `training.loss.num_samples` | Ensemble members per training example (N=2 per paper) | +| `training.domain_parallel_size` | GPUs per sample for domain parallelism (1 = pure DDP) | +| `dataset.stats_path` | Path to `.npz` normalization stats | + +Training outputs (checkpoints, logs, plots) are saved to: + +``` +rundir/{training.experiment_name}/{training.run_id}/ +``` + +## Training + +### Single GPU + +```bash +python train.py --config-name fgn_arco \ + dataset.stats_path=rundir/fgn_2024_val/stats_2024.npz \ + training.experiment_name=fgn_run \ + training.batch_size=1 +``` + +### Multi-GPU (torchrun) + +```bash +torchrun --standalone --nnodes=1 --nproc_per_node=2 \ + train.py --config-name fgn_arco \ + dataset.stats_path=rundir/fgn_2024_val/stats_2024.npz \ + training.experiment_name=fgn_run \ + training.batch_size=2 +``` + +With 2 GPUs and `domain_parallel_size=1` (DDP), `batch_size` is the global batch +size — each GPU processes `batch_size / 2` samples. + +### SLURM + +```bash +sbatch scripts/train_fgn.sh +``` + +Override defaults via environment variables: + +```bash +sbatch --export=ALL,EXP_NAME=fgn_2024,RUN_ID=1,STEPS=10000 scripts/train_fgn.sh +``` + +See `scripts/train_fgn.sh` for all overridable variables (`EXP_NAME`, `RUN_ID`, +`STEPS`, `CFG`, `STATS_PATH`, `NGPU`). + +### Resuming + +Set `training.resume_checkpoint=latest` (default) to automatically resume from +the most recent checkpoint in the run directory. + +### Domain Parallelism + +For models too large to fit one sample on a single GPU, enable domain parallelism: + +```bash +torchrun --standalone --nproc_per_node=4 \ + train.py --config-name fgn_arco \ + training.domain_parallel_size=2 \ + training.batch_size=2 +``` + +With `domain_parallel_size=2` and 4 GPUs: 2 domain-parallel pairs, each handling +1 sample (`batch_size / data_parallel_size = 2 / 2 = 1`). + +### AR Fine-Tuning Schedule (Table A.2) + +The trainer implements the paper's multi-stage AR schedule automatically when +`training.ar_steps` increases across runs. Start with single-step pre-training, +then resume with progressively longer rollouts: + +| Stage | `ar_steps` | Steps | Notes | +|---|---|---|---| +| 1 | 1 | 8000 | Single-step pre-train | +| 2 | 2 | 4000 | Resume from stage 1 | +| 3–8 | 3–8 | 1000 each | Resume from previous | + +## Inference + +Run stochastic ensemble inference from a trained checkpoint: + +```bash +python inference.py --config-name inference_fgn \ + inference.checkpoint=rundir/fgn_run/0/checkpoints/FGNUNet.mdlus +``` + +For deep ensemble inference across multiple independently trained seeds: + +```bash +python inference.py --config-name inference_fgn \ + "inference.checkpoints=[seed0/FGNUNet.mdlus, seed1/FGNUNet.mdlus, seed2/FGNUNet.mdlus, seed3/FGNUNet.mdlus]" +``` + +Trajectories are distributed across checkpoints following paper §2.2.1. + +### Bad-Seed Detection + +Before including a checkpoint in a deep ensemble, check its spectral properties: + +```bash +python scripts/check_spectra.py \ + --checkpoint rundir/fgn_run/0/checkpoints/FGNUNet.mdlus \ + --stats rundir/fgn_2024_val/stats_2024.npz +``` + +## Adding Custom Datasets + +Implement the `FGNDataset` interface from `datasets/dataset.py`: + +```python +class MyDataset(FGNDataset): + def state_channels(self) -> list[str]: ... + def background_channels(self) -> list[str]: ... + def image_shape(self) -> tuple[int, int]: ... + def __len__(self) -> int: ... + def __getitem__(self, idx): ... + # Optional: + def get_invariants(self) -> np.ndarray | None: ... + def output_only_channels(self) -> list[int]: ... +``` + +`__getitem__` should return a dict with keys `history` (shape `(T, C, H, W)`), +`target` (shape `(K, C, H, W)`), and optionally `background`. Register your +dataset by placing it in `datasets/` — it is discovered automatically via +`pkgutil.iter_modules` at import time. + +## Memory Management + +At 0.25° (721×1440), each training sample is large. Recommended settings for an +80 GB H100: + +- `training.batch_size=2` (1 per GPU) with 2 GPUs, `domain_parallel_size=1` +- bf16 AMP is enabled automatically (`torch.autocast(bfloat16)`) +- `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` (set in `train_fgn.sh`) + +For larger models (hidden_channels ≥ 256), use `domain_parallel_size=2` with +4+ GPUs, or enable gradient checkpointing via `model.checkpoint_level`. + +## References + +- [Skillful joint probabilistic weather forecasting from marginals](https://arxiv.org/abs/2506.10772) +- [WeatherNext 2 model overview](https://developers.google.com/weathernext/guides/models) +- [WeatherNext 2 variable schema](https://developers.google.com/weathernext/guides/model-specs-vmg) +- [Generative Ensemble Downscaling with Diffusion Models (CorrDiff)](https://arxiv.org/abs/2308.14453) +- [Kilometer-Scale Convection Allowing Model Emulation (StormCast)](https://arxiv.org/abs/2408.10958) +- [GraphCast: Learning skillful medium-range global weather forecasting](https://arxiv.org/abs/2212.12794) diff --git a/examples/weather/fgn/config/eval_fgn.yaml b/examples/weather/fgn/config/eval_fgn.yaml new file mode 100644 index 0000000000..cdca0578ea --- /dev/null +++ b/examples/weather/fgn/config/eval_fgn.yaml @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Standalone evaluation config for FGN (§4 of arXiv:2506.10772). +# +# Usage: +# python eval.py --config-name eval_fgn \ +# dataset.stats_path=rundir/fgn_2024_val/stats_2024.npz \ +# eval.checkpoint=rundir/fgn_2024_long/0/checkpoints/FGNUNet.0.5000.mdlus +# +# To evaluate a deep ensemble pass a list: +# eval.checkpoints=[seed0/FGNUNet.mdlus,seed1/FGNUNet.mdlus] + +defaults: + - fgn # model + dataset schema defaults + - _self_ + +dataset: + name: arco.ArcoFGNDataset + state_variables: null + invariant_variables: + - z + - lsm + step_hours: 6 + history_frames: 2 + # future_frames is overridden by eval.future_steps at runtime + future_frames: ${eval.future_steps} + val_start: "2024-10-01" + val_end: "2025-01-01" + # train_{start,end} unused in eval but required by the dataset schema + train_start: "2024-01-01" + train_end: "2024-10-01" + spatial_stride: 1 + static_date: "2016-01-01" + arco_cache: true + stats_path: ??? + tp_accumulation_hours: null + +model: + latent_dim: 16 + hidden_channels: 64 + +# training section is required by TrainMainConfig but unused during eval; +# keep it minimal so Pydantic passes validation. +training: + experiment_name: fgn_eval + run_id: "0" + batch_size: 1 + total_train_steps: 1 + ar_steps: ${eval.future_steps} + +eval: + # Path to a single .mdlus checkpoint, or "latest" to auto-detect. + checkpoint: "latest" + # For deep-ensemble eval, list multiple checkpoints; overrides checkpoint. + checkpoints: null + # Number of AR steps forward (each step = step_hours hours). + # 20 steps = 5 days at 6h. Max ~40 steps (10 days) fits in memory. + future_steps: 20 + # Ensemble members per checkpoint. Paper uses 56 total (14×4 seeds). + ensemble_size: 8 + # Batch size for the eval DataLoader. Keep at 1 for full-resolution. + batch_size: 1 + # Number of DataLoader workers. + num_workers: 0 + # Output directory for plots + eval_metrics.npz. + outdir: "rundir/fgn_2024_long/0/eval" + # Pooled-CRPS cell sizes (number of 0.25° grid cells per side). + # [4,8,16,32] ≈ [120, 240, 480, 960] km. + pool_sizes: [4, 8, 16, 32] diff --git a/examples/weather/fgn/config/fgn.yaml b/examples/weather/fgn/config/fgn.yaml new file mode 100644 index 0000000000..3f2614c8cb --- /dev/null +++ b/examples/weather/fgn/config/fgn.yaml @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Base FGN config — required by train.py's @hydra.main(config_name="fgn"). +# Provides model defaults and a skeleton dataset section; override dataset +# and training fields on the command line or via a derived config such as +# fgn_arco.yaml. +# +# Minimal usage (all dataset fields required as overrides): +# python train.py \ +# dataset.name=arco.ArcoFGNDataset \ +# dataset.stats_path=/path/to/stats.npz \ +# [dataset.train_start=... dataset.train_end=... ...] + +defaults: + - training: default + - _self_ + +dataset: + name: ??? # required — e.g. arco.ArcoFGNDataset + +model: + model_name: fgn + history_frames: 2 + latent_dim: 16 + hidden_channels: 32 + background_channels: auto + invariant_channels: auto + group_norm_groups: 8 diff --git a/examples/weather/fgn/config/fgn_arco.yaml b/examples/weather/fgn/config/fgn_arco.yaml new file mode 100644 index 0000000000..1263d4a3d5 --- /dev/null +++ b/examples/weather/fgn/config/fgn_arco.yaml @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# ARCO ERA5 training config — practical starting point for a single-GPU run +# that reproduces the paper's channel set and loss settings (arXiv:2506.10772). +# +# Prerequisites: +# 1. Run scripts/compute_arco_stats.py to create the stats file. +# 2. Set dataset.stats_path to that file (override below or on the CLI). +# +# Usage (Stage-1 single-step training): +# python train.py --config-name fgn_arco \ +# dataset.stats_path=/path/to/stats.npz \ +# training.experiment_name=fgn_era5 training.run_id=0 +# +# Usage (Stage-4 AR finetune, paper Table A.2): +# python scripts/stage4_ar_schedule.py \ +# --config-name fgn_arco \ +# --rundir /path/to/stage4 \ +# --stats-path /path/to/stats.npz + +defaults: + - fgn # pulls in training/default + model defaults + - _self_ + +dataset: + name: arco.ArcoFGNDataset + # Full ERA5 variables from paper §2.1 / Table A.1. Set to null to use + # the dataset's built-in default variable list. + state_variables: null + invariant_variables: + - z # geopotential at surface (orography) + - lsm # land-sea mask + step_hours: 6 + history_frames: 2 + future_frames: ${training.ar_steps} + # Use 2018–2022 for training, 2023 for validation (paper split). + train_start: "2018-01-01" + train_end: "2022-12-31" + val_start: "2023-01-01" + val_end: "2023-12-31" + spatial_stride: 1 # full 0.25° resolution (721 × 1440) + static_date: "2016-01-01" + arco_cache: true + stats_path: ??? # required — run scripts/compute_arco_stats.py first + tp_accumulation_hours: null # set to 6 to include tp06 (paper §3) + +model: + latent_dim: 16 + hidden_channels: 64 # larger than smoke-test default (32); ~4× more params + +training: + experiment_name: fgn_era5 + run_id: "0" + batch_size: 4 + total_train_steps: 5000 # ~1 h on a single H100 at full resolution + print_progress_freq: 100 + checkpoint_freq: 500 + validation_freq: 500 + validation_steps: 8 # batches per rank per validation call + validation_metrics: true + validation_ensemble_size: 4 + ar_steps: 1 # Stage 1; use stage4_ar_schedule.py for ramp + loss: + num_samples: 4 + mse_weight: 0.1 + use_channel_weights: true # GraphCast-style per-variable weights (§2.2.3) + use_area_weights: true # cos(lat) area weighting + optimizer: + lr: 3.0e-4 diff --git a/examples/weather/fgn/config/training/default.yaml b/examples/weather/fgn/config/training/default.yaml new file mode 100644 index 0000000000..de26a3d0c5 --- /dev/null +++ b/examples/weather/fgn/config/training/default.yaml @@ -0,0 +1,49 @@ +outdir: rundir +experiment_name: fgn +run_id: "0" +rundir: ${training.outdir}/${training.experiment_name}/${training.run_id} +checkpoint_dir: checkpoints +num_data_workers: 0 +seed: 7 +batch_size: 8 +total_train_steps: 100 +print_progress_freq: 10 +checkpoint_freq: 50 +validation_freq: 25 +resume_checkpoint: latest +clip_grad_norm: -1.0 + +# Autoregressive rollout depth per training step. +# Must match dataset.future_frames. 1 = single-step training (paper Stages +# 1-3). Stage 4 ramps ar_steps up to 8 per Table A.2. +ar_steps: 1 + +# Data + domain parallelism knobs (mirrors StormCast's convention): +# - domain_parallel_size=1 & force_sharding=false → pure single-process or +# plain DDP; smoke-test default. +# - domain_parallel_size>1 → ShardTensor on a spatial axis of the input. +# - force_sharding=true → wrap with ShardTensor even at domain size 1, +# to exercise the sharding code path end-to-end. +domain_parallel_size: 1 +force_sharding: false + +# Number of validation batches per rank when running under ParallelHelper — +# the rank-sharded sampler is infinite by design (StormCast convention), +# so we cap iteration here. Null = use one local epoch (len(valid_dataset) +# / (world_size * batch_size)). +validation_steps: null + +# Figure 2 + 3 validation diagnostics (per-variable CRPS / RMSE / +# spread-skill / rank hist / 1D power spectra). Off by default so the +# smoke test stays cheap. +validation_metrics: false +validation_ensemble_size: 4 + +optimizer: + lr: 3.0e-4 + betas: [0.9, 0.999] + weight_decay: 1.0e-4 + +loss: + num_samples: 4 + mse_weight: 0.1 diff --git a/examples/weather/fgn/datasets/__init__.py b/examples/weather/fgn/datasets/__init__.py new file mode 100644 index 0000000000..d29899e794 --- /dev/null +++ b/examples/weather/fgn/datasets/__init__.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Auto-discovery registry of FGNDataset subclasses. + +Mirrors ``examples/weather/stormcast/datasets/__init__.py``: +scan every module in this package and register any class that subclasses +``FGNDataset`` under the key ``"."``. + +Usage:: + + from datasets import dataset_classes + cls = dataset_classes["arco.ArcoFGNDataset"] +""" + +from __future__ import annotations + +import importlib +import pathlib +import pkgutil + +from .dataset import FGNDataset + +_pkg_dir = str(pathlib.Path(__file__).parent) +dataset_classes: dict[str, type[FGNDataset]] = {} + +for _mod_info in pkgutil.iter_modules([_pkg_dir]): + if _mod_info.name == "dataset": + continue + _module = importlib.import_module(f"datasets.{_mod_info.name}") + for _name, _member in _module.__dict__.items(): + if ( + _name != "FGNDataset" + and isinstance(_member, type) + and issubclass(_member, FGNDataset) + ): + dataset_classes[f"{_mod_info.name}.{_name}"] = _member diff --git a/examples/weather/fgn/datasets/arco.py b/examples/weather/fgn/datasets/arco.py new file mode 100644 index 0000000000..c855bb5ace --- /dev/null +++ b/examples/weather/fgn/datasets/arco.py @@ -0,0 +1,517 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""ARCO/ERA5 dataset for FGN, following Appendix A.1 of arXiv:2506.10772v1. + +Wraps :class:`earth2studio.data.ARCO` so data fetching, caching, and the +compact-name lexicon live in earth2studio. This module only turns a sample +index into a `(history, target, background)` triple at a fixed 6-hour +stride, applies the SST land-NaN imputation described in the paper, and +computes clock / invariant features locally. +""" + +from __future__ import annotations + +from datetime import datetime, timedelta +from typing import Any + +import numpy as np +import torch + +from .dataset import FGNDataset + +# Paper Table A.1 atmospheric schema +PAPER_ATMOS_VARS: tuple[str, ...] = ("z", "q", "t", "u", "v", "w") +PAPER_LEVELS: tuple[int, ...] = ( + 50, + 100, + 150, + 200, + 250, + 300, + 400, + 500, + 600, + 700, + 850, + 925, + 1000, +) +PAPER_SURFACE_IN_OUT: tuple[str, ...] = ("t2m", "u10m", "v10m", "msl", "sst") + +# 6 atmospheric * 13 levels + 5 surface input/predicted = 83 state channels +# `tp` is predicted-only (6-h accumulation) and is not currently handled here; +# it belongs in the target-only output variables. +DEFAULT_STATE: tuple[str, ...] = tuple( + [f"{v}{lvl}" for v in PAPER_ATMOS_VARS for lvl in PAPER_LEVELS] + + list(PAPER_SURFACE_IN_OUT) +) + +# Static fields available from ARCO directly via the compact-name lexicon. +ARCO_STATIC_VARS: frozenset[str] = frozenset({"z", "lsm"}) +# Computed locally from the grid, not from ARCO. +LOCAL_INVARIANTS: frozenset[str] = frozenset({"lat", "lon"}) + +# Clock features are computed locally from the target timestamp. +CLOCK_CHANNELS: tuple[str, ...] = ( + "local_time_sin", + "local_time_cos", + "year_progress_sin", + "year_progress_cos", +) + +ARCO_LAT = np.linspace(90, -90, 721, dtype=np.float32) +ARCO_LON = np.linspace(0, 359.75, 1440, dtype=np.float32) + + +def _parse_date(value: Any) -> datetime: + if isinstance(value, datetime): + return value + return datetime.fromisoformat(str(value)) + + +class ArcoFGNDataset(FGNDataset): + """ERA5 training dataset for FGN, served through earth2studio ARCO. + + Each sample returns a history of ``history_frames`` state tensors at a + ``step_hours`` stride and the next state tensor as the target. The + default state variable list matches Table A.1 of arXiv:2506.10772v1. + + Parameters expected on ``params`` (Hydra DictConfig or similar): + + - ``state_variables`` (list[str], optional): compact ARCO names; defaults + to the 83-channel Table A.1 list. + - ``invariant_variables`` (list[str]): subset of ``{"z", "lsm", "lat", + "lon"}``; defaults to ``["z", "lsm"]``. + - ``train_start`` / ``train_end`` / ``val_start`` / ``val_end`` (str): + ISO-8601 dates defining the split. + - ``step_hours`` (int, default 6): temporal stride between frames. + - ``history_frames`` (int, default 2): number of prior frames used as + input. Paper uses 2. + - ``spatial_stride`` (int, default 1): sub-sample the ARCO 721x1440 grid + for cheaper dev runs. + - ``static_date`` (str, default ``"2016-01-01"``): date used to fetch + the truly-static ``z``/``lsm`` fields once. + - ``arco_cache`` (bool, default True): earth2studio ARCO local cache. + - ``tp_accumulation_hours`` (int or None, default None): when set, the + state variable named ``tp{tp_accumulation_hours}`` (e.g. ``tp06``) is + treated as a paper §3 predicted-only accumulated precipitation + channel: history values are forced to zero (matching HRES-fc0 + initialisation and the earth2studio ``gencast_mini`` convention for + ``tp12``) and target values are computed as the sum of + ``tp_accumulation_hours`` hourly ARCO ``tp`` values leading up to + each target timestamp. Requires ``tp{N}`` to appear in + ``state_variables`` so the channel exists in both input and output + tensors. + """ + + def __init__(self, params: Any, train: bool) -> None: + def _get(name: str, default: Any) -> Any: + return ( + getattr(params, name, default) + if hasattr(params, name) + else ( + params[name] + if isinstance(params, dict) and name in params + else default + ) + ) + + state = _get("state_variables", None) + self._state_variables: list[str] = list(state) if state else list(DEFAULT_STATE) + invariants = _get("invariant_variables", None) + self._invariant_variables: list[str] = ( + list(invariants) if invariants is not None else ["z", "lsm"] + ) + for v in self._invariant_variables: + if v not in ARCO_STATIC_VARS and v not in LOCAL_INVARIANTS: + raise ValueError( + f"invariant_variables entry {v!r} is not supported; " + f"expected one of {sorted(ARCO_STATIC_VARS | LOCAL_INVARIANTS)}" + ) + + self.step_hours = int(_get("step_hours", 6)) + self.history_frames = int(_get("history_frames", 2)) + if self.history_frames < 1: + raise ValueError("history_frames must be >= 1") + self.future_frames = int(_get("future_frames", 1)) + if self.future_frames < 1: + raise ValueError("future_frames must be >= 1") + + if train: + start = _get("train_start", "1979-01-01") + end = _get("train_end", "2018-01-15") + else: + start = _get("val_start", "2018-01-15") + end = _get("val_end", "2019-01-01") + self.start: datetime = _parse_date(start) + self.end: datetime = _parse_date(end) + + self.stride = int(_get("spatial_stride", 1)) + if self.stride < 1: + raise ValueError(f"spatial_stride must be >= 1, got {self.stride}") + self.height = len(ARCO_LAT[:: self.stride]) + self.width = len(ARCO_LON[:: self.stride]) + + self.static_date = _parse_date(_get("static_date", "2016-01-01")) + self.arco_cache = bool(_get("arco_cache", True)) + + # Paper §3 "Total precipitation" handling. When tp_accumulation_hours + # is set, the state variable literally named ``tp`` (e.g. + # ``tp06`` for the paper's 6-hour accumulation) is: + # - zeroed in history (predicted-only — matches HRES-fc0 init + # and earth2studio gencast_mini's tp12 convention) + # - computed in target by summing N hourly ARCO ``tp`` values + # leading up to each target timestamp. + self.tp_accumulation_hours: int | None = None + self._tp_channel_idx: int | None = None + self._tp_state_name: str | None = None + tp_hours = _get("tp_accumulation_hours", None) + if tp_hours is not None: + self.tp_accumulation_hours = int(tp_hours) + if self.tp_accumulation_hours < 1: + raise ValueError( + f"tp_accumulation_hours must be >= 1, got {self.tp_accumulation_hours}" + ) + tp_name = f"tp{self.tp_accumulation_hours:02d}" + if tp_name not in self._state_variables: + raise ValueError( + f"tp_accumulation_hours={self.tp_accumulation_hours} requires " + f"{tp_name!r} in state_variables; got {self._state_variables}" + ) + self._tp_state_name = tp_name + self._tp_channel_idx = self._state_variables.index(tp_name) + + # Optional per-channel z-score stats. File layout: an .npz with + # arrays ``mean`` and ``std``, each of shape ``(len(state_variables),)`` + # in the same order as ``state_variables``. Mirrors the StormCast + # convention (means.npy / stds.npy) but packaged as one file to + # avoid order-mismatch bugs when variable lists change. + self._mean: np.ndarray | None = None + self._std: np.ndarray | None = None + stats_path = _get("stats_path", None) + if stats_path: + self._load_stats(str(stats_path)) + + # Sample count: for each target index i, + # last_target_time(i) = start + (i + history_frames + future_frames - 1) * step + # and we need last_target_time <= end. + total_hours = max(0, int((self.end - self.start).total_seconds() // 3600)) + total_steps = total_hours // self.step_hours + window = self.history_frames + self.future_frames - 1 + self.num_samples = max(0, total_steps - window) + if self.num_samples <= 0: + raise ValueError( + "Date range is shorter than (history_frames + future_frames) * step_hours; " + f"start={self.start}, end={self.end}, step_hours={self.step_hours}" + ) + + # Lazy, per-worker ARCO client + self._arco = None + self._invariants_cache: np.ndarray | None = None + + # --- FGNDataset interface ------------------------------------------------ + + def state_channels(self) -> list[str]: + return list(self._state_variables) + + def background_channels(self) -> list[str]: + return list(CLOCK_CHANNELS) + + def image_shape(self) -> tuple[int, int]: + return (self.height, self.width) + + def __len__(self) -> int: + return self.num_samples + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + # ParallelHelper.sharded_dataloader yields numpy.int64 indices from a + # numpy index array; Python 3.12's timedelta rejects np.int64. + idx = int(idx) + if idx < 0 or idx >= self.num_samples: + raise IndexError(idx) + + first_target_time = self._target_time(idx) + total_frames = self.history_frames + self.future_frames + # Window start is the time of history[0]. + window_start = first_target_time - timedelta( + hours=self.history_frames * self.step_hours + ) + times = [ + window_start + timedelta(hours=k * self.step_hours) + for k in range(total_frames) + ] + + # Fetch real state variables from ARCO. The tp{N} placeholder name + # (e.g. "tp06") is NOT in the ARCOLexicon — that's a derived + # accumulation the paper §3 defines on top of the hourly ``tp`` + # variable — so we exclude it from the bulk fetch and fill its + # slot from _fetch_tp_accumulation below. + if self._tp_channel_idx is not None: + ci = self._tp_channel_idx + arco_vars = [v for i, v in enumerate(self._state_variables) if i != ci] + else: + ci = None + arco_vars = list(self._state_variables) + + da = self._ensure_arco()(time=times, variable=arco_vars) + fetched = np.asarray(da.values, dtype=np.float32) # (T, V-1 or V, 721, 1440) + if self.stride > 1: + fetched = fetched[..., :: self.stride, :: self.stride] + + # Re-embed the fetched channels into the full (T, V, H, W) layout + # with a zero tp{N} slot we'll overwrite immediately after. This + # keeps channel indexing in lockstep with ``self._state_variables``. + T = fetched.shape[0] + V = len(self._state_variables) + arr = np.zeros((T, V, self.height, self.width), dtype=np.float32) + if ci is None: + arr[:] = fetched + else: + arr[:, :ci] = fetched[:, :ci] + arr[:, ci + 1 :] = fetched[:, ci:] + + self._impute_sst_nans_(arr) + + # Paper §3: replace the tp{N} channel with N-hour accumulation for + # every target frame, keep history zeroed (predicted-only — "FGN + # is trained to only output tp, not taking it as input"). Done + # here before normalization so z-score stats apply to the + # accumulated values. + if ci is not None: + tp_acc = self._fetch_tp_accumulation(times) # (T, H, W) + arr[self.history_frames :, ci, :, :] = tp_acc[self.history_frames :] + + if self._mean is not None: + # Broadcast (V,) stats over (T, V, H, W). + arr = (arr - self._mean[None, :, None, None]) / self._std[ + None, :, None, None + ] + + history = arr[: self.history_frames] + target = arr[self.history_frames :] # (future_frames, V, H, W) + # Clock features are computed for the first target step only for now; + # if downstream code wants per-step clocks, extend to (future_frames, 4, H, W). + background = self._clock_features(first_target_time) + + return { + "history": torch.from_numpy(history), + "target": torch.from_numpy(target), + "background": torch.from_numpy(background), + } + + def output_only_channels(self) -> list[int]: + if self._tp_channel_idx is None: + return [] + return [self._tp_channel_idx] + + def get_invariants(self) -> np.ndarray | None: + if not self._invariant_variables: + return None + if self._invariants_cache is not None: + return self._invariants_cache + + pieces: list[np.ndarray] = [] + arco_vars = [v for v in self._invariant_variables if v in ARCO_STATIC_VARS] + if arco_vars: + da = self._ensure_arco()(time=[self.static_date], variable=arco_vars) + raw = np.asarray(da.values, dtype=np.float32)[0] # (V, 721, 1440) + if self.stride > 1: + raw = raw[..., :: self.stride, :: self.stride] + pieces.extend(raw[i] for i in range(raw.shape[0])) + + if "lat" in self._invariant_variables: + pieces.append( + np.broadcast_to( + ARCO_LAT[:: self.stride, None], (self.height, self.width) + ) + .astype(np.float32) + .copy() + ) + if "lon" in self._invariant_variables: + pieces.append( + np.broadcast_to( + ARCO_LON[None, :: self.stride], (self.height, self.width) + ) + .astype(np.float32) + .copy() + ) + + # Reorder to match `self._invariant_variables` order. + by_name: dict[str, np.ndarray] = {} + cursor = 0 + for v in arco_vars: + by_name[v] = pieces[cursor] + cursor += 1 + if "lat" in self._invariant_variables: + by_name["lat"] = pieces[cursor] + cursor += 1 + if "lon" in self._invariant_variables: + by_name["lon"] = pieces[cursor] + cursor += 1 + + self._invariants_cache = np.stack( + [by_name[v] for v in self._invariant_variables], axis=0 + ).astype(np.float32) + return self._invariants_cache + + def normalize_state( + self, x: np.ndarray | torch.Tensor + ) -> np.ndarray | torch.Tensor: + if self._mean is None: + return x + mean, std = self._broadcast_stats_for(x) + return (x - mean) / std + + def denormalize_state( + self, x: np.ndarray | torch.Tensor + ) -> np.ndarray | torch.Tensor: + if self._mean is None: + return x + mean, std = self._broadcast_stats_for(x) + return x * std + mean + + # --- Internals ----------------------------------------------------------- + + def _fetch_tp_accumulation(self, frame_times: list[datetime]) -> np.ndarray: + """Sum the N hourly ARCO ``tp`` values preceding each frame time. + + ARCO stores ``total_precipitation`` as the hourly accumulation + during ``[t-1h, t]``, matching ECMWF's ERA5 convention. So the + paper's N-hour accumulation ending at T equals + ``sum(tp(T-N+1), tp(T-N+2), ..., tp(T))`` — N hourly values. We + fetch all distinct hourly stamps required by any frame in a single + earth2studio call to minimise GCS round-trips. + """ + if self.tp_accumulation_hours is None: + raise RuntimeError( + "_fetch_tp_accumulation called without tp_accumulation_hours set" + ) + N = self.tp_accumulation_hours + + # Union of hours we need across all frames, sorted. + hourly_set: set[datetime] = set() + per_frame_hours: list[list[datetime]] = [] + for t in frame_times: + hours = [t - timedelta(hours=N - 1 - j) for j in range(N)] + per_frame_hours.append(hours) + hourly_set.update(hours) + unique_hours = sorted(hourly_set) + hour_to_idx = {t: i for i, t in enumerate(unique_hours)} + + da = self._ensure_arco()(time=unique_hours, variable=["tp"]) + hourly = np.asarray(da.values, dtype=np.float32) # (U, 1, 721, 1440) + hourly = hourly[:, 0] # (U, 721, 1440) + if self.stride > 1: + hourly = hourly[:, :: self.stride, :: self.stride] + + acc = np.zeros((len(frame_times), self.height, self.width), dtype=np.float32) + for k, hours_k in enumerate(per_frame_hours): + acc[k] = sum(hourly[hour_to_idx[h]] for h in hours_k) + return acc + + def _load_stats(self, stats_path: str) -> None: + from pathlib import Path as _Path + + path = _Path(stats_path) + if not path.exists(): + raise FileNotFoundError(f"stats_path does not exist: {path}") + data = np.load(path) + if "mean" not in data or "std" not in data: + raise KeyError( + f"{path} must contain arrays 'mean' and 'std'; got {list(data.files)}" + ) + mean = np.asarray(data["mean"], dtype=np.float32) + std = np.asarray(data["std"], dtype=np.float32) + expected = (len(self._state_variables),) + if mean.shape != expected or std.shape != expected: + raise ValueError( + f"stats mean/std must have shape {expected} matching " + f"state_variables; got mean={mean.shape}, std={std.shape}" + ) + if np.any(std == 0): + raise ValueError("stats std contains zeros; cannot z-score normalize") + self._mean = mean + self._std = std + + def _broadcast_stats_for(self, x: np.ndarray | torch.Tensor) -> tuple[Any, Any]: + """Reshape `(V,)` stats to broadcast along the channel axis of ``x``. + + Supports ``x`` of shape ``(V, H, W)``, ``(T, V, H, W)``, or + ``(B, T, V, H, W)`` — channel axis is the third-from-last. + """ + if x.ndim == 3: + shape = (-1, 1, 1) + elif x.ndim == 4: + shape = (1, -1, 1, 1) + elif x.ndim == 5: + shape = (1, 1, -1, 1, 1) + else: + raise ValueError(f"unsupported state tensor ndim {x.ndim}") + if isinstance(x, torch.Tensor): + mean = torch.as_tensor(self._mean, dtype=x.dtype, device=x.device).reshape( + shape + ) + std = torch.as_tensor(self._std, dtype=x.dtype, device=x.device).reshape( + shape + ) + else: + mean = self._mean.reshape(shape) + std = self._std.reshape(shape) + return mean, std + + def _ensure_arco(self): + if self._arco is None: + from earth2studio.data import ARCO + + self._arco = ARCO(cache=self.arco_cache, verbose=False) + return self._arco + + def _target_time(self, idx: int) -> datetime: + return self.start + timedelta( + hours=(idx + self.history_frames) * self.step_hours + ) + + @staticmethod + def _impute_sst_nans_(arr: np.ndarray) -> None: + """Replace NaNs in SST with the global min SST seen in the batch. + + Paper A.1.1: ERA5 represents land in SST with NaNs; we impute with a + global minimum to keep the tensor dense. This happens over the whole + fetched window to avoid leaking land-mask information. + """ + # Implicitly finds any channel whose values contain NaN; SST is the + # only Table A.1 variable expected to have them. + if not np.isnan(arr).any(): + return + for c in range(arr.shape[1]): + chan = arr[:, c, :, :] + mask = np.isnan(chan) + if not mask.any(): + continue + finite_min = float(np.nanmin(chan)) + chan[mask] = finite_min + arr[:, c, :, :] = chan + + def _clock_features(self, t: datetime) -> np.ndarray: + # Year progress in [0, 1) + year_start = datetime(t.year, 1, 1) + year_end = datetime(t.year + 1, 1, 1) + yp = (t - year_start).total_seconds() / (year_end - year_start).total_seconds() + yp_sin = float(np.sin(2 * np.pi * yp)) + yp_cos = float(np.cos(2 * np.pi * yp)) + + utc_hours = t.hour + t.minute / 60.0 + t.second / 3600.0 + lon = ARCO_LON[:: self.stride] + local_hours = (utc_hours + lon / 15.0) % 24.0 + local_frac = local_hours / 24.0 + lt_sin_row = np.sin(2 * np.pi * local_frac).astype(np.float32) + lt_cos_row = np.cos(2 * np.pi * local_frac).astype(np.float32) + + lt_sin = np.broadcast_to(lt_sin_row[None, :], (self.height, self.width)).copy() + lt_cos = np.broadcast_to(lt_cos_row[None, :], (self.height, self.width)).copy() + yp_sin_field = np.full((self.height, self.width), yp_sin, dtype=np.float32) + yp_cos_field = np.full((self.height, self.width), yp_cos, dtype=np.float32) + + return np.stack([lt_sin, lt_cos, yp_sin_field, yp_cos_field], axis=0) diff --git a/examples/weather/fgn/datasets/dataset.py b/examples/weather/fgn/datasets/dataset.py new file mode 100644 index 0000000000..038ad09c47 --- /dev/null +++ b/examples/weather/fgn/datasets/dataset.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""FGNDataset ABC and DataLoader worker initialiser. + +Mirrors ``examples/weather/stormcast/datasets/dataset.py`` — same ABC pattern, +same ``worker_init`` seeding convention. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod + +import numpy as np +import torch + + +class FGNDataset(torch.utils.data.Dataset, ABC): + """Abstract base class for all FGN training datasets. + + Subclasses must implement the five abstract methods below plus the standard + ``torch.utils.data.Dataset`` protocol (``__len__`` and ``__getitem__``). + ``__getitem__`` should return a dict with keys ``"history"``, ``"target"``, + and ``"background"`` — already z-score normalized. + """ + + @abstractmethod + def state_channels(self) -> list[str]: + """Ordered list of state variable names (e.g. ``["t2m", "z500", ...]``).""" + + @abstractmethod + def background_channels(self) -> list[str]: + """Ordered list of background / conditioning variable names.""" + + @abstractmethod + def image_shape(self) -> tuple[int, int]: + """Spatial grid size as ``(H, W)``.""" + + def get_invariants(self) -> np.ndarray | None: + """Static invariant channels as ``(C_inv, H, W)`` float32, or None.""" + return None + + def output_only_channels(self) -> list[int]: + """Channel indices that must not be fed back as input (e.g. tp06).""" + return [] + + +def worker_init(wrk_id: int) -> None: + np.random.seed(torch.utils.data.get_worker_info().seed % (2**32 - 1)) diff --git a/examples/weather/fgn/eval.py b/examples/weather/fgn/eval.py new file mode 100644 index 0000000000..d7a4bc927e --- /dev/null +++ b/examples/weather/fgn/eval.py @@ -0,0 +1,448 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Standalone evaluation for FGN — paper §4 (arXiv:2506.10772v1). + +Iterates the full validation split, runs an M-member AR ensemble rollout for K +lead times, and accumulates the diagnostics shown in Figures 2–3: + + - Fair CRPS per variable per lead (area-weighted via cos-lat) + - Ensemble-mean RMSE per variable per lead (area-weighted) + - Spread-skill ratio per variable per lead (area-weighted) + - Rank histograms per variable (aggregated over all leads) + - Average-pooled CRPS and max-pooled CRPS (Figure 3 a-b) + - Derived-variable CRPS: wind speed and z300-z500 (Figure 3 c) + - Azimuthal power spectra at the final lead (Figure 3 e-j) + - Fair energy score per lead (multivariate CRPS) + +Per-variable CRPS, RMSE, and rank histograms use +``earth2studio.statistics.{crps, rmse, rank_histogram}`` with +``earth2studio.statistics.lat_weight`` for area weighting. +Pooled CRPS and power spectra use the lightweight torch kernels in +``utils/metrics.py``. + +All results are written as ``eval_metrics.npz`` + PNG plots to ``eval.outdir``. + +Usage:: + + python eval.py --config-name eval_fgn \\ + dataset.stats_path=rundir/fgn_2024_val/stats_2024.npz \\ + eval.checkpoint=rundir/fgn_2024_long/0/checkpoints/FGNUNet.0.5000.mdlus + +Deep-ensemble:: + + python eval.py --config-name eval_fgn \\ + dataset.stats_path=... \\ + "eval.checkpoints=[seed0/FGNUNet.mdlus,seed1/FGNUNet.mdlus]" +""" + +from __future__ import annotations + +import logging +from collections import OrderedDict +from pathlib import Path + +import hydra +import numpy as np +import torch +from datasets import dataset_classes +from datasets.dataset import worker_init +from omegaconf import DictConfig, OmegaConf +from torch.utils.data import DataLoader +from utils.config import EvalMainConfig +from utils.metrics import ( + derived_variable_crps, + energy_score_per_lead, + plot_crps_scorecard, + plot_metric_vs_lead, + plot_pooled_crps, + plot_power_spectra, + plot_rank_histograms, + plot_spread_skill_lines, + pooled_crps_per_lead, + power_spectra_per_variable, + save_summary, +) +from utils.trainer import find_latest_model_checkpoint + +from earth2studio.statistics import crps as e2s_crps +from earth2studio.statistics import rank_histogram as e2s_rh +from earth2studio.statistics import rmse as e2s_rmse +from earth2studio.statistics.weights import lat_weight +from physicsnemo.core import Module + +log = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _resolve_checkpoints(cfg: DictConfig) -> list[str]: + checkpoints = getattr(cfg.eval, "checkpoints", None) + if checkpoints: + return [str(c) for c in checkpoints] + checkpoint = cfg.eval.checkpoint + if checkpoint == "latest": + ckpt_dir = Path(cfg.training.rundir) / cfg.training.checkpoint_dir + return [str(find_latest_model_checkpoint(ckpt_dir))] + return [str(checkpoint)] + + +def _make_coords( + B: int, + M: int, + variables: list[str], + lats: np.ndarray, + lons: np.ndarray, +) -> tuple[OrderedDict, OrderedDict]: + """CoordSystem pair for (B, M, C, H, W) ensemble and (B, C, H, W) target.""" + x_coords = OrderedDict( + [ + ("batch", np.arange(B)), + ("ensemble", np.arange(M)), + ("variable", np.array(variables)), + ("lat", lats), + ("lon", lons), + ] + ) + y_coords = OrderedDict( + [ + ("batch", np.arange(B)), + ("variable", np.array(variables)), + ("lat", lats), + ("lon", lons), + ] + ) + return x_coords, y_coords + + +def _ar_rollout( + model: torch.nn.Module, + history: torch.Tensor, + background: torch.Tensor, + invariants: torch.Tensor | None, + num_steps: int, + latent_dim: int, + num_members: int, + device: torch.device, + output_only: list[int], +) -> torch.Tensor: + """Return ``(B, K, M, C, H, W)`` ensemble from an M-member AR rollout. + + Mirrors ``utils/trainer.py:_run_validation_metrics``: each member advances + independently; predicted-only channels (e.g. tp06) are zeroed before being + fed back as history on the next step. + """ + B, T, C, H, W = history.shape + per_member_hist = ( + history.unsqueeze(1).expand(B, num_members, T, C, H, W).contiguous() + ) + preds_all: list[torch.Tensor] = [] + for k in range(num_steps): + members: list[torch.Tensor] = [] + for n in range(num_members): + latent = torch.randn(B, latent_dim, device=device, dtype=torch.float32) + with torch.autocast( + "cuda", dtype=torch.bfloat16, enabled=torch.cuda.is_available() + ): + pred = model( + history=per_member_hist[:, n], + latent=latent, + background=background, + invariants=invariants, + ).float() + members.append(pred) + preds = torch.stack(members, dim=1) # (B, M, C, H, W) + preds_all.append(preds) + if k < num_steps - 1: + next_frame = preds + if output_only: + next_frame = next_frame.clone() + for ci in output_only: + next_frame[:, :, ci].zero_() + per_member_hist = torch.cat( + [per_member_hist[:, :, 1:], next_frame.unsqueeze(2)], dim=2 + ) + return torch.stack(preds_all, dim=1) # (B, K, M, C, H, W) + + +# --------------------------------------------------------------------------- +# Main eval loop +# --------------------------------------------------------------------------- + + +def run_eval(cfg: DictConfig) -> None: + cfg_dict = OmegaConf.to_container(cfg, resolve=True) + ecfg = EvalMainConfig(**cfg_dict) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + torch.manual_seed(17) + + # --- Dataset --- + dataset_cls = dataset_classes[cfg.dataset.name] + val_dataset = dataset_cls(cfg.dataset, train=False) + num_workers = int(ecfg.eval.num_workers) + val_loader = DataLoader( + val_dataset, + batch_size=int(ecfg.eval.batch_size), + shuffle=False, + num_workers=num_workers, + worker_init_fn=worker_init if num_workers else None, + ) + log.info(f"Val dataset: {len(val_dataset)} samples, {len(val_loader)} batches") + + invariants = val_dataset.get_invariants() + if invariants is not None: + invariants = torch.from_numpy(invariants).to(device, dtype=torch.float32) + + variables = val_dataset.state_channels() + output_only = val_dataset.output_only_channels() + C = len(variables) + K = int(ecfg.eval.future_steps) + M = int(ecfg.eval.ensemble_size) + latent_dim = int(ecfg.model.latent_dim) + pool_sizes = list(ecfg.eval.pool_sizes) + step_hours = int(getattr(cfg.dataset, "step_hours", 6)) + spatial_stride = int(getattr(cfg.dataset, "spatial_stride", 1)) + + # --- lat/lon grids and area weights (earth2studio.statistics.lat_weight) --- + from datasets.arco import ARCO_LAT, ARCO_LON + + lats_np = ARCO_LAT[::spatial_stride] + lons_np = ARCO_LON[::spatial_stride] + H, W = len(lats_np), len(lons_np) + + area_w_1d = lat_weight(torch.from_numpy(lats_np)) # (H,) cos-lat + area_w_2d = area_w_1d.unsqueeze(-1).expand(H, W).contiguous() # (H, W) + + # --- earth2studio stat objects (area-weighted, per-batch per-lead) --- + crps_fn = e2s_crps( + ensemble_dimension="ensemble", + reduction_dimensions=["lat", "lon"], + weights=area_w_2d, + fair=True, + ) + rmse_fn = e2s_rmse( + reduction_dimensions=["lat", "lon"], + weights=area_w_2d, + ensemble_dimension="ensemble", + ) + rh_fn = e2s_rh( + ensemble_dimension="ensemble", + reduction_dimensions=["lat", "lon"], + ) + + # --- Load model(s) --- + checkpoint_paths = _resolve_checkpoints(cfg) + log.info(f"Checkpoints: {checkpoint_paths}") + models: list[torch.nn.Module] = [] + for ckpt_path in checkpoint_paths: + models.append(Module.from_checkpoint(ckpt_path).to(device).eval()) + + n_models = len(models) + base_m = M // n_models + rem_m = M - base_m * n_models + members_per_model = [base_m + (1 if i < rem_m else 0) for i in range(n_models)] + + # --- Accumulators --- + # Shape (K, C) for per-lead per-variable metrics. + crps_acc = np.zeros((K, C), dtype=np.float64) + rmse_acc = np.zeros((K, C), dtype=np.float64) + # Spread-skill: accumulate ensemble std and RMSE separately. + spread_acc = np.zeros((K, C), dtype=np.float64) + rank_acc = np.zeros((M + 1, C), dtype=np.float64) # aggregated over leads+batches + energy_acc = np.zeros(K, dtype=np.float64) + power_ens_acc: np.ndarray | None = None + power_tgt_acc: np.ndarray | None = None + pooled_avg_acc: np.ndarray | None = None + pooled_max_acc: np.ndarray | None = None + derived_acc: dict[str, np.ndarray] = {} + n_batches = 0 + + # --- Eval loop --- + with torch.no_grad(): + for batch_idx, batch in enumerate(val_loader): + history = batch["history"].to(device, dtype=torch.float32) + target = batch["target"].to(device, dtype=torch.float32) + background = batch["background"].to(device, dtype=torch.float32) + if target.ndim == 4: + target = target.unsqueeze(1) + + B = history.shape[0] + inv_b = ( + invariants.unsqueeze(0).expand(B, -1, -1, -1) + if invariants is not None + else None + ) + + # AR rollout across all models → (B, K, M, C, H, W) + preds: list[torch.Tensor] = [] + for model, n_mem in zip(models, members_per_model, strict=True): + if n_mem > 0: + preds.append( + _ar_rollout( + model, history, background, inv_b, + K, latent_dim, n_mem, device, output_only, + ) + ) + ensemble = torch.cat(preds, dim=2) # (B, K, M, C, H, W) + + # Per-lead metrics (earth2studio, area-weighted) + xc, yc = _make_coords(B, M, variables, lats_np, lons_np) + for k in range(K): + ens_k = ensemble[:, k] # (B, M, C, H, W) + tgt_k = target[:, k] # (B, C, H, W) + + crps_res, _ = crps_fn(ens_k, xc, tgt_k, yc) # (B, C) + crps_acc[k] += crps_res.mean(dim=0).cpu().numpy() + + rmse_res, _ = rmse_fn(ens_k, xc, tgt_k, yc) # (B, C) RMSE + rmse_acc[k] += rmse_res.mean(dim=0).cpu().numpy() + + # Spread: sqrt of mean ensemble variance over lat/lon (area-weighted) + ens_var = ens_k.var(dim=1, unbiased=True) # (B, C, H, W) var over members + w = area_w_2d.to(ens_var.device) + spread_kc = (ens_var * w).sum(dim=(-2, -1)) / w.sum() # (B, C) mean var + spread_acc[k] += spread_kc.sqrt().mean(dim=0).cpu().numpy() + + rh_res, _ = rh_fn(ens_k, xc, tgt_k, yc) + # rh_res: (2, M+1, B, C) → [bin_centers, bin_counts]; sum over batch + rank_acc += rh_res[1].sum(dim=-2).cpu().numpy() # (M+1, C) + + # Full-rollout metrics (utils/metrics.py) + ens_mean = ensemble.mean(dim=2) # (B, K, C, H, W) + k_vec, ens_spec, tgt_spec = power_spectra_per_variable(ens_mean, target) + if power_ens_acc is None: + power_ens_acc, power_tgt_acc = ens_spec, tgt_spec + else: + power_ens_acc += ens_spec + power_tgt_acc += tgt_spec + + energy_acc += energy_score_per_lead(ensemble, target) + + p_avg = pooled_crps_per_lead(ensemble, target, pool_sizes, "avg") + p_max = pooled_crps_per_lead(ensemble, target, pool_sizes, "max") + if pooled_avg_acc is None: + pooled_avg_acc, pooled_max_acc = p_avg, p_max + else: + pooled_avg_acc += p_avg + pooled_max_acc += p_max + + for dname, vals in derived_variable_crps(ensemble, target, variables).items(): + derived_acc[dname] = derived_acc.get(dname, 0.0) + vals + + n_batches += 1 + if (batch_idx + 1) % 50 == 0: + log.info(f" {batch_idx + 1}/{len(val_loader)} batches done") + + if n_batches == 0: + raise RuntimeError("Validation dataset is empty.") + + # --- Normalise --- + crps_mean = crps_acc / n_batches + rmse_mean = rmse_acc / n_batches + spread_mean = spread_acc / n_batches + ratio_mean = spread_mean / np.maximum(rmse_mean, 1e-12) + energy_mean = energy_acc / n_batches + power_ens_mean = power_ens_acc / n_batches + power_tgt_mean = power_tgt_acc / n_batches + pooled_avg_mean = pooled_avg_acc / n_batches + pooled_max_mean = pooled_max_acc / n_batches + + leads = np.arange(1, K + 1, dtype=np.int64) + lead_hours = leads * step_hours + + # --- Save --- + out_dir = Path(ecfg.eval.outdir) + out_dir.mkdir(parents=True, exist_ok=True) + + summary: dict = { + "crps_per_lead_per_channel": crps_mean, + "rmse_per_lead_per_channel": rmse_mean, + "spread_per_lead_per_channel": spread_mean, + "spread_skill_ratio": ratio_mean, + "rank_histograms": rank_acc, + "energy_score_per_lead": energy_mean, + "avg_pooled_crps": pooled_avg_mean, + "max_pooled_crps": pooled_max_mean, + "pool_sizes": np.array(pool_sizes, dtype=np.int64), + "power_spectrum_k": k_vec, + "power_spectrum_forecast": power_ens_mean, + "power_spectrum_truth": power_tgt_mean, + "variables": np.array(variables, dtype=object), + "lead_steps": leads, + "lead_hours": lead_hours, + "num_batches": np.array(n_batches), + "checkpoint_paths": np.array(checkpoint_paths, dtype=object), + } + for dname, vals in derived_acc.items(): + summary[f"derived_crps_{dname}"] = vals / n_batches + save_summary(summary, str(out_dir / "eval_metrics.npz")) + log.info(f"Saved eval_metrics.npz ({n_batches} batches) → {out_dir}") + + # --- Plots --- + # Figure 2a: CRPS scorecard heatmap (rows=variables, cols=lead times) + plot_crps_scorecard( + crps_mean, variables, lead_hours, + str(out_dir / "crps_scorecard.png"), + title="Fair CRPS scorecard (normalised per variable)", + ) + # Figure 2a equivalent for RMSE + plot_crps_scorecard( + rmse_mean, variables, lead_hours, + str(out_dir / "rmse_scorecard.png"), + title="Ensemble-mean RMSE scorecard (normalised per variable)", + ) + # Figure 2a equivalent for spread-skill ratio + plot_crps_scorecard( + ratio_mean, variables, lead_hours, + str(out_dir / "spread_skill_scorecard.png"), + title="Spread-skill ratio scorecard (normalised per variable)", + ) + # Figure 2b-f: spread vs RMSE line plots for 5 key variables + plot_spread_skill_lines( + spread_mean, rmse_mean, variables, lead_hours, + str(out_dir / "spread_skill_lines.png"), + ) + # rank_acc shape: (M+1, C) → plot expects (C, M+1) + plot_rank_histograms(rank_acc.T.astype(np.int64), variables, + str(out_dir / "rank_histograms.png")) + # Energy score: single line, fine to keep + plot_metric_vs_lead( + energy_mean[:, None], ["multivariate"], lead_hours, "energy score", + "Energy score per lead (lower is better)", + str(out_dir / "energy_score_vs_lead.png"), + ) + plot_power_spectra( + k_vec, power_ens_mean, power_tgt_mean, variables, + lead_hours_all=lead_hours, + out_path=str(out_dir / "power_spectra.png"), + ) + plot_pooled_crps( + pooled_avg_mean, pool_sizes, variables, lead_hours, + str(out_dir / "avg_pooled_crps.png"), title="Average-pooled CRPS", + ) + plot_pooled_crps( + pooled_max_mean, pool_sizes, variables, lead_hours, + str(out_dir / "max_pooled_crps.png"), title="Max-pooled CRPS", + ) + # Figure 3c: derived variable CRPS — single line per derived var, readable + for dname, vals in derived_acc.items(): + plot_metric_vs_lead( + (vals / n_batches)[:, None], [dname], lead_hours, "CRPS", + f"{dname} CRPS per lead (Figure 3c)", + str(out_dir / f"derived_crps_{dname}.png"), + ) + + log.info(f"Eval complete. All outputs in {out_dir}") + + +@hydra.main(version_base=None, config_path="config", config_name="eval_fgn") +def main(cfg: DictConfig) -> None: + run_eval(cfg) + + +if __name__ == "__main__": + main() diff --git a/examples/weather/fgn/inference.py b/examples/weather/fgn/inference.py new file mode 100644 index 0000000000..77d97b4fbb --- /dev/null +++ b/examples/weather/fgn/inference.py @@ -0,0 +1,215 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Run autoregressive stochastic inference for the FGN recipe. + +Supports both single-model stochastic rollout and the paper's deep-ensemble +inference path (§2.2.1 of arXiv:2506.10772v1: J=4 independently-trained +models, equal number of members per model, model identity fixed over all +timesteps of a given trajectory; aleatoric noise ``z_t`` resampled every +step). +""" + +from pathlib import Path + +import hydra +import torch +from datasets import dataset_classes +from omegaconf import DictConfig +from utils.trainer import find_latest_model_checkpoint + +from physicsnemo.core import Module +from physicsnemo.distributed import DistributedManager + + +def _resolve_checkpoints(cfg: DictConfig) -> list[str]: + """Resolve the inference config to an ordered list of checkpoint paths. + + Priority: ``inference.checkpoints`` (list) wins if set. Otherwise fall + back to ``inference.checkpoint`` (single path or ``"latest"``). Single- + model inference is just the length-1 deep-ensemble case. + """ + checkpoints = ( + cfg.inference.get("checkpoints", None) + if hasattr(cfg.inference, "get") + else getattr(cfg.inference, "checkpoints", None) + ) + if checkpoints: + return [str(c) for c in checkpoints] + + checkpoint = cfg.inference.checkpoint + if checkpoint == "latest": + return [ + str( + find_latest_model_checkpoint( + Path(cfg.training.rundir) / cfg.training.checkpoint_dir + ) + ) + ] + return [str(checkpoint)] + + +def _allocate_members(num_trajectories: int, num_models: int) -> list[int]: + """Distribute ``num_trajectories`` members across ``num_models`` models. + + Paper §2.2.1: "we generate an equal number of ensemble member + trajectories from each model". When ``num_trajectories`` is not + divisible by ``num_models``, put the remainder on the earlier models. + """ + if num_models <= 0: + raise ValueError("num_models must be positive") + base = num_trajectories // num_models + rem = num_trajectories % num_models + return [base + (1 if i < rem else 0) for i in range(num_models)] + + +def _rollout( + model: torch.nn.Module, + history: torch.Tensor, + background: torch.Tensor, + invariants: torch.Tensor | None, + num_steps: int, + latent_dim: int, + num_trajectories: int, + device: torch.device, + output_only_channels: list[int] | None = None, +) -> torch.Tensor: + """Run ``num_trajectories`` independent autoregressive rollouts. + + Returns a tensor of shape ``(num_trajectories, num_steps, C, H, W)``. + The model identity is fixed for the lifetime of each trajectory + (paper §2.2.1); ``z_t`` is resampled every step (paper §2.2.2). + Paper §3: predicted-only channels (``tp06``) are zeroed before being + fed back as input for the next rollout step. + """ + output_only_channels = output_only_channels or [] + trajectories: list[torch.Tensor] = [] + for _ in range(num_trajectories): + rollout_history = history.clone() + states: list[torch.Tensor] = [] + for _ in range(num_steps): + latent = torch.randn( + history.shape[0], + latent_dim, + device=device, + dtype=torch.float32, + ) + pred = model( + history=rollout_history, + latent=latent, + background=background, + invariants=invariants, + ) + states.append(pred) + next_frame = pred + if output_only_channels: + next_frame = next_frame.clone() + for ci in output_only_channels: + next_frame[:, ci].zero_() + rollout_history = torch.cat( + [rollout_history[:, 1:], next_frame.unsqueeze(1)], + dim=1, + ) + trajectories.append(torch.stack(states, dim=1)) + return torch.cat(trajectories, dim=0) + + +def run_inference(cfg: DictConfig) -> dict[str, float | str | int | list[int]]: + DistributedManager.initialize() + dist = DistributedManager() + if dist.world_size != 1: + raise NotImplementedError( + "The FGN inference scaffold currently supports a single process only." + ) + + device = dist.device + torch.manual_seed(int(cfg.inference.seed)) + + dataset_cls = dataset_classes[cfg.dataset.name] + dataset = dataset_cls(cfg.dataset, train=False) + sample = dataset[int(cfg.inference.dataset_index)] + + checkpoint_paths = _resolve_checkpoints(cfg) + num_trajectories = int(cfg.inference.num_trajectories) + members_per_model = _allocate_members(num_trajectories, len(checkpoint_paths)) + + history = sample["history"].unsqueeze(0).to(device=device, dtype=torch.float32) + target = sample["target"].unsqueeze(0).to(device=device, dtype=torch.float32) + # Datasets may emit target as (K, C, H, W) for AR training; inference MAE + # only uses the first step, so collapse to (B, C, H, W). + if target.ndim == 5: + target = target[:, 0] + background = ( + sample["background"].unsqueeze(0).to(device=device, dtype=torch.float32) + ) + + invariants = dataset.get_invariants() + if invariants is not None: + invariants = ( + torch.from_numpy(invariants) + .unsqueeze(0) + .to(device=device, dtype=torch.float32) + ) + + all_trajectories: list[torch.Tensor] = [] + num_steps = int(cfg.inference.num_steps) + output_only = dataset.output_only_channels() + with torch.no_grad(): + for ckpt_path, n_members in zip( + checkpoint_paths, members_per_model, strict=True + ): + if n_members <= 0: + continue + model = Module.from_checkpoint(ckpt_path).to(device).eval() + latent_dim = int(getattr(model, "latent_dim", cfg.model.latent_dim)) + traj = _rollout( + model=model, + history=history, + background=background, + invariants=invariants, + num_steps=num_steps, + latent_dim=latent_dim, + num_trajectories=n_members, + device=device, + output_only_channels=output_only, + ) + all_trajectories.append(traj) + del model # free VRAM before loading next checkpoint + + trajectory_tensor = torch.cat(all_trajectories, dim=0).cpu() + ensemble_mean = trajectory_tensor[:, 0].mean(dim=0, keepdim=True) + first_step_mae = float((ensemble_mean - target.cpu()).abs().mean()) + + output_path = Path(cfg.inference.output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + torch.save( + { + "history": history.cpu(), + "target": target.cpu(), + "trajectories": trajectory_tensor, + "first_step_mae": first_step_mae, + "num_models": len(checkpoint_paths), + "members_per_model": members_per_model, + "checkpoint_paths": checkpoint_paths, + }, + output_path, + ) + + return { + "output_path": str(output_path), + "first_step_mae": first_step_mae, + "num_models": len(checkpoint_paths), + "members_per_model": members_per_model, + } + + +@hydra.main(version_base=None, config_path="config", config_name="inference_fgn") +def main(cfg: DictConfig) -> None: + result = run_inference(cfg) + print(f"Saved inference outputs to {result['output_path']}") + print(f"First-step ensemble-mean MAE: {result['first_step_mae']:.6f}") + + +if __name__ == "__main__": + main() diff --git a/examples/weather/fgn/scripts/compute_arco_stats.py b/examples/weather/fgn/scripts/compute_arco_stats.py new file mode 100644 index 0000000000..f5de7808ae --- /dev/null +++ b/examples/weather/fgn/scripts/compute_arco_stats.py @@ -0,0 +1,230 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Compute per-channel mean/std from ARCO/ERA5 for FGN normalization. + +Writes an ``.npz`` with arrays ``mean`` and ``std`` (each of shape +``(len(variables),)``) in the same variable order you pass in. The output +is consumed by `datasets.arco.ArcoFGNDataset` via the ``stats_path`` config +key. + +Usage +----- + python scripts/compute_arco_stats.py \\ + --variables u10m v10m t2m msl z500 q850 \\ + --samples 128 --stride 4 --output stats.npz + +Defaults target the 83-channel Table A.1 schema; ``--samples`` draws random +timestamps uniformly from the training window. Uses a Welford-style online +accumulator to avoid materialising the full sample stack in memory. +""" + +from __future__ import annotations + +import argparse +from datetime import datetime, timedelta +from pathlib import Path + +import numpy as np + +# Paper Table A.1 defaults -- keep in sync with datasets.arco.DEFAULT_STATE. +DEFAULT_ATMOS_VARS = ("z", "q", "t", "u", "v", "w") +DEFAULT_LEVELS = ( + 50, + 100, + 150, + 200, + 250, + 300, + 400, + 500, + 600, + 700, + 850, + 925, + 1000, +) +DEFAULT_SURFACE = ("t2m", "u10m", "v10m", "msl", "sst") +DEFAULT_STATE = tuple( + [f"{v}{lvl}" for v in DEFAULT_ATMOS_VARS for lvl in DEFAULT_LEVELS] + + list(DEFAULT_SURFACE) +) + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description=__doc__) + p.add_argument( + "--variables", + nargs="+", + default=list(DEFAULT_STATE), + help="ARCO compact-name variables, in the order you want the stats " + "arrays to be laid out. Default: paper Table A.1 (83 channels).", + ) + p.add_argument( + "--start", + type=str, + default="1979-01-01", + help="Earliest sample time (ISO date).", + ) + p.add_argument( + "--end", + type=str, + default="2018-01-01", + help="Latest sample time (ISO date, exclusive).", + ) + p.add_argument( + "--step-hours", + type=int, + default=6, + help="Sampling cadence; restricts timestamps to a 6h grid by default.", + ) + p.add_argument( + "--samples", + type=int, + default=256, + help="Number of random timestamps to average over.", + ) + p.add_argument( + "--stride", + type=int, + default=1, + help="Spatial stride applied to the 721x1440 grid to cut fetch cost.", + ) + p.add_argument( + "--tp-accumulation-hours", + type=int, + default=None, + help="If set to N, any variable named tp{N:02d} (e.g. tp06 for N=6) " + "in --variables is treated as a paper §3 N-hour accumulation of " + "ARCO hourly ``tp``, not a native ARCOLexicon key. Matches " + "ArcoFGNDataset.tp_accumulation_hours so stats and training see " + "the same representation.", + ) + p.add_argument("--seed", type=int, default=0) + p.add_argument( + "--output", + type=Path, + required=True, + help="Destination .npz path.", + ) + return p.parse_args() + + +def iter_sample_times( + start: datetime, + end: datetime, + step_hours: int, + samples: int, + rng: np.random.Generator, +): + total_hours = int((end - start).total_seconds() // 3600) + max_offset = total_hours // step_hours + if max_offset <= 0: + raise ValueError("start/end window is too narrow for the requested step_hours") + offsets = rng.integers(0, max_offset, size=samples) + return [start + timedelta(hours=int(o) * step_hours) for o in offsets] + + +def _accumulate_tp(arco, time: datetime, hours: int) -> np.ndarray: + """Sum ``hours`` hourly ARCO ``tp`` values ending at ``time``. + + Mirrors ArcoFGNDataset._fetch_tp_accumulation's semantics (ARCO ``tp`` + is the hourly accumulation during ``[t-1h, t]``; an N-hour window + ending at T sums values at T-N+1, ..., T). + """ + window = [time - timedelta(hours=hours - 1 - j) for j in range(hours)] + da = arco(time=window, variable=["tp"]) + hourly = np.asarray(da.values, dtype=np.float32)[:, 0] # (N, 721, 1440) + return hourly.sum(axis=0) + + +def main() -> None: + args = parse_args() + from earth2studio.data import ARCO + + rng = np.random.default_rng(args.seed) + start = datetime.fromisoformat(args.start) + end = datetime.fromisoformat(args.end) + + times = iter_sample_times(start, end, args.step_hours, args.samples, rng) + arco = ARCO(cache=True, verbose=True) + + # Figure out which slot (if any) is the tp{N} accumulation. Exclude it + # from the bulk fetch just like ArcoFGNDataset does, and splice the + # accumulation back into the channel tensor before computing stats. + tp_name: str | None = None + tp_idx: int | None = None + arco_vars = list(args.variables) + if args.tp_accumulation_hours is not None: + tp_name = f"tp{args.tp_accumulation_hours:02d}" + if tp_name in args.variables: + tp_idx = args.variables.index(tp_name) + arco_vars = [v for v in args.variables if v != tp_name] + + # Welford online accumulator per channel. + n_channels = len(args.variables) + count = np.int64(0) + mean = np.zeros(n_channels, dtype=np.float64) + m2 = np.zeros(n_channels, dtype=np.float64) + + for i, t in enumerate(times): + da = arco(time=[t], variable=arco_vars) + fetched = np.asarray(da.values, dtype=np.float32)[0] # (V', 721, 1440) + if args.stride > 1: + fetched = fetched[:, :: args.stride, :: args.stride] + + # Embed fetched channels + (optional) tp accumulation into a single + # tensor whose channel axis matches args.variables order. + arr = np.zeros( + (n_channels, fetched.shape[-2], fetched.shape[-1]), dtype=np.float32 + ) + if tp_idx is None: + arr[:] = fetched + else: + arr[:tp_idx] = fetched[:tp_idx] + arr[tp_idx + 1 :] = fetched[tp_idx:] + tp_acc = _accumulate_tp(arco, t, args.tp_accumulation_hours) + if args.stride > 1: + tp_acc = tp_acc[:: args.stride, :: args.stride] + arr[tp_idx] = tp_acc + + # SST NaN imputation with global min (paper A.1.1) so the stats are + # computed on the same representation the training pipeline sees. + if "sst" in args.variables: + sst_idx = args.variables.index("sst") + sst = arr[sst_idx] + nan_mask = np.isnan(sst) + if nan_mask.any(): + sst[nan_mask] = float(np.nanmin(sst)) + arr[sst_idx] = sst + + flat = arr.reshape(n_channels, -1).astype(np.float64) + for v in range(n_channels): + col = flat[v] + n_col = col.size + delta = col - mean[v] + total = count + n_col + mean[v] += delta.sum() / total + m2[v] += (delta * (col - mean[v])).sum() + count += flat.shape[1] + + if (i + 1) % 16 == 0 or i == len(times) - 1: + print(f"[{i + 1}/{len(times)}] running mean[0]={mean[0]:.4g}") + + var = m2 / max(count - 1, 1) + std = np.sqrt(var).astype(np.float32) + mean = mean.astype(np.float32) + + args.output.parent.mkdir(parents=True, exist_ok=True) + np.savez( + args.output, + mean=mean, + std=std, + variables=np.array(list(args.variables), dtype=object), + ) + print(f"wrote {args.output} (variables={n_channels}, samples={len(times)})") + + +if __name__ == "__main__": + main() diff --git a/examples/weather/fgn/scripts/prefetch_arco.py b/examples/weather/fgn/scripts/prefetch_arco.py new file mode 100644 index 0000000000..7342d2b185 --- /dev/null +++ b/examples/weather/fgn/scripts/prefetch_arco.py @@ -0,0 +1,215 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Pre-warm the earth2studio ARCO local cache for an FGN training window. + +The default FGN config streams ERA5 chunks from Google's public ARCO store +at training time. For small state-variable counts that's fine, but at the +paper's 83-channel schema (5 surface + 78 atmospheric) the first epoch is +dominated by GCS fetch latency rather than GPU compute. + +This script iterates over the (time, variable) combinations that +``ArcoFGNDataset`` would need for a given window and calls the ARCO data +source so earth2studio populates its on-disk cache. Subsequent training runs +hitting the same cache location read locally and run at GPU speed. + +Fetches are issued in monthly batches to stay within GCS per-request limits. +earth2studio's ARCO data source fires all (time, variable) pairs in the batch +concurrently via asyncio, so each batch saturates available network bandwidth. + +Fetches: + - State variables at ``step_hours`` cadence, covering + ``[start - history_frames*step_hours, end]`` (so the earliest sample + has all its prior frames cached). The derived ``tp{N}`` placeholder names + are excluded — we handle total precipitation separately. + - Hourly ``tp`` over ``[start - (history_frames*step_hours + N - 1), end]`` + so every sample's ``tp_accumulation_hours`` backward-window is cached. + - Invariants ``z``, ``lsm`` at ``static_date`` (cheap, one-off fetch). + +earth2studio ≥0.15.0a0 dynamically reads ``valid_time_stop`` from the ARCO +zarr metadata, extending coverage to 2025-12-31. + +Runs fine on a CPU-only slurm node — no GPU needed. +""" + +from __future__ import annotations + +import argparse +import sys +from datetime import datetime, timedelta +from pathlib import Path + +_EXAMPLE_DIR = Path(__file__).resolve().parents[1] +if str(_EXAMPLE_DIR) not in sys.path: + sys.path.insert(0, str(_EXAMPLE_DIR)) + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description=__doc__) + p.add_argument( + "--start", default="2024-01-01", help="Window start (ISO, inclusive)." + ) + p.add_argument("--end", default="2025-01-01", help="Window end (ISO, exclusive).") + p.add_argument("--step-hours", type=int, default=6) + p.add_argument("--history-frames", type=int, default=2) + p.add_argument( + "--tp-accumulation-hours", + type=int, + default=6, + help="N for tp{N:02d} accumulation. 0 = skip tp fetch.", + ) + p.add_argument( + "--variables", + nargs="+", + default=None, + help="Override state variables. Default: ArcoFGNDataset.DEFAULT_STATE (82 ARCO vars).", + ) + p.add_argument( + "--static-date", + default="2016-01-01", + help="Date for one-off invariants fetch (z, lsm).", + ) + p.add_argument( + "--batch-days", + type=int, + default=31, + help="Days of data per time-batch. Default: 31.", + ) + p.add_argument( + "--var-group-size", + type=int, + default=10, + help="Variables per sub-batch. Limits concurrent GCS requests to " + "batch_days_timestamps × var_group_size. Default: 10.", + ) + p.add_argument( + "--no-tp", + dest="include_tp", + action="store_false", + help="Skip hourly tp prefetch.", + ) + p.add_argument( + "--no-invariants", + dest="include_invariants", + action="store_false", + help="Skip invariants prefetch.", + ) + p.set_defaults(include_tp=True, include_invariants=True) + return p.parse_args() + + +def _window_times(start: datetime, end: datetime, step_hours: int) -> list[datetime]: + out: list[datetime] = [] + t = start + while t <= end: + out.append(t) + t += timedelta(hours=step_hours) + return out + + +def _batch( + times: list[datetime], batch_days: int, step_hours: int = 1 +) -> list[list[datetime]]: + """Split a time list into chunks covering at most batch_days of real time.""" + n = max(1, batch_days * 24 // step_hours) + return [times[i : i + n] for i in range(0, len(times), n)] + + +def main() -> int: + args = parse_args() + # Silence earth2studio's per-fetch DEBUG lines — they flood log files at scale. + import sys + + from loguru import logger + + logger.remove() + logger.add(sys.stderr, level="INFO") + + from datasets.arco import DEFAULT_STATE # noqa: E402 — after sys.path patch + from earth2studio.data import ARCO # noqa: E402 + + state_vars = list(args.variables) if args.variables else list(DEFAULT_STATE) + # Filter the derived tp{N} placeholder — not a valid ARCOLexicon key. + fetch_vars = [v for v in state_vars if not v.startswith("tp")] + + start = datetime.fromisoformat(args.start) + end = datetime.fromisoformat(args.end) + static_date = datetime.fromisoformat(args.static_date) + + # Warm window: shift start back so the first sample's history frames are cached. + warm_start_state = start - timedelta(hours=args.history_frames * args.step_hours) + warm_end_state = end + times_state = _window_times(warm_start_state, warm_end_state, args.step_hours) + + tp_acc = args.tp_accumulation_hours + fetch_tp = args.include_tp and tp_acc > 0 + if fetch_tp: + warm_start_tp = warm_start_state - timedelta(hours=tp_acc - 1) + times_tp = _window_times(warm_start_tp, warm_end_state, step_hours=1) + + n_state_req = len(times_state) * len(fetch_vars) + n_tp_req = len(times_tp) * 1 if fetch_tp else 0 + print( + f"State : {len(fetch_vars)} vars × {len(times_state)} timestamps" + f" = {n_state_req:,} requests ({warm_start_state} → {warm_end_state})" + ) + if fetch_tp: + print( + f"TP : 1 var × {len(times_tp)} hourly timestamps" + f" = {n_tp_req:,} requests" + ) + # --var-group-size: number of variables per sub-batch. Keeps concurrent + # GCS requests to batch_days_timestamps × var_group_size to avoid GCS + # rate-limit timeouts at large scale. + var_group_size = args.var_group_size + var_groups = [ + fetch_vars[i : i + var_group_size] + for i in range(0, len(fetch_vars), var_group_size) + ] + total_state_batches = len(_batch(times_state, args.batch_days, args.step_hours)) + print( + f"Batch : {args.batch_days} days × {var_group_size} vars/group" + f" → {total_state_batches} time-batches × {len(var_groups)} var-groups" + f" = {total_state_batches * len(var_groups)} calls\n" + ) + + # async_timeout per call (one time-batch × one var-group). + arco = ARCO(cache=True, async_timeout=3600) + + # --- state variables: time-batch outer loop, var-group inner loop --- + state_batches = _batch(times_state, args.batch_days, step_hours=args.step_hours) + n_tb = len(state_batches) + n_vg = len(var_groups) + for ti, batch_times in enumerate(state_batches, 1): + for vi, vgroup in enumerate(var_groups, 1): + print( + f"[state {ti}/{n_tb} vars {vi}/{n_vg}] " + f"{batch_times[0].date()} → {batch_times[-1].date()}" + f" ({len(batch_times)} steps × {len(vgroup)} vars" + f" = {len(batch_times) * len(vgroup)} requests)" + ) + arco(time=batch_times, variable=vgroup) + + # --- hourly tp, batched by time only (1 var, small requests) --- + if fetch_tp: + tp_batches = _batch(times_tp, args.batch_days, step_hours=1) + for i, batch_times in enumerate(tp_batches, 1): + print( + f"[tp {i}/{len(tp_batches)}] " + f"{batch_times[0].date()} → {batch_times[-1].date()}" + f" ({len(batch_times)} hourly steps)" + ) + arco(time=batch_times, variable=["tp"]) + + # --- invariants (one-off) --- + if args.include_invariants: + print(f"\n[invariants] z, lsm at {static_date.date()}") + arco(time=[static_date], variable=["z", "lsm"]) + + print("\nDone.") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/examples/weather/fgn/scripts/stage4_ar_schedule.py b/examples/weather/fgn/scripts/stage4_ar_schedule.py new file mode 100644 index 0000000000..e3786fbaa6 --- /dev/null +++ b/examples/weather/fgn/scripts/stage4_ar_schedule.py @@ -0,0 +1,221 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Paper Stage-4 AR-finetune scheduler (arXiv:2506.10772v1 Table A.2). + +Chains multiple training stages at increasing ``ar_steps``. Table A.2 +Stage 4 is: + + 8000 steps at 1AR, then + 4000 steps at 2AR, then + 1000 steps each at 3AR, 4AR, 5AR, 6AR, 7AR, 8AR + +(LR decays ``8e-5 → 8e-6 → 8e-7``.) + +This helper runs the sequence **in-process**: for each stage we call +``train.run_training(cfg)``, then copy the final checkpoint of that stage +into the next stage's rundir so the existing ``resume_checkpoint: latest`` +logic in ``Trainer._resume_if_needed`` picks it up without any trainer +code changes. Operators are expected to wrap a single invocation of this +script in their own ``sbatch`` (one GPU job, sequential stages). + +Usage +----- + python scripts/stage4_ar_schedule.py \\ + --config-name fgn_arco_dev \\ + --rundir /mnt/data/.../fgn_stage4 \\ + --stats-path /mnt/data/.../stats.npz \\ + [--dry-run] + +The ``--config-name`` is the Hydra base config; everything else is +supplied via command-line overrides so the orchestration layer is thin +and the per-stage configuration stays faithful to Table A.2. + +Per-stage knobs can be overridden via ``--stages`` which takes a JSON +list of ``{"ar_steps": int, "total_train_steps": int, "lr": float}`` +dicts — useful for dev runs that don't want the full 18 000-step +schedule. +""" + +from __future__ import annotations + +import argparse +import copy +import json +import shutil +import sys +from pathlib import Path + +# Ensure ``train``/``utils``/``datasets`` resolve when invoking this script +# from outside the example directory, e.g. via sbatch --chdir elsewhere. +_EXAMPLE_DIR = Path(__file__).resolve().parents[1] +if str(_EXAMPLE_DIR) not in sys.path: + sys.path.insert(0, str(_EXAMPLE_DIR)) + +from hydra import compose, initialize # noqa: E402 +from omegaconf import DictConfig, OmegaConf # noqa: E402 + +# Stage 4 of paper Table A.2. +PAPER_STAGES: list[dict] = [ + {"ar_steps": 1, "total_train_steps": 8000, "lr": 8e-5}, + {"ar_steps": 2, "total_train_steps": 4000, "lr": 8e-5}, + {"ar_steps": 3, "total_train_steps": 1000, "lr": 8e-5}, + {"ar_steps": 4, "total_train_steps": 1000, "lr": 8e-6}, + {"ar_steps": 5, "total_train_steps": 1000, "lr": 8e-6}, + {"ar_steps": 6, "total_train_steps": 1000, "lr": 8e-7}, + {"ar_steps": 7, "total_train_steps": 1000, "lr": 8e-7}, + {"ar_steps": 8, "total_train_steps": 1000, "lr": 8e-7}, +] + +# Small-footprint dev schedule for quick smoke testing on ARCO. +DEV_STAGES: list[dict] = [ + {"ar_steps": 1, "total_train_steps": 20, "lr": 3e-4}, + {"ar_steps": 2, "total_train_steps": 20, "lr": 1e-4}, + {"ar_steps": 4, "total_train_steps": 10, "lr": 1e-4}, +] + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description=__doc__) + p.add_argument( + "--config-name", + default="fgn", + help="Hydra config name. Default: fgn. Use fgn_arco_dev for dev runs.", + ) + p.add_argument( + "--config-path", + default="../config", + help="Hydra config_path relative to this script. Default: ../config.", + ) + p.add_argument( + "--rundir", + required=True, + type=Path, + help="Base run directory; each stage writes to /stage/.", + ) + p.add_argument( + "--stats-path", + type=Path, + default=None, + help="Normalization stats .npz (propagated to dataset.stats_path).", + ) + p.add_argument( + "--stages", + type=str, + default=None, + help="JSON list of stage dicts to override the paper schedule.", + ) + p.add_argument( + "--dev", + action="store_true", + help="Use the small-footprint DEV_STAGES schedule instead of paper.", + ) + p.add_argument( + "--dry-run", + action="store_true", + help="Print each stage's resolved config without invoking the trainer.", + ) + p.add_argument("--extra", nargs="*", default=(), help="Hydra overrides.") + return p.parse_args() + + +def _last_checkpoint(checkpoint_dir: Path) -> Path | None: + candidates = sorted(checkpoint_dir.glob("*.mdlus")) + return candidates[-1] if candidates else None + + +def _seed_from_prev_stage( + prev_checkpoint_dir: Path, stage_checkpoint_dir: Path +) -> None: + """Copy the previous stage's final ``.mdlus`` + ``.pt`` to the new stage. + + ``Trainer._resume_if_needed`` uses ``physicsnemo.utils.load_checkpoint`` + which looks in the configured ``checkpoint_dir``. Copying the final + files over is enough for ``resume_checkpoint: latest`` to pick them up, + and keeps the trainer completely unchanged. + """ + last_mdlus = _last_checkpoint(prev_checkpoint_dir) + if last_mdlus is None: + raise FileNotFoundError( + f"No .mdlus checkpoint found in {prev_checkpoint_dir} — previous stage didn't save?" + ) + stage_checkpoint_dir.mkdir(parents=True, exist_ok=True) + shutil.copy2(last_mdlus, stage_checkpoint_dir / last_mdlus.name) + # Copy optimizer/scheduler state if present so the resume is exact. + # physicsnemo names these ``checkpoint.{mp_rank}.{epoch}.pt``; find the + # matching epoch by filename suffix (``..mdlus``). + epoch_suffix = last_mdlus.stem.split(".")[-1] + for cand in prev_checkpoint_dir.glob(f"checkpoint.*.{epoch_suffix}.pt"): + shutil.copy2(cand, stage_checkpoint_dir / cand.name) + + +def build_stage_cfg( + base_cfg: DictConfig, + stage: dict, + stage_rundir: Path, + stats_path: Path | None, +) -> DictConfig: + cfg = copy.deepcopy(base_cfg) + cfg.training.outdir = str(stage_rundir.parent) + cfg.training.experiment_name = stage_rundir.parent.name + cfg.training.run_id = stage_rundir.name + cfg.training.rundir = str(stage_rundir) + cfg.training.ar_steps = int(stage["ar_steps"]) + cfg.training.total_train_steps = int(stage["total_train_steps"]) + cfg.training.optimizer.lr = float(stage["lr"]) + cfg.training.resume_checkpoint = "latest" + if stats_path is not None: + cfg.dataset.stats_path = str(stats_path) + return cfg + + +def main() -> int: + args = parse_args() + + if args.stages is not None: + stages = json.loads(args.stages) + elif args.dev: + stages = DEV_STAGES + else: + stages = PAPER_STAGES + + args.rundir.mkdir(parents=True, exist_ok=True) + + # Hydra's ``initialize`` resolves ``config_path`` relative to THIS + # file, not the caller's cwd. Default ``../config`` points at the + # example's config tree regardless of where the user runs the script. + with initialize(version_base=None, config_path=args.config_path, job_name="stage4"): + base_cfg = compose(config_name=args.config_name, overrides=list(args.extra)) + + prev_checkpoint_dir: Path | None = None + + from train import run_training # noqa: E402 — imported after Hydra setup + + for i, stage in enumerate(stages): + stage_rundir = args.rundir / f"stage{i}_ar{stage['ar_steps']}" + stage_rundir.mkdir(parents=True, exist_ok=True) + + stage_cfg = build_stage_cfg(base_cfg, stage, stage_rundir, args.stats_path) + print( + f"\n[stage {i}] ar_steps={stage['ar_steps']} " + f"steps={stage['total_train_steps']} lr={stage['lr']:g} " + f"rundir={stage_rundir}" + ) + if prev_checkpoint_dir is not None and not args.dry_run: + target = stage_rundir / stage_cfg.training.checkpoint_dir + _seed_from_prev_stage(prev_checkpoint_dir, target) + print(f"[stage {i}] seeded from {prev_checkpoint_dir} → {target}") + + if args.dry_run: + print(OmegaConf.to_yaml(stage_cfg)) + else: + run_training(stage_cfg) + + prev_checkpoint_dir = stage_rundir / stage_cfg.training.checkpoint_dir + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/examples/weather/fgn/test_loss.py b/examples/weather/fgn/test_loss.py new file mode 100644 index 0000000000..be3e6b9c31 --- /dev/null +++ b/examples/weather/fgn/test_loss.py @@ -0,0 +1,254 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for the fair-CRPS loss used in FGN training. + +The reference values are computed directly from arXiv:2506.10772v1 eq. (4): + + fCRPS(x_{1:N}, y) = (1/N) sum_n |x_n - y| + - (1/(2 N (N-1))) sum_{n, n'} |x_n - x_{n'}| +""" + +import math + +import numpy as np +import pytest +import torch +from utils.loss import ( + build_area_weights, + build_channel_weights, + ensemble_mean_mse, + fair_crps, +) + +# --------------------------------------------------------------------------- +# Paper eq. (4): scalar reference calculations +# --------------------------------------------------------------------------- + + +def _fcrps_reference(x: list[float], y: float) -> float: + """Reference fCRPS per eq. (4) using naive Python (no torch/numpy).""" + n = len(x) + first = sum(abs(xi - y) for xi in x) / n + pair = sum(abs(xi - xj) for xi in x for xj in x) + second = pair / (2.0 * n * (n - 1)) + return first - second + + +def _to_ensemble(x: list[float]) -> torch.Tensor: + """Turn a flat list into shape ``[1, N, 1, 1, 1]``.""" + return torch.tensor(x, dtype=torch.float64).view(1, -1, 1, 1, 1) + + +@pytest.mark.parametrize( + "x,y,expected", + [ + # Symmetric ensemble around truth, perfect calibration. + # N=2, y=0, x=[1,-1]: first=1, second=(|1-1|+|1+1|+|-1-1|+|-1+1|)/4=1 → 0. + ([1.0, -1.0], 0.0, 0.0), + # N=3, y=0, x=[1,2,3] — worked example from the chat log. + # first = (1+2+3)/3 = 2 + # pair = 0+1+2 + 1+0+1 + 2+1+0 = 8 + # second = 8 / (2*3*2) = 8/12 + # fCRPS = 2 - 8/12 + ([1.0, 2.0, 3.0], 0.0, 2.0 - 8.0 / 12.0), + # Degenerate ensemble: all members identical, spread term = 0, + # so fCRPS reduces to plain |x - y|. + ([5.0, 5.0, 5.0], 3.0, 2.0), + ], +) +def test_fair_crps_matches_paper_eq4(x, y, expected): + ensemble = _to_ensemble(x) + target = torch.tensor(y, dtype=torch.float64).view(1, 1, 1, 1) + got = fair_crps(ensemble, target).item() + assert math.isclose(got, expected, rel_tol=0, abs_tol=1e-12), (got, expected) + + +@pytest.mark.parametrize( + "x,y", + [ + ([0.3, -1.7], 1.1), + ([-2.0, 0.5, 4.5], -0.3), + ([1.0, 1.0, 2.0, 2.0, 3.0], 1.5), + ], +) +def test_fair_crps_matches_naive_reference(x, y): + ensemble = _to_ensemble(x) + target = torch.tensor(y, dtype=torch.float64).view(1, 1, 1, 1) + expected = _fcrps_reference(x, y) + got = fair_crps(ensemble, target).item() + assert math.isclose(got, expected, rel_tol=1e-12, abs_tol=1e-12) + + +def test_fair_crps_rejects_single_member(): + with pytest.raises(ValueError, match="at least two"): + fair_crps(torch.zeros(1, 1, 1, 1, 1), torch.zeros(1, 1, 1, 1)) + + +def test_fair_crps_rejects_bad_shapes(): + with pytest.raises(ValueError, match=r"\[B, M, C, H, W\]"): + fair_crps(torch.zeros(1, 2, 1, 1), torch.zeros(1, 1, 1, 1)) + with pytest.raises(ValueError, match=r"\[B, C, H, W\]"): + fair_crps(torch.zeros(1, 2, 1, 1, 1), torch.zeros(1, 1, 1)) + + +# --------------------------------------------------------------------------- +# Paper eq. (5): weighted loss reduction matches (1/G) sum_i a_i fCRPS_i +# --------------------------------------------------------------------------- + + +def test_weighted_fair_crps_matches_paper_eq5(): + # Two channels, 1x3 spatial, N=2 ensemble, hand-computed. + # Build a case where channels have different per-location fCRPS values. + ensemble = torch.tensor( + [ + [ + [[[1.0, 2.0, 3.0]], [[4.0, 5.0, 6.0]]], + [[[-1.0, 0.0, 1.0]], [[2.0, 3.0, 4.0]]], + ] + ], + dtype=torch.float64, + ) # shape (B=1, M=2, C=2, H=1, W=3) + target = torch.tensor( + [[[[0.0, 0.0, 0.0]], [[5.0, 5.0, 5.0]]]], dtype=torch.float64 + ) # (1, 2, 1, 3) + + # Per-location fCRPS with N=2: + # fCRPS = (|x_1 - y| + |x_2 - y|) / 2 - |x_1 - x_2| / 2 + # (since pairwise sum = 2 * |x_1 - x_2| and 2N(N-1) = 4, so second term + # = 2|x_1 - x_2|/4 = |x_1 - x_2|/2) + def fc(x1, x2, y): + return (abs(x1 - y) + abs(x2 - y)) / 2.0 - abs(x1 - x2) / 2.0 + + per_loc = torch.tensor( + [ + [ + [[fc(1, -1, 0), fc(2, 0, 0), fc(3, 1, 0)]], + [[fc(4, 2, 5), fc(5, 3, 5), fc(6, 4, 5)]], + ] + ], + dtype=torch.float64, + ) + # Unweighted: mean over (B, C, H, W). + expected_unw = per_loc.mean().item() + got_unw = fair_crps(ensemble, target).item() + assert math.isclose(got_unw, expected_unw, rel_tol=0, abs_tol=1e-12) + + # Weighted reduction per paper eq. (5): (1/G) sum a_i fCRPS_i. + # weights shape (1, 2, 1, 1): per-channel. + weights = torch.tensor([0.1, 1.0], dtype=torch.float64).view(1, 2, 1, 1) + expected_w = (per_loc * weights).mean().item() + got_w = fair_crps(ensemble, target, weights=weights).item() + assert math.isclose(got_w, expected_w, rel_tol=0, abs_tol=1e-12) + + +def test_weighted_fair_crps_broadcast_shapes(): + ensemble = torch.randn(2, 3, 4, 5, 6, dtype=torch.float64) + target = torch.randn(2, 4, 5, 6, dtype=torch.float64) + + # Per-channel weights only. + w_c = torch.rand(1, 4, 1, 1, dtype=torch.float64) + 0.1 + # Per-lat weights only. + w_h = torch.rand(1, 1, 5, 1, dtype=torch.float64) + 0.1 + # Full broadcast. + w_full = (w_c * w_h).broadcast_to(2, 4, 5, 6) + + loss_c = fair_crps(ensemble, target, weights=w_c).item() + loss_h = fair_crps(ensemble, target, weights=w_h).item() + loss_full = fair_crps(ensemble, target, weights=w_full).item() + + # Not equal to each other, but each should match the explicit per-loc + # computation when reduced with .mean(). + assert math.isfinite(loss_c) and math.isfinite(loss_h) and math.isfinite(loss_full) + + +# --------------------------------------------------------------------------- +# ensemble_mean_mse +# --------------------------------------------------------------------------- + + +def test_ensemble_mean_mse_matches_hand(): + ensemble = torch.tensor( + [[[[[1.0]]], [[[3.0]]]]], dtype=torch.float64 + ) # (1, 2, 1, 1, 1) + target = torch.tensor([[[[2.0]]]], dtype=torch.float64) + # mean pred = 2.0, squared error = 0. + assert ensemble_mean_mse(ensemble, target).item() == pytest.approx(0.0) + + target2 = torch.tensor([[[[5.0]]]], dtype=torch.float64) + # mean pred = 2.0, error = -3, sq = 9. + assert ensemble_mean_mse(ensemble, target2).item() == pytest.approx(9.0) + + +# --------------------------------------------------------------------------- +# build_channel_weights — paper §2.2.3 scheme +# --------------------------------------------------------------------------- + + +def test_channel_weights_surface_and_t2m(): + # Paper scheme: surface → 0.1; t2m special-cased to 1.0. + w = build_channel_weights(["u10m", "v10m", "t2m", "msl"]) + assert np.allclose(w, [0.1, 0.1, 1.0, 0.1]) + + +def test_channel_weights_atmospheric_linear_by_level(): + # Two atmospheric variables of the same prefix, levels 300 and 500. + # Expected: level / sum(levels) = 3/8 and 5/8. + w = build_channel_weights(["t300", "t500"]) + assert math.isclose(w[0], 3 / 8, abs_tol=1e-6) + assert math.isclose(w[1], 5 / 8, abs_tol=1e-6) + + +def test_channel_weights_geopotential_halved(): + # Paper §2.2.3: geopotential weights halved to tame overfitting. + w_z = build_channel_weights(["z300", "z500"]) + assert math.isclose(w_z[0], 0.5 * 3 / 8, abs_tol=1e-6) + assert math.isclose(w_z[1], 0.5 * 5 / 8, abs_tol=1e-6) + + # Non-geopotential atmospheric (temperature) with the same levels is NOT + # halved — acts as a control to confirm the halving is scoped to z*. + w_t = build_channel_weights(["t300", "t500"]) + assert math.isclose(w_z[0] * 2.0, w_t[0], abs_tol=1e-6) + assert math.isclose(w_z[1] * 2.0, w_t[1], abs_tol=1e-6) + + +def test_channel_weights_paper_table_a1_mixed(): + # Mix of atmospheric and surface from Table A.1 — all weights positive + # and finite, geopotential scaled down relative to temperature. + variables = [ + "z500", + "z850", # atmospheric geopotential, halved + "t500", + "t850", # atmospheric temperature, NOT halved + "u10m", + "v10m", + "t2m", + "msl", # surface + ] + w = build_channel_weights(variables) + assert (w > 0).all() and np.all(np.isfinite(w)) + # z at level L_i gets half the temperature weight at level L_i. + assert math.isclose(w[0] * 2.0, w[2], abs_tol=1e-6) # z500 -> t500 + assert math.isclose(w[1] * 2.0, w[3], abs_tol=1e-6) # z850 -> t850 + # Surface layout (float32 tolerance). + assert math.isclose(w[4], 0.1, abs_tol=1e-6) + assert math.isclose(w[5], 0.1, abs_tol=1e-6) + assert math.isclose(w[6], 1.0, abs_tol=1e-6) + assert math.isclose(w[7], 0.1, abs_tol=1e-6) + + +# --------------------------------------------------------------------------- +# build_area_weights +# --------------------------------------------------------------------------- + + +def test_area_weights_normalised_to_unit_mean(): + for h in (37, 73, 181, 721): + w = build_area_weights(h) + assert w.shape == (h, 1) + # Mean over latitude equals 1 by construction. + assert math.isclose(float(w.mean()), 1.0, abs_tol=1e-6) + # Poles get the smallest weight, equator the largest. + assert w[0, 0] < w[h // 2, 0] + assert w[-1, 0] < w[h // 2, 0] diff --git a/examples/weather/fgn/test_metrics.py b/examples/weather/fgn/test_metrics.py new file mode 100644 index 0000000000..f29313b697 --- /dev/null +++ b/examples/weather/fgn/test_metrics.py @@ -0,0 +1,194 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the Figure 2 + Figure 3 validation diagnostics in +``utils/metrics.py``. + +The lightweight torch kernels here mirror the coord-aware +``earth2studio.statistics.{rmse, spread_skill_ratio, rank_histogram, crps}`` +family; canonical numerics are cross-checked against hand computation +and the paper's eq. (4). +""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pytest +import torch +from utils.metrics import ( + crps_per_variable_per_lead, + derived_variable_crps, + ensemble_rmse_per_variable_per_lead, + flag_bad_seeds, + plot_metric_vs_lead, + plot_power_spectra, + plot_rank_histograms, + power_spectra_per_variable, + rank_histogram_per_variable, + save_summary, + spread_skill_per_variable_per_lead, +) + + +def _fixture(seed: int = 0, B: int = 2, K: int = 3, M: int = 4, C: int = 5): + torch.manual_seed(seed) + ensemble = torch.randn(B, K, M, C, 6, 6, dtype=torch.float64) + target = torch.randn(B, K, C, 6, 6, dtype=torch.float64) + return ensemble, target + + +def test_crps_per_variable_per_lead_shape_and_finite(): + ens, tgt = _fixture() + out = crps_per_variable_per_lead(ens, tgt) + assert out.shape == (ens.shape[1], ens.shape[3]) + assert np.all(np.isfinite(out)) + + +def test_ensemble_rmse_matches_hand(): + # With a single-member ensemble (M=1) RMSE reduces to the deterministic + # RMSE of that single prediction against the target. + torch.manual_seed(0) + B, K, C, H, W = 1, 1, 2, 3, 3 + pred = torch.randn(B, K, 1, C, H, W, dtype=torch.float64) + tgt = torch.randn(B, K, C, H, W, dtype=torch.float64) + out = ensemble_rmse_per_variable_per_lead(pred, tgt) + expected = ((pred[:, :, 0] - tgt) ** 2).mean(dim=(0, -2, -1)).sqrt().numpy() + np.testing.assert_allclose(out, expected, rtol=1e-12) + + +def test_spread_skill_degenerate_ensemble_is_zero_over_skill(): + # Identical members → variance is 0 → spread=0, ratio=0. + torch.manual_seed(0) + B, K, M, C, H, W = 1, 2, 3, 2, 4, 4 + base = torch.randn(B, K, C, H, W, dtype=torch.float64) + pred = base.unsqueeze(2).expand(B, K, M, C, H, W) + tgt = base + 1.0 # non-zero skill + spread, skill, ratio = spread_skill_per_variable_per_lead(pred, tgt) + np.testing.assert_allclose(spread, np.zeros_like(spread), atol=1e-12) + np.testing.assert_allclose(ratio, np.zeros_like(ratio), atol=1e-12) + # Skill is nonzero when ensemble mean misses target. + assert np.all(skill > 0) + + +def test_rank_histogram_sums_to_total_positions(): + ens, tgt = _fixture() + hist = rank_histogram_per_variable(ens, tgt) + # Sum per channel == B * K * H * W. + B, K, _, C, H, W = ens.shape + assert hist.shape == (C, ens.shape[2] + 1) + expected_total = B * K * H * W + np.testing.assert_array_equal(hist.sum(axis=1), np.full(C, expected_total)) + + +def test_derived_variable_crps_returns_wspd_and_dz_when_present(): + # Fabricate a tensor with a variable ordering that contains all the + # components needed by both derived quantities. + variables = ["u10m", "v10m", "z300", "z500", "t2m"] + ens, tgt = _fixture(C=5) + derived = derived_variable_crps(ens, tgt, variables) + assert set(derived) == {"wspd10m", "z300_minus_z500"} + for arr in derived.values(): + assert arr.shape == (ens.shape[1],) + assert np.all(np.isfinite(arr)) + + +def test_derived_variable_crps_skips_missing_components(): + variables = ["t2m", "msl"] + ens, tgt = _fixture(C=2) + derived = derived_variable_crps(ens, tgt, variables) + assert derived == {} + + +def test_power_spectra_shape(): + torch.manual_seed(0) + B, K, C, H, W = 1, 2, 3, 32, 32 + ens_mean = torch.randn(B, K, C, H, W, dtype=torch.float32) + tgt = torch.randn(B, K, C, H, W, dtype=torch.float32) + k, ens_pow, tgt_pow = power_spectra_per_variable(ens_mean, tgt) + assert k.ndim == 1 + assert ens_pow.shape == (K, C, k.size) + assert tgt_pow.shape == (K, C, k.size) + assert np.all(ens_pow >= 0) and np.all(tgt_pow >= 0) + + +def test_flag_bad_seeds_picks_amplified_tail_only(): + # 3 seeds, 2 leads, 2 channels, 8 wavenumber bins. Truth has a flat + # low-power tail (top 20% = last 2 bins). + K, C, B = 2, 2, 8 + truth = np.ones((K, C, B), dtype=np.float64) + truth[..., -2:] = 0.1 # small truth tail + + fore = np.tile(truth, (3, 1, 1, 1)).astype(np.float64) # (S, K, C, B) + # Seed 1 amplifies the tail of one channel / one lead 10x — above 3x + # threshold. + fore[1, 1, 0, -2:] = 1.0 # 10x the truth tail + # Seed 2 amplifies modestly (2x) — below threshold. + fore[2, :, :, -2:] = 0.2 + + bad = flag_bad_seeds(fore, truth, tail_fraction=0.25, threshold=3.0) + assert bad == [1] + + # Tighter threshold catches seed 2 too. + bad_strict = flag_bad_seeds(fore, truth, tail_fraction=0.25, threshold=1.5) + assert bad_strict == [1, 2] + + +def test_flag_bad_seeds_validates_shapes_and_knobs(): + truth = np.ones((1, 1, 4)) + good = np.ones((1, 1, 1, 4)) + with pytest.raises(ValueError, match="shape"): + flag_bad_seeds(good.squeeze(0), truth) # wrong ndim + with pytest.raises(ValueError, match="shape"): + flag_bad_seeds(good, truth[:, :, :3]) # mismatched bins + with pytest.raises(ValueError, match="tail_fraction"): + flag_bad_seeds(good, truth, tail_fraction=0.0) + with pytest.raises(ValueError, match="threshold"): + flag_bad_seeds(good, truth, threshold=0.0) + + +def test_flag_bad_seeds_returns_empty_when_all_ok(): + truth = np.ones((1, 1, 4)) + fore = truth[None].repeat(4, axis=0) # (4, 1, 1, 4), all matching truth + assert flag_bad_seeds(fore, truth, threshold=1.5) == [] + + +def test_plot_and_save_roundtrip(tmp_path: Path): + ens, tgt = _fixture() + crps = crps_per_variable_per_lead(ens, tgt) + variables = [f"v{i}" for i in range(ens.shape[3])] + leads = np.arange(1, ens.shape[1] + 1) + + crps_path = tmp_path / "crps.png" + was_plotted = plot_metric_vs_lead( + crps, variables, leads, "CRPS", "test crps", str(crps_path) + ) + if was_plotted: # matplotlib available + assert crps_path.is_file() and crps_path.stat().st_size > 0 + + hist = rank_histogram_per_variable(ens, tgt) + hist_path = tmp_path / "hist.png" + was_plotted = plot_rank_histograms(hist, variables, str(hist_path)) + if was_plotted: + assert hist_path.is_file() and hist_path.stat().st_size > 0 + + B, K, C, H, W = 1, 1, 2, 32, 32 + ens_mean = torch.randn(B, K, C, H, W, dtype=torch.float32) + tgt2 = torch.randn(B, K, C, H, W, dtype=torch.float32) + k, ens_pow, tgt_pow = power_spectra_per_variable(ens_mean, tgt2) + spec_path = tmp_path / "spec.png" + was_plotted = plot_power_spectra( + k, ens_pow, tgt_pow, ["v0", "v1"], lead_idx=0, out_path=str(spec_path) + ) + if was_plotted: + assert spec_path.is_file() and spec_path.stat().st_size > 0 + + summary_path = tmp_path / "summary.npz" + save_summary( + {"crps_per_lead_per_channel": crps, "rank_histograms": hist}, str(summary_path) + ) + loaded = np.load(summary_path) + assert "crps_per_lead_per_channel" in loaded.files + assert "rank_histograms" in loaded.files diff --git a/examples/weather/fgn/test_tp.py b/examples/weather/fgn/test_tp.py new file mode 100644 index 0000000000..eec4b8614e --- /dev/null +++ b/examples/weather/fgn/test_tp.py @@ -0,0 +1,175 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the paper §3 / Table A.1 predicted-only ``tp06`` channel. + +The accumulation helper ``_fetch_tp_accumulation`` is tested with a +monkey-patched ARCO client so we don't touch GCS — the logic we care +about is: (a) union of required hourly stamps, (b) correct per-frame +sum of the N preceding hourly values, (c) zeroing of the history tp +channel vs correct target accumulation. +""" + +from __future__ import annotations + +from datetime import datetime, timedelta +from types import SimpleNamespace + +import numpy as np +import pytest +from datasets.arco import ArcoFGNDataset + + +@pytest.fixture +def cfg() -> SimpleNamespace: + return SimpleNamespace( + state_variables=["u10m", "v10m", "t2m", "msl", "tp06"], + invariant_variables=[], + step_hours=6, + history_frames=2, + future_frames=2, + train_start="2016-01-01", + train_end="2016-01-15", + val_start="2018-01-01", + val_end="2018-02-01", + spatial_stride=40, + static_date="2016-01-01", + arco_cache=False, + stats_path=None, + tp_accumulation_hours=6, + ) + + +def test_tp_config_enforces_channel_presence(): + bad = SimpleNamespace( + state_variables=["u10m", "v10m"], # no tp06 + invariant_variables=[], + step_hours=6, + history_frames=2, + future_frames=1, + train_start="2016-01-01", + train_end="2016-01-15", + val_start="2018-01-01", + val_end="2018-02-01", + spatial_stride=40, + tp_accumulation_hours=6, + ) + with pytest.raises(ValueError, match="tp06"): + ArcoFGNDataset(bad, train=True) + + +def test_tp_channel_index_and_output_only(cfg): + ds = ArcoFGNDataset(cfg, train=True) + assert ds.state_channels() == ["u10m", "v10m", "t2m", "msl", "tp06"] + assert ds._tp_channel_idx == 4 + assert ds.output_only_channels() == [4] + assert ds.tp_accumulation_hours == 6 + + +def test_tp_no_op_when_disabled(): + c = SimpleNamespace( + state_variables=["u10m", "v10m"], + invariant_variables=[], + step_hours=6, + history_frames=2, + future_frames=1, + train_start="2016-01-01", + train_end="2016-01-15", + val_start="2018-01-01", + val_end="2018-02-01", + spatial_stride=40, + tp_accumulation_hours=None, + ) + ds = ArcoFGNDataset(c, train=True) + assert ds.output_only_channels() == [] + + +class _FakeARCO: + """Stand-in for ``earth2studio.data.ARCO`` that returns a deterministic + value derived from the request timestamp so accumulation sums are + easy to hand-verify. Only used for the accumulation helper test. + """ + + def __init__(self, reference: datetime): + self.reference = reference + + def __call__(self, *, time, variable): + import xarray as xr + + assert variable == ["tp"], variable + data = np.zeros((len(time), 1, 721, 1440), dtype=np.float32) + for i, t in enumerate(time): + # Value = hours since reference (so a 6-hour accumulation ending + # at T yields sum of the 6 integer offsets preceding T). + offset = (t - self.reference).total_seconds() / 3600.0 + data[i, 0, :, :] = float(offset) + return xr.DataArray( + data, + dims=("time", "variable", "lat", "lon"), + ) + + +def test_tp_accumulation_sums_six_hourly_values(cfg): + ds = ArcoFGNDataset(cfg, train=True) + reference = datetime(2016, 1, 1, 0, 0) + ds._arco = _FakeARCO(reference) + + frame_times = [ + reference + timedelta(hours=24), + reference + timedelta(hours=30), + ] + acc = ds._fetch_tp_accumulation(frame_times) + assert acc.shape == (2, ds.height, ds.width) + # For T = reference + 24h, 6-hour accumulation sums offsets + # {19, 20, 21, 22, 23, 24} = 129. + np.testing.assert_allclose(acc[0], np.full((ds.height, ds.width), 129.0)) + # For T = reference + 30h, sums {25, 26, 27, 28, 29, 30} = 165. + np.testing.assert_allclose(acc[1], np.full((ds.height, ds.width), 165.0)) + + +def test_tp_getitem_zeros_history_and_accumulates_target(cfg, monkeypatch): + ds = ArcoFGNDataset(cfg, train=True) + + reference = datetime(2016, 1, 1, 0, 0) + ds._arco = _FakeARCO(reference) + + # Stub the state-variable ARCO fetch so __getitem__ doesn't hit the + # network. We return uniform values equal to the channel index to + # make shape + indexing easy to verify. + import xarray as xr + + original_call = ds._arco.__call__ + + def patched_call(*, time, variable): + if variable == ["tp"]: + return original_call(time=time, variable=variable) + data = np.zeros((len(time), len(variable), 721, 1440), dtype=np.float32) + for j in range(len(variable)): + data[:, j] = float(j) + 1.0 # non-zero placeholder + return xr.DataArray(data, dims=("time", "variable", "lat", "lon")) + + monkeypatch.setattr( + ds, "_arco", type("Stub", (), {"__call__": staticmethod(patched_call)})() + ) + # The above monkeypatch replaces `_arco` with an object whose + # `__call__` is our stub; `_ensure_arco()` returns `self._arco`. + + sample = ds[0] + history = sample["history"].numpy() + target = sample["target"].numpy() + + assert history.shape == (2, 5, ds.height, ds.width) + assert target.shape == (2, 5, ds.height, ds.width) + + # Paper §3: tp06 channel (index 4) is zero throughout history. + np.testing.assert_array_equal( + history[:, 4], np.zeros((2, ds.height, ds.width), dtype=np.float32) + ) + # Target tp06 is the accumulation (non-zero by construction). + assert np.all(target[:, 4] > 0.0) + + # Non-tp channels in both history and target are unchanged (= j+1 from stub). + for j in range(4): + np.testing.assert_allclose(history[:, j], float(j) + 1.0) + np.testing.assert_allclose(target[:, j], float(j) + 1.0) diff --git a/examples/weather/fgn/test_training.py b/examples/weather/fgn/test_training.py new file mode 100644 index 0000000000..912c579423 --- /dev/null +++ b/examples/weather/fgn/test_training.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path + +import numpy as np +import torch +from hydra import compose, initialize +from inference import _allocate_members, run_inference +from train import run_training +from utils.trainer import find_latest_model_checkpoint + + +def _load_config(config_name: str): + with initialize(version_base=None, config_path="config", job_name="test_fgn"): + return compose(config_name=config_name) + + +def test_fgn_training_and_inference(tmp_path: Path): + cfg = _load_config("test_fgn") + cfg.training.outdir = str(tmp_path) + cfg.training.experiment_name = "fgn-smoke" + cfg.training.run_id = "0" + cfg.training.rundir = str(tmp_path / "fgn-smoke" / "0") + + run_training(cfg) + + checkpoint = find_latest_model_checkpoint( + Path(cfg.training.rundir) / cfg.training.checkpoint_dir + ) + assert checkpoint.endswith(".mdlus") + + infer_cfg = _load_config("inference_fgn") + infer_cfg.training.rundir = cfg.training.rundir + infer_cfg.inference.checkpoint = "latest" + infer_cfg.inference.output_path = str(tmp_path / "forecast.pt") + + result = run_inference(infer_cfg) + assert Path(result["output_path"]).is_file() + assert result["num_models"] == 1 + assert result["members_per_model"] == [int(infer_cfg.inference.num_trajectories)] + + payload = torch.load(result["output_path"], map_location="cpu") + assert payload["trajectories"].ndim == 5 + assert payload["target"].ndim == 4 + + +def test_fgn_deep_ensemble_inference(tmp_path: Path): + """Two independently-trained seeds rolled out together (paper §2.2.1).""" + # Train two seeds with distinct run_ids so checkpoints live in separate + # directories we can point the ensemble inference path at. + checkpoint_paths: list[str] = [] + for seed_idx, seed in enumerate([7, 13]): + cfg = _load_config("test_fgn") + cfg.training.outdir = str(tmp_path) + cfg.training.experiment_name = "fgn-ensemble" + cfg.training.run_id = f"seed{seed_idx}" + cfg.training.rundir = str(tmp_path / "fgn-ensemble" / f"seed{seed_idx}") + cfg.training.seed = seed + run_training(cfg) + checkpoint_paths.append( + find_latest_model_checkpoint( + Path(cfg.training.rundir) / cfg.training.checkpoint_dir + ) + ) + assert len(checkpoint_paths) == 2 + + infer_cfg = _load_config("inference_fgn") + # rundir is unused when ``checkpoints`` is given; set something benign. + infer_cfg.training.rundir = str(tmp_path / "fgn-ensemble" / "seed0") + infer_cfg.inference.checkpoints = checkpoint_paths + # 5 trajectories across 2 models -> [3, 2] (remainder on earlier model). + infer_cfg.inference.num_trajectories = 5 + infer_cfg.inference.output_path = str(tmp_path / "ensemble_forecast.pt") + + result = run_inference(infer_cfg) + assert result["num_models"] == 2 + assert result["members_per_model"] == [3, 2] + + payload = torch.load(result["output_path"], map_location="cpu") + # Trajectories from both models concatenated on the leading axis. + assert payload["trajectories"].shape[0] == 5 + assert payload["num_models"] == 2 + assert payload["checkpoint_paths"] == checkpoint_paths + + +def test_fgn_training_writes_validation_metrics(tmp_path: Path): + """With ``training.validation_metrics: true`` the trainer should emit + an .npz summary + PNG plots under ``rundir/validation/step=/``. + """ + cfg = _load_config("test_fgn") + cfg.training.outdir = str(tmp_path) + cfg.training.experiment_name = "fgn-diag" + cfg.training.run_id = "0" + cfg.training.rundir = str(tmp_path / "fgn-diag" / "0") + cfg.training.validation_metrics = True + cfg.training.validation_ensemble_size = 2 + + run_training(cfg) + + val_root = Path(cfg.training.rundir) / "validation" + assert val_root.is_dir() + step_dirs = sorted(val_root.glob("step=*")) + assert step_dirs, "expected at least one validation snapshot" + npz_path = step_dirs[-1] / "metrics.npz" + assert npz_path.is_file() + data = np.load(npz_path, allow_pickle=True) + # Paper Figure 2 core panels are all present. + for key in ( + "crps_per_lead_per_channel", + "rmse_per_lead_per_channel", + "spread_skill_ratio", + "rank_histograms", + "power_spectrum_forecast", + "power_spectrum_truth", + ): + assert key in data.files, key + + +def test_allocate_members_distribution(): + # Equal split when divisible. + assert _allocate_members(16, 4) == [4, 4, 4, 4] + # Remainder on the earlier models (paper-faithful default). + assert _allocate_members(14, 4) == [4, 4, 3, 3] + assert _allocate_members(1, 4) == [1, 0, 0, 0] + assert _allocate_members(0, 4) == [0, 0, 0, 0] + # Single-model degenerate case. + assert _allocate_members(7, 1) == [7] diff --git a/examples/weather/fgn/train.py b/examples/weather/fgn/train.py new file mode 100644 index 0000000000..d9eb2eab83 --- /dev/null +++ b/examples/weather/fgn/train.py @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Train a latent-conditioned FGN weather model.""" + +import hydra +from omegaconf import DictConfig +from utils.trainer import Trainer + +from physicsnemo.distributed import DistributedManager + + +def run_training(cfg: DictConfig) -> None: + DistributedManager.initialize() + trainer = Trainer(cfg) + trainer.train() + + +@hydra.main(version_base=None, config_path="config", config_name="fgn") +def main(cfg: DictConfig) -> None: + run_training(cfg) + + +if __name__ == "__main__": + main() diff --git a/examples/weather/fgn/utils/config.py b/examples/weather/fgn/utils/config.py new file mode 100644 index 0000000000..a8a175921e --- /dev/null +++ b/examples/weather/fgn/utils/config.py @@ -0,0 +1,132 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Literal + +from pydantic import Field +from pydantic.dataclasses import dataclass + + +@dataclass(config={"extra": "allow"}) +class DatasetConfig: + name: str + + +@dataclass(config={"extra": "forbid"}) +class ModelConfig: + model_name: Literal["fgn"] = "fgn" + history_frames: int = Field(default=2, ge=2) + latent_dim: int = Field(default=16, ge=1) + hidden_channels: int = Field(default=32, ge=4) + background_channels: int | Literal["auto"] = "auto" + invariant_channels: int | Literal["auto"] = "auto" + group_norm_groups: int = Field(default=8, ge=1) + + +@dataclass(config={"extra": "forbid"}) +class OptimizerConfig: + lr: float = Field(default=3e-4, gt=0.0) + betas: tuple[float, float] = (0.9, 0.999) + weight_decay: float = Field(default=1e-4, ge=0.0) + + +@dataclass(config={"extra": "forbid"}) +class LossConfig: + num_samples: int = Field(default=4, ge=2) + mse_weight: float = Field(default=0.1, ge=0.0) + # GraphCast-style per-variable weights with geopotential halved per + # FGN §2.2.3. Independent of cos(lat) area weighting. + use_channel_weights: bool = False + # cos(lat) area weighting for the lat/lon grid. + use_area_weights: bool = False + + +@dataclass(config={"extra": "forbid"}) +class TrainingConfig: + outdir: str = "rundir" + experiment_name: str = "fgn" + run_id: str = "0" + rundir: str = "rundir/fgn/0" + checkpoint_dir: str = "checkpoints" + num_data_workers: int = Field(default=0, ge=0) + seed: int = 7 + batch_size: int = Field(default=8, ge=1) + total_train_steps: int = Field(default=100, ge=1) + print_progress_freq: int = Field(default=10, ge=1) + checkpoint_freq: int = Field(default=50, ge=1) + validation_freq: int = Field(default=25, ge=1) + resume_checkpoint: int | Literal["latest"] | None = "latest" + clip_grad_norm: float = -1.0 + ar_steps: int = Field(default=1, ge=1, le=8) + # Data + domain parallelism knobs. Mirrors StormCast's convention. + # - domain_parallel_size=1 & force_sharding=False → pure single-process + # or plain DDP, no ShardTensor overhead (default for smoke tests). + # - domain_parallel_size>1 → spatial sharding on the domain mesh axis. + # - force_sharding=True → wrap tensors/model in ShardTensor even with a + # single domain rank (useful to test the sharding path end-to-end). + domain_parallel_size: int = Field(default=1, ge=1) + force_sharding: bool = False + # Validation diagnostic hooks. When enabled, the trainer runs a short + # ensemble rollout on a single validation batch at each ``validation_freq`` + # step and writes per-variable CRPS / RMSE / spread-skill / rank-hist / + # power-spectrum artifacts (Figures 2 + 3 of arXiv:2506.10772v1, minus + # baseline-dependent scorecards and REV) to ``rundir/validation/``. + # Cap on per-rank validation batches when running under ParallelHelper + # (the rank-sharded sampler is infinite by design — StormCast convention). + # None = sweep one local epoch. + validation_steps: int | None = None + validation_metrics: bool = False + validation_ensemble_size: int = Field(default=4, ge=2) + optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig) + loss: LossConfig = Field(default_factory=LossConfig) + + +@dataclass(config={"extra": "forbid"}) +class InferenceConfig: + # Single-checkpoint mode: set ``checkpoint`` ("latest" or a path). + # Deep-ensemble mode: set ``checkpoints`` (list of paths); paper §2.2.1 + # uses J=4 independently-trained models with equal members each and a + # fixed model identity per trajectory. + checkpoint: str = "latest" + checkpoints: list[str] | None = None + dataset_index: int = Field(default=0, ge=0) + num_steps: int = Field(default=3, ge=1) + num_trajectories: int = Field(default=4, ge=1) + seed: int = 17 + output_path: str = "rundir/fgn/0/forecast.pt" + + +@dataclass(config={"extra": "forbid"}) +class EvalConfig: + checkpoint: str = "latest" + checkpoints: list[str] | None = None + future_steps: int = Field(default=20, ge=1, le=60) + ensemble_size: int = Field(default=8, ge=2) + batch_size: int = Field(default=1, ge=1) + num_workers: int = Field(default=0, ge=0) + outdir: str = "rundir/fgn/0/eval" + pool_sizes: list[int] = Field(default_factory=lambda: [4, 8, 16, 32]) + + +@dataclass(config={"extra": "forbid"}) +class TrainMainConfig: + dataset: DatasetConfig + model: ModelConfig + training: TrainingConfig + + +@dataclass(config={"extra": "forbid"}) +class InferenceMainConfig: + dataset: DatasetConfig + model: ModelConfig + training: TrainingConfig + inference: InferenceConfig + + +@dataclass(config={"extra": "forbid"}) +class EvalMainConfig: + dataset: DatasetConfig + model: ModelConfig + training: TrainingConfig + eval: EvalConfig diff --git a/examples/weather/fgn/utils/loss.py b/examples/weather/fgn/utils/loss.py new file mode 100644 index 0000000000..143d85cade --- /dev/null +++ b/examples/weather/fgn/utils/loss.py @@ -0,0 +1,149 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""FGN training loss functions (arXiv:2506.10772 §2.2). + +fair_crps — eq. (4): fair CRPS over ensemble members. +ensemble_mean_mse — MSE of the ensemble mean (supplementary term). +build_channel_weights — per-variable weights (§2.2.3 / GraphCast scheme). +build_area_weights — cos(lat) area weights (§2.2.3). +""" + +from __future__ import annotations + +import re + +import numpy as np +import torch + + +def fair_crps( + ensemble: torch.Tensor, + target: torch.Tensor, + weights: torch.Tensor | None = None, +) -> torch.Tensor: + """Fair CRPS (paper eq. 4), mean-reduced over batch and spatial dims. + + Parameters + ---------- + ensemble : ``(B, M, C, H, W)`` + M ensemble members. + target : ``(B, C, H, W)`` + Ground-truth state. + weights : broadcastable to ``(B, C, H, W)``, optional + Per-location/channel loss weights (eq. 5). + + Returns + ------- + torch.Tensor + Scalar loss. + """ + if ensemble.ndim != 5: + raise ValueError( + f"ensemble must have shape [B, M, C, H, W], got {tuple(ensemble.shape)}" + ) + if target.ndim != 4: + raise ValueError( + f"target must have shape [B, C, H, W], got {tuple(target.shape)}" + ) + M = ensemble.shape[1] + if M < 2: + raise ValueError( + f"fair_crps requires at least two ensemble members, got M={M}" + ) + + # term1: E[|X - y|] per location — shape (B, C, H, W) + term1 = (ensemble - target.unsqueeze(1)).abs().mean(dim=1) + + # term2: (1/2) E[|X - X'|] per location via exhaustive pairwise sum. + # The diagonal is zero so including it is free; the factor 2M(M-1) in the + # denominator matches the fair (unbiased) estimator in eq. (4). + x_i = ensemble.unsqueeze(2) # (B, M, 1, C, H, W) + x_j = ensemble.unsqueeze(1) # (B, 1, M, C, H, W) + term2 = (x_i - x_j).abs().sum(dim=(1, 2)) / (2.0 * M * (M - 1)) + + per_loc = term1 - term2 # (B, C, H, W) + if weights is not None: + per_loc = per_loc * weights + return per_loc.mean() + + +def ensemble_mean_mse( + ensemble: torch.Tensor, + target: torch.Tensor, + weights: torch.Tensor | None = None, +) -> torch.Tensor: + """MSE of the ensemble mean, mean-reduced over batch and spatial dims. + + Parameters + ---------- + ensemble : ``(B, M, C, H, W)`` + target : ``(B, C, H, W)`` + weights : broadcastable to ``(B, C, H, W)``, optional + """ + sq_err = (ensemble.mean(dim=1) - target).pow(2) + if weights is not None: + sq_err = sq_err * weights + return sq_err.mean() + + +def build_channel_weights(state_channels: list[str]) -> np.ndarray: + """Per-channel loss weights following paper §2.2.3 / GraphCast scheme. + + Rules + ----- + - Atmospheric channels (name matches ```` with a purely + numeric suffix, e.g. ``t500``): weight = level / sum_of_levels_in_prefix. + Geopotential (``z*``) weights are halved to tame overfitting. + - Surface channels (anything else): weight = 0.1, except ``t2m`` = 1.0. + + Returns + ------- + np.ndarray, shape ``(C,)``, float32. + """ + # Identify atmospheric channels and their pressure levels. + _atmos_re = re.compile(r"^([a-zA-Z]+)(\d+)$") + prefix_levels: dict[str, list[int]] = {} + for ch in state_channels: + m = _atmos_re.match(ch) + if m: + prefix_levels.setdefault(m.group(1), []).append(int(m.group(2))) + + prefix_sum: dict[str, float] = {p: float(sum(lvls)) for p, lvls in prefix_levels.items()} + + weights = np.zeros(len(state_channels), dtype=np.float32) + for i, ch in enumerate(state_channels): + m = _atmos_re.match(ch) + if m: + prefix, level = m.group(1), int(m.group(2)) + w = level / prefix_sum[prefix] + if prefix == "z": + w *= 0.5 + weights[i] = w + elif ch == "t2m": + weights[i] = 1.0 + else: + weights[i] = 0.1 + + return weights + + +def build_area_weights(H: int) -> np.ndarray: + """Cos(lat) area weights, normalised so the mean over rows equals 1. + + Follows ERA5 north-to-south ordering (lat 90° → −90°). + + Parameters + ---------- + H : int + Number of latitude rows (e.g. 721 for 0.25° ERA5). + + Returns + ------- + np.ndarray, shape ``(H, 1)``, float32. + """ + lats = np.linspace(90.0, -90.0, H, dtype=np.float64) + w = np.cos(np.deg2rad(lats)) + w /= w.mean() + return w.astype(np.float32).reshape(H, 1) diff --git a/examples/weather/fgn/utils/metrics.py b/examples/weather/fgn/utils/metrics.py new file mode 100644 index 0000000000..a371c2ee05 --- /dev/null +++ b/examples/weather/fgn/utils/metrics.py @@ -0,0 +1,870 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Validation diagnostics for FGN, drawing on Figures 2 + 3 of arXiv:2506.10772v1. + +Canonical metric library: **earth2studio has coord-aware equivalents** of +every metric here under ``earth2studio.statistics`` — ``spread_skill_ratio``, +``rmse``, ``rank_histogram``, ``crps``, ``lsd`` (log spectral distance), +``energy_score``, ``fss``, plus ``weights.lat_weight`` for ``cos(lat)``. Those +classes operate on ``(tensor, CoordSystem)`` pairs and are the right choice +for xarray-style evaluation pipelines. This module intentionally uses +lightweight torch kernels against the FGN trainer's pure +``(B, K, M, C, H, W)`` tensors to keep the inline validation hook cheap and +dependency-free; docstrings below cite the canonical equivalent for each +diagnostic so a later refactor can swap them in. + +What we compute per validation rollout: + +- **CRPS per variable per lead time** (Figure 2a, without a baseline). + Delegates to :func:`physicsnemo.metrics.general.crps.kcrps` with + ``biased=False`` (the Zamo-Naveau fair estimator used for training). + Canonical coord-aware equivalent: :class:`earth2studio.statistics.crps`. +- **Spread-skill ratio per variable per lead time** (Figure 2 b-f). + Standard definition: ``spread = sqrt(mean over grid of var across + members)`` vs ``skill = sqrt(mean over grid of MSE of ensemble mean)``. + A well-calibrated ensemble sits near 1. + Canonical: :class:`earth2studio.statistics.spread_skill_ratio`. +- **Ensemble-mean RMSE per variable per lead time** (Figure 2 companion). + Canonical: :class:`earth2studio.statistics.rmse`. +- **Rank histogram** per variable, aggregated over grid + validation + batches. A uniform histogram indicates good calibration; U-shaped → + under-dispersive, hump-shaped → over-dispersive. + Canonical: :class:`earth2studio.statistics.rank_histogram`. +- **Energy score per lead time** (multivariate CRPS generalisation). + Computed over the *variable* axis so that cross-channel calibration + is captured, averaged over a spatially subsampled grid to keep the + O(M²) pairwise term cheap. A single scalar per lead; lower is better. + Added in earth2studio 0.13.0 as :class:`earth2studio.statistics.energy_score`. +- **Azimuthal 1D power spectra** per variable for ensemble-mean vs + ground truth (Figure 3 e-j). Uses + :func:`physicsnemo.metrics.general.power_spectrum.power_spectrum` — + 2D-FFT azimuthally averaged, a reasonable proxy for the paper's + spherical-harmonic spectra without adding a ``torch-harmonics`` dep. + Honeycomb artifacts at the mesh frequency (Figure 5) would surface here + as a localised high-frequency bump. +- **Derived-variable CRPS** for `10m wind speed = sqrt(u10m^2 + v10m^2)` + and `z300 - z500` (Figure 3 c-d), when those component variables are + present in the state channel list. + +Deferred (require more scope or data we don't have yet): + +- FGN-vs-baseline scorecards (no baseline model wired in). +- Pooled CRPS (Figure 3 a-b) — pool-size sweep. +- REV for extreme thresholds (Figure 2 g-h) — needs a climatology. +- Cyclone track evaluation (Figure 4) — needs IBTrACS + Tempest Extremes. +- Direct honeycomb-artifact viz (Figure 5) — implicit in the spectra plot. + +earth2studio examples (see ``earth2studio/examples/02_medium_range/``) use +raw ``matplotlib.pyplot`` + ``cartopy`` for geospatial plots — there is no +shared plotting wrapper to delegate to. The plot helpers below mirror that +convention: plain ``matplotlib`` with ``Agg`` backend so the hook stays +headless-safe. +""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + +import numpy as np +import torch + +from physicsnemo.metrics.general.crps import kcrps +from physicsnemo.metrics.general.power_spectrum import power_spectrum + +# --------------------------------------------------------------------------- +# Metric computations. All expect float32 tensors on the same device and +# return CPU numpy arrays for easy plotting / logging. +# --------------------------------------------------------------------------- + + +def _check_shapes(ensemble: torch.Tensor, target: torch.Tensor) -> None: + """Validate the shared (B, K, M, C, H, W) vs (B, K, C, H, W) layout.""" + if ensemble.ndim != 6: + raise ValueError( + f"ensemble must have shape [B, K, M, C, H, W], got {tuple(ensemble.shape)}" + ) + if target.ndim != 5: + raise ValueError( + f"target must have shape [B, K, C, H, W], got {tuple(target.shape)}" + ) + if ensemble.shape[0] != target.shape[0] or ensemble.shape[1] != target.shape[1]: + raise ValueError( + f"ensemble/target batch + lead dims must match, got {tuple(ensemble.shape)} vs {tuple(target.shape)}" + ) + + +def crps_per_variable_per_lead( + ensemble: torch.Tensor, target: torch.Tensor +) -> np.ndarray: + """Fair CRPS averaged over batch + spatial dims, retaining (K, C). + + Shapes: ensemble (B, K, M, C, H, W), target (B, K, C, H, W). + Returns numpy array of shape (K, C). + """ + _check_shapes(ensemble, target) + B, K, M, C, H, W = ensemble.shape + flat_ens = ensemble.reshape(B * K, M, C, H, W) + flat_tgt = target.reshape(B * K, C, H, W) + per_loc = kcrps(flat_ens, flat_tgt, dim=1, biased=False) # (B*K, C, H, W) + per_loc = per_loc.reshape(B, K, C, H, W) + return per_loc.mean(dim=(0, -2, -1)).detach().cpu().numpy() + + +def ensemble_rmse_per_variable_per_lead( + ensemble: torch.Tensor, target: torch.Tensor +) -> np.ndarray: + """sqrt(mean((mean_n x_n − y)^2)) per lead, per channel. Shape (K, C).""" + _check_shapes(ensemble, target) + mean_pred = ensemble.mean(dim=2) # (B, K, C, H, W) + sq = (mean_pred - target) ** 2 + mse = sq.mean(dim=(0, -2, -1)) # (K, C) + return mse.sqrt().detach().cpu().numpy() + + +def spread_skill_per_variable_per_lead( + ensemble: torch.Tensor, target: torch.Tensor +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Ensemble spread, skill, and their ratio — per lead, per channel. + + Standard definitions: + skill = sqrt(mean over grid of MSE of ensemble mean) + spread = sqrt(mean over grid of ensemble variance across members) + + A well-calibrated ensemble has spread/skill ≈ 1 (paper Figure 2 b-f). + + Returns ``(spread, skill, ratio)`` each of shape (K, C) as numpy. + """ + _check_shapes(ensemble, target) + # Variance across member axis (dim=2), unbiased estimator. + member_var = ensemble.var(dim=2, unbiased=True) # (B, K, C, H, W) + spread = member_var.mean(dim=(0, -2, -1)).sqrt() + skill = ensemble_rmse_per_variable_per_lead(ensemble, target) + skill_t = torch.from_numpy(skill).to(spread.device) + ratio = spread / skill_t.clamp_min(1e-12) + return ( + spread.detach().cpu().numpy(), + skill, + ratio.detach().cpu().numpy(), + ) + + +def rank_histogram_per_variable( + ensemble: torch.Tensor, target: torch.Tensor, num_bins: int | None = None +) -> np.ndarray: + """Verification rank histogram per channel, aggregated over batch/lead/grid. + + For each observation, count the rank of the truth among the sorted + ensemble members; the histogram over many observations reveals + calibration. With ``M`` members there are ``M + 1`` possible ranks. + + Returns an integer numpy array of shape ``(C, num_bins)`` with + ``num_bins = M + 1`` by default. + """ + _check_shapes(ensemble, target) + B, K, M, C, H, W = ensemble.shape + bins = num_bins if num_bins is not None else M + 1 + # For each position, rank of target among members is the count of + # members strictly less than target (breaks ties by assigning the + # lowest possible rank). Paper-standard convention uses a random + # tiebreak; with float32 data ties are rare, so strict-less-than is + # a reasonable MVP. + less = (ensemble < target.unsqueeze(2)).sum(dim=2) # (B, K, C, H, W) + # Rank values are in 0..M inclusive. Scale/clip to 0..bins-1. + less = less.to(torch.long) + if bins != M + 1: + less = (less * bins // (M + 1)).clamp(0, bins - 1) + hist = torch.zeros(C, bins, dtype=torch.long, device=less.device) + for c in range(C): + flat = less[..., c, :, :].reshape(-1) + hist[c] = torch.bincount(flat, minlength=bins) + return hist.detach().cpu().numpy() + + +def derived_variable_crps( + ensemble: torch.Tensor, + target: torch.Tensor, + variables: Sequence[str], +) -> dict[str, np.ndarray]: + """Paper Figure 3 c-d: CRPS of derived quantities. + + - ``wspd10m = sqrt(u10m^2 + v10m^2)`` when both ``u10m`` and ``v10m`` + are present. + - ``z300_minus_z500 = z300 - z500`` when both levels are present. + + Returns a dict of ``name -> (K,)`` CRPS arrays; empty dict if neither + derived quantity is available. + """ + _check_shapes(ensemble, target) + idx = {name: i for i, name in enumerate(variables)} + derived: dict[str, np.ndarray] = {} + + if "u10m" in idx and "v10m" in idx: + u_ens = ensemble[:, :, :, idx["u10m"]] + v_ens = ensemble[:, :, :, idx["v10m"]] + u_tgt = target[:, :, idx["u10m"]] + v_tgt = target[:, :, idx["v10m"]] + wspd_ens = torch.sqrt(u_ens**2 + v_ens**2).unsqueeze(3) # (B,K,M,1,H,W) + wspd_tgt = torch.sqrt(u_tgt**2 + v_tgt**2).unsqueeze(2) # (B,K,1,H,W) + per_loc = kcrps( + wspd_ens.reshape(-1, wspd_ens.shape[2], 1, *wspd_ens.shape[-2:]), + wspd_tgt.reshape(-1, 1, *wspd_tgt.shape[-2:]), + dim=1, + biased=False, + ) + B, K = ensemble.shape[0], ensemble.shape[1] + per_loc = per_loc.reshape(B, K, 1, *per_loc.shape[-2:]) + derived["wspd10m"] = per_loc.mean(dim=(0, 2, -2, -1)).detach().cpu().numpy() + + if "z300" in idx and "z500" in idx: + dz_ens = ensemble[:, :, :, idx["z300"]] - ensemble[:, :, :, idx["z500"]] + dz_tgt = target[:, :, idx["z300"]] - target[:, :, idx["z500"]] + dz_ens = dz_ens.unsqueeze(3) + dz_tgt = dz_tgt.unsqueeze(2) + per_loc = kcrps( + dz_ens.reshape(-1, dz_ens.shape[2], 1, *dz_ens.shape[-2:]), + dz_tgt.reshape(-1, 1, *dz_tgt.shape[-2:]), + dim=1, + biased=False, + ) + B, K = ensemble.shape[0], ensemble.shape[1] + per_loc = per_loc.reshape(B, K, 1, *per_loc.shape[-2:]) + derived["z300_minus_z500"] = ( + per_loc.mean(dim=(0, 2, -2, -1)).detach().cpu().numpy() + ) + + return derived + + +def power_spectra_per_variable( + ensemble_mean: torch.Tensor, target: torch.Tensor +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Azimuthal 1D power spectrum of ensemble-mean and target, per channel. + + ``ensemble_mean`` and ``target`` are ``(B, K, C, H, W)`` tensors; the + spectrum is averaged over batch and any requested lead times by the + caller (pass in ``[:, lead_idx:lead_idx+1]`` to isolate a single lead). + + Uses :func:`physicsnemo.metrics.general.power_spectrum.power_spectrum` + — 2D-FFT azimuthal averaging. Not a true spherical-harmonic spectrum + (the paper uses spherical harmonics), but a cheap proxy that still + surfaces the mesh-frequency spike described in Figure 3e / Figure 5. + + Returns ``(k_bins, ens_spectra, tgt_spectra)`` with shapes + ``(nbins,)`` and ``(K, C, nbins)`` respectively. + """ + if ensemble_mean.shape != target.shape: + raise ValueError( + f"ensemble_mean/target shapes must match, got {tuple(ensemble_mean.shape)} vs {tuple(target.shape)}" + ) + # Average spectrum over batch on a per-lead, per-channel basis. + B, K, C, H, W = ensemble_mean.shape + ens_flat = ensemble_mean.reshape(B * K * C, H, W) + tgt_flat = target.reshape(B * K * C, H, W) + k, ens_pow = power_spectrum(ens_flat) + _, tgt_pow = power_spectrum(tgt_flat) + ens_pow = ens_pow.reshape(B, K, C, -1).mean(dim=0) + tgt_pow = tgt_pow.reshape(B, K, C, -1).mean(dim=0) + return ( + k.detach().cpu().numpy(), + ens_pow.detach().cpu().numpy(), + tgt_pow.detach().cpu().numpy(), + ) + + +def energy_score_per_lead( + ensemble: torch.Tensor, + target: torch.Tensor, + spatial_stride: int = 8, + fair: bool = True, +) -> np.ndarray: + """Fair energy score (multivariate CRPS) per lead, averaged over variables + grid. + + The energy score is the multivariate generalisation of CRPS: + + ES = E[||X - y||] - (1/2) E[||X - X'||] + + where the norm is taken over the *variable* axis (dim C) at each spatial + point. This captures cross-channel calibration that per-variable CRPS misses. + + Computing the O(M²) pairwise term over the full 721 × 1440 grid is + expensive; ``spatial_stride`` sub-samples before computing to keep it fast. + With the default stride of 8 the spatial footprint is ~91 × 180 = 16 k + points, well within budget for a validation hook. + + Shapes: ensemble (B, K, M, C, H, W), target (B, K, C, H, W). + Returns numpy array of shape (K,). + + Canonical coord-aware equivalent: :class:`earth2studio.statistics.energy_score` + (added in earth2studio 0.13.0, March 2026). + """ + _check_shapes(ensemble, target) + B, K, M, C, H, W = ensemble.shape + + ens = ensemble[:, :, :, :, ::spatial_stride, ::spatial_stride] + tgt = target[:, :, :, ::spatial_stride, ::spatial_stride] + Hs, Ws = ens.shape[-2], ens.shape[-1] + N = B * K * Hs * Ws + + # (N, M, C) and (N, C) + flat_ens = ens.permute(0, 1, 4, 5, 2, 3).reshape(N, M, C).float() + flat_tgt = tgt.permute(0, 1, 3, 4, 2).reshape(N, C).float() + + # Term 1: (1/M) * sum_m ||x_m - y||_C + term1 = (flat_ens - flat_tgt.unsqueeze(1)).norm(dim=-1).mean(dim=-1) # (N,) + + # Term 2: pairwise spread in batches to cap peak memory. + CHUNK = 65536 + term2_parts: list[torch.Tensor] = [] + for i in range(0, N, CHUNK): + pw = torch.cdist(flat_ens[i : i + CHUNK], flat_ens[i : i + CHUNK], p=2) + if fair: + mask = ~torch.eye(M, device=pw.device, dtype=torch.bool) + term2_parts.append((pw * mask).sum(dim=(-2, -1)) / (2.0 * M * (M - 1))) + else: + term2_parts.append(pw.sum(dim=(-2, -1)) / (2.0 * M * M)) + term2 = torch.cat(term2_parts) # (N,) + + es = (term1 - term2).reshape(B, K, Hs, Ws).mean(dim=(0, 2, 3)) # (K,) + return es.detach().cpu().numpy() + + +# --------------------------------------------------------------------------- +# Plotting (matplotlib-gated; skip silently if unavailable so the lightweight +# smoke test still runs in headless CI). +# --------------------------------------------------------------------------- + + +def _import_matplotlib(): + try: + import matplotlib # noqa: F401 + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + return plt + except ImportError: + return None + + +# Variable ordering for the scorecard — mirrors Figure 2a of the paper. +# Groups: surface vars first, then z (geopotential) and q (humidity) most +# prominent (as in paper), then t, u, v, w. +# 9 standard pressure levels: compact (~58 rows total). +_KEY_LEVELS = ["1000", "925", "850", "700", "500", "300", "200", "100", "50"] +_SCORECARD_GROUPS: list[tuple[str, list[str]]] = [ + ("surface", ["t2m", "msl", "u10m", "v10m", "sst"]), + ("z", [f"z{p}" for p in _KEY_LEVELS]), + ("q", [f"q{p}" for p in _KEY_LEVELS]), + ("t", [f"t{p}" for p in _KEY_LEVELS]), + ("u", [f"u{p}" for p in _KEY_LEVELS]), + ("v", [f"v{p}" for p in _KEY_LEVELS]), + ("w", [f"w{p}" for p in _KEY_LEVELS]), +] + + +def plot_crps_scorecard( + crps: np.ndarray, + variables: Sequence[str], + lead_hours: Sequence[float], + out_path: str, + title: str = "CRPS scorecard", +) -> bool: + """Figure 2a-style heatmap: rows = variables (grouped by type), cols = lead times. + + Each row is normalised to its own [min, max] range so the colormap shows + the relative degradation with lead time, making all variables comparable + regardless of their absolute CRPS magnitude. + """ + plt = _import_matplotlib() + if plt is None: + return False + + var_list = list(variables) + # Build ordered rows: (display_name, channel_index) + rows: list[tuple[str, int]] = [] + group_boundaries: list[int] = [] # row indices where a new group starts + group_labels: list[tuple[int, str]] = [] # (center_row, group_name) + + for group_name, names in _SCORECARD_GROUPS: + present = [(n, var_list.index(n)) for n in names if n in var_list] + if not present: + continue + group_boundaries.append(len(rows)) + group_labels.append((len(rows) + len(present) // 2, group_name)) + rows.extend(present) + + if not rows: + return False + + R = len(rows) + K = len(lead_hours) + # Build data matrix (R, K), normalise each row to [0, 1] + data = np.zeros((R, K), dtype=np.float32) + for ri, (_, ci) in enumerate(rows): + row = crps[:, ci].astype(np.float32) + lo, hi = row.min(), row.max() + data[ri] = (row - lo) / (hi - lo + 1e-12) + + fig, ax = plt.subplots(figsize=(max(4, K * 0.6 + 2), max(4, R * 0.12 + 1.5)), + constrained_layout=True) + im = ax.imshow(data, aspect="auto", cmap="Blues", vmin=0, vmax=1, + interpolation="nearest") + + # x-axis: lead times in hours (or convert to days if ≥ 48 h) + lh = np.asarray(lead_hours) + if lh[-1] >= 48: + x_labels = [f"{h/24:.0f}d" for h in lh] + ax.set_xlabel("lead time (days)", fontsize=9) + else: + x_labels = [f"{h:.0f}h" for h in lh] + ax.set_xlabel("lead time (hours)", fontsize=9) + ax.set_xticks(range(K)) + ax.set_xticklabels(x_labels, fontsize=8) + + # y-axis: variable names (right side shows group labels) + y_names = [r[0] for r in rows] + ax.set_yticks(range(R)) + ax.set_yticklabels(y_names, fontsize=7) + + # Horizontal separators between groups + for b in group_boundaries[1:]: + ax.axhline(b - 0.5, color="white", linewidth=1.5) + + # Group labels on the right + ax2 = ax.twinx() + ax2.set_ylim(ax.get_ylim()) + ax2.set_yticks([c for _, (c, _) in enumerate(group_labels)]) + ax2.set_yticks([c for c, _ in group_labels]) + ax2.set_yticklabels([n for _, n in group_labels], fontsize=8, fontstyle="italic") + ax2.tick_params(length=0) + + fig.colorbar(im, ax=ax, fraction=0.02, pad=0.12, label="normalised (per variable)") + ax.set_title(title, fontsize=10, pad=6) + fig.savefig(out_path, dpi=100) + plt.close(fig) + return True + + +# Figure 2b-f: spread-skill calibration for these 5 key variables. +_SPREAD_SKILL_VARS = ["z500", "q700", "t850", "t2m", "u10m"] +_SPREAD_SKILL_LABELS: dict[str, str] = {"t2m": "2t", "u10m": "10u"} + + +def plot_spread_skill_lines( + spread: np.ndarray, + rmse: np.ndarray, + variables: Sequence[str], + lead_hours: Sequence[float], + out_path: str, +) -> bool: + """Figure 2b-f: spread vs ensemble-mean RMSE for 5 key variables. + + Each panel shows spread (dashed) and RMSE (solid) vs lead time so the + reader can judge calibration: a well-calibrated ensemble has spread ≈ RMSE. + Variables: z500, q700, t850, 2t, 10u. + """ + plt = _import_matplotlib() + if plt is None: + return False + + var_list = list(variables) + pairs = [ + (_SPREAD_SKILL_LABELS.get(v, v), var_list.index(v)) + for v in _SPREAD_SKILL_VARS + if v in var_list + ] + if not pairs: + return False + + n = len(pairs) + lh = np.asarray(lead_hours, dtype=float) + x_label = "lead time (days)" if lh[-1] >= 48 else "lead time (hours)" + x_vals = lh / 24.0 if lh[-1] >= 48 else lh + + fig, axes = plt.subplots(1, n, figsize=(3.2 * n, 3.2), constrained_layout=True) + if n == 1: + axes = [axes] + for ax, (name, ci) in zip(axes, pairs): + ax.plot(x_vals, rmse[:, ci], color="C0", linewidth=1.5, label="RMSE") + ax.plot(x_vals, spread[:, ci], color="C0", linewidth=1.5, + linestyle="--", label="spread") + ax.set_title(name, fontsize=10, pad=4) + ax.set_xlabel(x_label, fontsize=8) + ax.grid(True, alpha=0.3) + ax.tick_params(labelsize=7) + axes[0].set_ylabel("std dev (normalised units)", fontsize=8) + axes[-1].legend(fontsize=8) + fig.suptitle("Spread-skill calibration", fontsize=10) + fig.savefig(out_path, dpi=100) + plt.close(fig) + return True + + +def plot_metric_vs_lead( + metric: np.ndarray, + variables: Sequence[str], + steps: Sequence[float], + ylabel: str, + title: str, + out_path: str, + hline_y: float | None = None, + xlabel: str = "lead time (hours)", +) -> bool: + """One line per channel over lead-time axis. Returns True if plotted.""" + plt = _import_matplotlib() + if plt is None: + return False + fig, ax = plt.subplots(figsize=(10, 5)) + for ci, name in enumerate(variables): + ax.plot(steps, metric[:, ci], marker="o", markersize=4, label=name) + if hline_y is not None: + ax.axhline(hline_y, color="k", linestyle="--", linewidth=0.7) + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + ax.set_title(title) + ax.grid(True, alpha=0.3) + # Legend outside axes so it never overlaps data (70+ channels) + ax.legend( + fontsize=7, ncol=2, loc="upper left", + bbox_to_anchor=(1.01, 1), borderaxespad=0, + ) + fig.tight_layout() + fig.savefig(out_path, dpi=120, bbox_inches="tight") + plt.close(fig) + return True + + +def plot_rank_histograms( + histograms: np.ndarray, + variables: Sequence[str], + out_path: str, +) -> bool: + """Grid of rank histograms (one per channel) for calibration inspection.""" + plt = _import_matplotlib() + if plt is None: + return False + C, bins = histograms.shape + ncols = min(4, C) + nrows = (C + ncols - 1) // ncols + fig, axes = plt.subplots( + nrows, ncols, figsize=(3 * ncols, 2.5 * nrows), + constrained_layout=True, squeeze=False, + ) + for ci, name in enumerate(variables): + ax = axes[ci // ncols][ci % ncols] + ax.bar(np.arange(bins), histograms[ci], color="#2266aa") + expected = histograms[ci].sum() / bins + ax.axhline(expected, color="k", linestyle="--", linewidth=0.7) + ax.set_title(name, fontsize=8, pad=3) + ax.set_xlabel("rank", fontsize=7) + ax.set_ylabel("count", fontsize=7) + ax.tick_params(labelsize=6) + for j in range(C, nrows * ncols): + axes[j // ncols][j % ncols].set_visible(False) + fig.suptitle("Rank histograms (uniform = well calibrated)", fontsize=11) + fig.savefig(out_path, dpi=120) + plt.close(fig) + return True + + +# Paper Figure 3e-j: spectra for these 3 variables at 2 lead times. +_SPECTRA_VARS_PAPER = ["t2m", "q700", "z500"] +# Display labels matching the paper (t2m → "2t", etc.) +_SPECTRA_LABELS: dict[str, str] = {"t2m": "2t"} + + +def plot_power_spectra( + k: np.ndarray, + ens_spectra: np.ndarray, + tgt_spectra: np.ndarray, + variables: Sequence[str], + lead_hours_all: np.ndarray, + out_path: str, + grid_deg: float = 0.25, + var_subset: Sequence[str] | None = None, + target_lead_hours: Sequence[float] = (12, 360), +) -> bool: + """Figure 3 e-j: 2×3 grid — rows = lead times, cols = variables. + + Rows correspond to the two ``target_lead_hours`` (defaults: 12 h and + 15 d = 360 h); the closest available lead is used when the exact value + is not present. Columns show the paper variables {2t, q700, z500} + (or ``var_subset`` if provided). + """ + plt = _import_matplotlib() + if plt is None: + return False + + subset = list(var_subset) if var_subset is not None else _SPECTRA_VARS_PAPER + var_list = list(variables) + pairs = [(name, var_list.index(name)) for name in subset if name in var_list] + if not pairs: + pairs = [(var_list[i], i) for i in range(min(3, len(var_list)))] + + # Find closest available lead indices for the requested target hours + lh = np.asarray(lead_hours_all, dtype=float) + lead_indices = [int(np.argmin(np.abs(lh - th))) for th in target_lead_hours] + # Deduplicate while preserving order + seen: set[int] = set() + lead_indices = [i for i in lead_indices if not (i in seen or seen.add(i))] # type: ignore[func-returns-value] + lead_labels = [f"Mean power at {lh[i]:.0f} h" for i in lead_indices] + + nrows = len(lead_indices) + ncols = len(pairs) + km_per_deg = 111.0 + grid_km = grid_deg * km_per_deg + kk = k[1:] + n_cells = round(40030 / grid_km) + wavelength_km = n_cells * grid_km / kk + + fig, axes = plt.subplots( + nrows, ncols, figsize=(5 * ncols, 4 * nrows), + constrained_layout=True, squeeze=False, + ) + for ri, (li, row_label) in enumerate(zip(lead_indices, lead_labels)): + for ci_plot, (name, ci) in enumerate(pairs): + ax = axes[ri][ci_plot] + ax.loglog(wavelength_km, ens_spectra[li, ci, 1:], label="FGN", color="C0", linewidth=1.5) + ax.loglog(wavelength_km, tgt_spectra[li, ci, 1:], label="truth", color="k", linewidth=1.5) + ax.invert_xaxis() + display_name = _SPECTRA_LABELS.get(name, name) + if ri == 0: + ax.set_title(display_name, fontsize=10, pad=4) + if ri == nrows - 1: + ax.set_xlabel("Wavelength (km)", fontsize=9) + if ci_plot == 0: + ax.set_ylabel(row_label, fontsize=9) + ax.grid(True, which="both", alpha=0.3) + # Limit x-ticks to avoid overlap + ax.xaxis.set_major_locator( + __import__("matplotlib.ticker", fromlist=["LogLocator"]).LogLocator(numticks=5) + ) + ax.xaxis.set_major_formatter( + __import__("matplotlib.ticker", fromlist=["LogFormatter"]).LogFormatter(minor_thresholds=(2, 0.5)) + ) + ax.tick_params(axis="x", labelsize=7) + ax.tick_params(axis="y", labelsize=7) + if ri == 0 and ci_plot == ncols - 1: + ax.legend(fontsize=8) + fig.suptitle("Spherical Harmonic Power Spectrum", fontsize=11) + fig.savefig(out_path, dpi=120) + plt.close(fig) + return True + + +def pooled_crps_per_lead( + ensemble: torch.Tensor, + target: torch.Tensor, + pool_sizes: Sequence[int], + pool_type: str = "avg", +) -> np.ndarray: + """Pooled CRPS at multiple spatial scales (Figure 3 a-b of arXiv:2506.10772). + + Coarsens ensemble and target by pooling P×P grid-cell windows and then + computes fair CRPS on the coarsened field. Tests calibration at scales + larger than a single grid point. + + pool_sizes [4, 8, 16, 32] ≈ [120, 240, 480, 960] km at 0.25° resolution. + + Parameters + ---------- + pool_type : {"avg", "max"} + ``"avg"`` (Figure 3a) averages over the P×P window; ``"max"`` + (Figure 3b) takes the maximum — tests tail / extreme calibration. + + Shapes: ensemble (B, K, M, C, H, W), target (B, K, C, H, W). + Returns numpy array of shape (len(pool_sizes), K, C). + """ + import torch.nn.functional as F + + if pool_type not in ("avg", "max"): + raise ValueError(f"pool_type must be 'avg' or 'max', got {pool_type!r}") + _check_shapes(ensemble, target) + B, K, M, C, H, W = ensemble.shape + results = [] + for P in pool_sizes: + ens_flat = ensemble.reshape(B * K * M, C, H, W) + tgt_flat = target.reshape(B * K, C, H, W) + if pool_type == "avg": + ens_p = F.avg_pool2d(ens_flat, kernel_size=P, stride=P, ceil_mode=True) + tgt_p = F.avg_pool2d(tgt_flat, kernel_size=P, stride=P, ceil_mode=True) + else: + ens_p = F.max_pool2d(ens_flat, kernel_size=P, stride=P, ceil_mode=True) + tgt_p = F.max_pool2d(tgt_flat, kernel_size=P, stride=P, ceil_mode=True) + Hp, Wp = ens_p.shape[-2], ens_p.shape[-1] + ens_p = ens_p.reshape(B, K, M, C, Hp, Wp) + tgt_p = tgt_p.reshape(B, K, C, Hp, Wp) + results.append(crps_per_variable_per_lead(ens_p, tgt_p)) # (K, C) + return np.stack(results, axis=0) # (len(pool_sizes), K, C) + + +def plot_pooled_crps( + pooled: np.ndarray, + pool_sizes: Sequence[int], + variables: Sequence[str], + lead_hours: Sequence[float], + out_path: str, + title: str = "Pooled CRPS", + grid_deg: float = 0.25, +) -> bool: + """Figure 3 a-b: two side-by-side heatmaps (avg | max), rows = variables. + + Mirrors the paper layout: rows follow _SCORECARD_GROUPS (surface, z, q …), + columns = pool sizes in km, value = CRPS averaged over all lead times. + Each row is normalised to its own [min, max] range so variables with + different CRPS magnitudes can share the same colormap. + + ``pooled`` shape: (P, K, C) — pool sizes × leads × channels. + Pass the avg-pooled array for the left panel and the max-pooled array for + the right panel by calling this function twice with different ``out_path`` + values, or pass a dict (see ``plot_pooled_crps_pair``). + """ + plt = _import_matplotlib() + if plt is None: + return False + + P, _K, _C = pooled.shape + km_per_cell = grid_deg * 111.0 + km_labels = [f"{int(ps * km_per_cell)}" for ps in pool_sizes] + var_list = list(variables) + + # Build rows using the same groups as the scorecard + rows: list[tuple[str, int]] = [] + group_boundaries: list[int] = [] + group_labels: list[tuple[int, str]] = [] + for group_name, names in _SCORECARD_GROUPS: + present = [(n, var_list.index(n)) for n in names if n in var_list] + if not present: + continue + group_boundaries.append(len(rows)) + group_labels.append((len(rows) + len(present) // 2, group_name)) + rows.extend(present) + + if not rows: + return False + + R = len(rows) + # Average over all lead times → (P, C), then select per-row + crps_pk = pooled.mean(axis=1) # (P, C) + + data = np.zeros((R, P), dtype=np.float32) + for ri, (_, ci) in enumerate(rows): + row = crps_pk[:, ci].astype(np.float32) + lo, hi = row.min(), row.max() + data[ri] = (row - lo) / (hi - lo + 1e-12) + + fig, ax = plt.subplots( + figsize=(max(3, P * 0.7 + 2), max(4, R * 0.12 + 1.5)), + constrained_layout=True, + ) + im = ax.imshow(data, aspect="auto", cmap="Blues", vmin=0, vmax=1, + interpolation="nearest") + + ax.set_xticks(range(P)) + ax.set_xticklabels(km_labels, fontsize=8) + ax.set_xlabel("Pool size (km)", fontsize=9) + + y_names = [r[0] for r in rows] + ax.set_yticks(range(R)) + ax.set_yticklabels(y_names, fontsize=7) + + for b in group_boundaries[1:]: + ax.axhline(b - 0.5, color="white", linewidth=1.5) + + ax2 = ax.twinx() + ax2.set_ylim(ax.get_ylim()) + ax2.set_yticks([c for c, _ in group_labels]) + ax2.set_yticklabels([n for _, n in group_labels], fontsize=8, fontstyle="italic") + ax2.tick_params(length=0) + + fig.colorbar(im, ax=ax, fraction=0.02, pad=0.12, label="normalised (per variable)") + ax.set_title(title, fontsize=10, pad=6) + fig.savefig(out_path, dpi=100) + plt.close(fig) + return True + + +def save_summary(metrics: dict[str, Any], out_path: str) -> None: + """Persist a flat dict of numpy arrays + scalars as a single .npz file.""" + np.savez(out_path, **{k: np.asarray(v) for k, v in metrics.items()}) + + +# --------------------------------------------------------------------------- +# Bad-seed detector — paper §6.2 (Discussion / Weaknesses) +# --------------------------------------------------------------------------- + + +def flag_bad_seeds( + forecast_spectra: np.ndarray, + truth_spectra: np.ndarray, + tail_fraction: float = 0.2, + threshold: float = 3.0, +) -> list[int]: + """Flag seeds whose high-wavenumber power diverges from the truth. + + Paper §6.2 (Discussion): *"we found that a particular training seed + produced a number of unstable rollouts, which was detected by + examining the averaged spectra of the validation year forecasts. We + removed this seed and retrained that particular model with a different + seed."* + + The diagnostic operates on the azimuthal 1D power spectra computed by + :func:`power_spectra_per_variable` (or any equivalent pipeline). For + each seed we compare the mean power in the top ``tail_fraction`` of + wavenumbers against the same tail of the ground-truth spectrum; a seed + whose ratio exceeds ``threshold`` on **any** channel/lead pair has + amplified high-frequency content and is treated as unstable. + + Parameters + ---------- + forecast_spectra : np.ndarray + Shape ``(S, K, C, B)`` — S seeds, K lead times, C channels, + B wavenumber bins. Values are power (>= 0). + truth_spectra : np.ndarray + Shape ``(K, C, B)`` — ground-truth spectra shared by all seeds. + tail_fraction : float, default 0.2 + Fraction of highest-wavenumber bins to average over. 0.2 = top 20%. + threshold : float, default 3.0 + ``forecast_tail / truth_tail`` ratio above which a seed is flagged. + Paper does not specify a number; 3x is a defensible starting + default and is surfaced as a knob so callers can tune it. + + Returns + ------- + list[int] + Seed indices (into axis 0 of ``forecast_spectra``) that should be + dropped before running deep-ensemble inference. + """ + if forecast_spectra.ndim != 4: + raise ValueError( + f"forecast_spectra must have shape (S, K, C, B), got {forecast_spectra.shape}" + ) + if truth_spectra.shape != forecast_spectra.shape[1:]: + raise ValueError( + f"truth_spectra shape {truth_spectra.shape} must match " + f"forecast_spectra[1:] {forecast_spectra.shape[1:]}" + ) + if not 0 < tail_fraction <= 1: + raise ValueError(f"tail_fraction must be in (0, 1], got {tail_fraction}") + if threshold <= 0: + raise ValueError(f"threshold must be > 0, got {threshold}") + + B = forecast_spectra.shape[-1] + tail_start = max(0, B - max(1, int(round(B * tail_fraction)))) + + # Mean power in the high-wavenumber tail — shape (S, K, C) / (K, C). + fore_tail = forecast_spectra[..., tail_start:].mean(axis=-1) + truth_tail = truth_spectra[..., tail_start:].mean(axis=-1) + + # Guard against a truth-tail of zero (degenerate variables like a + # constant field) by clamping before the divide. + safe_truth = np.where(truth_tail > 0, truth_tail, np.finfo(np.float32).eps) + ratio = fore_tail / safe_truth # (S, K, C) + + flagged = [ + s for s in range(forecast_spectra.shape[0]) if np.any(ratio[s] > threshold) + ] + return flagged diff --git a/examples/weather/fgn/utils/nn.py b/examples/weather/fgn/utils/nn.py new file mode 100644 index 0000000000..c5a59c27f2 --- /dev/null +++ b/examples/weather/fgn/utils/nn.py @@ -0,0 +1,223 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +import torch +import torch.nn.functional as F +from torch import nn + +from physicsnemo.core import ModelMetaData, Module + + +def nested_to( + x: torch.Tensor | Mapping | list | tuple | Any, **kwargs +) -> torch.Tensor | dict | list | Any: + """Move tensors inside a nested structure to a device / dtype. + + Mirrors ``examples/weather/stormcast/utils/nn.nested_to`` so the two + recipes share the same container-handling convention. + """ + if isinstance(x, Mapping): + return {k: nested_to(v, **kwargs) for (k, v) in x.items()} + if isinstance(x, (list, tuple)): + return [nested_to(v, **kwargs) for v in x] + if not isinstance(x, torch.Tensor): + return x + return x.to(**kwargs) + + +class _ConditionalResidualBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int, cond_dim: int, groups: int): + super().__init__() + norm_groups_in = min(groups, in_channels) + norm_groups_out = min(groups, out_channels) + self.norm1 = nn.GroupNorm(norm_groups_in, in_channels) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) + self.norm2 = nn.GroupNorm(norm_groups_out, out_channels) + self.cond = nn.Linear(cond_dim, 2 * out_channels) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) + self.skip = ( + nn.Identity() + if in_channels == out_channels + else nn.Conv2d(in_channels, out_channels, kernel_size=1) + ) + + def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: + residual = self.skip(x) + h = self.conv1(F.silu(self.norm1(x))) + scale, shift = self.cond(cond).chunk(2, dim=-1) + scale = scale.unsqueeze(-1).unsqueeze(-1) + shift = shift.unsqueeze(-1).unsqueeze(-1) + h = self.norm2(h) + h = h * (1.0 + scale) + shift + h = self.conv2(F.silu(h)) + return h + residual + + +class FGNUNet(Module, register=True): + r"""Latent-conditioned U-Net backbone for Functional Generative Networks (FGN). + + A shallow encoder-decoder U-Net whose decoder features are modulated by a + per-step latent noise vector ``z``, implementing the stochastic generator + :math:`G_\theta(x_{t-T:t}, z_t)` from arXiv:2506.10772 §2.1. + + Parameters + ---------- + state_channels : int + Number of prognostic channels :math:`C` (output channels = input state + channels per frame). + history_frames : int, optional, default=2 + Number of past frames :math:`T` concatenated as input. + background_channels : int, optional, default=0 + Number of slowly-varying background channels (e.g. SST) appended to + the encoder input. + invariant_channels : int, optional, default=0 + Number of static invariant channels (e.g. land-sea mask, orography) + appended to the encoder input. + latent_dim : int, optional, default=16 + Dimensionality :math:`d_z` of the latent noise vector ``z``. + hidden_channels : int, optional, default=32 + Base width :math:`H` of the U-Net. Channel counts at successive + encoder levels are :math:`H`, :math:`2H`. + group_norm_groups : int, optional, default=8 + Number of groups in all ``GroupNorm`` layers. + + Forward + ------- + history : torch.Tensor + Past state frames of shape :math:`(B, T, C, H_{in}, W_{in})`. + latent : torch.Tensor + Noise sample of shape :math:`(B, d_z)`. + background : torch.Tensor, optional + Background field of shape :math:`(B, C_{bg}, H_{in}, W_{in})`. + invariants : torch.Tensor, optional + Static invariants of shape :math:`(B, C_{inv}, H_{in}, W_{in})`. + + Outputs + ------- + torch.Tensor + Predicted next state of shape :math:`(B, C, H_{in}, W_{in})`. + + Examples + -------- + >>> import torch + >>> model = FGNUNet(state_channels=4, history_frames=2, latent_dim=8, hidden_channels=16) + >>> history = torch.randn(2, 2, 4, 32, 48) + >>> latent = torch.randn(2, 8) + >>> out = model(history=history, latent=latent) + >>> out.shape + torch.Size([2, 4, 32, 48]) + """ + + def __init__( + self, + state_channels: int, + history_frames: int = 2, + background_channels: int = 0, + invariant_channels: int = 0, + latent_dim: int = 16, + hidden_channels: int = 32, + group_norm_groups: int = 8, + ): + super().__init__(meta=ModelMetaData()) + self.state_channels = state_channels + self.history_frames = history_frames + self.background_channels = background_channels + self.invariant_channels = invariant_channels + self.latent_dim = latent_dim + self.hidden_channels = hidden_channels + self.group_norm_groups = group_norm_groups + + input_channels = history_frames * state_channels + input_channels += background_channels + invariant_channels + + cond_dim = hidden_channels * 4 + self.latent_mlp = nn.Sequential( + nn.Linear(latent_dim, cond_dim), + nn.SiLU(), + nn.Linear(cond_dim, cond_dim), + ) + + self.stem = nn.Conv2d(input_channels, hidden_channels, kernel_size=3, padding=1) + self.down1 = _ConditionalResidualBlock( + hidden_channels, hidden_channels, cond_dim, group_norm_groups + ) + self.down2 = _ConditionalResidualBlock( + hidden_channels, hidden_channels * 2, cond_dim, group_norm_groups + ) + self.bottleneck = _ConditionalResidualBlock( + hidden_channels * 2, hidden_channels * 2, cond_dim, group_norm_groups + ) + self.up1 = _ConditionalResidualBlock( + hidden_channels * 4, hidden_channels * 2, cond_dim, group_norm_groups + ) + self.up2 = _ConditionalResidualBlock( + hidden_channels * 3, hidden_channels, cond_dim, group_norm_groups + ) + self.head = nn.Conv2d(hidden_channels, state_channels, kernel_size=1) + + def forward( + self, + history: torch.Tensor, + latent: torch.Tensor, + background: torch.Tensor | None = None, + invariants: torch.Tensor | None = None, + ) -> torch.Tensor: + if history.ndim != 5: + raise ValueError("history must have shape [B, T, C, H, W]") + batch, frames, channels, height, width = history.shape + if frames != self.history_frames or channels != self.state_channels: + raise ValueError("history shape does not match model configuration") + + pieces = [history.reshape(batch, frames * channels, height, width)] + if background is not None: + pieces.append(background) + if invariants is not None: + pieces.append(invariants) + x = torch.cat(pieces, dim=1) + + cond = self.latent_mlp(latent) + + stem = self.stem(x) + skip1 = self.down1(stem, cond) + x = F.avg_pool2d(skip1, kernel_size=2) + skip2 = self.down2(x, cond) + x = F.avg_pool2d(skip2, kernel_size=2) + x = self.bottleneck(x, cond) + + x = F.interpolate( + x, size=skip2.shape[-2:], mode="bilinear", align_corners=False + ) + x = self.up1(torch.cat([x, skip2], dim=1), cond) + x = F.interpolate( + x, size=skip1.shape[-2:], mode="bilinear", align_corners=False + ) + x = self.up2(torch.cat([x, skip1], dim=1), cond) + return self.head(x) + + +def build_model( + cfg, + state_channels: int, + background_channels: int, + invariant_channels: int, +) -> FGNUNet: + if cfg.model.background_channels not in ("auto", background_channels): + raise ValueError("config model.background_channels disagrees with dataset") + if cfg.model.invariant_channels not in ("auto", invariant_channels): + raise ValueError("config model.invariant_channels disagrees with dataset") + + return FGNUNet( + state_channels=state_channels, + history_frames=int(cfg.model.history_frames), + background_channels=background_channels, + invariant_channels=invariant_channels, + latent_dim=int(cfg.model.latent_dim), + hidden_channels=int(cfg.model.hidden_channels), + group_norm_groups=int(cfg.model.group_norm_groups), + ) diff --git a/examples/weather/fgn/utils/parallel.py b/examples/weather/fgn/utils/parallel.py new file mode 100644 index 0000000000..5bf076315e --- /dev/null +++ b/examples/weather/fgn/utils/parallel.py @@ -0,0 +1,273 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Data- and domain-parallel helpers for FGN training. + +Slim adaptation of ``examples/weather/stormcast/utils/parallel.py`` tailored +to the FGN recipe: same FSDP-plus-ShardTensor strategy, same sharded- +dataloader conventions, but without the diffusion noise-scheduler plumbing +StormCast needs. + +See StormCast's ``utils/parallel.py`` for the reference implementation and +docstrings; this file mirrors its public API so the two recipes can stay in +lockstep. +""" + +from __future__ import annotations + +from collections.abc import Iterator, Mapping +from typing import Any + +import numpy as np +import torch +from datasets.dataset import worker_init +from torch.distributed.fsdp import ( + BackwardPrefetch, + ShardingStrategy, +) +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, +) +from torch.distributed.tensor import DTensor, distribute_module, distribute_tensor +from torch.distributed.tensor.placement_types import Replicate, Shard +from utils.nn import nested_to + +from physicsnemo.distributed import DistributedManager +from physicsnemo.domain_parallel.shard_tensor import scatter_tensor + + +class ParallelHelper: + """Manage data + domain parallelism for the FGN recipe. + + Mirrors StormCast's ``ParallelHelper`` so FGN inherits the same tested + pattern: a 2D device mesh with a ``ddp`` axis and a ``domain`` axis, FSDP + on the ddp axis, optional ShardTensor spatial sharding on the domain + axis. + + Parameters + ---------- + domain_parallel_size : int + Number of ranks in the domain-parallel dimension. Use 1 for pure DDP + or single-process runs. + use_shard_tensor : bool + Whether to shard batches and selected module parameters across the + domain mesh. Typically ``domain_parallel_size > 1`` OR + ``force_sharding`` is true. + shard_dim : int, default 2 + Spatial dimension along which tensors are partitioned for domain + parallelism. For ``(B, C, H, W)`` sharded along height, use ``2``. + """ + + def __init__( + self, + domain_parallel_size: int, + use_shard_tensor: bool = False, + shard_dim: int = 2, + ): + if not DistributedManager.is_initialized(): + DistributedManager.initialize() + self.dist = DistributedManager() + self.domain_parallel_size = domain_parallel_size + self.shard_dim = shard_dim + + if self.dist.world_size % domain_parallel_size != 0: + raise ValueError( + "domain_parallel_size must evenly divide the number of processes" + ) + self.data_parallel_size = self.dist.world_size // domain_parallel_size + self.mesh = self.dist.initialize_mesh( + mesh_shape=(self.data_parallel_size, domain_parallel_size), + mesh_dim_names=["ddp", "domain"], + ) + self.domain_rank = self.mesh["domain"].get_local_rank() + self.use_shard_tensor = use_shard_tensor + + def get_domain_group_zero_rank(self) -> int: + return torch.distributed.get_global_rank(self.mesh["domain"].get_group(), 0) + + def local_batch_size(self, global_batch_size: int) -> int: + return global_batch_size // self.data_parallel_size + + def sharded_dataloader( + self, + dataset: torch.utils.data.Dataset, + batch_size: int = 1, + seed: int | None = None, + num_workers: int = 2, + shuffle: bool = True, + ) -> torch.utils.data.DataLoader: + """Build a rank-sharded DataLoader. + + Each rank sees a contiguous slice of ``range(len(dataset))`` (rather + than a strided slice as in ``DistributedSampler``), which plays + nicely with caches that key on neighbouring time indices. + """ + global_samples = np.arange(len(dataset)) + num_samples_global = len(global_samples) + source_rank = ( + global_samples / num_samples_global * self.dist.world_size + ).astype(int) + local_samples = global_samples[source_rank == self.dist.rank] + + def sampler() -> Iterator[int]: + local_seed = None if seed is None else seed + self.dist.rank + rng = np.random.default_rng(seed=local_seed) + while True: + if shuffle: + rng.shuffle(local_samples) + yield from local_samples + + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=batch_size, + sampler=sampler(), + num_workers=num_workers, + worker_init_fn=worker_init, + drop_last=True, + pin_memory=torch.cuda.is_available(), + prefetch_factor=2 if num_workers > 0 else None, + ) + + def sharded_data_iter( + self, + dataloader: torch.utils.data.DataLoader, + num_samples: int | None = None, + ) -> Iterator[Any]: + data_iter = iter(dataloader) + i = 0 + batch: Any = None + domain_group = self.mesh["domain"].get_group() + while True: + source_rank_in_mesh = i % self.domain_parallel_size + source_rank = torch.distributed.get_global_rank( + domain_group, source_rank_in_mesh + ) + if source_rank == self.dist.rank or i == 0: + batch = nested_to( + next(data_iter), + device=self.dist.device, + non_blocking=True, + ) + + yield ( + self.nested_scatter(batch, source_rank) + if self.use_shard_tensor + else batch + ) + + i += 1 + if i == num_samples: + break + + def distribute_tensor(self, x: torch.Tensor) -> torch.Tensor: + if self.use_shard_tensor: + return self.nested_scatter(x, self.get_domain_group_zero_rank()) + return x + + def distribute_model(self, model: torch.nn.Module) -> torch.nn.Module: + """Wrap a model with FSDP, with optional ShardTensor domain sharding.""" + if self.use_shard_tensor: + model = distribute_module( + model, + device_mesh=self.mesh["domain"], + partition_fn=partition_model_selective, + ) + return FSDP( + model, + device_mesh=self.mesh["ddp"], + use_orig_params=False, # required for ShardTensor compatibility + sharding_strategy=ShardingStrategy.NO_SHARD, + sync_module_states=True, + forward_prefetch=True, + backward_prefetch=BackwardPrefetch.BACKWARD_PRE, + ) + + def replicate_tensor(self, t: torch.Tensor) -> torch.Tensor: + if not self.use_shard_tensor or isinstance(t, DTensor): + return t + return DTensor.from_local( + t, device_mesh=self.mesh["domain"], placements=[Replicate()] + ) + + def nested_scatter( + self, + x: torch.Tensor | Mapping | list | tuple | Any, + global_rank_of_source: int, + shard_dim: int | None = None, + ) -> Any: + if shard_dim is None: + shard_dim = self.shard_dim + if isinstance(x, Mapping): + return { + k: self.nested_scatter(v, global_rank_of_source, shard_dim=shard_dim) + for (k, v) in x.items() + } + if isinstance(x, (list, tuple)): + return [ + self.nested_scatter(v, global_rank_of_source, shard_dim=shard_dim) + for v in x + ] + + x_type = type(x) + is_scalar = not isinstance(x, torch.Tensor) + if is_scalar: + x = torch.as_tensor(x, device=self.dist.device) + + placement = ( + Shard(shard_dim) + if (x.ndim >= 3 and x.shape[shard_dim] > 1) + else Replicate() + ) + x = scatter_tensor( + x, + global_rank_of_source, + self.mesh["domain"], + placements=(placement,), + global_shape=x.shape, + dtype=x.dtype, + ) + if is_scalar: + x = x_type(x.cpu()) + return x + + +def shard_dim_selector(param_name: str) -> int | None: + """Return the spatial axis along which a parameter should be sharded, if any. + + Matches the FGN backbone's spatial-parameter naming. Currently returns + ``None`` since the U-Net has no spatial positional embeddings; add names + here when the CLN / graph-transformer backbone lands (e.g. + ``"pos_embed"``, ``"mesh_pos_embed"``). + """ + sharded_params: tuple[str, ...] = () # e.g. ("pos_embed",) in the future + return 1 if any(p in param_name for p in sharded_params) else None + + +def partition_model_selective( + name: str, # noqa: ARG001 — signature required by distribute_module + submodule: torch.nn.Module, + device_mesh: torch.distributed.device_mesh.DeviceMesh, +) -> None: + """Parameter-by-parameter domain-mesh placement selector. + + Mirrors StormCast's ``partition_model_selective``: every parameter is + wrapped in a ``DTensor`` (Shard or Replicate) so that + ``distribute_module``'s internal ``replicate_module_params_buffers`` + never sees a plain tensor and cannot silently flip ``requires_grad`` on + frozen params. + """ + for key, param in submodule._parameters.items(): + if param is None: + continue + if (shard_dim := shard_dim_selector(key)) is not None: + dt = distribute_tensor( + param, device_mesh=device_mesh, placements=[Shard(shard_dim)] + ) + else: + dt = distribute_tensor( + param, device_mesh=device_mesh, placements=[Replicate()] + ) + submodule.register_parameter( + key, torch.nn.Parameter(dt, requires_grad=param.requires_grad) + ) diff --git a/examples/weather/fgn/utils/trainer.py b/examples/weather/fgn/utils/trainer.py new file mode 100644 index 0000000000..7253ce1ce8 --- /dev/null +++ b/examples/weather/fgn/utils/trainer.py @@ -0,0 +1,595 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Iterator +from pathlib import Path + +import numpy as np +import torch +from datasets import dataset_classes +from omegaconf import OmegaConf +from utils.config import TrainMainConfig +from utils.loss import ( + build_area_weights, + build_channel_weights, + ensemble_mean_mse, + fair_crps, +) +from utils.metrics import ( + crps_per_variable_per_lead, + derived_variable_crps, + energy_score_per_lead, + ensemble_rmse_per_variable_per_lead, + plot_metric_vs_lead, + plot_power_spectra, + plot_rank_histograms, + power_spectra_per_variable, + rank_histogram_per_variable, + save_summary, + spread_skill_per_variable_per_lead, +) +from utils.nn import build_model +from utils.parallel import ParallelHelper + +from physicsnemo.distributed import DistributedManager +from physicsnemo.utils import load_checkpoint, save_checkpoint +from physicsnemo.utils.logging import PythonLogger, RankZeroLoggingWrapper + + +def find_latest_model_checkpoint(checkpoint_dir: Path) -> str: + candidates = sorted(checkpoint_dir.glob("*.mdlus")) + if not candidates: + raise FileNotFoundError(f"No .mdlus checkpoints found in {checkpoint_dir}") + return str(candidates[-1]) + + +class Trainer: + def __init__(self, cfg): + cfg_dict = OmegaConf.to_container(cfg, resolve=True) + self.cfg = TrainMainConfig(**cfg_dict) + + self.dist = DistributedManager() + self.device = self.dist.device + # Rank-0-only logger mirrors the StormCast convention + # (examples/weather/stormcast/utils/logging.ExperimentLogger) — uses + # physicsnemo.utils.logging.PythonLogger so output flushes on each + # record instead of sitting in a print() stdio buffer under srun. + self.logger = RankZeroLoggingWrapper(PythonLogger("fgn"), self.dist) + self.logger.info("Trainer.__init__ starting") + + # Data + domain parallel setup. For single-process runs we skip the + # ParallelHelper entirely: DistributedManager may be in its fallback + # "single process" state (no process group), which is incompatible + # with ShardTensor mesh creation. StormCast's trainer always builds a + # ParallelHelper because it assumes a real distributed init; the FGN + # recipe keeps a no-helper path so the CPU-only smoke test stays + # runnable without an init_process_group call. + self.parallel_helper: ParallelHelper | None = None + domain_parallel_size = int(self.cfg.training.domain_parallel_size) + force_sharding = bool(self.cfg.training.force_sharding) + self.use_shard_tensor = domain_parallel_size > 1 or force_sharding + if self.dist.world_size > 1 or self.use_shard_tensor: + self.parallel_helper = ParallelHelper( + domain_parallel_size=domain_parallel_size, + use_shard_tensor=self.use_shard_tensor, + ) + if ( + self.use_shard_tensor + and self.parallel_helper.local_batch_size( + int(self.cfg.training.batch_size) + ) + > 1 + ): + raise ValueError("Domain parallelism requires a local batch size of 1") + + self.checkpoint_dir = ( + Path(self.cfg.training.rundir) / self.cfg.training.checkpoint_dir + ) + if self.dist.rank == 0: + self.checkpoint_dir.mkdir(parents=True, exist_ok=True) + + # All ranks use the same seed so parameter initialization is identical. + torch.manual_seed(int(self.cfg.training.seed)) + + dataset_cls = dataset_classes[self.cfg.dataset.name] + self.logger.info(f"Building datasets: {self.cfg.dataset.name}") + self.train_dataset = dataset_cls(self.cfg.dataset, train=True) + self.valid_dataset = dataset_cls(self.cfg.dataset, train=False) + self.logger.info( + f"Dataset ready: train={len(self.train_dataset)} val={len(self.valid_dataset)}" + ) + + self.logger.info("Fetching dataset invariants") + invariants = self.train_dataset.get_invariants() + self.invariants = None + invariant_channels = 0 + if invariants is not None: + self.invariants = torch.from_numpy(invariants).to( + self.device, dtype=torch.float32 + ) + invariant_channels = int(self.invariants.shape[0]) + + self.logger.info("Building model") + self.model = build_model( + self.cfg, + state_channels=len(self.train_dataset.state_channels()), + background_channels=len(self.train_dataset.background_channels()), + invariant_channels=invariant_channels, + ).to(self.device) + self.logger.info( + f"Model ready on {self.device} " + f"(params={sum(p.numel() for p in self.model.parameters()):,})" + ) + + # Wrap with FSDP / ShardTensor when running distributed. Domain- + # sharded invariant tensor so forward passes on sharded inputs find + # the invariant in the same layout. + if self.parallel_helper is not None: + self.model = self.parallel_helper.distribute_model(self.model) + if self.invariants is not None and self.use_shard_tensor: + self.invariants = self.parallel_helper.distribute_tensor( + self.invariants + ) + + # Optimizer must be built after FSDP wrapping. + self.optimizer = torch.optim.AdamW( + self.model.parameters(), + lr=float(self.cfg.training.optimizer.lr), + betas=tuple(self.cfg.training.optimizer.betas), + weight_decay=float(self.cfg.training.optimizer.weight_decay), + ) + + # Train/val loaders: ranks get disjoint contiguous index slices via + # ParallelHelper.sharded_dataloader. Single-process falls back to a + # plain DataLoader so we don't depend on a process group. + batch_size = int(self.cfg.training.batch_size) + num_workers = int(self.cfg.training.num_data_workers) + seed = int(self.cfg.training.seed) + if self.parallel_helper is not None: + local_batch = self.parallel_helper.local_batch_size(batch_size) + self.train_loader = self.parallel_helper.sharded_dataloader( + self.train_dataset, + batch_size=local_batch, + seed=seed, + num_workers=num_workers, + shuffle=True, + ) + self.valid_loader = self.parallel_helper.sharded_dataloader( + self.valid_dataset, + batch_size=local_batch, + seed=seed + 1, + num_workers=0, + shuffle=False, + ) + # Cap validation length: the parallel_helper sampler is infinite + # by design (StormCast convention), so we bound iteration the + # same way StormCast does — `sharded_data_iter(loader, N)`. By + # default sweep one local epoch over each rank's shard. + local_valid = max( + 1, + len(self.valid_dataset) + // (max(self.dist.world_size, 1) * max(local_batch, 1)), + ) + self.validation_steps = int( + getattr(self.cfg.training, "validation_steps", local_valid) + or local_valid + ) + else: + from datasets.dataset import worker_init + from torch.utils.data import DataLoader + + self.train_loader = DataLoader( + self.train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + worker_init_fn=worker_init if num_workers else None, + ) + self.valid_loader = DataLoader( + self.valid_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=0, + ) + self.validation_steps = None # plain DataLoader is finite + + # Optional per-channel + cos(lat) loss weights. Channel weights + # follow GraphCast/GenCast scheme with geopotential halved per FGN + # §2.2.3; area weights normalise cos(lat) so the mean row-sum over + # latitudes equals 1 (preserves loss scale when toggled on/off). + self.loss_weights: torch.Tensor | None = None + channel_w = None + if bool(self.cfg.training.loss.use_channel_weights): + channel_w = torch.from_numpy( + build_channel_weights(self.train_dataset.state_channels()) + ).to(self.device, dtype=torch.float32) + area_w = None + if bool(self.cfg.training.loss.use_area_weights): + H, _ = self.train_dataset.image_shape() + area_w = torch.from_numpy(build_area_weights(H)).to( + self.device, dtype=torch.float32 + ) # shape (H, 1) + if channel_w is not None or area_w is not None: + # Build a (1, C, H, W)-broadcastable tensor. + H, W = self.train_dataset.image_shape() + combined = torch.ones(1, 1, H, W, device=self.device, dtype=torch.float32) + if channel_w is not None: + combined = combined * channel_w.view(1, -1, 1, 1) + if area_w is not None: + combined = combined * area_w.view(1, 1, H, 1) + self.loss_weights = combined + + self.step = 0 + self.best_val_loss = float("inf") + self._resume_if_needed() + + def _resume_if_needed(self) -> None: + resume = self.cfg.training.resume_checkpoint + if resume is None: + return + if not self.checkpoint_dir.exists(): + return + epoch = None if resume == "latest" else int(resume) + metadata = {} + loaded = load_checkpoint( + self.checkpoint_dir, + models=self.model, + optimizer=self.optimizer, + epoch=epoch, + metadata_dict=metadata, + device=self.device, + ) + self.step = int(loaded) + if metadata.get("best_val_loss") is not None: + self.best_val_loss = float(metadata["best_val_loss"]) + + def _step_ensemble( + self, + history: torch.Tensor, + background: torch.Tensor, + invariants: torch.Tensor | None, + num_samples: int, + ) -> torch.Tensor: + """Run `num_samples` forward passes of the model, one per latent draw. + + Returns a tensor of shape ``(B, num_samples, C, H, W)``. + """ + + members = [] + for _ in range(num_samples): + latent = torch.randn( + history.shape[0], + int(self.cfg.model.latent_dim), + device=self.device, + dtype=torch.float32, + ) + members.append( + self.model( + history=history, + latent=latent, + background=background, + invariants=invariants, + ) + ) + return torch.stack(members, dim=1) + + def _loss(self, batch: dict[str, torch.Tensor]) -> torch.Tensor: + history = batch["history"].to(self.device, dtype=torch.float32) + target = batch["target"].to(self.device, dtype=torch.float32) + background = batch["background"].to(self.device, dtype=torch.float32) + + # Normalize target layout to (B, K, C, H, W); datasets may emit (B, C, H, W). + if target.ndim == 4: + target = target.unsqueeze(1) + if target.ndim != 5: + raise ValueError( + f"target must have shape [B, K, C, H, W] or [B, C, H, W], got {tuple(target.shape)}" + ) + + ar_steps = int(target.shape[1]) + cfg_ar = int(getattr(self.cfg.training, "ar_steps", 1)) + if cfg_ar != ar_steps: + raise ValueError( + f"training.ar_steps={cfg_ar} but dataset produced {ar_steps} future frames; " + "set future_frames to match ar_steps" + ) + + invariants = None + if self.invariants is not None: + invariants = self.invariants.unsqueeze(0).expand( + history.shape[0], -1, -1, -1 + ) + + num_samples = int(self.cfg.training.loss.num_samples) + mse_weight = float(self.cfg.training.loss.mse_weight) + + # For each rollout step, run N-member ensemble, score against that + # step's ground truth, then advance history by appending each member's + # prediction (so the N trajectories diverge in parallel). + # History shape per member: (B, T, C, H, W). + B, T, C, H, W = history.shape + per_member_hist = ( + history.unsqueeze(1).expand(B, num_samples, T, C, H, W).contiguous() + ) + + step_losses: list[torch.Tensor] = [] + for k in range(ar_steps): + members = [] + for n in range(num_samples): + hist_n = per_member_hist[:, n] + latent = torch.randn( + hist_n.shape[0], + int(self.cfg.model.latent_dim), + device=self.device, + dtype=torch.float32, + ) + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=torch.cuda.is_available()): + members.append( + self.model( + history=hist_n, + latent=latent, + background=background, + invariants=invariants, + ) + ) + preds = torch.stack(members, dim=1).float() # (B, N, C, H, W) + + step_loss = fair_crps(preds, target[:, k], weights=self.loss_weights) + if mse_weight > 0.0: + step_loss = step_loss + mse_weight * ensemble_mean_mse( + preds, target[:, k], weights=self.loss_weights + ) + step_losses.append(step_loss) + + if k < ar_steps - 1: + # Paper §3: predicted-only channels (e.g. tp06) must not be + # fed back as input on the next AR step — mirrors + # earth2studio gencast_mini's zeroing of tp12 in inputs. + # Clone before mutating because ``preds`` is still used in + # ``step_loss`` and autograd is tracking it. + next_frame = preds + output_only = self.train_dataset.output_only_channels() + if output_only: + next_frame = next_frame.clone() + for ci in output_only: + next_frame[:, :, ci].zero_() + per_member_hist = torch.cat( + [per_member_hist[:, :, 1:], next_frame.unsqueeze(2)], dim=2 + ) + + return torch.stack(step_losses).mean() + + def _validation_loss(self) -> float: + self.model.eval() + losses = [] + # Mirror StormCast: with parallel_helper the sampler is infinite, so + # bound iteration via sharded_data_iter(loader, N). Plain DataLoader + # path is finite and falls through to the default for-loop. + if self.parallel_helper is not None: + iterator = self.parallel_helper.sharded_data_iter( + self.valid_loader, self.validation_steps + ) + else: + iterator = self.valid_loader + with torch.no_grad(): + losses.extend(float(self._loss(batch).detach().cpu()) for batch in iterator) + self.model.train() + return sum(losses) / max(len(losses), 1) + + def _run_validation_metrics(self) -> None: + """Figure 2 + 3 diagnostics on a single validation batch. + + Runs an ensemble rollout across all ``ar_steps`` lead times and + writes per-variable CRPS / RMSE / spread-skill / rank hist / 1D + power spectra to ``rundir/validation/step=/``. No-op on + non-rank-0 ranks. + """ + if self.dist.rank != 0: + return + try: + batch = next(iter(self.valid_loader)) + except StopIteration: + return + + self.model.eval() + history = batch["history"].to(self.device, dtype=torch.float32) + target = batch["target"].to(self.device, dtype=torch.float32) + background = batch["background"].to(self.device, dtype=torch.float32) + if target.ndim == 4: + target = target.unsqueeze(1) + K = target.shape[1] + + invariants = None + if self.invariants is not None: + invariants = self.invariants.unsqueeze(0).expand( + history.shape[0], -1, -1, -1 + ) + + M = int(self.cfg.training.validation_ensemble_size) + latent_dim = int(self.cfg.model.latent_dim) + + # N parallel trajectories diverge step-by-step exactly as in the + # training loop, but we don't need gradients. + B, T, C, H, W = history.shape + per_member_hist = history.unsqueeze(1).expand(B, M, T, C, H, W).contiguous() + preds_all: list[torch.Tensor] = [] + with torch.no_grad(): + for k in range(K): + members = [] + for n in range(M): + latent = torch.randn( + B, latent_dim, device=self.device, dtype=torch.float32 + ) + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=torch.cuda.is_available()): + pred = self.model( + history=per_member_hist[:, n], + latent=latent, + background=background, + invariants=invariants, + ) + members.append(pred.float()) + preds = torch.stack(members, dim=1) # (B, M, C, H, W) + preds_all.append(preds) + if k < K - 1: + # Paper §3: zero predicted-only channels (e.g. tp06) + # before feeding them back as next-step history. + next_frame = preds + output_only = self.train_dataset.output_only_channels() + if output_only: + next_frame = next_frame.clone() + for ci in output_only: + next_frame[:, :, ci].zero_() + per_member_hist = torch.cat( + [per_member_hist[:, :, 1:], next_frame.unsqueeze(2)], dim=2 + ) + + self.model.train() + ensemble = torch.stack(preds_all, dim=1) # (B, K, M, C, H, W) + + variables = list(self.train_dataset.state_channels()) + crps_kc = crps_per_variable_per_lead(ensemble, target) + rmse_kc = ensemble_rmse_per_variable_per_lead(ensemble, target) + spread_kc, skill_kc, ratio_kc = spread_skill_per_variable_per_lead( + ensemble, target + ) + ranks_cb = rank_histogram_per_variable(ensemble, target) + es_k = energy_score_per_lead(ensemble, target) + derived = derived_variable_crps(ensemble, target, variables) + ensemble_mean = ensemble.mean(dim=2) + k_vec, ens_spec, tgt_spec = power_spectra_per_variable(ensemble_mean, target) + + out_dir = Path(self.cfg.training.rundir) / "validation" / f"step={self.step}" + out_dir.mkdir(parents=True, exist_ok=True) + + summary = { + "crps_per_lead_per_channel": crps_kc, + "rmse_per_lead_per_channel": rmse_kc, + "spread_per_lead_per_channel": spread_kc, + "skill_per_lead_per_channel": skill_kc, + "spread_skill_ratio": ratio_kc, + "rank_histograms": ranks_cb, + "energy_score_per_lead": es_k, + "variables": np.array(variables, dtype=object), + "lead_steps": np.arange(1, K + 1, dtype=np.int64), + "power_spectrum_k": k_vec, + "power_spectrum_forecast": ens_spec, + "power_spectrum_truth": tgt_spec, + } + for dname, vals in derived.items(): + summary[f"derived_crps_{dname}"] = vals + save_summary(summary, str(out_dir / "metrics.npz")) + + leads = np.arange(1, K + 1) + plot_metric_vs_lead( + crps_kc, + variables, + leads, + "CRPS", + "fCRPS per lead (lower is better)", + str(out_dir / "crps_vs_lead.png"), + ) + plot_metric_vs_lead( + rmse_kc, + variables, + leads, + "ensemble-mean RMSE", + "Ensemble-mean RMSE per lead", + str(out_dir / "rmse_vs_lead.png"), + ) + plot_metric_vs_lead( + ratio_kc, + variables, + leads, + "spread / skill", + "Spread-skill ratio (1.0 = calibrated)", + str(out_dir / "spread_skill_vs_lead.png"), + hline_y=1.0, + ) + plot_rank_histograms(ranks_cb, variables, str(out_dir / "rank_histograms.png")) + # Energy score is a (K,) scalar — plot as a single-series lead curve. + plot_metric_vs_lead( + es_k[:, None], + ["multivariate"], + leads, + "energy score", + "Energy score per lead (lower is better)", + str(out_dir / "energy_score_vs_lead.png"), + ) + plot_power_spectra( + k_vec, + ens_spec, + tgt_spec, + variables, + lead_idx=K - 1, + out_path=str(out_dir / f"power_spectra_lead{K}.png"), + ) + + def save_checkpoint(self) -> None: + save_checkpoint( + self.checkpoint_dir, + models=self.model, + optimizer=self.optimizer, + epoch=self.step, + metadata={"best_val_loss": self.best_val_loss}, + ) + + def _make_train_iter(self) -> Iterator: + # When domain parallelism is active, sharded_data_iter handles both + # data-parallel sample routing and spatial scatter (ShardTensor). + # Mirrors StormCast's pattern (stormcast/utils/trainer.py). + if self.parallel_helper is not None: + remaining = int(self.cfg.training.total_train_steps) - self.step + return self.parallel_helper.sharded_data_iter( + self.train_loader, num_samples=remaining + ) + # Plain single-process / DDP path: restart the DataLoader on exhaustion. + def _plain() -> Iterator: + loader_iter = iter(self.train_loader) + while True: + try: + yield next(loader_iter) + except StopIteration: + loader_iter = iter(self.train_loader) + yield next(loader_iter) + + return _plain() + + def train(self) -> None: + self.model.train() + total_steps = int(self.cfg.training.total_train_steps) + + for batch in self._make_train_iter(): + self.optimizer.zero_grad(set_to_none=True) + loss = self._loss(batch) + loss.backward() + + clip = float(self.cfg.training.clip_grad_norm) + if clip > 0.0: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip) + + self.optimizer.step() + self.step += 1 + + if self.step % int(self.cfg.training.print_progress_freq) == 0: + self.logger.info( + f"step={self.step} train_loss={float(loss.detach().cpu()):.6f}" + ) + + if self.step % int(self.cfg.training.validation_freq) == 0: + val_loss = self._validation_loss() + self.best_val_loss = min(self.best_val_loss, val_loss) + self.logger.info(f"step={self.step} val_loss={val_loss:.6f}") + if bool(self.cfg.training.validation_metrics): + self._run_validation_metrics() + + if self.step % int(self.cfg.training.checkpoint_freq) == 0: + self.save_checkpoint() + + if self.step >= total_steps: + break + + if self.step % int(self.cfg.training.checkpoint_freq) != 0: + self.save_checkpoint()