Skip to content

Radiation transport example#1647

Open
melo-gonzo wants to merge 68 commits into
NVIDIA:mainfrom
melo-gonzo:radiation-transport-example
Open

Radiation transport example#1647
melo-gonzo wants to merge 68 commits into
NVIDIA:mainfrom
melo-gonzo:radiation-transport-example

Conversation

@melo-gonzo
Copy link
Copy Markdown
Collaborator

PhysicsNeMo Pull Request

Description

Radiation Transport Surrogate Model with Transolver

This PR adds a new PhysicsNeMo example under examples/nuclear_engineering/radiation_transport/ that trains a Transolver surrogate for the 2-D linear radiation transport equation on two benchmark problems relevant for nuclear reactor assembly design and inertial confinement fusion — the Lattice and Hohlraum benchmarks from Kusch et al. 2025, with data generated by the KiT-RT simulation code. The example is built end-to-end on PhysicsNeMo's: Mesh datapipes, DataLoader / Compose / Normalize transforms, the Transolver model, the CombinedOptimizer (Muon + AdamW), and physicsnemo.utils.checkpoint. It supports distributed training, well-defined quantities of interest, a differentiable physics loss, and reusable modules for extending to other archetectures.

Why

As with many other scientific and engineering pipelines, running simulations is the bottleneck. This workflow demonstrates how PhysicsNeMo, and models traditionally used in CFD, CAE, and other domains, can be reused for radiation transport. Because the KiT-RT code is "benchmark-style," it is a natural interface for validating scientific ML and surrogate models.

Key Changes

New example tree

examples/nuclear_engineering/radiation_transport/
├── README.md                 # walkthrough: science, install, dataset, training, eval
├── DATASET_CARD.md           # dataset card describing the .pmsh layout
└── src/                      # 12 Python modules, 8 YAML configs
    ├── train.py              # Hydra entry — composes case/data/model/train
    ├── inference.py          # Hydra-driven evaluation; writes metrics + figures
    ├── trainer.py            # training loop (DDP, AMP, gradient accumulation, warmup+cosine)
    ├── dataset.py            # `RTEBaseDataset` over a directory of `.pmsh/` stores
    ├── loader.py             # `TransolverAdapter`, `collate_no_padding`, `build_dataloaders`
    ├── transforms.py         # RTE-specific `Transform`s registered with the datapipes registry
    ├── losses.py             # region-weighted MSE + QoI physics loss
    ├── qoi.py                # differentiable QoI evaluators (final-time, T=1)
    ├── evaluation_metrics.py # field + QoI aggregators
    ├── checkpointing.py      # `best_model/` checkpointing, Muon + AdamW combo optimizer
    ├── compute_normalizations.py  # standalone CLI to produce flux / material stats YAMLs
    ├── viz.py                # 3-panel flux plot + per-region QoI scatter
    └── conf/                 # Hydra groups: case/, data/, model/, train/, inference/

Data layout

Each simulation is one <name>.pmsh/ directory (written by physicsnemo.mesh.Mesh.save) next to a <name>.attrs.json sidecar. RTEBaseDataset._load uses physicsnemo.mesh.Mesh.load for the memmap tensors and reads raw_attrs from the sidecar, exposing it as a NonTensorData metadata entry on the returned TensorDict. Splits are basename arrays; the reader appends .pmsh when opening stores.

Training

  • DDP-ready via torchrun --nproc_per_node=N src/train.py; single-process
    works via plain python (no DDP-specific code is gated on launch
    detection beyond what DistributedManager provides).
  • AMP via torch.amp.autocast + GradScaler (fp16) / direct autocast (bf16).
  • Optimizer is Adam by default; train.optimizer.type=muon returns a
    CombinedOptimizer with torch.optim.Muon for 2-D weight matrices
    • AdamW for everything else.
  • Scheduler is SequentialLR([LinearLR, CosineAnnealingLR]) for warmup
    • cosine annealing.
  • Single best-by-val_loss checkpoint kept at checkpoints/best_model/.

