[weather] Add Fgn model #1660

Open
kashif wants to merge 11 commits into NVIDIA:main from kashif:fgn

@kashif kashif commented May 21, 2026

PhysicsNeMo Pull Request

Description

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

copy-pr-bot Bot commented May 21, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

greptile-apps Bot commented May 21, 2026

Greptile Summary

This PR adds a new examples/weather/fgn/ example implementing the Functional Generative Network (FGN) weather model scaffold in PhysicsNeMo, covering training, inference, datasets, loss/metrics utilities, and SLURM scripts. Two P1 bugs need attention before merging.

  • Inference rollout bug: _rollout in inference.py uses rollout_history[:, 1] which hardcodes history_frames=2; any other history size causes a shape mismatch runtime error on the second step.
  • AMP silently ignored: TrainingConfig.amp is declared and the production shell script passes training.amp=true, but the trainer never reads this flag — no autocast or GradScaler is applied.
  • Personal paths in SLURM scripts: /mnt/home/kashif/ and /mnt/data/kashif/ are hardcoded in scripts/train_2024_val.sh, compute_stats_2024.sh, and prefetch_arco_2024.sh, making them unusable for any other user.

Important Files Changed

| Filename | Overview |
|---|---|
| examples/weather/fgn/inference.py | Adds autoregressive inference with deep-ensemble support; contains a P1 bug in _rollout where rollout_history[:, 1] hardcodes history_frames=2, breaking at runtime for any other history window size. |
| examples/weather/fgn/utils/trainer.py | Full training loop with AR rollout, validation metrics, and checkpoint management; AMP config field declared and used in production scripts but never implemented in the training loop. |
| examples/weather/fgn/utils/loss.py | Implements fair CRPS (eq. 4-5 of the FGN paper) and GraphCast-style channel weights with correct geopotential halving; logic appears sound and well-tested. |
| examples/weather/fgn/utils/nn.py | Latent-conditioned U-Net (FGNUNet) scaffold with conditional residual blocks; straightforward implementation, no issues found. |
| examples/weather/fgn/utils/parallel.py | FSDP + ShardTensor data/domain parallelism helper adapted from StormCast; sharded dataloader and nested scatter logic look correct. |
| examples/weather/fgn/datasets/arco.py | ERA5/ARCO dataset wrapper with tp accumulation, SST NaN imputation, and z-score normalization; num_samples formula is correct, SST impute logic is sound. |
| examples/weather/fgn/utils/metrics.py | Validation diagnostics (CRPS, RMSE, spread-skill, rank histograms, power spectra, derived variables); logic appears correct, well-documented limitations noted. |
| examples/weather/fgn/scripts/train_2024_val.sh | Production SLURM training script with personal home-directory paths hardcoded (/mnt/home/kashif/, /mnt/data/kashif/) that will fail for any other user without modification. |
| examples/weather/fgn/utils/config.py | Pydantic dataclass configs for training/inference/model; amp: bool = False declared but the flag is never consumed by the trainer. |
| examples/weather/fgn/datasets/mock.py | Synthetic smoke-test dataset with deterministic wave fields; no issues found. |

Reviews (1): Last reviewed commit: "Merge branch 'main' into fgn"

Comment thread examples/weather/fgn/inference.py Outdated
Comment on lines +109 to +112
rollout_history = torch.stack(
    [rollout_history[:, 1], next_frame],
    dim=1,
)

P1 The rollout history update hardcodes history_frames=2 by stacking rollout_history[:, 1] (always index 1) with the new prediction. When history_frames > 2, this keeps only the old frame at index 1, which is then no longer the most recent one, discards the rest, and produces a 2-frame history for a model that expects T frames. The result is a shape mismatch on the very next model call. The trainer's _loss method handles this correctly with torch.cat([per_member_hist[:, :, 1:], next_frame.unsqueeze(2)], dim=2).

Suggested change
- rollout_history = torch.stack(
-     [rollout_history[:, 1], next_frame],
-     dim=1,
- )
+ rollout_history = torch.cat(
+     [rollout_history[:, 1:], next_frame.unsqueeze(1)],
+     dim=1,
+ )
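
The difference between the two updates can be seen without any tensor library. The sketch below is a framework-free analogy, not the PR's code: a Python list stands in for the time axis (dim=1) of rollout_history, strings stand in for frames, and both helper names are hypothetical.

```python
def update_hardcoded(history, next_frame):
    # Analogue of torch.stack([rollout_history[:, 1], next_frame], dim=1):
    # always keeps the frame at index 1 and returns exactly two frames.
    return [history[1], next_frame]


def update_sliding(history, next_frame):
    # Analogue of torch.cat([rollout_history[:, 1:], next_frame.unsqueeze(1)], dim=1):
    # drops the oldest frame and appends the new one, so the length stays T.
    return history[1:] + [next_frame]


two_frames = ["t-1", "t0"]
three_frames = ["t-2", "t-1", "t0"]

# With T=2 the two updates agree, which is why the bug is easy to miss.
same_for_two = update_hardcoded(two_frames, "t+1") == update_sliding(two_frames, "t+1")

# With T=3 the hardcoded update shrinks the window and loses frame "t0".
hardcoded_three = update_hardcoded(three_frames, "t+1")
sliding_three = update_sliding(three_frames, "t+1")

print(same_for_two)     # True
print(hardcoded_three)  # ['t-1', 't+1']
print(sliding_three)    # ['t-1', 't0', 't+1']
```

Only the sliding version preserves the window length, so a model expecting T history frames keeps receiving T frames on every rollout step.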

# record instead of sitting in a print() stdio buffer under srun.
self.logger = RankZeroLoggingWrapper(PythonLogger("fgn"), self.dist)
self.logger.info("Trainer.__init__ starting")

P1 amp config field is silently ignored throughout the trainer. TrainingConfig declares amp: bool = False and the production shell script in scripts/train_2024_val.sh explicitly passes training.amp=true, but self.cfg.training.amp is never read. No torch.cuda.amp.autocast context or GradScaler is used anywhere in train() or _loss(), so AMP is a silent no-op even when enabled. Users following the provided scripts would expect half-precision training but get none of the memory savings or throughput benefits.
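
One way a flag like this could be honoured is to pass it to the `enabled` argument of `torch.autocast`, so that a single code path serves both settings. The sketch below is an illustration under stated assumptions, not the trainer's code: `forward_loss`, the toy `nn.Linear` model, and the `amp` argument (standing in for cfg.training.amp) are hypothetical, and it runs on CPU with bfloat16 so that no GPU is needed.

```python
import torch
from torch import nn


def forward_loss(model, x, target, amp):
    # `amp` stands in for cfg.training.amp; enabled=False makes autocast a no-op.
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=amp):
        pred = model(x)
    # Cast back so the loss is computed in float32 regardless of the flag.
    loss = nn.functional.mse_loss(pred.float(), target)
    return pred.dtype, loss


torch.manual_seed(0)
model = nn.Linear(8, 4)
x = torch.randn(2, 8)
target = torch.randn(2, 4)

dtype_on, loss_on = forward_loss(model, x, target, amp=True)
dtype_off, loss_off = forward_loss(model, x, target, amp=False)
print(dtype_on, dtype_off, loss_on.dtype)
```

A GradScaler is normally paired with float16 autocast on CUDA; with bfloat16, loss scaling is generally not required, so the sketch omits it.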

Comment on lines +10 to +16
#SBATCH --output=/mnt/home/kashif/physicsnemo/examples/weather/fgn/logs/train_2024_val_%j.log

set -euo pipefail

EXAMPLE_DIR="/mnt/home/kashif/physicsnemo/examples/weather/fgn"

source /mnt/data/kashif/miniconda3/etc/profile.d/conda.sh

P2 Personal paths hardcoded in the SLURM scripts. Both the #SBATCH --output directive (line 10) and EXAMPLE_DIR / conda source path (lines 14, 16) reference /mnt/home/kashif/ and /mnt/data/kashif/, which will silently fail for any other user. The same pattern appears in scripts/compute_stats_2024.sh and scripts/prefetch_arco_2024.sh. These should use environment variables or relative paths so the scripts are usable without modification.
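
The environment-variable approach suggested here can be sketched as follows. The example is written in Python rather than shell, and the variable names FGN_EXAMPLE_DIR and FGN_LOG_DIR as well as the `resolve_dirs` helper are hypothetical, not names used by the PR.

```python
import os
from pathlib import Path


def resolve_dirs(env, default_root):
    # FGN_EXAMPLE_DIR and FGN_LOG_DIR are hypothetical variable names.
    example_dir = Path(env.get("FGN_EXAMPLE_DIR", default_root))
    # Logs default to a subdirectory of the example directory.
    log_dir = Path(env.get("FGN_LOG_DIR", example_dir / "logs"))
    return example_dir, log_dir


# Default: paths are derived from the current working directory.
example_dir, log_dir = resolve_dirs(os.environ, Path.cwd())

# Override: a user points the script at their own checkout.
custom = resolve_dirs({"FGN_EXAMPLE_DIR": "/scratch/alice/fgn"}, Path.cwd())
print(custom[0], custom[1])
```

Note that `#SBATCH` lines are not shell-expanded, so the `--output` path cannot use such variables directly; a relative path or one of sbatch's filename patterns such as `%j` would be needed there.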

Comment thread examples/weather/fgn/utils/trainer.py Outdated
Comment on lines +319 to +330
for k in range(ar_steps):
    members = []
    for n in range(num_samples):
        hist_n = per_member_hist[:, n]
        pred_n = self._step_ensemble(
            history=hist_n,
            background=background,
            invariants=invariants,
            num_samples=1,
        )[:, 0]
        members.append(pred_n)
    preds = torch.stack(members, dim=1)  # (B, N, C, H, W)

P2 Redundant wrapping of single model forward pass through _step_ensemble. The outer for n in range(num_samples) loop calls _step_ensemble(..., num_samples=1) for each member, which internally runs its own for _ in range(1) loop. This double-loops and makes the code harder to follow; calling the model directly (as the validation loop in _run_validation_metrics does) would be cleaner and avoids the vestigial num_samples=1 sentinel.

kashif added 3 commits May 21, 2026 12:01
Remove cluster-specific slurm scripts (local paths), untrack FGN.md
(dev notes), add .gitignore, and fix README references.

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
…e indirection

- Fix inference _rollout: torch.cat([history[:, 1:], next_frame.unsqueeze(1)])
  so history window slides correctly for any history_frames value, not just 2
- Remove unimplemented amp config field from TrainingConfig and default.yaml
- Inline model call in _loss AR loop instead of routing through _step_ensemble
  with num_samples=1 (each member needs its own history, so the single-call
  collapse doesn't apply; direct call is cleaner)

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
- Run ruff format + fix across all fgn/ Python files
- Remove unused imports (Sequence, Callable, ShardTensor, math, torch)
- Replace assert with if/raise (S101), fix import order (I001),
  simplify loops to list-comprehension/extend (PERF401/102)
- Add noqa: E402 on intentional post-path-insert imports in stage4
- Upgrade FGNUNet docstring to MOD-003 (r-string, NumPy sections,
  Parameters/Forward/Outputs with LaTeX shapes, Examples)
- Add CHANGELOG.md entry under [2.1.0a0]
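
The assert-to-if/raise replacement (S101) listed above matters because `python -O` strips assert statements, so an assert-based check can disappear at run time. A minimal sketch of the pattern, with a hypothetical function name and check that are not taken from the PR:

```python
def check_history_frames(history_frames):
    # Explicit validation survives `python -O`, which removes assert statements.
    if history_frames < 1:
        raise ValueError(f"history_frames must be >= 1, got {history_frames}")
    return history_frames


print(check_history_frames(2))  # 2
```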

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
kashif added 8 commits May 21, 2026 12:18
…nfig

- utils/metrics.py: add energy_score_per_lead — fair energy score
  (multivariate CRPS) over the variable axis with spatial subsampling;
  new in earth2studio 0.13.0, captures cross-channel calibration
- utils/trainer.py: wire energy_score_per_lead into validation hook,
  save to metrics.npz and plot energy_score_vs_lead.png
- config/fgn.yaml: base Hydra config required by train.py
  (@hydra.main config_name="fgn") with model defaults and dataset
  skeleton; was missing, causing Hydra to error without all overrides
- config/fgn_arco.yaml: practical single-GPU ARCO ERA5 training config
  (2018–2022 train / 2023 val, hidden_channels=64, 5000 steps, full
  loss weights) for runs beyond the smoke-test default

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
- datasets/__init__.py: auto-discovery registry (mirrors stormcast)
  that populates dataset_classes dict by scanning all FGNDataset
  subclasses; fixes ImportError since the regular HF `datasets`
  package beat the namespace package without __init__.py
- datasets/dataset.py: FGNDataset ABC (state_channels, background_channels,
  image_shape, get_invariants, output_only_channels) + worker_init;
  mirrors stormcast/datasets/dataset.py convention
- utils/loss.py: fair_crps (paper eq. 4), ensemble_mean_mse,
  build_channel_weights (§2.2.3 GraphCast scheme with z halved),
  build_area_weights (cos-lat normalised to unit mean)
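
For orientation, the two loss ingredients named in this commit can be sketched in plain Python. This is the standard fair CRPS estimator and a cos-latitude weighting normalised to unit mean; whether it matches the PR's fair_crps and build_area_weights term for term is an assumption, and the `_sketch` function names are hypothetical.

```python
import math


def fair_crps_sketch(members, obs):
    # Fair CRPS for one scalar observation and an ensemble of N >= 2 members:
    # mean |x_i - y| minus the pairwise term scaled by 1 / (2 N (N - 1)).
    n = len(members)
    if n < 2:
        raise ValueError("fair CRPS needs at least two ensemble members")
    skill = sum(abs(x - obs) for x in members) / n
    spread = sum(abs(a - b) for a in members for b in members) / (2 * n * (n - 1))
    return skill - spread


def area_weights_sketch(lats_deg):
    # cos(latitude) weights rescaled so that their mean is 1.
    raw = [math.cos(math.radians(lat)) for lat in lats_deg]
    mean = sum(raw) / len(raw)
    return [w / mean for w in raw]


crps = fair_crps_sketch([1.0, 3.0], obs=2.0)
weights = area_weights_sketch([-60.0, 0.0, 60.0])
print(crps, weights)
```

The `N - 1` in the denominator is what distinguishes the fair estimator from the plain ensemble CRPS, which divides the pairwise term by `2 N^2` instead.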

All three files existed locally before the branch cleanup but were
never committed; this adds them properly.
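
The auto-discovery registry described in this commit can be illustrated with a subclass scan. This is a minimal sketch that assumes the registry is keyed by class name; `BaseDataset`, `MockData`, `ArcoData`, and `discover` are hypothetical stand-ins, and the actual FGNDataset interface and the StormCast registry may differ.

```python
from abc import ABC, abstractmethod


class BaseDataset(ABC):
    # Hypothetical stand-in for the FGNDataset ABC named in the commit message.
    @abstractmethod
    def state_channels(self): ...


class MockData(BaseDataset):
    def state_channels(self):
        return ["t2m"]


class ArcoData(BaseDataset):
    def state_channels(self):
        return ["t2m", "z500"]


def discover(base):
    # Walk the subclass tree and key each class by its name.
    found = {}
    stack = list(base.__subclasses__())
    while stack:
        cls = stack.pop()
        found[cls.__name__] = cls
        stack.extend(cls.__subclasses__())
    return found


dataset_classes = discover(BaseDataset)
print(sorted(dataset_classes))  # ['ArcoData', 'MockData']
```

A subclass only appears in the scan once its defining module has been imported, which is one reason a registry of this kind usually lives in the package's `__init__.py`.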

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
- Use _make_train_iter() to route through sharded_data_iter when domain
  parallelism is active (mirrors StormCast trainer pattern); plain DDP
  path gets an infinite-restart iterator instead of the old bare iter()
- Wrap both model forward sites in torch.autocast(bfloat16) and call
  .float() on preds to keep loss computation in fp32; halves activation
  memory at full 721x1440 resolution on H100 80GB
- train_fgn.sh: batch_size=1, domain_parallel_size=1 (DDP), run_id
  Hydra string quoting fix, PYTORCH_CUDA_ALLOC_CONF=expandable_segments
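
The infinite-restart iterator mentioned in the first bullet can be sketched with a generator. This is an illustration of the idea only; `infinite_restart` is a hypothetical name, the list stands in for a DataLoader, and the PR's _make_train_iter() may be implemented differently.

```python
import itertools


def infinite_restart(make_iter):
    # Re-create the underlying iterator every time it is exhausted, so the
    # consumer can keep calling next() for a fixed number of training steps.
    while True:
        yielded = False
        for batch in make_iter():
            yielded = True
            yield batch
        if not yielded:
            # Guard against spinning forever on an empty loader.
            raise ValueError("underlying iterable produced no batches")


loader = [0, 1, 2]  # stand-in for a DataLoader with three batches
batches = list(itertools.islice(infinite_restart(lambda: iter(loader)), 7))
print(batches)  # [0, 1, 2, 0, 1, 2, 0]
```

A bare `iter(loader)` would raise StopIteration after the first pass. Re-creating the iterator, rather than using `itertools.cycle`, avoids caching every batch and lets a shuffling loader draw a new order on each pass.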

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
With domain_parallel_size=1 and 2 GPUs, data_parallel_size=2 so
local_batch = batch_size // 2; batch_size=1 → local_batch=0 causing
BatchSampler ValueError. Use batch_size=2 (global) = 1 per GPU.
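
The arithmetic behind this fix can be checked directly. The sketch assumes that the data-parallel size is the world size divided by domain_parallel_size, which is consistent with the numbers in the commit message but is not confirmed by the source; `local_batch_size` is a hypothetical helper.

```python
def local_batch_size(batch_size, world_size, domain_parallel_size):
    # Assumed relation: ranks not used for domain parallelism form the
    # data-parallel group, and the global batch is split evenly across it.
    data_parallel_size = world_size // domain_parallel_size
    if batch_size < data_parallel_size or batch_size % data_parallel_size:
        raise ValueError(
            f"batch_size={batch_size} cannot be split evenly across "
            f"{data_parallel_size} data-parallel ranks"
        )
    return batch_size // data_parallel_size


# The failing case from the commit message: 1 // 2 == 0 samples per GPU.
print(1 // 2)                     # 0
# The fixed case: a global batch of 2 gives 1 sample per GPU.
print(local_batch_size(2, 2, 1))  # 1
```

Validating the global batch size up front gives a clearer error than the ValueError raised later by BatchSampler for a batch size of zero.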

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Mark tp06, multi-rank sanity, AR stage scheduler, bad-seed detector
as done. Add status for currently running 5000-step job 99807.

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Replaces the stale MVP-scaffold README with a full recipe README
modelled on stormcast/README.md: problem overview, dataset (ARCO),
getting started, configuration table, training (single-GPU / torchrun /
SLURM), AR fine-tuning schedule, inference, custom dataset interface,
memory guidance, and references.

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
- Link to developers.google.com/weathernext/guides/models and
  model-specs-vmg from the intro and References section
- Clarify production deployment: 64 members (4 seeds × 16 each)
- Note u100m/v100m omission: ERA5/ARCO lacks 100m winds

Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>