[weather] Add Fgn model #1660
Conversation
Greptile SummaryThis PR adds a new
Important Files Changed
Reviews (1): Last reviewed commit: "Merge branch 'main' into fgn" | Re-trigger Greptile |
| rollout_history = torch.stack( | ||
| [rollout_history[:, 1], next_frame], | ||
| dim=1, | ||
| ) |
There was a problem hiding this comment.
The rollout history update hardcodes
history_frames=2 by stacking rollout_history[:, 1] (always index 1) with the new prediction. When history_frames > 2, this silently discards all but the last old frame and produces a 2-frame history for a model that expects T frames — causing a shape mismatch on the very next model call. The trainer's _loss method handles this correctly with torch.cat([per_member_hist[:, :, 1:], next_frame.unsqueeze(2)], dim=2).
| rollout_history = torch.stack( | |
| [rollout_history[:, 1], next_frame], | |
| dim=1, | |
| ) | |
| rollout_history = torch.cat( | |
| [rollout_history[:, 1:], next_frame.unsqueeze(1)], | |
| dim=1, | |
| ) |
| # record instead of sitting in a print() stdio buffer under srun. | ||
| self.logger = RankZeroLoggingWrapper(PythonLogger("fgn"), self.dist) | ||
| self.logger.info("Trainer.__init__ starting") | ||
|
|
There was a problem hiding this comment.
amp config field is silently ignored throughout the trainer. TrainingConfig declares amp: bool = False and the production shell script in scripts/train_2024_val.sh explicitly passes training.amp=true, but self.cfg.training.amp is never read. No torch.cuda.amp.autocast context or GradScaler is used anywhere in train() or _loss(), so AMP is a silent no-op even when enabled. Users following the provided scripts would expect half-precision training but get none of the memory savings or throughput benefits.
| #SBATCH --output=/mnt/home/kashif/physicsnemo/examples/weather/fgn/logs/train_2024_val_%j.log | ||
|
|
||
| set -euo pipefail | ||
|
|
||
| EXAMPLE_DIR="/mnt/home/kashif/physicsnemo/examples/weather/fgn" | ||
|
|
||
| source /mnt/data/kashif/miniconda3/etc/profile.d/conda.sh |
There was a problem hiding this comment.
Personal paths hardcoded in the SLURM scripts. Both the
#SBATCH --output directive (line 10) and EXAMPLE_DIR / conda source path (lines 14, 16) reference /mnt/home/kashif/ and /mnt/data/kashif/, which will silently fail for any other user. The same pattern appears in scripts/compute_stats_2024.sh and scripts/prefetch_arco_2024.sh. These should use environment variables or relative paths so the scripts are usable without modification.
| for k in range(ar_steps): | ||
| members = [] | ||
| for n in range(num_samples): | ||
| hist_n = per_member_hist[:, n] | ||
| pred_n = self._step_ensemble( | ||
| history=hist_n, | ||
| background=background, | ||
| invariants=invariants, | ||
| num_samples=1, | ||
| )[:, 0] | ||
| members.append(pred_n) | ||
| preds = torch.stack(members, dim=1) # (B, N, C, H, W) |
There was a problem hiding this comment.
Redundant wrapping of single model forward pass through
_step_ensemble. The outer for n in range(num_samples) loop calls _step_ensemble(..., num_samples=1) for each member, which internally runs its own for _ in range(1) loop. This double-loops and makes the code harder to follow; calling the model directly (as the validation loop in _run_validation_metrics does) would be cleaner and avoids the vestigial num_samples=1 sentinel.
Remove cluster-specific slurm scripts (local paths), untrack FGN.md (dev notes), add .gitignore, and fix README references. Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
…e indirection - Fix inference _rollout: torch.cat([history[:, 1:], next_frame.unsqueeze(1)]) so history window slides correctly for any history_frames value, not just 2 - Remove unimplemented amp config field from TrainingConfig and default.yaml - Inline model call in _loss AR loop instead of routing through _step_ensemble with num_samples=1 (each member needs its own history, so the single-call collapse doesn't apply; direct call is cleaner) Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
- Run ruff format + fix across all fgn/ Python files - Remove unused imports (Sequence, Callable, ShardTensor, math, torch) - Replace assert with if/raise (S101), fix import order (I001), simplify loops to list-comprehension/extend (PERF401/102) - Add noqa: E402 on intentional post-path-insert imports in stage4 - Upgrade FGNUNet docstring to MOD-003 (r-string, NumPy sections, Parameters/Forward/Outputs with LaTeX shapes, Examples) - Add CHANGELOG.md entry under [2.1.0a0] Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
…nfig - utils/metrics.py: add energy_score_per_lead — fair energy score (multivariate CRPS) over the variable axis with spatial subsampling; new in earth2studio 0.13.0, captures cross-channel calibration - utils/trainer.py: wire energy_score_per_lead into validation hook, save to metrics.npz and plot energy_score_vs_lead.png - config/fgn.yaml: base Hydra config required by train.py (@hydra.main config_name="fgn") with model defaults and dataset skeleton; was missing, causing Hydra to error without all overrides - config/fgn_arco.yaml: practical single-GPU ARCO ERA5 training config (2018–2022 train / 2023 val, hidden_channels=64, 5000 steps, full loss weights) for runs beyond the smoke-test default Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
- datasets/__init__.py: auto-discovery registry (mirrors stormcast) that populates dataset_classes dict by scanning all FGNDataset subclasses; fixes ImportError since the regular HF `datasets` package beat the namespace package without __init__.py - datasets/dataset.py: FGNDataset ABC (state_channels, background_channels, image_shape, get_invariants, output_only_channels) + worker_init; mirrors stormcast/datasets/dataset.py convention - utils/loss.py: fair_crps (paper eq. 4), ensemble_mean_mse, build_channel_weights (§2.2.3 GraphCast scheme with z halved), build_area_weights (cos-lat normalised to unit mean) All three files existed locally before the branch cleanup but were never committed; this adds them properly. Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
- Use _make_train_iter() to route through sharded_data_iter when domain parallelism is active (mirrors StormCast trainer pattern); plain DDP path gets an infinite-restart iterator instead of the old bare iter() - Wrap both model forward sites in torch.autocast(bfloat16) and call .float() on preds to keep loss computation in fp32; halves activation memory at full 721x1440 resolution on H100 80GB - train_fgn.sh: batch_size=1, domain_parallel_size=1 (DDP), run_id Hydra string quoting fix, PYTORCH_CUDA_ALLOC_CONF=expandable_segments Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
With domain_parallel_size=1 and 2 GPUs, data_parallel_size=2 so local_batch = batch_size // 2; batch_size=1 → local_batch=0 causing BatchSampler ValueError. Use batch_size=2 (global) = 1 per GPU. Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Mark tp06, multi-rank sanity, AR stage scheduler, bad-seed detector as done. Add status for currently running 5000-step job 99807. Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Replaces the stale MVP-scaffold README with a full recipe README modelled on stormcast/README.md: problem overview, dataset (ARCO), getting started, configuration table, training (single-GPU / torchrun / SLURM), AR fine-tuning schedule, inference, custom dataset interface, memory guidance, and references. Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
- Link to developers.google.com/weathernext/guides/models and model-specs-vmg from the intro and References section - Clarify production deployment: 64 members (4 seeds × 16 each) - Note u100m/v100m omission: ERA5/ARCO lacks 100m winds Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
PhysicsNeMo Pull Request
Description
Checklist
Dependencies
Review Process
All PRs are reviewed by the PhysicsNeMo team before merging.
Depending on which files are changed, GitHub may automatically assign a maintainer for review.
We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.
AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.