Known Limitations

  • Hohlraum boundary input flux is present in the source data but is
    not used as a model input in this example.

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 14, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 14, 2026

Greptile Summary

This PR adds a new end-to-end example under examples/nuclear_engineering/radiation_transport/ that trains a Transolver surrogate for 2-D linear radiation transport on the Lattice and Hohlraum benchmarks from Kusch et al. 2025. The implementation is well-structured, leveraging PhysicsNeMo's datapipes, transforms, Transolver model, CombinedOptimizer, and checkpoint utilities with DDP, AMP, gradient accumulation, and a differentiable physics (QoI) loss.

  • Checkpoint metadata bug (trainer.py): save_best_checkpoint is passed the old best_val_loss in the metadata dict, not the new val_loss that triggered the save. On resume, the incorrect threshold is loaded, allowing future validation losses that are worse than the true best to overwrite the checkpoint.
  • set_seed counterproductive flag (trainer.py): torch.backends.cudnn.benchmark = True is set inside set_seed, which disables deterministic algorithm selection and contradicts the reproducibility goal.
  • Minor style issues: max_grad_norm is hardcoded at 10.0 without a config override, and _parse_amp (a private symbol) is imported across module boundaries in train.py.

Important Files Changed

Filename Overview
examples/nuclear_engineering/radiation_transport/src/trainer.py Core training loop with DDP, AMP, gradient accumulation, and checkpointing. Contains a P1 bug: stale best_val_loss written into checkpoint metadata breaks resume logic. Also sets cudnn.benchmark=True inside set_seed (counterproductive for reproducibility) and hardcodes max_grad_norm=10.0 without config exposure.
examples/nuclear_engineering/radiation_transport/src/train.py Hydra entry point for training; imports private _parse_amp helper from trainer.py which is fragile. Otherwise well-structured.
examples/nuclear_engineering/radiation_transport/src/checkpointing.py Optimizer creation, checkpoint save/resume helpers. Logic is sound; stale metadata issue originates in caller (trainer.py), not here.
examples/nuclear_engineering/radiation_transport/src/dataset.py Filename-indexed dataset and mesh reader with in-memory static-array caching. Overrides _load_and_transform using a private _PrefetchResult protocol, which is a fragile dependency on PhysicsNeMo internals.
examples/nuclear_engineering/radiation_transport/src/losses.py Region-weighted MSE, physics (QoI) loss, LR scheduler construction. Well-structured with correct dispatch and warmup ramp logic.
examples/nuclear_engineering/radiation_transport/src/loader.py DataLoader construction, TransolverAdapter transform, and collate function. Correctly assembles the per-phase transform pipeline with Fourier features, normalization, and spatial subsampling.
examples/nuclear_engineering/radiation_transport/src/transforms.py RTE-specific transforms: log-clip flux, coordinate backup, Fourier features, spatial subsampling, final-time extraction, material property stacking. Transform ordering in loader.py is correct (SpatialSampler before RTEBackupCoords).
examples/nuclear_engineering/radiation_transport/src/qoi.py Differentiable QoI evaluators for lattice and hohlraum benchmarks. Uses only batch_size=1 (documented) and correctly handles the KiT-RT symmetric wall behavior.
examples/nuclear_engineering/radiation_transport/src/compute_normalizations.py One-shot CLI to compute flux and material normalization statistics. Numerically stable variance accumulation, consistent with training-pipeline preprocessing.
examples/nuclear_engineering/radiation_transport/src/inference.py Hydra-driven evaluation script yielding per-sample metrics, QoI comparisons, and visualizations. Clean and well-guarded against missing optional fields.
examples/nuclear_engineering/radiation_transport/src/evaluation_metrics.py Pointwise and QoI metric aggregation. Straightforward numpy operations with correct error handling for empty inputs.

Comments Outside Diff (1)

  1. examples/nuclear_engineering/radiation_transport/src/trainer.py, line 1771-1789 (link)

    P1 Stale best_val_loss written into checkpoint metadata

    save_best_checkpoint is called with metadata={"best_val_loss": best_val_loss, ...} where best_val_loss is still the old best (before the save). The function only runs when val_loss < best_val_loss and returns float(val_loss), updating the Python local variable after the checkpoint is on disk.

    When training is later resumed, resume_if_available reads metadata.get("best_val_loss", float("inf")), which will be the stale pre-save value (e.g., inf on the very first save). This means the resume threshold is reset too high, allowing any future validation loss—even one worse than the actual best model—to overwrite the "best" checkpoint. The fix is to pass val_loss as "best_val_loss" in the metadata dict.

Reviews (1): Last reviewed commit: "docs: add split file to args for trainin..." | Re-trigger Greptile

Comment thread examples/nuclear_engineering/radiation_transport/src/trainer.py Outdated
case_type = cfg.case.type
use_amp, amp_dtype = _parse_amp(cfg)
accum_steps = cfg.train.get("gradient_accumulation_steps", 1)
max_grad_norm = 10.0
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 max_grad_norm = 10.0 is hardcoded inside the training loop with no way to override it from config. Other training hyperparameters are all configurable via cfg.train, but gradient clipping is silently fixed at 10.

Suggested change
max_grad_norm = 10.0
max_grad_norm = cfg.train.get("max_grad_norm", 10.0)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in a164911

Comment thread examples/nuclear_engineering/radiation_transport/src/train.py
pretrain_checkpoint: null
resume_checkpoint: null

# objective = mse_weight * regression_mse + physics_loss.weight * (regression_mse + qoi_loss); region_weighted swaps regression_mse for the weighted variant.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment claims objective = mse_weight * regression_mse + physics_loss.weight * (regression_mse + qoi_loss), but trainer.compute_losses (line 286 in src/trainer.py) actually computes loss = mse_w * loss_mse + physics_w * loss_qoi — i.e., physics_loss.weight multiplies only qoi_loss, not regression_mse + qoi_loss. Consider updating the comment.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in bb5d5d8

directly (skipping ``self._load``), so we override here too. Thread
pool + CUDA-stream wiring is inherited from the base class.
"""
from physicsnemo.datapipes.protocols import _PrefetchResult
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a private API import. Unrelated to this PR, but if this method is useful in custom datapipes, should we consider making it a public API?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good catch, and is the result of packing in some non-tensor fields into the custom dataset. This can be resolved by updating the dataset.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went ahead and updated this - would have been a maintenance burden and confusing for anyone reading the code. The HuggingFace dataset has also been updated to make wiring in the dataset keys smoother. 23c889e

`physicsnemo.mesh.Mesh.save(...)`. The loader uses the first and final
`scalar_flux` snapshots and ignores intermediate snapshots. The fields are:

`Mesh.points` — `(N, 2)` float32 cell-center coordinates.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated in 5bbd56e. Missed this initially as a result of updating the dataset structure

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be really nice to add an image for what problem is being solved and how the predictions would look like.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added some images! 73c6514


The dataset is available on Hugging Face:
[Linear Radiation Transport][hf-rte]. Alternatively, raw simulation data
may be curated from the [KiT-RT repositories](https://github.com/KiT-RT).
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could not find any information about how this curation should be done.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curation from KiT-RT repo will be left as an exercise for the reader at this point, scripts my be added to that repo in the future.

Copy link
Copy Markdown
Collaborator

@mnabian mnabian left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Left a few minor comments.

`physicsnemo.mesh.Mesh.save(...)`. The loader uses the first and final
`scalar_flux` snapshots and ignores intermediate snapshots. The fields are:

`Mesh.points` — `(N, 2)` float32 cell-center coordinates.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated in 5bbd56e. Missed this initially as a result of updating the dataset structure

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants