diff --git a/CHANGELOG.md b/CHANGELOG.md index b4e938a567..ded280b59a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Adds radiation transport example (`examples/nuclear_engineering/radiation_transport`) - Adds xDeepONet to experimental models (`physicsnemo.experimental.models.xdeeponet.DeepONet`). A single dimension-generic (2D/3D) DeepONet that accepts a spatial or MLP branch, diff --git a/docs/img/radiation_transport/transolver_hohlraum.png b/docs/img/radiation_transport/transolver_hohlraum.png new file mode 100644 index 0000000000..88d1e6a44b Binary files /dev/null and b/docs/img/radiation_transport/transolver_hohlraum.png differ diff --git a/docs/img/radiation_transport/transolver_lattice.png b/docs/img/radiation_transport/transolver_lattice.png new file mode 100644 index 0000000000..d56145bca8 Binary files /dev/null and b/docs/img/radiation_transport/transolver_lattice.png differ diff --git a/examples/nuclear_engineering/radiation_transport/README.md b/examples/nuclear_engineering/radiation_transport/README.md new file mode 100644 index 0000000000..efead25654 --- /dev/null +++ b/examples/nuclear_engineering/radiation_transport/README.md @@ -0,0 +1,493 @@ +# Radiation Transport with Transolver + +A PhysicsNeMo example that trains a [Transolver](https://arxiv.org/abs/2402.02366) +surrogate model for the 2-D linear radiation transport benchmark defined in +[Reference solutions for linear radiation transport: the Hohlraum and Lattice +benchmarks](https://arxiv.org/pdf/2505.17284). The pipeline learns the +final-time mapping from the initial flux snapshot to the final scalar flux, +using a physics-informed loss that combines region-weighted MSE with a +quantity-of-interest (QoI) penalty based on absorption in key regions. 
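As a rough illustration of how the two loss terms might combine, the sketch below adds a weighted QoI penalty to the MSE term. The additive form is an assumption inferred from the sample log line in section 6.2 and the lattice default weight of 0.005 quoted in section 4.1; the README does not state the formula, and `combined_loss` and its argument names are hypothetical.

```python
def combined_loss(mse: float, qoi_penalty: float, physics_weight: float) -> float:
    """Assumed additive form: data term plus weighted QoI penalty."""
    return mse + physics_weight * qoi_penalty


# Values copied from the sample training-log line in section 6.2.
train_mse = 1.7032e-05
train_qoi = 9.8040e-06

# 0.005 is the lattice default physics-loss weight quoted in section 4.1.
total = combined_loss(train_mse, train_qoi, physics_weight=0.005)
print(f"{total:.4e}")  # prints 1.7081e-05, matching train_loss in that log line
```

The same arithmetic reproduces the logged `val_loss` from `val_mse` and `val_qoi`, which is why the additive form seems plausible here.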
+ +The dataset used for this example was generated using +[KiT-RT](https://github.com/KiT-RT) [^1], and can be found on Hugging Face: +[Linear Radiation Transport][hf-rte]. + +[hf-rte]: https://huggingface.co/datasets/nvidia/Linear-Radiation-Transport + +--- + +## 1. The science + +The model approximates the final-time scalar flux `φ(x)` of the 2-D linear +radiative-transfer equation. The simulator is run forward in time and the +training target is the last snapshot — the underlying transport problem +is not run to convergence. Inputs to the surrogate are: + +- **Coordinates** `(x, y)` per cell, normalized to `[-1, 1]` and augmented with + Fourier features (3 frequencies × 2 axes × {sin, cos} = 12 extra channels). +- **Material properties** per cell: absorption coefficient `σ_a`, scattering + coefficient `σ_s`, total cross-section `σ_t`, and, for lattice cases, heat + source `Q`. Boundary input flux may be incorporated from upstream hohlraum + data, but it is not used as a model input in this example. + +The surrogate predicts the **z-score-of-log scalar flux**, which is then +inverted via `transforms.denormalize_flux` to recover the physical flux. + +### 1.1 Lattice benchmark + +A square domain partitioned into a 7×7 grid of material blocks. Each block is +either **absorber** (high `σ_a`, low `σ_s`), **scatterer** (low `σ_a`, high +`σ_s`), or **source** (interior `Q > 0`). The model has to capture sharp flux +discontinuities at material interfaces and reproduce the integrated +absorption in the absorbing regions. + +**QoI** — matches **QoI-3** of the reference paper (Kusch et al. 2025, §3.1): +the final-time radiation absorption over the absorbing blocks `B`: + +$$\mathrm{QoI}_{\mathrm{Lattice}} = \int_{B} \sigma_a(x)\,\phi(x, T)\,dx.$$ + +In code this is `cur_absorption`, computed as +`Σ_{c ∈ B} σ_a,c · φ_c · A_c` over absorber cells. 
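The discrete sum above can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions, not the example's implementation: the function name `absorption_qoi`, its argument names, and the boolean `in_region` mask are hypothetical, and the example itself presumably works on tensors rather than Python lists.

```python
import math
from typing import Sequence


def absorption_qoi(
    sigma_a: Sequence[float],
    flux: Sequence[float],
    cell_area: Sequence[float],
    in_region: Sequence[bool],
) -> float:
    """Sum sigma_a * flux * area over the cells flagged by in_region."""
    if not (len(sigma_a) == len(flux) == len(cell_area) == len(in_region)):
        raise ValueError("all per-cell inputs must have the same length")
    # fsum keeps the accumulation accurate when many small terms are added.
    return math.fsum(
        s * f * a
        for s, f, a, keep in zip(sigma_a, flux, cell_area, in_region)
        if keep
    )


# Four toy cells; only the first two are absorber cells.
qoi = absorption_qoi(
    sigma_a=[10.0, 10.0, 0.0, 0.5],
    flux=[0.2, 0.1, 1.0, 0.4],
    cell_area=[0.25, 0.25, 0.25, 0.25],
    in_region=[True, True, False, False],
)
print(qoi)
```

With these toy values the two absorber cells contribute `10 · 0.2 · 0.25` and `10 · 0.1 · 0.25`, so the sum is 0.75; cells outside the mask are ignored regardless of their flux.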
+ +![Lattice: target, prediction, absolute error of final-time flux][lattice-fig] + +[lattice-fig]: ../../../docs/img/radiation_transport/transolver_lattice.png + +### 1.2 Hohlraum benchmark + +An axisymmetric cylindrical cavity with interior void regions, +representing a simplified inertial-confinement-fusion target. There is no +interior heat source — flux enters from boundary conditions and propagates +through the cavity. Geometry parameters (upper/lower laser-entry radii, +center offsets) vary across simulations. + +**QoI** — variation of **QoI-2** of the reference paper (Kusch et al. +2025, §3.2): per-material final-time absorption, evaluated separately +over each of three regions `S ∈ {G ∪ B, R, K}`: + +$$\mathrm{QoI}_{\mathrm{Hohlraum}, S} = \int_{S} \sigma_a(x)\,\phi(x, T)\,dx.$$ + +In code the three regions are labeled +`cur_absorption_{center, vertical, horizontal}` and each is computed as +`Σ_{c ∈ S} σ_a,c · φ_c · A_c`. The training-time physics loss +additionally synthesizes a fourth `total` term as the mean of the three, +so every region contributes to the gradient (mean-of-four). Inference +reports the three component QoIs only. + +![Hohlraum: target, prediction, absolute error of final-time flux][hohlraum-fig] + +[hohlraum-fig]: ../../../docs/img/radiation_transport/transolver_hohlraum.png + +--- + +## 2. Installation + +Prerequisites: + +- **PhysicsNeMo** — install the host repo with `[model-extras,datapipes-extras]` + to get `physicsnemo.models.transolver.Transolver` and the `tensordict`-based + data utilities. + +From the PhysicsNeMo repo root, install the example dependencies: + +```bash +uv pip install -e ".[model-extras,datapipes-extras]" tensorboard +``` + +--- + +## 3. Dataset + +### 3.1 Data source + +The dataset is available on Hugging Face: +[Linear Radiation Transport][hf-rte]. Alternatively, raw simulation data +may be curated from the [KiT-RT repositories](https://github.com/KiT-RT). 
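One possible way to fetch the Hugging Face copy is sketched below, assuming the `huggingface_hub` package is installed. The repository id is taken from the dataset link above; `REPO_ID`, `download_dataset`, and the target directory are illustrative names, and whether the downloaded tree already matches the on-disk layout the example expects should be checked against the dataset card.

```python
from pathlib import Path

# Repository id taken from the Hugging Face dataset link in this README.
REPO_ID = "nvidia/Linear-Radiation-Transport"


def download_dataset(local_dir: str) -> Path:
    """Download a snapshot of the dataset repo into local_dir and return its path."""
    # Imported lazily so this module can be imported without huggingface_hub.
    from huggingface_hub import snapshot_download

    path = snapshot_download(
        repo_id=REPO_ID,
        repo_type="dataset",  # the default repo_type is "model"
        local_dir=local_dir,
    )
    return Path(path)
```

Calling `download_dataset("./rte_data")` would then populate `./rte_data`, which can be passed to the training scripts as the data root if the layout matches.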
+ +### 3.2 Expected on-disk layout + +The runtime data format is the PhysicsNeMo `Mesh` memmap layout. Each +simulation lives in a `.pmsh/` directory next to a `.attrs.json` +sidecar, loaded via `physicsnemo.mesh.Mesh.load(.pmsh)`. + +```text +/ +├── lattice/ +│ ├── lattice_abs<...>_scatter<...>_p<...>_q<...>.pmsh/ +│ ├── lattice_abs<...>_scatter<...>_p<...>_q<...>.attrs.json +│ └── ... +├── hohlraum/ +│ ├── hohlraum_variable_cl<...>_q<...>_ulr<...>_llr<...>_<...>.pmsh/ +│ ├── hohlraum_variable_cl<...>_q<...>_ulr<...>_llr<...>_<...>.attrs.json +│ └── ... +├── splits/ +│ ├── lattice_splits.json # train/val/test split lists +│ └── hohlraum_splits.json +└── stats/ + ├── lattice_flux_stats.yaml + ├── lattice_material_stats.yaml + ├── hohlraum_flux_stats.yaml + └── hohlraum_material_stats.yaml +``` + +### 3.3 What's in each mesh store + +Each `*.pmsh/` directory is one simulation written via +`physicsnemo.mesh.Mesh.save(...)`. The flux series is stored as just +the first and final snapshots (`T = 2`); only those are used. + +Cell-center coordinates and per-cell areas are not stored as fields; +the loader derives them from the mesh topology via `mesh.cell_centroids` +and `mesh.cell_areas`. + +`Mesh.cell_data` (per-cell tensors the loader requires): + +| Key | Shape | Dtype | Notes | +|---|---|---|---| +| `scalar_flux` | `(N, 2)` | float32 | flux at first / final snapshot, cells-first | +| `material_id` | `(N,)` | int64 | region IDs (mapped by `LatticeMaterialMapper` / `HohlraumMaterialMapper`) | +| `sigma_a`, `sigma_s`, `sigma_t` | `(N,)` | float32 | absorption / scattering / total cross-section | +| `Q` | `(N,)` | float32 | heat source (non-zero in lattice; zeros in hohlraum) | + +`Mesh.global_data`: the loader consumes only `sim_time` (shape `(2,)`, +simulation time of each flux snapshot). Other simulation diagnostics +shipped with the data (`cur_absorption`, `total_absorption`, `mass`, +...) are ignored at training time, but may be useful for other downstream tasks. + +`.attrs.json` (sidecar): JSON with `case_type`, +`simulation_params`, `solver_config`, and `mesh_info`. The loader +exposes the full dict as a `metadata` `NonTensorData` entry on the +returned `TensorDict`. + +`N` is the number of cells per simulation (~tens of thousands). Different +simulations may have different `N`; point-cloud collation handles this. 
+ +### 3.4 Splits file format + +The dataset reader (`dataset._load_split_from_file`) expects a wrapped +JSON document with a `"splits"` key: + +```json +{ + "case_type": "lattice", + "split_name": "default", + "total_samples": 707, + "train_size": 494, + "val_size": 106, + "test_size": 107, + "splits": { + "train": ["lattice_abs52.5_scatter4.6_p0.015_q6", ...], + "val": ["lattice_abs85.0_scatter9.1_p0.015_q6", ...], + "test": ["lattice_abs77.5_scatter4.1_p0.015_q6", ...] + } +} +``` + +Filenames in the splits arrays are **basenames** without any format +suffix; the reader appends `.pmsh` when opening stores. + +If the splits file is named with a different suffix, point at it explicitly: + +```bash +... case.split_file=/splits/my_split_file.json +``` + +### 3.5 Computing normalization stats + +If `/stats/_{flux,material}_stats.yaml` are missing (e.g. you +re-curated the data, or you started from a fresh download that only ships +flux stats), generate them with: + +```bash +python src/compute_normalizations.py \ + --data_path /Datasets/lattice \ + --case_type lattice \ + --split_file /Datasets/splits/lattice_splits.json \ + --output_dir /Datasets/stats + +python src/compute_normalizations.py \ + --data_path /Datasets/hohlraum \ + --case_type hohlraum \ + --split_file /Datasets/splits/hohlraum_splits.json \ + --output_dir /Datasets/stats +``` + +`--split_file` is required so stats are computed over the same train split +used by training. + +The flux stats YAML contains the log-flux mean/std/min/max + `clip_threshold`, +used by `RTEFluxLogClip` and `denormalize_flux`. The material stats YAML +contains per-channel mean/std/min/max for `{σ_a, σ_s, σ_t, Q}`. + +--- + +## 4. Training + +### 4.1 Quick start + +Full-mesh training used at least a 48 GB GPU during development (RTX6000 Ada). 
+ +Lattice: + +```bash +python src/train.py case=lattice data=lattice \ + case.data_root= \ + case.split_file=./path/to/lattice_splits.json +``` + +Hohlraum: + +```bash +python src/train.py case=hohlraum data=hohlraum \ + case.data_root= \ + case.split_file=./path/to/hohlraum_splits.json +``` + +Single-process default: 500 epochs, AMP-bf16, cosine LR with 10 warmup epochs, +peak LR 3e-5, physics loss enabled at weight 0.005 (lattice) / 0.01 (hohlraum). + +### 4.2 Multi-GPU + +```bash +torchrun --nproc_per_node=N src/train.py \ + case=lattice data=lattice case.data_root= +``` + +Use `torchrun` for DDP. A plain `python src/train.py ...` launch runs as a +single process. + +### 4.3 Common overrides + +| Override | Effect | +|---|---| +| `train.epochs=200` | Shorter run | +| `train.optimizer.type=muon` | Use `torch.optim.Muon` for 2-D weights, `torch.optim.AdamW` for all other parameters (biases, norms) | +| `train.amp=false` | Disable mixed precision (debug / numerical parity) | +| `train.physics_loss.weight=0.0` | Pure MSE training (disables QoI penalty) | +| `train.max_grad_norm=1.0` | Tighter gradient L2-norm clip (default `10.0`) | +| `train.dataloader.num_streams=4` | CUDA streams used by `physicsnemo.datapipes.DataLoader` for prefetch overlap (no CPU fork workers) | +| `train.dataloader.use_streams=false` | Disable CUDA-stream prefetching; useful for debugging or CPU-only runs | +| `train.dataloader.prefetch_factor=4` | How many batches to prefetch ahead | +| `model.num_spatial_points=8192` | Subsample cells per training step (`-1` = use all) | +| `model.n_layers=12 model.n_hidden=384` | Bigger Transolver | +| `model.use_te=true` | Use NVIDIA TransformerEngine layers (requires `[model-extras]`) | +| `train.resume_checkpoint=.../checkpoints/best_model` | Resume from a checkpoint directory | + +### 4.4 Output structure + +Per run, under `outputs/${project.name}/${case.type}/${exp_tag}/`: + +```text +outputs/RTE_Transolver/lattice/transolver/ +├── hydra/ +│ ├── config.yaml # resolved Hydra config 
(canonical record of the run) +│ ├── hydra.yaml +│ └── overrides.yaml +├── checkpoints/ +│ └── best_model/ # the lowest-val_loss snapshot to date +│ ├── checkpoint.0.0.pt # training state (optimizer, scheduler, scaler, metadata) +│ └── Transolver.0.0.mdlus # model state dict +├── tensorboard/ # TB event files (open with `tensorboard --logdir tensorboard/`) +└── train.log +``` + +Inference defaults to `checkpoints/best_model/` — the single +best-by-val_loss checkpoint maintained during training. No periodic, +rolling, or per-epoch snapshots are kept. + +--- + +## 5. Evaluation + +### 5.1 Run inference + +Inference is Hydra-driven; supply the checkpoint path, data root, and split +file as standard Hydra overrides: + +```bash +RUN=outputs/RTE_Transolver/lattice/transolver +python src/inference.py \ + case=lattice data=lattice \ + case.data_root=/path/to/data_root \ + case.split_file=/path/to/splits.json \ + inference.checkpoint_path=$RUN/checkpoints/best_model \ + inference.output_dir=$RUN/evaluation +``` + +The flux normalization stats file is read from +`cfg.data.flux_normalization_stats_file` (interpolated from `case.data_root` +by default); override it directly via +`data.flux_normalization_stats_file=` if you keep stats elsewhere. + +Inference-specific config keys (under `inference.*`): + +| Key | Effect | +|---|---| +| `inference.checkpoint_path` | Required. Directory containing `Transolver.0.0.mdlus` + `checkpoint.0.0.pt`. Point at the `best_model/` directory under the run's `checkpoints/`. | +| `inference.output_dir` | Required. Where to write `metrics.yaml`, `qoi_metrics.yaml`, and `figures/`. | +| `inference.num_samples` | Cap on the number of test simulations (default: `null` = all). | +| `inference.num_plot_samples` | Number of `flux_panels_.png` figures to write (default: 3, evenly sampled across the test set). | +| `inference.device` | Override torch device (default: `null` = CUDA if available). 
| +| `inference.use_amp` | Autocast in eval; bf16 on CUDA, off on CPU (default: `true`). | + +The case (`lattice` / `hohlraum`) is selected the same way as in training: +`case= data=`. The dataset root, split file, and material/flux +stats paths interpolate from `case.data_root` exactly as during training. + +### 5.2 Outputs + +```text +/ +├── metrics.yaml # field-level metrics over the whole test set +├── qoi_metrics.yaml # per-region QoI relative error +└── figures/ + ├── flux_panels_0000.png # target / prediction / error 3-panel per plotted sample + ├── ... + └── qoi_true_vs_pred.png # predicted vs ground-truth QoI scatter (one panel per region) +``` + +### 5.3 Metric definitions + +`metrics.yaml::overall` is computed once over **all** evaluation samples +flattened together (denormalized to physical flux): + +| Key | Definition | +|---|---| +| `mse` | `mean((pred − target)^2)` | +| `rmse` | `sqrt(mse)` | +| `mae` | `mean(|pred − target|)` | +| `l2_relative_error` | `‖pred − target‖₂ / ‖target‖₂` — the headline number | +| `relative_error` | `mean(|pred − target| / |target|)` — sensitive to near-zero target cells, often dominated by void regions | +| `max_error` | `max(|pred − target|)` | + +`metrics.yaml::per_sample_aggregate` reports `{mean, std, min, max}` of each +metric across simulations — useful for catching outliers (one bad simulation +dominating the mean). + +`qoi_metrics.yaml` reports per-region: + +| Key | Definition | +|---|---| +| `mae` | mean absolute error of the integrated QoI scalar | +| `rmse` | RMSE of the integrated QoI scalar | +| `max_error` | worst single-simulation QoI error | +| `mean_relative_error_pct` | mean of `100 · |Q_pred − Q_true| / |Q_true|` | +| `median_relative_error_pct` | median of the same | +| `max_relative_error_pct` | worst single-simulation relative error | + +For lattice, the only region is `cur_absorption`. 
For hohlraum, inference +reports `cur_absorption_{center, vertical, horizontal}` when geometry +metadata is available on the sample (the training-time physics loss +additionally averages in a synthesized `total` term across those three +regions; inference does not). + +### 5.4 Comparing runs + +The single most useful comparison is +**`qoi_metrics.yaml::::mean_relative_error_pct`**. On the default +randomized splits, a well-trained surrogate should reach low single-digit +percent QoI error. + +For field-level comparisons, use `metrics.yaml::overall::l2_relative_error`, +which helps interpret global flux structure and sharp interface features. + +--- + +## 6. Interpreting model performance + +### 6.1 What "good" looks like + +A converged model on either benchmark typically reaches `l2_relative_error` +in the **1–2%** range and per-region QoI `mean_relative_error_pct` **below +1%**. + +### 6.2 Reading the training log + +Each epoch logs train/validation loss and any per-component sub-losses +present (`mse`, `qoi`, `qoi_`, ...) followed by the current +learning rate. A typical line looks like: + +```text +Epoch 500: train_loss=1.7081e-05, val_loss=2.0973e-05, + train_mse=1.7032e-05, val_mse=2.0900e-05, + train_qoi=9.8040e-06, val_qoi=1.4658e-05, lr=1.00e-06 +``` + +A `best_model/` checkpoint is written whenever `val_loss` improves; no +periodic per-epoch snapshots are kept. + +### 6.3 Reading the inference figures + +- **`flux_panels_.png`** — three panels per sample: target, + prediction, absolute error. +- **`qoi_true_vs_pred.png`** — predicted vs ground-truth QoI scatter, one + panel per region. Points should lie close to the `y = x` diagonal + across the full test set. + +--- + +## 7. 
Configuration reference + +All training hyperparameters live under `src/conf/`, composed by Hydra: + +```text +src/conf/ +├── config.yaml # root: composes case / data / model / train / inference +├── case/{lattice,hohlraum}.yaml +├── data/{lattice,hohlraum}.yaml +├── model/transolver.yaml +├── train/base.yaml +└── inference/default.yaml +``` + +`config.yaml` defaults list: + +```yaml +defaults: + - case: lattice + - data: lattice + - model: transolver + - train: base + - inference: default + - _self_ +``` + +CLI overrides follow Hydra's standard syntax: + +```bash +python src/train.py \ + case=hohlraum data=hohlraum \ + case.data_root=/path/to/data \ + train.epochs=300 \ + train.optimizer.type=muon \ + train.physics_loss.weight=0.02 \ + model.n_layers=12 model.n_hidden=384 +``` + +The Hydra group structure means `case=hohlraum` swaps the entire +`case/hohlraum.yaml` (including `physics_loss_weight`, +`include_q_in_embedding`, and `embedding_dim_override`). The downstream +`train/base.yaml` and `model/transolver.yaml` interpolate from `${case.*}` +so case-specific overrides propagate automatically. + +--- + +## References + +[^1]: Kusch, J., Schotthöfer, S., Stammer, P., Wolters, J., & Xiao, T. (2023). +"KiT-RT: An extendable framework for radiative transfer and therapy." +*ACM Transactions on Mathematical Software*, **49**(4), 1–24. 
+ +```bibtex +@article{kitrt2023, + title = {KiT-RT: An extendable framework for radiative transfer and therapy}, + author = {Kusch, Jonas and Schotth{\"o}fer, Steffen and Stammer, Pia + and Wolters, Jannick and Xiao, Tianbai}, + journal = {ACM Transactions on Mathematical Software}, + volume = {49}, + number = {4}, + pages = {1--24}, + year = {2023}, + publisher = {ACM New York, NY} +} +``` diff --git a/examples/nuclear_engineering/radiation_transport/src/checkpointing.py b/examples/nuclear_engineering/radiation_transport/src/checkpointing.py new file mode 100644 index 0000000000..bf059d7a1d --- /dev/null +++ b/examples/nuclear_engineering/radiation_transport/src/checkpointing.py @@ -0,0 +1,256 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +import shutil +from pathlib import Path +from typing import Any, Dict, Tuple, Union + +import hydra +import torch +import torch.nn as nn +from omegaconf import DictConfig, OmegaConf + +from physicsnemo.distributed import DistributedManager +from physicsnemo.optim import CombinedOptimizer +from physicsnemo.utils.checkpoint import load_checkpoint + + +def create_optimizer( + model: nn.Module, + optimizer_type: str = "adam", + learning_rate: float = 1e-3, + weight_decay: float = 0.0, + muon_momentum_beta: float = 0.95, + logger=None, +) -> torch.optim.Optimizer: + """Create optimizer based on configuration. + + For ``optimizer_type='muon'`` returns a hybrid: Muon for 2D weight + matrices, AdamW for 1D params (biases, layer norms, embeddings). Muon + only supports 2D weight matrices, hence the split. The shared + ``learning_rate`` drives both halves because Muon is constructed with + ``adjust_lr_fn='match_rms_adamw'``. + """ + if optimizer_type not in ("adam", "muon"): + raise ValueError(f"Unknown optimizer type: {optimizer_type}") + + if optimizer_type == "muon": + return _create_muon_optimizer( + model=model, + learning_rate=learning_rate, + weight_decay=weight_decay, + muon_momentum_beta=muon_momentum_beta, + logger=logger, + ) + + optimizer = torch.optim.Adam( + model.parameters(), + lr=learning_rate, + weight_decay=weight_decay, + ) + if logger: + logger.info( + f"Using Adam optimizer with lr={learning_rate}, weight_decay={weight_decay}" + ) + return optimizer + + +def _create_muon_optimizer( + model: nn.Module, + learning_rate: float, + weight_decay: float, + muon_momentum_beta: float, + logger=None, +) -> torch.optim.Optimizer: + """Build a Muon + AdamW combined optimizer (Muon for 2D, AdamW for the rest). + + Requires PyTorch >= 2.9 for ``torch.optim.Muon`` with the + ``adjust_lr_fn`` argument. + """ + if not hasattr(torch.optim, "Muon"): + raise ImportError( + "Muon optimizer requires PyTorch >= 2.9. 
" + "Install a newer PyTorch or use optimizer.type=adam." + ) + base_model = model.module if hasattr(model, "module") else model + muon_params = [p for p in base_model.parameters() if p.ndim == 2] + other_params = [p for p in base_model.parameters() if p.ndim != 2] + + if logger: + logger.info( + f"Muon optimizer: {len(muon_params)} 2D params, " + f"{len(other_params)} other params, lr={learning_rate}" + ) + + muon = ( + torch.optim.Muon( + muon_params, + lr=learning_rate, + momentum=muon_momentum_beta, + weight_decay=weight_decay, + adjust_lr_fn="match_rms_adamw", + ) + if muon_params + else None + ) + adamw = ( + torch.optim.AdamW( + other_params, + lr=learning_rate, + weight_decay=weight_decay, + ) + if other_params + else None + ) + + if muon and adamw: + return CombinedOptimizer([muon, adamw]) + return muon or adamw + + +def save_best_checkpoint( + checkpoint_dir: Path, + val_loss: float, + best_val_loss: float, + save_checkpoint_fn, + logger=None, + **checkpoint_kwargs, +) -> float: + """Save a single ``best_model/`` checkpoint when ``val_loss`` improves. + + Returns the (possibly unchanged) ``best_val_loss``. Skips with a warning + when ``val_loss`` is not finite, and is a no-op when the current loss does + not beat the previous best. + """ + if not math.isfinite(float(val_loss)): + if logger: + logger.warning( + " Skipping best-checkpoint save: non-finite val_loss=%s", val_loss + ) + return best_val_loss + + if val_loss >= best_val_loss: + return best_val_loss + + checkpoint_dir = Path(checkpoint_dir) + best_model_dir = checkpoint_dir / "best_model" + if best_model_dir.exists(): + shutil.rmtree(best_model_dir) + best_model_dir.mkdir(parents=True, exist_ok=True) + + epoch = checkpoint_kwargs.pop("epoch") + save_checkpoint_fn(path=str(best_model_dir), epoch=epoch, **checkpoint_kwargs) + + if logger: + logger.info( + f" New best model! 
val_loss={val_loss:.6f} (prev best: {best_val_loss:.6f})" + ) + return float(val_loss) + + +def resume_if_available( + cfg: DictConfig, + model: nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: Any, + scaler: Any, + dist: DistributedManager, + logger: Any, +) -> Tuple[int, float]: + """Resume full training state or load pretrain weights, if configured. + + Returns ``(start_epoch, best_val_loss)``. PhysicsNeMo's ``load_checkpoint`` + raises on missing files, so no pre-validation is performed here. + """ + resume_checkpoint = cfg.train.get("resume_checkpoint", None) + pretrain_checkpoint = cfg.train.get("pretrain_checkpoint", None) + + if resume_checkpoint: + resume_path = Path(str(resume_checkpoint)) + if dist.rank == 0: + logger.info(f"\nResuming from checkpoint: {resume_path}") + metadata: Dict[str, Any] = {} + start_epoch = load_checkpoint( + path=str(resume_path), + models=model, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + metadata_dict=metadata, + device=dist.device, + ) + best_val_loss = float(metadata.get("best_val_loss", float("inf"))) + if dist.rank == 0: + logger.info(f" Resumed from epoch {start_epoch}") + if best_val_loss < float("inf"): + logger.info(f" Best val_loss: {best_val_loss:.6f}") + return start_epoch + 1, best_val_loss + + if pretrain_checkpoint: + pretrain_path = Path(str(pretrain_checkpoint)) + if dist.rank == 0: + logger.info( + f"\nLoading pretrained weights for fine-tuning: {pretrain_path}" + ) + load_checkpoint(path=str(pretrain_path), models=model, device=dist.device) + if dist.rank == 0: + logger.info(" Pretrained weights loaded; starting from epoch 0") + return 0, float("inf") + + return 0, float("inf") + + +def load_model_from_checkpoint( + checkpoint_path: Union[str, Path], + cfg: DictConfig, + device: torch.device, +) -> Tuple[nn.Module, Dict[str, Any]]: + """Build the Transolver model from cfg.model and load weights from checkpoint_path. 
+ + The caller supplies the full Hydra cfg (so the model definition is + fully controlled by the inference-time config, not pulled from a + saved training-time snapshot). The checkpoint_path must contain + matching ``checkpoint.0.*.pt`` + ``Transolver.0.*.mdlus`` shards. + + Returns (model in eval mode, metadata dict from the checkpoint). + """ + checkpoint_path = Path(checkpoint_path) + if not checkpoint_path.exists(): + raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_path}") + + # Build model from cfg.model. Strip RTE-specific keys consumed elsewhere. + cfg_model = OmegaConf.to_container(cfg.model, resolve=True) + for k in ("num_spatial_points", "include_q_in_embedding"): + cfg_model.pop(k, None) + model = hydra.utils.instantiate(cfg_model).to(device) + + metadata: Dict[str, Any] = {} + epoch = load_checkpoint( + path=str(checkpoint_path), + models=model, + metadata_dict=metadata, + device=device, + ) + metadata.setdefault("epoch", epoch) + + model.eval() + print( + f"Loaded model from {checkpoint_path} " + f"(epoch={metadata.get('epoch', '?')}, " + f"params={sum(p.numel() for p in model.parameters()):,})" + ) + return model, metadata diff --git a/examples/nuclear_engineering/radiation_transport/src/compute_normalizations.py b/examples/nuclear_engineering/radiation_transport/src/compute_normalizations.py new file mode 100644 index 0000000000..a2057ae7b1 --- /dev/null +++ b/examples/nuclear_engineering/radiation_transport/src/compute_normalizations.py @@ -0,0 +1,336 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Standalone CLI to compute flux + material statistics over a mesh data root. + +Run this once before training to produce the two YAML statistics files the +training pipeline expects: + + /_flux_stats.yaml + /_material_stats.yaml + +Usage:: + + python src/compute_normalizations.py \\ + --data_path /lattice \\ + --case_type lattice \\ + --split_file /splits/lattice_splits.json \\ + --output_dir /stats + +The flux statistics walk the training split of the dataset, log-clip the raw +``scalar_flux`` field, and accumulate (mean, std, min, max) plus the clip +threshold the training pipeline must use. The material statistics walk the +training split, read the precomputed ``sigma_a / sigma_s / sigma_t / Q`` +fields from each store, and accumulate per-property (mean, std, min, max) +across all cells. +""" + +from __future__ import annotations + +import argparse +import sys +from pathlib import Path +from typing import Dict + +import numpy as np +import torch +import yaml + +from dataset import RTEBaseDataset +from transforms import MaterialPropertyExtractor + + +def compute_flux_statistics( + data_path: Path, + case_type: str, + output_file: Path, + split_file: Path, + clip_threshold: float = 1e-8, +) -> Dict[str, float]: + """Compute flux normalization statistics from the training split. + + Args: + data_path: path to the mesh stores for one case. + case_type: ``"lattice"`` or ``"hohlraum"``. + output_file: destination YAML path. + split_file: split JSON used to select the training split. + clip_threshold: minimum flux value before ``log10``. 
+ Returns: + The statistics dict written to ``output_file``. + """ + print(f"Computing flux statistics for {case_type} [final time only]") + print(f"Data path: {data_path}") + print(f"Split file: {split_file}") + + dataset = RTEBaseDataset( + data_path=data_path, + case_type=case_type, + phase="train", + split_file=split_file, + ) + + print(f"\nProcessing {len(dataset)} training simulations...") + + n_samples = 0 + sum_log_flux = 0.0 + sum_log_flux_sq = 0.0 + min_log_flux = float("inf") + max_log_flux = float("-inf") + + for i in range(len(dataset)): + sample, _ = dataset[i] + flux = sample["scalar_flux"] + if isinstance(flux, torch.Tensor): + flux = flux.detach().cpu().numpy() + flux = np.asarray(flux) + + # ``scalar_flux`` from the reader is shape (T, n_cells) with T=2 + # (first + final snapshots). The target the model predicts is the + # final-time only. + if flux.ndim > 1: + flux = flux[-1] + + # match training-pipeline preprocessing + flux = np.clip(flux, clip_threshold, None) + log_flux = np.log10(flux + clip_threshold) + + n = log_flux.size + n_samples += n + sum_log_flux += float(np.sum(log_flux)) + sum_log_flux_sq += float(np.sum(log_flux**2)) + min_log_flux = min(min_log_flux, float(np.min(log_flux))) + max_log_flux = max(max_log_flux, float(np.max(log_flux))) + + if (i + 1) % 10 == 0: + print(f" Processed {i + 1}/{len(dataset)} simulations") + + mean = sum_log_flux / n_samples + variance = (sum_log_flux_sq / n_samples) - (mean**2) + std = float(np.sqrt(max(variance, 0.0))) + + stats = { + "log_flux_mean": float(mean), + "log_flux_std": float(std), + "log_flux_min": float(min_log_flux), + "log_flux_max": float(max_log_flux), + "clip_threshold": float(clip_threshold), + "num_samples": int(n_samples), + "num_simulations": len(dataset), + "case_type": case_type, + } + + stats["note"] = "computed from the final-time snapshot only" + + output_file = Path(output_file) + output_file.parent.mkdir(parents=True, exist_ok=True) + with open(output_file, "w") as f: + 
yaml.dump(stats, f, default_flow_style=False, sort_keys=False) + + print("\nFlux statistics:") + print(f" Mean (log flux): {mean:.6f}") + print(f" Std (log flux): {std:.6f}") + print(f" Min (log flux): {min_log_flux:.6f}") + print(f" Max (log flux): {max_log_flux:.6f}") + print(f" Total samples: {n_samples:,}") + print(f"\nSaved to: {output_file}") + + return stats + + +def compute_material_statistics( + data_path: Path, + case_type: str, + output_file: Path, + split_file: Path, +) -> Dict[str, Dict[str, float]]: + """Compute per-property material statistics from the training split. + + Args: + data_path: path to the mesh stores for one case. + case_type: ``"lattice"`` or ``"hohlraum"``. + output_file: destination YAML path. + split_file: split JSON used to select the training split. + + Returns: + The nested statistics dict written to ``output_file``. + """ + print(f"\nComputing material statistics for {case_type}") + print(f"Data path: {data_path}") + print(f"Split file: {split_file}") + + dataset = RTEBaseDataset( + data_path=data_path, + case_type=case_type, + phase="train", + split_file=split_file, + ) + extractor = MaterialPropertyExtractor() + print(f"Dataset loaded: {len(dataset)} samples") + + print("\nAccumulating physical_properties...") + + # We track count, running mean, and M2 (sum of squared deviations from + # the running mean); the population std is sqrt(M2 / count) + prop_names = ("sigma_a", "sigma_s", "sigma_t", "Q") + count = 0 + mean_running = np.zeros(len(prop_names), dtype=np.float64) + m2_running = np.zeros(len(prop_names), dtype=np.float64) + min_running = np.full(len(prop_names), np.inf, dtype=np.float64) + max_running = np.full(len(prop_names), -np.inf, dtype=np.float64) + + for i in range(len(dataset)): + td, _ = dataset[i] + sample = extractor(td) + props = sample["physical_properties"] + if isinstance(props, torch.Tensor): + props = props.detach().cpu().numpy() + # Cast to float64 for the accumulator; the on-disk tensors are fp32. 
+ props = np.asarray(props, dtype=np.float64) + n_i = props.shape[0] + if n_i == 0: + continue + + # Per-batch sufficient stats (mean and M2) for combination + batch_mean = props.mean(axis=0) + batch_m2 = ((props - batch_mean) ** 2).sum(axis=0) + + new_count = count + n_i + delta = batch_mean - mean_running + mean_running = mean_running + delta * (n_i / new_count) + m2_running = m2_running + batch_m2 + (delta**2) * (count * n_i / new_count) + count = new_count + + np.minimum(min_running, props.min(axis=0), out=min_running) + np.maximum(max_running, props.max(axis=0), out=max_running) + + if (i + 1) % 100 == 0: + print(f" Processed {i + 1}/{len(dataset)} samples") + + if count == 0: + raise RuntimeError( + "compute_material_statistics: dataset produced zero cells; " + "cannot compute stats." + ) + + std_running = np.sqrt(m2_running / count) + + stats = { + name: { + "mean": float(mean_running[j]), + "std": float(std_running[j]), + "min": float(min_running[j]), + "max": float(max_running[j]), + } + for j, name in enumerate(prop_names) + } + + print("\nMaterial statistics:") + print("-" * 60) + for prop_name, prop_stats in stats.items(): + print(f"{prop_name}:") + for stat_name, value in prop_stats.items(): + print(f" {stat_name:6s}: {value:10.4f}") + + output_file = Path(output_file) + output_file.parent.mkdir(parents=True, exist_ok=True) + with open(output_file, "w") as f: + yaml.dump(stats, f, default_flow_style=False, sort_keys=False) + print(f"\nSaved to: {output_file}") + + return stats + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Compute flux + material normalization statistics over a mesh data " + "root. Emits two YAML files: {case_type}_flux_stats.yaml and " + "{case_type}_material_stats.yaml in the output directory." + ) + ) + parser.add_argument( + "--data_path", + type=Path, + required=True, + help="Path to the mesh data root for one case (e.g. 
/lattice).", + ) + parser.add_argument( + "--case_type", + type=str, + required=True, + choices=["lattice", "hohlraum"], + help="Case type.", + ) + parser.add_argument( + "--output_dir", + type=Path, + required=True, + help="Directory to write the two YAML statistics files into.", + ) + parser.add_argument( + "--split_file", + type=Path, + required=True, + help="Required split JSON; statistics are computed on its training split.", + ) + parser.add_argument( + "--clip_threshold", + type=float, + default=1e-8, + help="Flux clip threshold used during log-transform (default: 1e-8).", + ) + return parser.parse_args() + + +def main() -> int: + """CLI entry: compute and write flux + material statistics YAMLs.""" + args = _parse_args() + + output_dir: Path = args.output_dir + output_dir.mkdir(parents=True, exist_ok=True) + + flux_output = output_dir / f"{args.case_type}_flux_stats.yaml" + material_output = output_dir / f"{args.case_type}_material_stats.yaml" + + print("=" * 80) + print("COMPUTE NORMALIZATIONS") + print("=" * 80) + + compute_flux_statistics( + data_path=args.data_path, + case_type=args.case_type, + output_file=flux_output, + split_file=args.split_file, + clip_threshold=args.clip_threshold, + ) + + compute_material_statistics( + data_path=args.data_path, + case_type=args.case_type, + output_file=material_output, + split_file=args.split_file, + ) + + print("\n" + "=" * 80) + print("DONE") + print("=" * 80) + print(f" Flux stats: {flux_output}") + print(f" Material stats: {material_output}") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/examples/nuclear_engineering/radiation_transport/src/conf/case/hohlraum.yaml b/examples/nuclear_engineering/radiation_transport/src/conf/case/hohlraum.yaml new file mode 100644 index 0000000000..ac2fb9bc1b --- /dev/null +++ b/examples/nuclear_engineering/radiation_transport/src/conf/case/hohlraum.yaml @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & 
AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Hohlraum benchmark: axisymmetric cylindrical geometry with optional +# interior void regions. No interior heat source (Q is omitted from the +# embedding). Material properties are mapped via HohlraumMaterialMapper. + +type: hohlraum +data_root: ??? # set by user (HF download root) +data_path: ${case.data_root}/hohlraum +split_file: ${case.data_root}/splits/hohlraum_splits.json + +# Physics-loss configuration (hohlraum-specific override). +physics_loss_weight: 0.01 + +# Embedding/material configuration: hohlraum has no Q field. +include_q_in_embedding: false +embedding_dim_override: 3 # sigma_a, sigma_s, sigma_t diff --git a/examples/nuclear_engineering/radiation_transport/src/conf/case/lattice.yaml b/examples/nuclear_engineering/radiation_transport/src/conf/case/lattice.yaml new file mode 100644 index 0000000000..61f4273d1d --- /dev/null +++ b/examples/nuclear_engineering/radiation_transport/src/conf/case/lattice.yaml @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lattice benchmark: regular grid geometry with heterogeneous material blocks +# and an interior heat source Q. Material properties are mapped from +# integer region labels via LatticeMaterialMapper. + +type: lattice +data_root: ??? # set by user (HF download root) +data_path: ${case.data_root}/lattice +split_file: ${case.data_root}/splits/lattice_splits.json + +# Physics-loss configuration (lattice-specific) +physics_loss_weight: 0.005 + +# Embedding/material configuration +include_q_in_embedding: true # lattice has a heat source +embedding_dim_override: 4 # sigma_a, sigma_s, sigma_t, Q diff --git a/examples/nuclear_engineering/radiation_transport/src/conf/config.yaml b/examples/nuclear_engineering/radiation_transport/src/conf/config.yaml new file mode 100644 index 0000000000..5598a7db3b --- /dev/null +++ b/examples/nuclear_engineering/radiation_transport/src/conf/config.yaml @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +defaults: + - case: lattice + - data: lattice + - model: transolver + - train: base + - inference: default + - _self_ + +project: + name: RTE_Transolver + +exp_tag: transolver +output: outputs/${project.name}/${case.type}/${exp_tag} + +hydra: + run: + dir: ${output} + output_subdir: hydra diff --git a/examples/nuclear_engineering/radiation_transport/src/conf/data/hohlraum.yaml b/examples/nuclear_engineering/radiation_transport/src/conf/data/hohlraum.yaml new file mode 100644 index 0000000000..ebb261fb35 --- /dev/null +++ b/examples/nuclear_engineering/radiation_transport/src/conf/data/hohlraum.yaml @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +input_dir: ${case.data_path} +flux_normalization_stats_file: ${case.data_root}/stats/hohlraum_flux_stats.yaml +flux_clip_threshold: 1.0e-8 +normalize_coordinates: true +cache_static_arrays: true + +use_fourier_features: true +fourier_features: + num_frequencies: 3 + coord_dims: 2 + base_frequency: 2.0 diff --git a/examples/nuclear_engineering/radiation_transport/src/conf/data/lattice.yaml b/examples/nuclear_engineering/radiation_transport/src/conf/data/lattice.yaml new file mode 100644 index 0000000000..adbe5a91de --- /dev/null +++ b/examples/nuclear_engineering/radiation_transport/src/conf/data/lattice.yaml @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +input_dir: ${case.data_path} +flux_normalization_stats_file: ${case.data_root}/stats/lattice_flux_stats.yaml +flux_clip_threshold: 1.0e-8 +normalize_coordinates: true +cache_static_arrays: true + +# Fourier features for coordinates (adds 2 * coord_dims * num_frequencies features). +# Default: 3 freq * 2 coords * 2 (sin/cos) = 12 extra features, on top of 2 raw coords (x, y). 
+use_fourier_features: true +fourier_features: + num_frequencies: 3 + coord_dims: 2 + base_frequency: 2.0 diff --git a/examples/nuclear_engineering/radiation_transport/src/conf/inference/default.yaml b/examples/nuclear_engineering/radiation_transport/src/conf/inference/default.yaml new file mode 100644 index 0000000000..1b03053ea7 --- /dev/null +++ b/examples/nuclear_engineering/radiation_transport/src/conf/inference/default.yaml @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Inference / evaluation knobs. All fields are required to be set +# explicitly (no fallbacks from the training-time config); the user +# supplies them via Hydra overrides on the CLI. + +checkpoint_path: ??? # path to the best_model directory (contains checkpoint.0.*.pt + Transolver.0.*.mdlus) +output_dir: ??? 
# where to write metrics.yaml, qoi_metrics.yaml, and figures/ +num_samples: null # cap on number of test samples; null = all +num_plot_samples: 3 # number of per-sample flux-panel figures (evenly sampled across the test set) +device: null # torch device override; null = cuda if available else cpu +use_amp: true # autocast in eval; bf16 on CUDA, off on CPU diff --git a/examples/nuclear_engineering/radiation_transport/src/conf/model/transolver.yaml b/examples/nuclear_engineering/radiation_transport/src/conf/model/transolver.yaml new file mode 100644 index 0000000000..a6f7951a33 --- /dev/null +++ b/examples/nuclear_engineering/radiation_transport/src/conf/model/transolver.yaml @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Transolver hyperparameters. The `_target_` is consumed by hydra.utils.instantiate +# inside train.py::build_model. The two RTE-specific keys at the bottom +# (num_spatial_points, include_q_in_embedding) are stripped before instantiation +# — they configure the data adapter, not the model. 
+ +_target_: physicsnemo.models.transolver.Transolver +functional_dim: 14 # 2 coords + 12 Fourier features (2-D simulations) +embedding_dim: ${case.embedding_dim_override} +out_dim: 1 # predicted flux +n_layers: 8 +n_hidden: 256 +n_head: 16 +slice_num: 128 +mlp_ratio: 4 +dropout: 0.0 +use_te: false # default false for portability; flip to true if Transformer Engine is installed. +structured_shape: null + +# RTE-specific (stripped before instantiation in train.py::build_model) +num_spatial_points: -1 +include_q_in_embedding: ${case.include_q_in_embedding} diff --git a/examples/nuclear_engineering/radiation_transport/src/conf/train/base.yaml b/examples/nuclear_engineering/radiation_transport/src/conf/train/base.yaml new file mode 100644 index 0000000000..5baee1d5f0 --- /dev/null +++ b/examples/nuclear_engineering/radiation_transport/src/conf/train/base.yaml @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +epochs: 501 +gradient_accumulation_steps: 1 +seed: 6 +amp: true +amp_dtype: bf16 # bf16 | fp16 +loss_metric: mse +tensorboard: true + +optimizer: + type: muon # adam | muon + weight_decay: 0.0 + muon_momentum_beta: 0.95 # Muon-only; AdamW (for 1D params) uses the shared learning_rate via match_rms_adamw + +learning_rate: 3.0e-5 +min_learning_rate: 1.0e-6 +warmup_epochs: 10 +max_grad_norm: 10.0 # gradient L2-norm clip applied each optimizer step + +pretrain_checkpoint: null +resume_checkpoint: null + +# objective = mse_weight * regression_mse + physics_loss.weight * qoi_loss; region_weighted swaps regression_mse for the weighted variant. +# Physics-informed loss (case-specific weight). +use_physics_loss: true +physics_loss: + weight: ${case.physics_loss_weight} + mse_weight: 1.0 + warmup_epochs: 0 + warmup_start_fraction: 0.0 + +# Region-weighted MSE (heavier penalty on low absorption (void) points). +use_region_weighted_loss: true +region_weights: + void_weight: 10.0 + material_weight: 1.0 + +dataloader: + batch_size: 1 # only batch_size=1 is supported (point-cloud adapter) + prefetch_factor: 4 + num_streams: 4 + use_streams: true + +sampler: + shuffle: true + drop_last: false + +val: + dataloader: + batch_size: 1 + prefetch_factor: 4 + num_streams: 4 + use_streams: true + sampler: + shuffle: false + drop_last: false diff --git a/examples/nuclear_engineering/radiation_transport/src/dataset.py b/examples/nuclear_engineering/radiation_transport/src/dataset.py new file mode 100644 index 0000000000..b34cbeb06f --- /dev/null +++ b/examples/nuclear_engineering/radiation_transport/src/dataset.py @@ -0,0 +1,381 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Dict, List, Mapping, Optional, Sequence, Union + +import torch +import yaml +from physicsnemo.datapipes.dataset import Dataset as PhysicsNeMoDataset +from physicsnemo.datapipes.readers.base import Reader +from physicsnemo.datapipes.registry import register +from physicsnemo.datapipes.transforms.base import Transform +from physicsnemo.mesh import Mesh +from tensordict import TensorDict + + +@register("RTEMeshReader") +class MeshDataReader(Reader): + """Filename-indexed reader over a directory of RTE Mesh memmap stores. + + The ``TensorDict`` returned by ``load(filename)`` carries the tensor + fields RTE training and inference rely on. The on-disk format is the + PhysicsNeMo ``Mesh`` memmap layout (``.pmsh/`` + ``.attrs.json`` + sidecar). + + Example: + >>> reader = MeshDataReader( + ... "/path/to/mesh_stores/lattice", + ... filenames=["lattice_abs10.0_scatter0.1_p0.015_q6.pmsh"], + ... ) + >>> td = reader.load(reader.get_filenames()[0]) + >>> print(td["coordinates"].shape) # (N, 2) + """ + + def __init__( + self, + data_path: Path | str, + filenames: Sequence[str], + case_type: Optional[str] = None, + cache_static_arrays: bool = True, + ): + super().__init__(pin_memory=False, include_index_in_metadata=False) + + self.data_path = Path(data_path) + self.case_type = case_type + self.cache_static_arrays = cache_static_arrays + + # Plain dict cache of static-only fields keyed by filename. 
Mesh + # stores are small enough that the full train+val split fits in RAM + # without eviction. + self._static_cache: Dict[str, Dict[str, torch.Tensor]] = {} + + self._metadata_cache: Dict[str, Dict] = {} + + if not self.data_path.exists(): + raise ValueError(f"Data path {self.data_path} does not exist") + if not self.data_path.is_dir(): + raise ValueError(f"Data path {self.data_path} is not a directory") + + # ``filenames`` is the explicit train/val/test file list + # (typically derived from a split JSON); it is required so that + # ``Reader.__getitem__(idx)`` maps to a stable, intended file. + self._filenames: List[str] = list(filenames) + + def __len__(self) -> int: + return len(self._filenames) + + def _load_sample(self, index: int) -> Dict[str, torch.Tensor]: + td = self.load(self._filenames[index]) + return {key: td[key] for key in td.keys() if isinstance(td[key], torch.Tensor)} + + def _get_sample_metadata(self, index: int) -> Dict: + filename = self._filenames[index] + meta = self.get_metadata(filename) + meta["filename"] = filename + return meta + + def get_filenames(self) -> List[str]: + """Return a copy of the filenames the reader was constructed with.""" + return list(self._filenames) + + def _sidecar_path(self, filename: str) -> Path: + # ``.pmsh`` -> ``.attrs.json`` + stem = filename[: -len(".pmsh")] if filename.endswith(".pmsh") else filename + return self.data_path / f"{stem}.attrs.json" + + def _read_sidecar(self, filename: str) -> Dict: + sidecar = self._sidecar_path(filename) + if not sidecar.exists(): + return {} + with open(sidecar, "r", encoding="utf-8") as f: + return json.load(f) + + def load(self, filename: str) -> TensorDict: + """Load a Mesh memmap store into a ``TensorDict``. + + Reads cell-primary fields from ``mesh.cell_data`` and derives + ``coordinates`` and ``cell_areas`` from the mesh topology. 
Returned + tensor fields: ``coordinates``, ``cell_areas``, ``scalar_flux``, + ``sim_times``, ``material_properties``, ``sigma_a/s/t``, ``Q``, + plus the eight hohlraum geometry parameters (``ulr, llr, urr, lrr, + hlr, hrr, cx, cy``) when present on the store (hohlraum only). + """ + filepath = self.data_path / filename + if not filepath.exists(): + raise FileNotFoundError(f"Mesh store {filepath} not found") + + mesh = Mesh.load(str(filepath)) + cell_data = mesh.cell_data + global_data = mesh.global_data + + # Flux + timesteps (first -> final-time snapshots from the curated + # time series). ``cell_data['scalar_flux']`` is ``(n_cells, T)``. + if "scalar_flux" not in cell_data.keys(): + raise KeyError(f"cell_data['scalar_flux'] missing from {filepath}") + flux_nT = cell_data["scalar_flux"] + num_timesteps = flux_nT.shape[1] if flux_nT.ndim == 2 else 1 + full = flux_nT.reshape(flux_nT.shape[0], -1).transpose(0, 1).contiguous().to(torch.float32) # (T, n_cells); a 1-D flux becomes T=1 + resolved = [0] if num_timesteps == 1 else [0, num_timesteps - 1] + + td = TensorDict({}, batch_size=[]) + td["scalar_flux"] = full[resolved].contiguous() + if "sim_time" in global_data.keys() and global_data["sim_time"].numel() > 0: + td["sim_times"] = ( + global_data["sim_time"].to(torch.float32)[resolved].contiguous() + ) + + if self.cache_static_arrays and filename in self._static_cache: + for key, tensor in self._static_cache[filename].items(): + td[key] = tensor + else: + # Coordinates and cell areas come from the topology (Mesh + # properties) so the cell-primary fields share the same (n_cells,) + # indexing. 
+ td["coordinates"] = mesh.cell_centroids.to(torch.float32).contiguous() + td["cell_areas"] = mesh.cell_areas.to(torch.float32).contiguous() + if "material_id" not in cell_data.keys(): + raise KeyError(f"cell_data['material_id'] missing from {filepath}") + td["material_properties"] = ( + cell_data["material_id"].to(torch.int32).contiguous() + ) + for key in ("sigma_t", "sigma_s", "sigma_a", "Q"): + if key not in cell_data.keys(): + raise KeyError(f"cell_data['{key}'] missing from {filepath}") + td[key] = cell_data[key].to(torch.float32).contiguous() + + # Hohlraum geometry parameters: eight 0-D float32 tensors in + # ``mesh.global_data``. + for key in ("ulr", "llr", "urr", "lrr", "hlr", "hrr", "cx", "cy"): + if key in global_data.keys(): + td[key] = global_data[key].to(torch.float32).contiguous() + + if self.cache_static_arrays: + cached_keys = ( + "coordinates", + "cell_areas", + "material_properties", + "sigma_t", + "sigma_s", + "sigma_a", + "Q", + "ulr", + "llr", + "urr", + "lrr", + "hlr", + "hrr", + "cx", + "cy", + ) + self._static_cache[filename] = { + k: td[k] for k in cached_keys if k in td + } + + return td + + def get_metadata(self, filename: str) -> Dict: + """Return metadata (sidecar attrs + shape facts) without a full load.""" + cached = self._metadata_cache.get(filename) + if cached is not None: + return cached + + filepath = self.data_path / filename + mesh = Mesh.load(str(filepath)) + cell_data = mesh.cell_data + global_data = mesh.global_data + + sidecar = self._read_sidecar(filename) + metadata: Dict = {k: v for k, v in sidecar.items() if k != "missing_fields"} + + if "scalar_flux" not in cell_data.keys(): + raise KeyError(f"cell_data['scalar_flux'] missing from {filepath}") + flux_shape = cell_data["scalar_flux"].shape # (n_cells, T) + metadata["num_cells"] = int(flux_shape[0]) + metadata["num_timesteps"] = int(flux_shape[1]) if len(flux_shape) > 1 else 1 + + metadata["has_material_properties"] = "material_id" in cell_data.keys() + has_sim_times = 
( + "sim_time" in global_data.keys() and global_data["sim_time"].numel() > 0 + ) + metadata["has_sim_times"] = has_sim_times + if has_sim_times: + metadata["max_sim_time"] = float(global_data["sim_time"][-1].item()) + + self._metadata_cache[filename] = metadata + return metadata + + +class RTEBaseDataset(PhysicsNeMoDataset): + """File-indexed final-time dataset over a directory of mesh stores. + + Wraps :class:`MeshDataReader` and produces ``(TensorDict, metadata)`` + tuples per the :class:`physicsnemo.datapipes.Dataset` contract. The + metadata dict carries the source sidecar attrs plus ``filename``, + ``max_timestep``, ``max_sim_time`` and the resolved ``sim_time`` so the + rest of the pipeline can read them without unpacking ``NonTensorData``. + + The TensorDict still carries the per-sample tensor fields the reader + returned (``coordinates``, ``cell_areas``, ``scalar_flux``, etc.). + Transforms run on it in order; the trailing model adapter (e.g. + :class:`TransolverAdapter`) is wired in by the caller via the + ``transforms`` arg. + """ + + def __init__( + self, + data_path: Path | str, + case_type: Optional[str] = None, + phase: str = "train", + split_file: Optional[Path | str] = None, + seed: Optional[int] = None, + cache_static_arrays: bool = True, + transforms: Optional[Transform | Sequence[Transform]] = None, + device: Optional[Union[str, torch.device]] = None, + ): + self.data_path = Path(data_path) + self.case_type = case_type + self.phase = phase + self.split_file = Path(split_file) if split_file else None + self.seed = seed + + if self.split_file is None: + raise ValueError( + "split_file is required. RTE datasets must use explicit " + "train/val/test splits from a JSON split file." + ) + self.filenames = self._load_split_from_file() + + if not self.filenames: + raise ValueError(f"No files in {phase} split") + + # Hand the split list to the reader so its int-indexed + # ``__getitem__`` (called by ``Dataset._load``) resolves to the + # split's files. 
+ reader = MeshDataReader( + data_path=data_path, + filenames=self.filenames, + case_type=case_type, + cache_static_arrays=cache_static_arrays, + ) + + super().__init__(reader=reader, transforms=transforms, device=device) + + def _load_split_from_file(self) -> List[str]: + if not self.split_file.exists(): + raise FileNotFoundError(f"Split file not found: {self.split_file}") + with open(self.split_file, "r", encoding="utf-8") as f: + split_data = json.load(f) + if "splits" not in split_data: + raise ValueError("Invalid split file format: missing 'splits' key") + if self.phase not in split_data["splits"]: + raise ValueError( + f"Phase '{self.phase}' not found in split file. " + f"Available: {list(split_data['splits'].keys())}" + ) + filenames = split_data["splits"][self.phase] + # Split files may list basenames with or without a ``.pmsh`` suffix. + # Normalize to always point at a mesh store. + normalized: List[str] = [] + for f in filenames: + base = f[: -len(".pmsh")] if f.endswith(".pmsh") else f + normalized.append(base + ".pmsh") + return normalized + + +def load_flux_stats(path: Union[str, Path]) -> dict: + """Read an RTE flux statistics YAML. + + Returns a plain dict with keys ``log_flux_mean``, ``log_flux_std``, + ``clip_threshold``. Raises if any required key is missing. + """ + stats_path = Path(path) + if not stats_path.exists(): + raise FileNotFoundError(f"Flux statistics file not found: {stats_path}") + with open(stats_path, "r") as f: + stats = yaml.safe_load(f) + for key in ("log_flux_mean", "log_flux_std", "clip_threshold"): + if key not in stats: + raise ValueError(f"Flux statistics file missing required key: {key}") + return stats + + +def flux_normalize_kwargs( + stats: Mapping, + field: str = "scalar_flux", +) -> dict: + """Build ``Normalize`` kwargs for the log-clipped flux field. 
+ + Example: + stats = load_flux_stats(path) + Normalize(**flux_normalize_kwargs(stats)) + """ + return { + "input_keys": [field], + "method": "mean_std", + "means": {field: float(stats["log_flux_mean"])}, + "stds": {field: float(stats["log_flux_std"])}, + } + + +def load_material_stats(path: Union[str, Path]) -> dict: + """Read an RTE material statistics YAML. + + Returns the full per-property nested dict. Each of ``sigma_a``, + ``sigma_s``, ``sigma_t``, ``Q`` must be present with ``mean`` and ``std`` + sub-keys; ``min`` and ``max`` are not validated here. + """ + stats_path = Path(path) + if not stats_path.exists(): + raise FileNotFoundError(f"Material statistics file not found: {stats_path}") + with open(stats_path, "r") as f: + stats = yaml.safe_load(f) + required = ("sigma_a", "sigma_s", "sigma_t", "Q") + for key in required: + if key not in stats: + raise ValueError( + f"Material statistics file missing required property: {key}" + ) + for sub in ("mean", "std"): + if sub not in stats[key]: + raise ValueError( + f"Material statistics[{key!r}] missing required sub-key: {sub!r}" + ) + return stats + + +def material_normalize_kwargs( + stats: Mapping, + field: str = "physical_properties", + order: Sequence[str] = ("sigma_a", "sigma_s", "sigma_t", "Q"), +) -> dict: + """Build ``Normalize`` kwargs for ``physical_properties`` as (N, 4). + + The 4 columns are normalized independently via broadcasting: a per-column + ``torch.Tensor`` of shape ``(4,)`` is passed as the mean and the std, + delegating the math to ``physicsnemo.datapipes.transforms.Normalize``. 
+ """ + means = torch.tensor([float(stats[k]["mean"]) for k in order], dtype=torch.float32) + stds = torch.tensor([float(stats[k]["std"]) for k in order], dtype=torch.float32) + return { + "input_keys": [field], + "method": "mean_std", + "means": {field: means}, + "stds": {field: stds}, + } diff --git a/examples/nuclear_engineering/radiation_transport/src/evaluation_metrics.py b/examples/nuclear_engineering/radiation_transport/src/evaluation_metrics.py new file mode 100644 index 0000000000..fbe53d7a0e --- /dev/null +++ b/examples/nuclear_engineering/radiation_transport/src/evaluation_metrics.py @@ -0,0 +1,176 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, Dict, Optional + +import numpy as np +import torch + +from qoi import ( + evaluate_hohlraum_qoi_torch, + evaluate_lattice_qoi_torch, + extract_geometry_params, +) + +__all__ = [ + "compute_metrics", + "aggregate_metrics", + "compute_sample_qoi", + "aggregate_qoi", +] + + +def compute_metrics( + pred: np.ndarray, target: np.ndarray, eps: float = 1e-10 +) -> Dict[str, float]: + """Compute the full metric panel for one ``(pred, target)`` pair.""" + pred_flat = pred.flatten() + target_flat = target.flatten() + diff = pred_flat - target_flat + abs_diff = np.abs(diff) + mse = float(np.mean(diff**2)) + return { + "mse": mse, + "rmse": float(np.sqrt(mse)), + "mae": float(np.mean(abs_diff)), + "l2_relative_error": float( + np.linalg.norm(diff) / (np.linalg.norm(target_flat) + eps) + ), + "relative_error": float(np.mean(abs_diff / (np.abs(target_flat) + eps))), + "max_error": float(np.max(abs_diff)), + } + + +def aggregate_metrics(per_sample: list[Dict[str, float]]) -> Dict[str, float]: + """Aggregate per-sample metrics into mean/std/min/max.""" + if not per_sample: + return {} + keys = per_sample[0].keys() + out: Dict[str, float] = {} + for k in keys: + vals = [s[k] for s in per_sample] + out[f"{k}_mean"] = float(np.mean(vals)) + out[f"{k}_std"] = float(np.std(vals)) + out[f"{k}_min"] = float(np.min(vals)) + out[f"{k}_max"] = float(np.max(vals)) + return out + + +def compute_sample_qoi( + pred: torch.Tensor, + target: torch.Tensor, + cell_centers: torch.Tensor, + cell_areas: torch.Tensor, + sigma_t: torch.Tensor, + sigma_s: torch.Tensor, + sample: Any, + case_type: str, +) -> Optional[Dict[str, Dict[str, float]]]: + """Compute QoI(pred) vs QoI(target) for one sample on the tensors' device. + + All tensor inputs may live on GPU; only the scalar QoI values are + materialized to host (via ``.item()``). 
Returns ``{region: {predicted, + ground_truth, absolute_error, relative_error_pct}}`` or ``None`` for the + hohlraum case when geometry params are missing from ``sample``. + + Args: + sample: For ``case_type="hohlraum"``, a per-sample mapping carrying + the eight 0-D float32 geometry tensors (``ulr`` ... ``cy``), + typically the batch sliced at index ``b``, or a fresh dict + built from those entries by the caller. Ignored for lattice. + """ + # The QoI evaluators expect ``(1, N)`` batched flux + flat (N,) cell fields. + pred_batched = pred.float().reshape(1, -1) + target_batched = target.float().reshape(1, -1) + centers = cell_centers.float() + areas = cell_areas.float().flatten() + sigma_t_flat = sigma_t.float().flatten() + sigma_s_flat = sigma_s.float().flatten() + # Placeholder — the final-time QoI evaluators accept ``sim_times`` only + # for callsite uniformity with the time-dependent variants. + sim_times = torch.zeros(1, device=pred.device) + + if case_type == "lattice": + qoi_pred = evaluate_lattice_qoi_torch( + centers, areas, sigma_t_flat, sigma_s_flat, pred_batched, sim_times + ) + qoi_target = evaluate_lattice_qoi_torch( + centers, areas, sigma_t_flat, sigma_s_flat, target_batched, sim_times + ) + elif case_type == "hohlraum": + geometry_params = extract_geometry_params(sample) + if not geometry_params: + return None + qoi_pred = evaluate_hohlraum_qoi_torch( + centers, + areas, + sigma_t_flat, + sigma_s_flat, + pred_batched, + sim_times, + geometry_params, + ) + qoi_target = evaluate_hohlraum_qoi_torch( + centers, + areas, + sigma_t_flat, + sigma_s_flat, + target_batched, + sim_times, + geometry_params, + ) + else: + raise ValueError(f"Unknown case_type: {case_type}") + + out: Dict[str, Dict[str, float]] = {} + for region in qoi_pred: + pred_value = float(qoi_pred[region][0].item()) + target_value = float(qoi_target[region][0].item()) + abs_err = abs(pred_value - target_value) + out[region] = { + "predicted": pred_value, + "ground_truth": target_value, + 
"absolute_error": abs_err, + "relative_error_pct": abs_err / (abs(target_value) + 1e-10) * 100.0, + } + return out + + +def aggregate_qoi( + per_sample_qoi: list[Dict[str, Dict[str, float]]], +) -> Dict[str, Dict[str, float]]: + """Aggregate per-sample QoI dicts into per-region summary statistics.""" + by_region: Dict[str, list] = {} + for sample in per_sample_qoi: + if not sample: + continue + for region, entry in sample.items(): + by_region.setdefault(region, []).append(entry) + + summary: Dict[str, Dict[str, float]] = {} + for region, entries in by_region.items(): + abs_errs = np.array([e["absolute_error"] for e in entries]) + rel_errs = np.array([e["relative_error_pct"] for e in entries]) + summary[region] = { + "num_samples": len(entries), + "mae": float(np.mean(abs_errs)), + "rmse": float(np.sqrt(np.mean(abs_errs**2))), + "max_error": float(np.max(abs_errs)), + "mean_relative_error_pct": float(np.mean(rel_errs)), + "median_relative_error_pct": float(np.median(rel_errs)), + "max_relative_error_pct": float(np.max(rel_errs)), + } + return summary diff --git a/examples/nuclear_engineering/radiation_transport/src/inference.py b/examples/nuclear_engineering/radiation_transport/src/inference.py new file mode 100644 index 0000000000..2ebe04f693 --- /dev/null +++ b/examples/nuclear_engineering/radiation_transport/src/inference.py @@ -0,0 +1,260 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import Any, Dict, Iterator, Optional, Tuple + +import hydra +import numpy as np +import torch +import torch.nn as nn +import yaml +from omegaconf import DictConfig, OmegaConf +from torch.amp import autocast +from tqdm import tqdm + +from physicsnemo.datapipes import DataLoader + +from checkpointing import load_model_from_checkpoint +from dataset import load_flux_stats +from evaluation_metrics import ( + aggregate_metrics, + aggregate_qoi, + compute_metrics, + compute_sample_qoi, +) +from loader import build_dataloaders, collate_no_padding +from transforms import denormalize_flux +from viz import plot_flux_panels, plot_qoi_true_vs_pred + +from physicsnemo.distributed import DistributedManager + + +@torch.no_grad() +def run_evaluation( + model: nn.Module, + dataloader: DataLoader, + device: torch.device, + flux_stats: Dict[str, float], + case_type: str, + use_amp: bool = True, + max_samples: Optional[int] = None, +) -> Iterator[ + Tuple[ + np.ndarray, + np.ndarray, + Optional[Dict[str, Dict[str, float]]], + Optional[np.ndarray], + Optional[str], + ] +]: + """Yield ``(prediction, target, qoi, coordinates, filename)`` per sample. + + Predictions and targets are denormalized to physical-flux units and + returned as flattened numpy arrays for downstream pointwise metrics and + plotting. The QoI dict (or ``None``) is computed on-device before the + GPU->CPU transfer to avoid round-tripping per-mesh tensors through numpy. + ``coordinates`` is the per-sample point cloud (or ``None`` if absent); + ``filename`` is the sidecar filename (or ``None``). 
+ """ + model.eval() + n = 0 + + for batch in tqdm(dataloader, desc="evaluating"): + if max_samples is not None and n >= max_samples: + break + batch = { + k: v.to(device) if isinstance(v, torch.Tensor) else v + for k, v in batch.items() + } + + amp_enabled = use_amp and device.type == "cuda" + with autocast(device_type=device.type, enabled=amp_enabled): + pred = model(fx=batch["fx"], embedding=batch["embedding"]) + pred = pred.float() + target = batch["flux_target"].float() + + # Denormalize back to physical flux. ``denormalize_flux`` handles the + # full RTEFluxLogClip + Normalize inverse using the stats dict that + # the dataset transform recorded on the sample. + stats = batch.get("flux_normalization_stats", flux_stats) + if isinstance(stats, list): + stats = stats[0] if stats else flux_stats + + coords_t = batch.get("coordinates_unnormalized") + cell_areas_t = batch.get("cell_areas") + sigma_t_t = batch.get("sigma_t") + sigma_s_t = batch.get("sigma_s") + raw_meta = batch.get("metadata") or {} + if isinstance(raw_meta, list): + raw_meta = raw_meta[0] if raw_meta else {} + + # Batches always carry an outer batch dim of 1 (collate_no_padding). 
+ for b in range(pred.shape[0]): + pred_phys_t = denormalize_flux(pred[b].squeeze(-1), stats).flatten() + target_phys_t = denormalize_flux(target[b].squeeze(-1), stats).flatten() + + qoi: Optional[Dict[str, Dict[str, float]]] = None + if ( + coords_t is not None + and cell_areas_t is not None + and sigma_t_t is not None + and sigma_s_t is not None + ): + qoi = compute_sample_qoi( + pred_phys_t, + target_phys_t, + coords_t[b], + cell_areas_t[b], + sigma_t_t[b], + sigma_s_t[b], + batch, + case_type, + ) + + coords_np: Optional[np.ndarray] = None + if coords_t is not None: + coords_np = coords_t[b].detach().cpu().numpy() + filename = raw_meta.get("filename") if isinstance(raw_meta, dict) else None + + n += 1 + yield ( + pred_phys_t.detach().cpu().numpy(), + target_phys_t.detach().cpu().numpy(), + qoi, + coords_np, + filename, + ) + if max_samples is not None and n >= max_samples: + return + + +@hydra.main(version_base="1.3", config_path="conf", config_name="config") +def main(cfg: DictConfig) -> None: + """Hydra entry: load checkpoint, run evaluation, write metrics + figures.""" + DistributedManager.initialize() + # Full-mesh evaluation always — disable any training-time subsampling. + OmegaConf.update(cfg, "model.num_spatial_points", -1) + + output_dir = Path(cfg.inference.output_dir) + figures_dir = output_dir / "figures" + figures_dir.mkdir(parents=True, exist_ok=True) + + device = torch.device( + cfg.inference.device or ("cuda" if torch.cuda.is_available() else "cpu") + ) + + # Downstream calls (load_model_from_checkpoint, build_dataloaders -> + # MeshDataReader / split-file loader, load_flux_stats) all raise + # ``FileNotFoundError`` with the offending path if anything is missing. + model, _ = load_model_from_checkpoint( + Path(cfg.inference.checkpoint_path), cfg, device + ) + + # Build the test loader. ``test_batch_size=1`` matches the point-cloud + # adapter's invariant. 
+ loaders, _ = build_dataloaders( + cfg, + dist=None, + collate_fn=collate_no_padding, + phases=("test",), + test_batch_size=1, + ) + test_loader = loaders["test"] + print(f"Test set size: {len(test_loader.dataset)}") + + flux_stats = load_flux_stats(cfg.data.flux_normalization_stats_file) + case_type = cfg.case.type + + # Evenly sample plot indices across the test set. + n_total = cfg.inference.num_samples or len(test_loader.dataset) + n_plots = cfg.inference.num_plot_samples + plot_indices: set[int] = set() + if n_plots > 0: + plot_indices = set(np.linspace(0, n_total - 1, n_plots, dtype=int).tolist()) + + per_sample_metrics: list[Dict[str, float]] = [] + per_sample_qoi: list[Dict[str, Dict[str, float]]] = [] + all_targets: list[np.ndarray] = [] + all_preds: list[np.ndarray] = [] + + for idx, (pred, target, qoi, coords, _filename) in enumerate( + run_evaluation( + model, + test_loader, + device, + flux_stats, + case_type, + use_amp=cfg.inference.use_amp, + max_samples=cfg.inference.num_samples, + ) + ): + per_sample_metrics.append(compute_metrics(pred, target)) + if qoi is not None: + per_sample_qoi.append(qoi) + all_targets.append(target) + all_preds.append(pred) + + if idx in plot_indices and coords is not None: + plot_flux_panels( + coords, + target, + pred, + figures_dir / f"flux_panels_{idx:04d}.png", + log_flux=case_type == "lattice", + ) + + if not per_sample_metrics: + raise RuntimeError("No samples evaluated; check the test split / data path.") + + # Aggregate metrics over every sample (concatenate first for global stats). 
+ all_target_arr = np.concatenate(all_targets) + all_pred_arr = np.concatenate(all_preds) + overall_metrics = compute_metrics(all_pred_arr, all_target_arr) + aggregated = aggregate_metrics(per_sample_metrics) + + metrics_out: Dict[str, Any] = { + "num_samples": len(per_sample_metrics), + "overall": overall_metrics, + "per_sample_aggregate": aggregated, + } + with open(output_dir / "metrics.yaml", "w") as f: + yaml.safe_dump(metrics_out, f, sort_keys=False) + print("\nMetrics:") + for k, v in overall_metrics.items(): + print(f" {k}: {v:.6e}") + + # QoI summary. + if per_sample_qoi: + qoi_summary = aggregate_qoi(per_sample_qoi) + with open(output_dir / "qoi_metrics.yaml", "w") as f: + yaml.safe_dump(qoi_summary, f, sort_keys=False) + plot_qoi_true_vs_pred(per_sample_qoi, figures_dir / "qoi_true_vs_pred.png") + print("\nQoI summary:") + for region, stats in qoi_summary.items(): + print( + f" {region}: mae={stats['mae']:.4e}, " + f"mean_rel_err={stats['mean_relative_error_pct']:.3f}%" + ) + + print(f"\nResults written to: {output_dir}") + print(" metrics.yaml") + if per_sample_qoi: + print(" qoi_metrics.yaml") + + +if __name__ == "__main__": + main() diff --git a/examples/nuclear_engineering/radiation_transport/src/loader.py b/examples/nuclear_engineering/radiation_transport/src/loader.py new file mode 100644 index 0000000000..e0025520fa --- /dev/null +++ b/examples/nuclear_engineering/radiation_transport/src/loader.py @@ -0,0 +1,489 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Sequence, + Tuple, + Union, +) + +import torch +from omegaconf import DictConfig +from physicsnemo.datapipes import DataLoader +from physicsnemo.datapipes.registry import register +from physicsnemo.datapipes.transforms import Compose, Normalize, Scale, Translate +from physicsnemo.datapipes.transforms.base import Transform +from tensordict import TensorDict +from torch.utils.data import Sampler +from torch.utils.data.distributed import DistributedSampler + +from dataset import ( + RTEBaseDataset, + flux_normalize_kwargs, + load_flux_stats, + load_material_stats, + material_normalize_kwargs, +) +from transforms import ( + FourierFeatures, + MaterialPropertyExtractor, + RTEBackupCoords, + RTEFluxLogClip, + SpatialSampler, + FinalTimeSampler, + coord_translate_scale_params, +) + +__all__ = [ + "TransolverAdapter", + "collate_no_padding", + "build_dataloaders", +] + + +@register("RTETransolverAdapter") +class TransolverAdapter(Transform): + """Pack a transformed RTE ``TensorDict`` into Transolver-ready fields. + + Output TensorDict keys: + + * ``fx`` — spatial coordinates (plus Fourier features when enabled). + * ``embedding`` — material properties ``[sigma_a, sigma_s, sigma_t, Q]`` + (or just the first three when ``include_q_in_embedding=False``). + * ``flux_target`` — target flux to predict. 
+ + Pass-through fields when present: ``coordinates_unnormalized``, + ``material_labels``, ``cell_areas``, ``sigma_t``, ``sigma_s``, + ``sim_time``, ``flux_normalization_stats`` (NonTensorData), and the + eight hohlraum geometry parameters (``ulr``, ``llr``, ``urr``, + ``lrr``, ``hlr``, ``hrr``, ``cx``, ``cy``). + + The output has no batch dimension; :func:`collate_no_padding` adds one. + """ + + def __init__(self, include_q_in_embedding: bool = True): + super().__init__() + self.include_q_in_embedding = include_q_in_embedding + + # Simple passthroughs: same key on both sides, no transform. + _PASSTHROUGH_KEYS = ( + "coordinates_unnormalized", + "cell_areas", + "sigma_t", + "sigma_s", + ) + + def __call__(self, data: TensorDict) -> TensorDict: + out = TensorDict({}, batch_size=data.batch_size, device=data.device) + + # Rename: coordinates -> fx (Transolver's positional input). + if "coordinates" in data: + out["fx"] = data["coordinates"] + + # Passthroughs. + for key in self._PASSTHROUGH_KEYS: + if key in data: + out[key] = data[key] + + # physical_properties -> embedding (optionally drop Q for hohlraum). + if "physical_properties" in data: + mat_props = data["physical_properties"] + if not self.include_q_in_embedding: + mat_props = mat_props[..., :3] + out["embedding"] = mat_props + + # material_properties -> material_labels (long dtype for embedding lookups). + if "material_properties" in data: + out["material_labels"] = data["material_properties"].to(dtype=torch.long) + + # flux_target promoted to shape (N, 1) if delivered as (N,). + if "flux_target" in data: + flux_tgt = data["flux_target"] + out["flux_target"] = ( + flux_tgt.unsqueeze(-1) if flux_tgt.ndim == 1 else flux_tgt + ) + + # sim_times -> single-scalar sim_time at the final snapshot; + # zero-tensor placeholder when the source series is empty. 
+ if "sim_times" in data: + sim_times = data["sim_times"] + out["sim_time"] = ( + sim_times[-1].reshape(1).to(dtype=torch.float32) + if sim_times.numel() > 0 + else torch.zeros(1, dtype=torch.float32, device=data.device) + ) + + # NonTensorData passthroughs. + if "flux_normalization_stats" in data: + out.set_non_tensor( + "flux_normalization_stats", data["flux_normalization_stats"] + ) + + # Forward the eight hohlraum geometry parameters (0-D float32 + # tensors). Lattice samples never carry these keys. + for key in ("ulr", "llr", "urr", "lrr", "hlr", "hrr", "cx", "cy"): + if key in data: + out[key] = data[key] + + return out + + def extra_repr(self) -> str: + return f"include_q_in_embedding={self.include_q_in_embedding}" + + +@register("RTECollateNoPadding") +def collate_no_padding( + batch: Sequence[Tuple[TensorDict, Dict[str, Any]]], +) -> Dict[str, Any]: + """Batch-size-1 collate for the ``physicsnemo.datapipes.DataLoader``. + + Unsqueezes each tensor in the TensorDict to add a ``B=1`` leading + dim, passes NonTensorData entries through unchanged, and merges the + trailing metadata dict under ``batch["metadata"]``. Returns a plain + dict so downstream code can use ``batch["fx"]`` / ``batch["filename"]`` + without unpacking a TensorDict. ``build_dataloaders_for_training`` + enforces ``batch_size=1`` upstream so no padding is needed. + """ + assert len(batch) == 1, ( + f"collate_no_padding requires batch_size=1; got {len(batch)}" + ) + td, metadata = batch[0] + + out: Dict[str, Any] = {} + for key in td.keys(): + value = td[key] + out[key] = value.unsqueeze(0) if isinstance(value, torch.Tensor) else value + + # Merge the trailing metadata dict (filename / case_type / num_cells + # / num_timesteps / max_sim_time) under ``out["metadata"]``. Surface + # ``filename`` at the top level too for callers that use + # ``batch["filename"]`` directly (e.g. inference figure naming). 
+ if metadata: + existing = out.get("metadata") or {} + merged = {**metadata, **existing} + out["metadata"] = merged + if "filename" in merged and "filename" not in out: + out["filename"] = merged["filename"] + return out + + +def _build_rte_dataset_kwargs(cfg: DictConfig) -> dict: + """Translate a Hydra config into the kwargs ``_build_rte_dataset`` expects.""" + data_cfg = cfg.data + use_fourier_features = data_cfg.get("use_fourier_features", False) + fourier_cfg = data_cfg.get("fourier_features") if use_fourier_features else None + + return { + "data_path": cfg.case.data_path, + "num_spatial_points": cfg.model.num_spatial_points, + "flux_normalization_stats_file": data_cfg.flux_normalization_stats_file, + "normalize_coordinates": data_cfg.get("normalize_coordinates", True), + "flux_clip_threshold": data_cfg.flux_clip_threshold, + "split_file": cfg.case.split_file, + "seed": data_cfg.get("seed") or cfg.train.get("seed"), + "cache_static_arrays": data_cfg.get("cache_static_arrays", True), + "include_q_in_embedding": cfg.model.get("include_q_in_embedding", True), + "use_fourier_features": use_fourier_features, + "fourier_num_frequencies": fourier_cfg.num_frequencies if fourier_cfg else None, + "fourier_coord_dims": fourier_cfg.coord_dims if fourier_cfg else None, + "fourier_base_frequency": fourier_cfg.base_frequency if fourier_cfg else None, + } + + +def _build_rte_dataset( + case_type: str, + data_path: Union[str, Path], + phase: str, + num_spatial_points: int, + flux_normalization_stats_file: Union[str, Path], + normalize_coordinates: bool, + flux_clip_threshold: float, + split_file: Union[str, Path], + seed: Optional[int], + cache_static_arrays: bool, + include_q_in_embedding: bool, + use_fourier_features: bool, + fourier_num_frequencies: Optional[int], + fourier_coord_dims: Optional[int], + fourier_base_frequency: Optional[float], + device: Optional[Union[str, torch.device]] = None, +) -> RTEBaseDataset: + """Build the canonical training/inference RTE dataset 
(transforms baked in).""" + if case_type not in ("lattice", "hohlraum"): + raise ValueError( + f"Unknown case_type: {case_type!r}. Expected 'lattice' or 'hohlraum'." + ) + + transforms = _build_transforms( + case_type=case_type, + flux_normalization_stats_file=flux_normalization_stats_file, + flux_clip_threshold=flux_clip_threshold, + seed=seed, + num_spatial_points=num_spatial_points, + normalize_coordinates=normalize_coordinates, + use_fourier_features=use_fourier_features, + fourier_num_frequencies=fourier_num_frequencies, + fourier_coord_dims=fourier_coord_dims, + fourier_base_frequency=fourier_base_frequency, + include_q_in_embedding=include_q_in_embedding, + ) + + return RTEBaseDataset( + data_path=data_path, + case_type=case_type, + phase=phase, + split_file=split_file, + seed=seed, + cache_static_arrays=cache_static_arrays, + transforms=transforms, + device=device, + ) + + +def _build_transforms( + case_type: str, + flux_normalization_stats_file: Union[str, Path], + flux_clip_threshold: float, + seed: Optional[int], + num_spatial_points: int, + normalize_coordinates: bool, + use_fourier_features: bool, + fourier_num_frequencies: Optional[int], + fourier_coord_dims: Optional[int], + fourier_base_frequency: Optional[float], + include_q_in_embedding: bool = True, +) -> Compose: + """Assemble the canonical RTE transform pipeline.""" + flux_stats = load_flux_stats(flux_normalization_stats_file) + if abs(flux_stats["clip_threshold"] - flux_clip_threshold) > 1e-10: + raise ValueError( + f"Clip threshold mismatch: got {flux_clip_threshold}, " + f"stats computed with {flux_stats['clip_threshold']}" + ) + + transform_list: List[Transform] = [ + RTEFluxLogClip( + clip_threshold=flux_clip_threshold, + log_flux_mean=flux_stats["log_flux_mean"], + log_flux_std=flux_stats["log_flux_std"], + ), + Normalize(**flux_normalize_kwargs(flux_stats, field="scalar_flux")), + ] + + transform_list.append(FinalTimeSampler()) + transform_list.append(MaterialPropertyExtractor()) + + material_stats_path = ( + 
Path(flux_normalization_stats_file).parent / f"{case_type}_material_stats.yaml" + ) + if not material_stats_path.exists(): + raise FileNotFoundError( + f"Material statistics file not found: {material_stats_path}\n" + f"Run src/compute_normalizations.py to generate it." + ) + material_stats = load_material_stats(material_stats_path) + transform_list.append( + Normalize( + **material_normalize_kwargs(material_stats, field="physical_properties") + ) + ) + + transform_list.append(SpatialSampler(num_points=num_spatial_points, seed=seed)) + + if normalize_coordinates: + center, half_extent = coord_translate_scale_params(case_type) + transform_list.append(RTEBackupCoords()) + transform_list.append( + Translate( + input_keys=["coordinates"], + center_key_or_value=center, + subtract=True, + ) + ) + transform_list.append( + Scale( + input_keys=["coordinates"], + scale=half_extent, + divide=True, + ) + ) + + if use_fourier_features: + transform_list.append( + FourierFeatures( + num_frequencies=fourier_num_frequencies, + coord_dims=fourier_coord_dims, + base_frequency=fourier_base_frequency, + append_to_coordinates=True, + ) + ) + + transform_list.append( + TransolverAdapter(include_q_in_embedding=include_q_in_embedding) + ) + + return Compose(transform_list) + + +def _make_loader( + dataset, + cfg: DictConfig, + phase: str, + sampler: Optional[Sampler], + collate_fn: Optional[Callable], + test_batch_size: int, +) -> DataLoader: + """Assemble a :class:`physicsnemo.datapipes.DataLoader` for one phase. + + The ``test`` phase has no matching ``cfg.test.*`` block; callers pass + ``test_batch_size`` explicitly. Stream-based prefetching settings + (``num_streams``, ``use_streams``) come from the per-phase Hydra config + when present and default to ``4`` and ``True`` otherwise. 
+ """ + if phase == "test": + return DataLoader( + dataset, + batch_size=test_batch_size, + shuffle=False, + collate_fn=collate_fn, + ) + + phase_cfg = cfg.train.dataloader if phase == "train" else cfg.train.val.dataloader + sampler_cfg = cfg.train.sampler if phase == "train" else cfg.train.val.sampler + + # sampler handles shuffling when present; keep ``shuffle=False`` to avoid + # the "sampler is incompatible with shuffle" path inside the DataLoader. + shuffle_train = sampler_cfg.shuffle if phase == "train" else False + shuffle = shuffle_train if sampler is None else False + + seed = cfg.train.get("seed", None) + seed = int(seed) if seed is not None else None + + return DataLoader( + dataset, + batch_size=phase_cfg.batch_size, + shuffle=shuffle, + drop_last=sampler_cfg.get("drop_last", False), + sampler=sampler, + collate_fn=collate_fn, + prefetch_factor=phase_cfg.get("prefetch_factor", 2), + num_streams=phase_cfg.get("num_streams", 4), + use_streams=phase_cfg.get("use_streams", True), + seed=seed, + ) + + +def build_dataloaders( + cfg: DictConfig, + dist=None, + collate_fn: Optional[Callable] = None, + phases: Iterable[str] = ("train", "val"), + test_batch_size: int = 1, + logger: Optional[logging.Logger] = None, +) -> Tuple[Dict[str, DataLoader], Optional[DistributedSampler]]: + """Build per-phase DataLoaders for training and/or evaluation. + + Args: + cfg: Hydra configuration (training cfg or a loaded checkpoint cfg). + dist: ``DistributedManager`` for training; ``None`` for eval. + collate_fn: Collate function. Defaults to :func:`collate_no_padding`. + phases: Which splits to build (subset of ``{"train", "val", "test"}``). + test_batch_size: Used only when ``test`` is in ``phases``. + logger: Optional logger; defaults to module logger. + + Returns: + ``({phase: DataLoader}, train_sampler)``. ``train_sampler`` is + ``None`` when ``train`` is not in ``phases`` or ``dist`` is not + distributed. 
+ """ + logger = logger or logging.getLogger(__name__) + phases = tuple(phases) + + if collate_fn is None: + collate_fn = collate_no_padding + + rank_zero = dist is None or dist.rank == 0 + + if rank_zero: + logger.info(f"Loading {cfg.case.type} data from: {cfg.case.data_path}") + + common_kwargs = _build_rte_dataset_kwargs(cfg) + + if rank_zero: + logger.info("Mapping mode: first-snapshot -> final-time flux") + if common_kwargs["split_file"]: + logger.info(f"Using predefined splits from: {common_kwargs['split_file']}") + + if dist is not None and getattr(dist, "device", None) is not None: + device = dist.device + else: + device = "cuda" if torch.cuda.is_available() else "cpu" + + datasets = { + phase: _build_rte_dataset( + cfg.case.type, phase=phase, device=device, **common_kwargs + ) + for phase in phases + } + + if rank_zero: + split_summary = ", ".join(f"{p}={len(datasets[p])}" for p in phases) + logger.info(f"\nData split summary: {split_summary}") + + # Samplers + loaders. + train_sampler: Optional[DistributedSampler] = None + loaders: Dict[str, DataLoader] = {} + for phase in phases: + sampler = None + if dist is not None and dist.distributed and phase in ("train", "val"): + if phase == "train": + sampler = DistributedSampler( + datasets[phase], + num_replicas=dist.world_size, + rank=dist.rank, + shuffle=cfg.train.sampler.shuffle, + drop_last=cfg.train.sampler.get("drop_last", False), + seed=int(cfg.train.get("seed", 0) or 0), + ) + train_sampler = sampler + else: + sampler = DistributedSampler( + datasets[phase], + num_replicas=dist.world_size, + rank=dist.rank, + shuffle=False, + ) + + loaders[phase] = _make_loader( + datasets[phase], + cfg, + phase, + sampler, + collate_fn, + test_batch_size=test_batch_size, + ) + + return loaders, train_sampler diff --git a/examples/nuclear_engineering/radiation_transport/src/losses.py b/examples/nuclear_engineering/radiation_transport/src/losses.py new file mode 100644 index 0000000000..566233232e --- /dev/null +++ 
b/examples/nuclear_engineering/radiation_transport/src/losses.py @@ -0,0 +1,400 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any, Mapping, Optional + +import torch +from omegaconf import DictConfig + +from qoi import ( + evaluate_hohlraum_qoi_torch, + evaluate_lattice_qoi_torch, + extract_geometry_params, +) +from transforms import denormalize_flux + +__all__ = [ + # Schedulers + "create_scheduler", + # Regression losses + "region_weighted_loss_fn", + "parse_loss_config", + "physics_loss_weight_for_epoch", + # Physics loss + "compute_physics_loss", + "compute_lattice_qoi_loss", + "compute_hohlraum_qoi_loss", +] + + +def create_scheduler(cfg: DictConfig, optimizer: torch.optim.Optimizer, logger=None): + """Build the LR scheduler: linear warmup chained into cosine annealing.""" + warmup_epochs = cfg.train.get("warmup_epochs", 5) + peak_lr = cfg.train.learning_rate + min_lr = cfg.train.get("min_learning_rate", 1e-6) + total_epochs = cfg.train.epochs + + if logger: + logger.info("\nLearning rate schedule (warmup + cosine):") + logger.info(f" Peak LR: {peak_lr}") + logger.info(f" Min LR: {min_lr}") + logger.info(f" Warmup epochs: {warmup_epochs}") + + warmup = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=min_lr / peak_lr, 
+ end_factor=1.0, + total_iters=max(warmup_epochs, 1), + ) + cosine = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=max(total_epochs - warmup_epochs, 1), + eta_min=min_lr, + ) + return torch.optim.lr_scheduler.SequentialLR( + optimizer, + schedulers=[warmup, cosine], + milestones=[warmup_epochs], + ) + + +_VOID_LABELS = {"hohlraum": 4, "lattice": 2} + + +def region_weighted_loss_fn( + output: torch.Tensor, + target: torch.Tensor, + material_labels: torch.Tensor, + case_type: str, + void_weight: float = 3.0, + material_weight: float = 1.0, +) -> torch.Tensor: + """Weighted MSE that penalizes void cells more than material cells. + + Material-label definitions (set by ``MaterialPropertyExtractor``): + + * Hohlraum: ``0`` black wall, ``1`` red wall, ``2`` green wall, + ``3`` blue capsule (all material); ``4`` white fill gas (void). + * Lattice: ``0`` blue absorber, ``1`` red scattering source + (material); ``2`` white background (void). + + Void cells are where radiation streams through and the surrogate has + to capture fine flux features, so we weight their squared error more + heavily. + + Args: + output, target: Predicted vs ground-truth flux, shape ``(B, N, 1)``. + material_labels: Per-cell label, shape ``(B, N)`` or ``(B, N, 1)``. + case_type: ``"hohlraum"`` or ``"lattice"``. + void_weight, material_weight: Per-region weights. + + Returns: + Scalar weighted-MSE loss. + """ + if case_type not in _VOID_LABELS: + raise ValueError( + f"Unknown case_type: {case_type}. Must be 'hohlraum' or 'lattice'." 
+ ) + + labels = ( + material_labels.squeeze(-1) if material_labels.dim() == 3 else material_labels + ) + is_void = labels == _VOID_LABELS[case_type] # (B, N) bool + weights = ( + torch.where(is_void, float(void_weight), float(material_weight)) + .to(dtype=torch.float32) + .unsqueeze(-1) + ) # (B, N, 1) + + squared_error = (output - target) ** 2 + return (weights * squared_error).sum() / (weights.sum() + 1e-8) + + +def parse_loss_config( + cfg: DictConfig, + dist: Any, + logger: Any, +) -> dict: + """ + Parse the common loss configuration options shared across all models: + physics loss (including warmup schedule), region-weighted loss. + + The returned ``physics_loss_weight`` is the **base** weight; per-epoch + warmup ramping is applied by :func:`physics_loss_weight_for_epoch` inside + the trainer loop. + + Args: + cfg: Hydra config + dist: DistributedManager (only ``dist.rank`` is read) + logger: Logger + + Returns: + Dict with keys: ``use_physics_loss``, ``physics_loss_weight``, + ``physics_loss_mse_weight``, ``physics_loss_warmup_epochs``, + ``physics_loss_warmup_start_fraction``, + ``use_region_weighted_loss``, ``region_weight_cfg``. 
+ """ + use_physics_loss = cfg.train.get("use_physics_loss", False) + if use_physics_loss: + physics_loss_weight = cfg.train.physics_loss.weight + physics_loss_mse_weight = cfg.train.physics_loss.mse_weight + physics_loss_warmup_epochs = cfg.train.physics_loss.get("warmup_epochs", 0) + physics_loss_warmup_start_fraction = cfg.train.physics_loss.get( + "warmup_start_fraction", 0.0 + ) + else: + physics_loss_weight = 0.0 + physics_loss_mse_weight = 1.0 + physics_loss_warmup_epochs = 0 + physics_loss_warmup_start_fraction = 0.0 + + use_region_weighted_loss = cfg.train.get("use_region_weighted_loss", False) + region_weight_cfg = { + "void_weight": cfg.train.get("region_weights", {}).get("void_weight", 3.0), + "material_weight": cfg.train.get("region_weights", {}).get( + "material_weight", 1.0 + ), + } + + if dist.rank == 0: + if use_physics_loss: + logger.info("\nPhysics loss configuration:") + logger.info(f" Weight: {physics_loss_weight}") + logger.info(f" MSE weight: {physics_loss_mse_weight}") + if physics_loss_warmup_epochs > 0: + logger.info(f" Warmup epochs: {physics_loss_warmup_epochs}") + logger.info( + f" Warmup start fraction: {physics_loss_warmup_start_fraction}" + ) + if use_region_weighted_loss: + logger.info("Region-weighted loss: enabled") + logger.info(f" Void weight: {region_weight_cfg['void_weight']}") + logger.info(f" Material weight: {region_weight_cfg['material_weight']}") + + return { + "use_physics_loss": use_physics_loss, + "physics_loss_weight": physics_loss_weight, + "physics_loss_mse_weight": physics_loss_mse_weight, + "physics_loss_warmup_epochs": physics_loss_warmup_epochs, + "physics_loss_warmup_start_fraction": physics_loss_warmup_start_fraction, + "use_region_weighted_loss": use_region_weighted_loss, + "region_weight_cfg": region_weight_cfg, + } + + +def physics_loss_weight_for_epoch(loss_cfg: dict, epoch: int) -> float: + """Linear ramp of the physics-loss weight over the warmup window. 
+ + Ramps from ``warmup_start_fraction * base`` at epoch 0 to ``base`` at + ``warmup_epochs``, then stays at ``base``. With no warmup configured + (``warmup_epochs <= 0``), returns ``base`` unchanged. + """ + base = loss_cfg.get("physics_loss_weight", 0.0) + warmup_epochs = loss_cfg.get("physics_loss_warmup_epochs", 0) + if warmup_epochs <= 0 or epoch >= warmup_epochs: + return base + start_frac = loss_cfg.get("physics_loss_warmup_start_fraction", 0.0) + progress = epoch / max(1, warmup_epochs) + return (start_frac + (1.0 - start_frac) * progress) * base + + +def _relative_squared_error_loss( + pred: torch.Tensor, + target: torch.Tensor, + epsilon: float = 1e-10, +) -> torch.Tensor: + """Mean of ``((pred - target) / |target|)^2`` over finite cells. + + Returns ``0.0`` (no graph) when every cell is non-finite — degenerate but + keeps the trainer alive instead of propagating NaN. + """ + squared = ((pred - target) / (torch.abs(target) + epsilon)) ** 2 + is_valid = torch.isfinite(squared) & torch.isfinite(pred) & torch.isfinite(target) + if not is_valid.any(): + return torch.zeros((), device=pred.device) + return squared[is_valid].mean() + + +def _prepare_for_qoi( + pred: torch.Tensor, + target: torch.Tensor, + sim_time: torch.Tensor, + stats: Optional[Mapping[str, Any]], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Squeeze ``(B, N, 1) -> (B, N)``, denormalize, then ``(B, 1, N)`` for QoI.""" + if pred.ndim == 3: + pred = pred.squeeze(-1) + if target.ndim == 3: + target = target.squeeze(-1) + if stats is not None: + pred = denormalize_flux(pred, stats) + target = denormalize_flux(target, stats) + sim_times = sim_time.unsqueeze(-1) if sim_time.ndim == 1 else sim_time + return pred.unsqueeze(1), target.unsqueeze(1), sim_times + + +def compute_lattice_qoi_loss( + predicted_flux: torch.Tensor, + target_flux: torch.Tensor, + cell_centers: torch.Tensor, + cell_areas: torch.Tensor, + sigma_t: torch.Tensor, + sigma_s: torch.Tensor, + sim_time: torch.Tensor, + 
flux_normalization_stats: Optional[Mapping[str, Any]] = None, + epsilon: float = 1e-10, +) -> tuple[torch.Tensor, dict[str, float]]: + """Relative-squared-error loss on the lattice absorption QoI. + + QoIs are evaluated in physical flux space; if normalization stats are + supplied, both flux tensors are denormalized first. Differentiable end + to end so the loss backprops into the model. + """ + pred_qoi, target_qoi, sim_times = _prepare_for_qoi( + predicted_flux, target_flux, sim_time, flux_normalization_stats + ) + qoi_pred = evaluate_lattice_qoi_torch( + cell_centers, + cell_areas, + sigma_t, + sigma_s, + pred_qoi, + sim_times, + ) + with torch.no_grad(): + qoi_target = evaluate_lattice_qoi_torch( + cell_centers, + cell_areas, + sigma_t, + sigma_s, + target_qoi, + sim_times, + ) + loss = _relative_squared_error_loss( + qoi_pred["cur_absorption"][:, 0], + qoi_target["cur_absorption"][:, 0], + epsilon, + ) + return loss, {"loss_qoi_absorption": loss.item()} + + +def compute_hohlraum_qoi_loss( + predicted_flux: torch.Tensor, + target_flux: torch.Tensor, + cell_centers: torch.Tensor, + cell_areas: torch.Tensor, + sigma_t: torch.Tensor, + sigma_s: torch.Tensor, + sim_time: torch.Tensor, + geometry_params: dict, + flux_normalization_stats: Optional[Mapping[str, Any]] = None, + epsilon: float = 1e-10, +) -> tuple[torch.Tensor, dict[str, float]]: + """Mean of the four hohlraum region relative-squared-error losses. + + Loss = mean of {center, vertical, horizontal, total} so every region + contributes to the gradient. All four are recorded in the details dict. 
+ """ + pred_qoi, target_qoi, sim_times = _prepare_for_qoi( + predicted_flux, target_flux, sim_time, flux_normalization_stats + ) + qoi_pred = evaluate_hohlraum_qoi_torch( + cell_centers, + cell_areas, + sigma_t, + sigma_s, + pred_qoi, + sim_times, + geometry_params, + ) + with torch.no_grad(): + qoi_target = evaluate_hohlraum_qoi_torch( + cell_centers, + cell_areas, + sigma_t, + sigma_s, + target_qoi, + sim_times, + geometry_params, + ) + + region_losses: dict[str, torch.Tensor] = {} + pred_sum = target_sum = None + for key in ( + "cur_absorption_center", + "cur_absorption_vertical", + "cur_absorption_horizontal", + ): + p, t = qoi_pred[key][:, 0], qoi_target[key][:, 0] + region_losses[key.removeprefix("cur_absorption_")] = ( + _relative_squared_error_loss(p, t, epsilon) + ) + pred_sum = p if pred_sum is None else pred_sum + p + target_sum = t if target_sum is None else target_sum + t + region_losses["total"] = _relative_squared_error_loss(pred_sum, target_sum, epsilon) + + loss = torch.stack(list(region_losses.values())).mean() + details = {f"loss_qoi_{name}": val.item() for name, val in region_losses.items()} + return loss, details + + +def compute_physics_loss( + case_type: str, + predicted_flux: torch.Tensor, + target_flux: torch.Tensor, + cell_centers: torch.Tensor, + cell_areas: torch.Tensor, + sigma_t: torch.Tensor, + sigma_s: torch.Tensor, + sim_time: torch.Tensor, + sample=None, + flux_normalization_stats: dict | None = None, + qoi_epsilon: float = 1e-10, +) -> tuple[torch.Tensor, dict[str, float]]: + """Dispatch the per-case QoI loss; returns ``(loss, per-region details)``.""" + common = dict( + predicted_flux=predicted_flux, + target_flux=target_flux, + cell_centers=cell_centers, + cell_areas=cell_areas, + sigma_t=sigma_t, + sigma_s=sigma_s, + sim_time=sim_time, + flux_normalization_stats=flux_normalization_stats, + epsilon=qoi_epsilon, + ) + if case_type == "lattice": + return compute_lattice_qoi_loss(**common) + if case_type == "hohlraum": + if sample 
is None: + raise ValueError( + "hohlraum physics loss requires the sample TensorDict to read " + "geometry parameters (ulr, llr, urr, lrr, hlr, hrr, cx, cy)" + ) + geometry_params = extract_geometry_params(sample) + if not geometry_params: + raise ValueError( + "could not read hohlraum geometry parameters from the sample " + "TensorDict; expected 8 0-D float32 tensors (ulr, llr, urr, " + "lrr, hlr, hrr, cx, cy) on the TD top level (see " + "MeshDataReader.load)" + ) + return compute_hohlraum_qoi_loss(**common, geometry_params=geometry_params) + raise ValueError( + f"Unknown case type: {case_type}. Must be 'lattice' or 'hohlraum'." + ) diff --git a/examples/nuclear_engineering/radiation_transport/src/qoi.py b/examples/nuclear_engineering/radiation_transport/src/qoi.py new file mode 100644 index 0000000000..e1c613b7d3 --- /dev/null +++ b/examples/nuclear_engineering/radiation_transport/src/qoi.py @@ -0,0 +1,230 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Differentiable PyTorch QoI evaluators for the RTE benchmarks. + +These match KiT-RT's SNSolverHPC::IterPostprocessing() and are shared by the +training-time physics loss (losses.py) and the evaluation-time QoI metrics +(inference.py). 
+ +https://github.com/KiT-RT/kitrt_code/blob/d257b1a3c6fb3fa13d8a346adca5360a95101932/src/solvers/snsolver_hpc.cpp#L594 + +The evaluators are differentiable and final-time (T=1); ``sim_times`` is +accepted only for callsite uniformity with the time-dependent variants. +""" + +from __future__ import annotations + +import torch + +__all__ = [ + "evaluate_lattice_qoi_torch", + "evaluate_hohlraum_qoi_torch", + "extract_geometry_params", +] + + +_HOHLRAUM_GEOMETRY_KEYS = ("ulr", "llr", "urr", "lrr", "hlr", "hrr", "cx", "cy") + + +def extract_geometry_params(sample) -> dict: + """Extract hohlraum geometry parameters from a sample TensorDict. + + Reads the eight 0-D float32 tensors that the curator writes into + ``mesh.global_data`` for hohlraum stores (``ulr, llr, urr, lrr, hlr, + hrr, cx, cy``) and that :meth:`MeshDataReader.load` promotes to the + TensorDict top level. Returns ``{}`` if any key is missing (e.g. on a + lattice sample, which has no geometry parameters). + """ + if sample is None: + return {} + try: + if not all(k in sample for k in _HOHLRAUM_GEOMETRY_KEYS): + return {} + except TypeError: + return {} + + out: dict = {} + for k in _HOHLRAUM_GEOMETRY_KEYS: + v = sample[k] + if hasattr(v, "ndim") and v.ndim > 0: + # Batched value (e.g. shape ``(B,)``): collapse to a single + # scalar by picking the first entry. Geometry parameters are + # static per simulation, so every batch element matches. + v = v.reshape(-1)[0] + out[k] = float(v.item() if hasattr(v, "item") else v) + return out + + +def evaluate_lattice_qoi_torch( + cell_centers: torch.Tensor, + cell_areas: torch.Tensor, + sigma_t: torch.Tensor, + sigma_s: torch.Tensor, + scalar_flux: torch.Tensor, + sim_times: torch.Tensor, +) -> dict[str, torch.Tensor]: + """Lattice absorption QoI, differentiable. + + When the leading dim of ``cell_centers`` is a (size-1) batch dim, the + call recurses on the squeezed slot and re-adds the dim on the way out. 
+ + Args: + cell_centers: (N, 2) or (1, N, 2) + cell_areas: (N,) or (1, N) + sigma_t: (N,) or (1, N) + sigma_s: (N,) or (1, N) + scalar_flux: (T, N) or (1, T, N) — only T=1 is exercised + sim_times: (T,) or (1, T) — accepted for callsite uniformity, unused + + Returns: + ``{"cur_absorption": (T,) or (1, T)}`` + """ + if cell_centers.ndim == 3: + if cell_centers.shape[0] != 1: + raise NotImplementedError( + "evaluate_lattice_qoi_torch only supports batch_size=1; " + f"got batch={cell_centers.shape[0]}." + ) + result = evaluate_lattice_qoi_torch( + cell_centers[0], + cell_areas[0], + sigma_t[0], + sigma_s[0], + scalar_flux[0], + sim_times[0] if sim_times.ndim == 2 else sim_times, + ) + return {k: v.unsqueeze(0) for k, v in result.items()} + + if scalar_flux.ndim != 2: + raise ValueError(f"Expected scalar_flux shape (T, N), got {scalar_flux.shape}") + + x = cell_centers[:, 0] + y = cell_centers[:, 1] + sigma_a = sigma_t - sigma_s + + xy_corrector = -3.5 + lbounds = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + xy_corrector + ubounds = torch.tensor([2.0, 3.0, 4.0, 5.0, 6.0]) + xy_corrector + + in_absorption = torch.zeros_like(x, dtype=torch.bool) + for k in range(5): + for l in range(5): # noqa: E741 + if (l + k) % 2 == 1: + continue + if (k == 2 and l == 2) or (k == 2 and l == 4): + continue + in_square = ( + (x >= lbounds[k]) + & (x <= ubounds[k]) + & (y >= lbounds[l]) + & (y <= ubounds[l]) + ) + in_absorption = in_absorption | in_square + + absorption_density = scalar_flux * sigma_a.unsqueeze(0) * cell_areas.unsqueeze(0) + cur_absorption = torch.sum( + absorption_density * in_absorption.unsqueeze(0).to(dtype=torch.float32), + dim=1, + ) + return {"cur_absorption": cur_absorption} + + +def evaluate_hohlraum_qoi_torch( + cell_centers: torch.Tensor, + cell_areas: torch.Tensor, + sigma_t: torch.Tensor, + sigma_s: torch.Tensor, + scalar_flux: torch.Tensor, + sim_times: torch.Tensor, + geometry_params: dict[str, float], +) -> dict[str, torch.Tensor]: + """Hohlraum per-region 
absorption QoI, differentiable. + + Three regions are returned: ``center`` (the capsule volume), ``vertical`` + (red wall strips on either x boundary), and ``horizontal`` (the top + bottom + strips). The vertical-wall predicate uses ``pos_red_left_bottom`` for both + sides — see the inline ``NOTE`` for why. + + When the leading dim of ``cell_centers`` is a (size-1) batch dim, the + call recurses on the squeezed slot and re-adds the dim on the way out. + + Args: + cell_centers: (N, 2) or (1, N, 2) + cell_areas: (N,) or (1, N) + sigma_t: (N,) or (1, N) + sigma_s: (N,) or (1, N) + scalar_flux: (T, N) or (1, T, N) — only T=1 is exercised + sim_times: (T,) or (1, T) — accepted for callsite uniformity, unused + geometry_params: dict with cx, cy, hlr, hrr, llr, ulr, lrr, urr + + Returns: + Dict with ``cur_absorption_{center,vertical,horizontal}``. + """ + if cell_centers.ndim == 3: + if cell_centers.shape[0] != 1: + raise NotImplementedError( + "evaluate_hohlraum_qoi_torch only supports batch_size=1; " + f"got batch={cell_centers.shape[0]}." 
+ ) + result = evaluate_hohlraum_qoi_torch( + cell_centers[0], + cell_areas[0], + sigma_t[0], + sigma_s[0], + scalar_flux[0], + sim_times[0] if sim_times.ndim == 2 else sim_times, + geometry_params, + ) + return {k: v.unsqueeze(0) for k, v in result.items()} + + if scalar_flux.ndim != 2: + raise ValueError(f"Expected scalar_flux shape (T, N), got {scalar_flux.shape}") + + x = cell_centers[:, 0] + y = cell_centers[:, 1] + + cx = geometry_params["cx"] + cy = geometry_params["cy"] + pos_red_left_border = geometry_params["hlr"] + pos_red_right_border = geometry_params["hrr"] + pos_red_left_bottom = geometry_params["llr"] + pos_red_left_top = geometry_params["ulr"] + pos_red_right_top = geometry_params["urr"] + + sigma_a = sigma_t - sigma_s + + in_center = (x > -0.2 + cx) & (x < 0.2 + cx) & (y > -0.4 + cy) & (y < 0.4 + cy) + # NOTE: matches KiT-RT's behavior of using pos_red_left_bottom for both sides + in_vertical = ( + (x < pos_red_left_border) & (y > pos_red_left_bottom) & (y < pos_red_left_top) + ) | ( + (x > pos_red_right_border) & (y > pos_red_left_bottom) & (y < pos_red_right_top) + ) + in_horizontal = (y > 0.6) | (y < -0.6) + + absorption_density = scalar_flux * sigma_a.unsqueeze(0) * cell_areas.unsqueeze(0) + + def _region_sum(mask: torch.Tensor) -> torch.Tensor: + return torch.sum( + absorption_density * mask.unsqueeze(0).to(dtype=torch.float32), dim=1 + ) + + return { + "cur_absorption_center": _region_sum(in_center), + "cur_absorption_vertical": _region_sum(in_vertical), + "cur_absorption_horizontal": _region_sum(in_horizontal), + } diff --git a/examples/nuclear_engineering/radiation_transport/src/train.py b/examples/nuclear_engineering/radiation_transport/src/train.py new file mode 100644 index 0000000000..4f29e7f24b --- /dev/null +++ b/examples/nuclear_engineering/radiation_transport/src/train.py @@ -0,0 +1,168 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +from typing import Any, Optional, Tuple + +import hydra +import torch +import torch.nn as nn +from omegaconf import DictConfig, OmegaConf +from torch.amp import GradScaler +from torch.utils.tensorboard import SummaryWriter + +from physicsnemo.datapipes import DataLoader +from physicsnemo.utils.logging.launch import LaunchLogger + +from checkpointing import create_optimizer, resume_if_available +from loader import build_dataloaders, collate_no_padding +from losses import create_scheduler, parse_loss_config +from trainer import ( + parse_amp, + run_training_loop, + set_seed, + setup_training_environment, + wrap_ddp, +) + + +def build_model(cfg: DictConfig, device: torch.device) -> nn.Module: + """Instantiate the Transolver model from the Hydra ``model`` group. + + Two RTE-specific keys (``num_spatial_points``, ``include_q_in_embedding``) + are stripped from the config before ``hydra.utils.instantiate`` because + they are consumed by the data pipeline, not the model constructor. 
+ """ + cfg_model = OmegaConf.to_container(cfg.model, resolve=True) + for k in ("num_spatial_points", "include_q_in_embedding"): + cfg_model.pop(k, None) + return hydra.utils.instantiate(cfg_model).to(device) + + +def build_dataloaders_for_training( + cfg: DictConfig, dist: Any, logger: Any +) -> Tuple[DataLoader, DataLoader, Optional[Any]]: + """Build train / val DataLoaders for the Transolver point-cloud adapter.""" + if cfg.train.dataloader.batch_size != 1: + raise ValueError( + "Only batch_size=1 is supported for the Transolver point-cloud adapter." + ) + loaders, train_sampler = build_dataloaders( + cfg, + dist, + collate_fn=collate_no_padding, + phases=("train", "val"), + logger=logger, + ) + return loaders["train"], loaders["val"], train_sampler + + +@hydra.main(version_base="1.3", config_path="conf", config_name="config") +def main(cfg: DictConfig) -> None: + """Train the Transolver RTE surrogate.""" + dist, logger = setup_training_environment(cfg, "Transolver") + + seed = cfg.train.get("seed", None) + if seed is not None: + set_seed(seed + dist.rank if dist.distributed else seed) + logger.info(f"Random seed: {seed}") + else: + logger.info("Random seed: not set (non-reproducible)") + + grad_accum_steps = cfg.train.get("gradient_accumulation_steps", 1) + use_amp, amp_dtype = parse_amp(cfg) + + amp_info = ( + f"ENABLED (dtype={cfg.train.get('amp_dtype', 'bf16')})" + if use_amp + else "DISABLED" + ) + batch_size = cfg.train.dataloader.batch_size + world_size = dist.world_size if dist.distributed else 1 + logger.info(f"Device: {dist.device}") + logger.info(f"Batch size: {batch_size}") + logger.info(f"Gradient accumulation steps: {grad_accum_steps}") + logger.info(f"AMP (mixed precision): {amp_info}") + logger.info(f"Effective batch size: {batch_size * grad_accum_steps * world_size}") + + train_loader, val_loader, _ = build_dataloaders_for_training(cfg, dist, logger) + + logger.info("\nInitializing Transolver model...") + model = build_model(cfg, dist.device) + 
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + logger.info(f"Transolver initialized — {num_params:,} trainable parameters") + model = wrap_ddp(model, dist, logger) + + optimizer_cfg = cfg.train.get("optimizer", {}) + optimizer = create_optimizer( + model=model, + optimizer_type=optimizer_cfg.get("type", "adam"), + learning_rate=cfg.train.learning_rate, + weight_decay=optimizer_cfg.get( + "weight_decay", cfg.train.get("weight_decay", 0.0) + ), + muon_momentum_beta=optimizer_cfg.get("muon_momentum_beta", 0.95), + logger=logger, + ) + scheduler = create_scheduler(cfg, optimizer, logger) + # GradScaler is only meaningful for fp16 AMP; bf16 doesn't underflow and + # disabling avoids the overhead + masks fp16-specific failure modes. + scaler = GradScaler(enabled=use_amp and amp_dtype is torch.float16) + LaunchLogger.initialize(use_wandb=False, use_mlflow=False) + use_tensorboard = cfg.train.get("tensorboard", True) + writer = ( + SummaryWriter(os.path.join(cfg.output, "tensorboard")) + if (use_tensorboard and dist.rank == 0) + else None + ) + checkpoint_dir = os.path.join(cfg.output, "checkpoints") + os.makedirs(checkpoint_dir, exist_ok=True) + + loss_cfg = parse_loss_config(cfg, dist, logger) + loss_metric = cfg.train.get("loss_metric", "mse") + loss_cfg["loss_metric"] = loss_metric + logger.info(f"Loss metric: {loss_metric}") + + start_epoch, best_val_loss = resume_if_available( + cfg, model, optimizer, scheduler, scaler, dist, logger + ) + + logger.info("\n" + "=" * 70) + logger.info("Starting training...") + logger.info("=" * 70) + + run_training_loop( + cfg=cfg, + dist=dist, + model=model, + train_loader=train_loader, + val_loader=val_loader, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + loss_cfg=loss_cfg, + logger=logger, + checkpoint_dir=checkpoint_dir, + writer=writer, + best_val_loss=best_val_loss, + start_epoch=start_epoch, + ) + + +if __name__ == "__main__": + main() diff --git 
a/examples/nuclear_engineering/radiation_transport/src/trainer.py b/examples/nuclear_engineering/radiation_transport/src/trainer.py new file mode 100644 index 0000000000..b146a5d51a --- /dev/null +++ b/examples/nuclear_engineering/radiation_transport/src/trainer.py @@ -0,0 +1,692 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +import os +import random +from contextlib import nullcontext +from pathlib import Path +from typing import Any, Dict, Mapping, Optional, Tuple + +import numpy as np +import torch +import torch.distributed as torch_dist +import torch.nn as nn +from omegaconf import DictConfig, OmegaConf +from torch.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel +from torch.utils.tensorboard import SummaryWriter +from physicsnemo.datapipes import DataLoader +from physicsnemo.distributed import DistributedManager +from physicsnemo.distributed.utils import reduce_loss +from physicsnemo.utils.checkpoint import save_checkpoint +from physicsnemo.utils.logging import PythonLogger, RankZeroLoggingWrapper +from physicsnemo.utils.logging.launch import LaunchLogger + +from checkpointing import save_best_checkpoint +from losses import ( + compute_physics_loss, + physics_loss_weight_for_epoch, + region_weighted_loss_fn, +) + + +def set_seed(seed: int) -> None: + """Set random seed for reproducibility across all RNGs.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +_AMP_DTYPES: Dict[str, torch.dtype] = { + "bf16": torch.bfloat16, + "fp16": torch.float16, +} + + +def parse_amp(cfg: DictConfig) -> Tuple[bool, torch.dtype]: + """Read ``cfg.train.amp`` and ``cfg.train.amp_dtype`` into ``(use_amp, dtype)``.""" + name = cfg.train.get("amp_dtype", "bf16") + if name not in _AMP_DTYPES: + raise ValueError( + f"Unsupported amp_dtype {name!r}; allowed: {sorted(_AMP_DTYPES)}." + ) + return cfg.train.get("amp", True), _AMP_DTYPES[name] + + +def synchronize_output_directory( + cfg: DictConfig, + dist: DistributedManager, +) -> str: + """Ensure ``cfg.output`` exists; barrier so DDP ranks don't race past it. + + Rank 0 creates the directory tree; + a final barrier keeps the other ranks from proceeding before it lands. 
+ """ + if "output" not in cfg: + OmegaConf.set_struct(cfg, False) + cfg.output = os.path.join("outputs", "default") + OmegaConf.set_struct(cfg, True) + + output_dir = cfg.output + if dist.rank == 0: + os.makedirs(output_dir, exist_ok=True) + os.makedirs(os.path.join(output_dir, "checkpoints"), exist_ok=True) + if dist.distributed: + torch_dist.barrier() + return output_dir + + +def aggregate_validation_loss( + loss_sum: float, + num_batches: int, + dist: DistributedManager, +) -> float: + """Aggregate validation loss across DDP ranks via ``reduce_loss``. + + Returns the rank-0 mean-of-means; non-rank-0 ranks get their local mean + (unused downstream). Eval sampler pads the split to equal length across + ranks, so the mean-of-means equals the global mean up to at most + ``world_size - 1`` duplicate samples. + """ + per_rank_mean = loss_sum / max(num_batches, 1) + if not dist.distributed: + return per_rank_mean + reduced = reduce_loss(per_rank_mean, dst_rank=0, mean=True) + return reduced if reduced is not None else per_rank_mean + + +def aggregate_validation_metrics( + metric_sums: Mapping[str, float], + metric_counts: Mapping[str, int], + dist: DistributedManager, +) -> Dict[str, float]: + """Aggregate named validation metrics via tensor ``all_reduce`` over a + known schema. + + Every rank emits the same metric keys (the schema is fixed at config + time, not per-batch), so we sort keys, stack values into a single + tensor, and issue one collective per (sums, counts) pair. 
+ """ + if not dist.distributed: + return { + key: metric_sums[key] / metric_counts[key] + for key in metric_sums + if metric_counts.get(key, 0) > 0 + } + + keys = sorted(metric_sums.keys()) + if not keys: + return {} + + sums = torch.tensor( + [float(metric_sums[k]) for k in keys], + dtype=torch.float64, + device=dist.device, + ) + counts = torch.tensor( + [int(metric_counts.get(k, 0)) for k in keys], + dtype=torch.int64, + device=dist.device, + ) + torch_dist.all_reduce(sums, op=torch_dist.ReduceOp.SUM) + torch_dist.all_reduce(counts, op=torch_dist.ReduceOp.SUM) + + return { + key: float(sums[i].item() / counts[i].item()) + for i, key in enumerate(keys) + if counts[i].item() > 0 + } + + +def setup_training_environment( + cfg: DictConfig, + model_name: str, +) -> Tuple[DistributedManager, Any]: + """Initialize DDP, sync the output dir, build a logger, and log a banner. + + Args: + cfg: Hydra configuration. + model_name: Human-readable model name for logging (e.g. "Transolver"). + + Returns: + ``(dist, logger)``. + """ + DistributedManager.initialize() + dist = DistributedManager() + + synchronize_output_directory(cfg, dist) + + logger = RankZeroLoggingWrapper(PythonLogger(f"RTE_{model_name}"), dist) + if dist.rank == 0: + logger.file_logging(os.path.join(cfg.output, "train.log")) + + logger.info("=" * 70) + logger.info(f"RTE {model_name} Training - {cfg.case.type.upper()}") + logger.info("=" * 70) + if dist.distributed: + logger.info(f"Distributed training: {dist.world_size} GPUs") + logger.info(f"\nConfiguration:\n{OmegaConf.to_yaml(cfg, sort_keys=True)}\n") + + return dist, logger + + +def wrap_ddp( + model: nn.Module, + dist: DistributedManager, + logger: Any, + find_unused_parameters: bool = False, +) -> nn.Module: + """Wrap ``model`` with DistributedDataParallel if running distributed. + + Returns the unwrapped model in single-GPU mode. 
+ """ + if not dist.distributed: + return model + + ddps = torch.cuda.Stream() + with torch.cuda.stream(ddps): + model = DistributedDataParallel( + model, + device_ids=[dist.local_rank], + output_device=dist.device, + broadcast_buffers=dist.broadcast_buffers, + find_unused_parameters=find_unused_parameters, + ) + torch.cuda.current_stream().wait_stream(ddps) + + fup = " (find_unused_parameters=True)" if find_unused_parameters else "" + logger.info(f"Using DistributedDataParallel with {dist.world_size} GPUs{fup}") + return model + + +def compute_losses( + pred: torch.Tensor, + target: torch.Tensor, + loss_inputs: Mapping[str, Any], + loss_cfg: Mapping[str, Any], + case_type: str, + device: torch.device, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], dict]: + """Compose the per-batch training loss. + + Args: + pred, target: ``(B, N, 1)`` tensors. + loss_inputs: presence-driven dispatch dict. Recognized keys: + - ``material_labels`` ``(B, N)`` or ``(B, N, 1)``: enables + region-weighted loss when ``loss_cfg['use_region_weighted_loss']``. + - ``coordinates_unnormalized`` ``(B, N, D)``, ``cell_areas`` + ``(B, N)``, ``sigma_t`` ``(B, N)``, ``sigma_s`` ``(B, N)``, + ``sim_time`` ``(B,)`` or ``(B, 1)``: required for physics loss. + - ``metadata``, ``flux_normalization_stats``: optional physics + context. + loss_cfg: ``use_region_weighted_loss``, ``region_weight_cfg``, + ``loss_metric`` ("mse"|"rmse"), ``use_physics_loss``, + ``physics_loss_weight``, ``physics_loss_mse_weight``. + + Returns: + ``(loss, loss_mse, loss_qoi_or_None, qoi_details_dict)``. 
+ """ + use_region_weighted = loss_cfg.get("use_region_weighted_loss", False) + loss_metric = loss_cfg.get("loss_metric", "mse") + + if use_region_weighted and "material_labels" in loss_inputs: + rw = loss_cfg.get("region_weight_cfg") or {} + loss_mse = region_weighted_loss_fn( + pred, + target, + material_labels=loss_inputs["material_labels"], + case_type=case_type, + void_weight=rw.get("void_weight", 3.0), + material_weight=rw.get("material_weight", 1.0), + ) + else: + loss_mse = ((pred - target) ** 2).mean() + if loss_metric == "rmse": + loss_mse = torch.sqrt(loss_mse) + + if not loss_cfg.get("use_physics_loss", False): + return loss_mse, loss_mse, None, {} + + physics_w = loss_cfg.get("physics_loss_weight", 0.1) + if not physics_w: + # Zero or ``None`` weight -> physics loss is disabled; skip the + # QoI computation entirely (a missing key falls back to 0.1). + return loss_mse, loss_mse, None, {} + + with autocast(enabled=False, device_type=device.type): + loss_qoi, qoi_details = compute_physics_loss( + case_type=case_type, + predicted_flux=pred, + target_flux=target, + cell_centers=loss_inputs["coordinates_unnormalized"], + cell_areas=loss_inputs["cell_areas"], + sigma_t=loss_inputs["sigma_t"], + sigma_s=loss_inputs["sigma_s"], + sim_time=loss_inputs["sim_time"], + sample=loss_inputs, + flux_normalization_stats=loss_inputs.get("flux_normalization_stats"), + ) + + mse_w = loss_cfg.get("physics_loss_mse_weight", 1.0) + loss = mse_w * loss_mse + physics_w * loss_qoi + return loss, loss_mse, loss_qoi, qoi_details + + +def to_device(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]: + """Move tensor entries of a batch dict to ``device``; pass through the rest.""" + return { + k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items() + } + + +def forward( + model: nn.Module, + batch: Dict[str, Any], +) -> torch.Tensor: + """Run a forward pass with the Transolver-expected input keys.""" + return model(fx=batch["fx"], embedding=batch["embedding"]) + + 
+_PHYSICS_KEYS = ( + "coordinates_unnormalized", + "cell_areas", + "sigma_t", + "sigma_s", + "sim_time", +) + + +def loss_inputs(batch: Dict[str, Any], require_physics: bool = False) -> Dict[str, Any]: + """Assemble the optional/physics inputs consumed by ``compute_losses``. + + Always copies ``material_labels`` if present. Physics-loss tensors are + copied only when all of ``_PHYSICS_KEYS`` are in the batch; + ``require_physics=True`` raises if any is missing. Hohlraum geometry + scalars and ``flux_normalization_stats`` are forwarded when present. + """ + inputs: Dict[str, Any] = {} + if "material_labels" in batch: + inputs["material_labels"] = batch["material_labels"] + + missing = [k for k in _PHYSICS_KEYS if k not in batch] + if missing: + if require_physics: + msg = f"Missing physics-loss input(s): {missing}." + if "coordinates_unnormalized" in missing: + msg += " (Enable the RTEBackupCoords transform in the data pipeline.)" + raise KeyError(msg) + return inputs + + for k in _PHYSICS_KEYS: + inputs[k] = batch[k] + for k in ("ulr", "llr", "urr", "lrr", "hlr", "hrr", "cx", "cy"): + if k in batch: + inputs[k] = batch[k] + if "flux_normalization_stats" in batch: + inputs["flux_normalization_stats"] = batch["flux_normalization_stats"] + return inputs + + +def _log_minibatch( + launch_logger: LaunchLogger, + loss: torch.Tensor, + loss_mse: torch.Tensor, + loss_qoi: Optional[torch.Tensor], + qoi_details: Dict[str, float], + scale: float, +) -> None: + metrics = {"loss": loss.item() * scale, "loss_mse": loss_mse.item()} + if loss_qoi is not None: + metrics["loss_qoi"] = loss_qoi.item() + metrics.update(qoi_details) + launch_logger.log_minibatch(metrics) + + +def train_epoch( + cfg: DictConfig, + dataloader: DataLoader, + model: nn.Module, + optimizer: torch.optim.Optimizer, + scaler: GradScaler, + device: torch.device, + launch_logger: LaunchLogger, + loss_cfg: Dict[str, Any], +) -> None: + """Run one Transolver training epoch. 
+ + Reads ``cfg.case.type``, ``cfg.train.amp*``, ``cfg.train.max_grad_norm``, and + ``cfg.train.gradient_accumulation_steps`` directly so callers only + thread the per-epoch ``loss_cfg`` (which is pre-processed by + :func:`losses.parse_loss_config` and varies per epoch via warmup). + """ + case_type = cfg.case.type + use_amp, amp_dtype = parse_amp(cfg) + accum_steps = cfg.train.get("gradient_accumulation_steps", 1) + max_grad_norm = float(cfg.train.get("max_grad_norm", 10.0)) + + model.train() + epoch_len = len(dataloader) + + for i, batch in enumerate(dataloader): + # Gradient accumulation with DDP-aware grad-sync skip: zero at window + # start, run backward inside ``model.no_sync()`` until the boundary + # step (or the final batch of the epoch), then unscale, clip, step, update. + if i % accum_steps == 0: + optimizer.zero_grad(set_to_none=True) + is_step_boundary = (i + 1) % accum_steps == 0 or (i + 1) == epoch_len + + batch = to_device(batch, device) + + with autocast(enabled=use_amp, device_type=device.type, dtype=amp_dtype): + prediction = forward(model, batch) + + pred, target = prediction, batch["flux_target"] + + loss, loss_mse, loss_qoi, qoi_details = compute_losses( + pred=pred.float(), + target=target.float(), + loss_inputs=loss_inputs( + batch, require_physics=loss_cfg.get("use_physics_loss", False) + ), + loss_cfg=loss_cfg, + case_type=case_type, + device=device, + ) + + _log_minibatch( + launch_logger, + loss, + loss_mse, + loss_qoi, + qoi_details, + scale=1, + ) + + sync_ctx = ( + model.no_sync() + if (not is_step_boundary and hasattr(model, "no_sync")) + else nullcontext() + ) + with sync_ctx: + scaler.scale(loss / accum_steps).backward() + + if is_step_boundary: + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm) + scaler.step(optimizer) + scaler.update() + + +@torch.no_grad() +def validate( + cfg: DictConfig, + dataloader: DataLoader, + model: nn.Module, + device: torch.device, + launch_logger: LaunchLogger, + loss_cfg: Dict[str, 
Any], +) -> Tuple[float, int, Dict[str, float], Dict[str, int]]: + """Run validation and return loss plus metric sums/counts for DDP reduce.""" + case_type = cfg.case.type + use_amp, amp_dtype = parse_amp(cfg) + + model.eval() + eval_model = model.module if hasattr(model, "module") else model + + loss_sum = 0.0 + num_batches = 0 + metric_sums: Dict[str, float] = {} + metric_counts: Dict[str, int] = {} + + def accumulate_metric(name: str, value: Any) -> None: + scalar = float(value) + metric_sums[name] = metric_sums.get(name, 0.0) + scalar + metric_counts[name] = metric_counts.get(name, 0) + 1 + + for batch in dataloader: + batch = to_device(batch, device) + + with autocast(enabled=use_amp, device_type=device.type, dtype=amp_dtype): + prediction = forward(eval_model, batch) + + pred, target = prediction, batch["flux_target"] + + loss, loss_mse, loss_qoi, qoi_details = compute_losses( + pred=pred.float(), + target=target.float(), + loss_inputs=loss_inputs( + batch, require_physics=loss_cfg.get("use_physics_loss", False) + ), + loss_cfg=loss_cfg, + case_type=case_type, + device=device, + ) + + _log_minibatch(launch_logger, loss, loss_mse, loss_qoi, qoi_details, scale=1) + + loss_sum += loss.item() + num_batches += 1 + accumulate_metric("loss_mse", loss_mse.item()) + if loss_qoi is not None: + accumulate_metric("loss_qoi", loss_qoi.item()) + for key, value in qoi_details.items(): + accumulate_metric(key, value) + + return loss_sum, num_batches, metric_sums, metric_counts + + +def _format_epoch_log( + epoch: int, + train_log: Any, + val_log: Any, + val_loss: float, + current_lr: float, +) -> str: + """Build the per-epoch rank-0 log line. + + Emits ``train_loss`` / ``val_loss`` first, then ``train_X`` / ``val_X`` + pairs for every other metric key present in either log (sorted), then + ``lr``. Joined with ", ". 
+ """ + parts = [ + f"train_loss={train_log.epoch_losses.get('loss', 0.0):.4e}", + f"val_loss={val_loss:.4e}", + ] + extra_keys = sorted( + {k for k in (*train_log.epoch_losses, *val_log.epoch_losses) if k != "loss"} + ) + for key in extra_keys: + short = key.removeprefix("loss_") + if key in train_log.epoch_losses: + parts.append(f"train_{short}={train_log.epoch_losses[key]:.4e}") + if key in val_log.epoch_losses: + parts.append(f"val_{short}={val_log.epoch_losses[key]:.4e}") + parts.append(f"lr={current_lr:.2e}") + return f"Epoch {epoch}: " + ", ".join(parts) + + +def run_training_loop( + cfg: DictConfig, + dist: DistributedManager, + model: torch.nn.Module, + train_loader: DataLoader, + val_loader: DataLoader, + optimizer: torch.optim.Optimizer, + scheduler: Any, + scaler: GradScaler, + loss_cfg: Dict[str, Any], + logger: Any, + checkpoint_dir: str, + writer: Optional[SummaryWriter], + best_val_loss: float, + start_epoch: int, +) -> None: + """Run the main training loop: epochs, validation, checkpointing, logging. + + Drives the epoch loop, applies the physics-loss warmup ramp inline, + aggregates validation loss across DDP ranks, steps the scheduler, and + saves the single best-by-val_loss checkpoint. The per-epoch train and + validate steps are :func:`train_epoch` and :func:`validate` in this + module; they read case type, AMP, and gradient-accumulation settings + from ``cfg`` directly. + + Args: + cfg: Hydra config (uses ``train.epochs``, ``case.type``, ``train.amp*``, + and ``train.gradient_accumulation_steps``). + dist: DistributedManager instance. + model: Model (possibly DDP-wrapped). The DistributedSampler is + already attached to ``train_loader``; the loader forwards + ``set_epoch`` to it. + train_loader: Training DataLoader. + val_loader: Validation DataLoader. + optimizer: Optimizer. + scheduler: LR scheduler. + scaler: GradScaler for AMP. + loss_cfg: Loss configuration from + :func:`losses.parse_loss_config`. 
The trainer applies the + physics-loss warmup ramp to the ``physics_loss_weight`` per epoch + for the training pass; validation always uses the unwarmed dict. + logger: Logger (rank 0). + checkpoint_dir: Directory for checkpoints. + writer: TensorBoard SummaryWriter (rank 0) or None. + best_val_loss: Best validation loss seen so far (lower is better). + start_epoch: First epoch index to run. + """ + case_type = cfg.case.type + + for epoch in range(start_epoch, cfg.train.epochs): + train_loader.set_epoch(epoch) + val_loader.set_epoch(epoch) + + current_physics_weight = physics_loss_weight_for_epoch(loss_cfg, epoch) + epoch_loss_cfg = { + **loss_cfg, + "physics_loss_weight": current_physics_weight, + } + if current_physics_weight != loss_cfg["physics_loss_weight"]: + logger.info( + f"Physics loss warmup: epoch {epoch}, " + f"weight={current_physics_weight:.6f}" + ) + + with LaunchLogger( + "train", + epoch=epoch, + num_mini_batch=len(train_loader), + mini_batch_log_freq=10, + ) as train_log: + train_epoch( + cfg, + train_loader, + model, + optimizer, + scaler, + dist.device, + train_log, + loss_cfg=epoch_loss_cfg, + ) + + with LaunchLogger( + "val", epoch=epoch, num_mini_batch=len(val_loader) + ) as val_log: + ( + val_loss_sum, + val_num_batches, + val_metric_sums, + val_metric_counts, + ) = validate( + cfg, + val_loader, + model, + dist.device, + val_log, + loss_cfg=loss_cfg, + ) + + train_loss = train_log.epoch_losses.get("loss", 0.0) + val_loss = aggregate_validation_loss(val_loss_sum, val_num_batches, dist) + val_metrics = aggregate_validation_metrics( + val_metric_sums, val_metric_counts, dist + ) + val_log.epoch_losses.update(val_metrics) + + scheduler.step() + current_lr = scheduler.get_last_lr()[0] + + val_loss_qoi = val_metrics.get("loss_qoi") + + if dist.rank == 0: + logger.info( + _format_epoch_log(epoch, train_log, val_log, val_loss, current_lr) + ) + + if writer: + writer.add_scalar("Loss/train", train_loss, epoch) + writer.add_scalar("Loss/val", 
val_loss, epoch) + writer.add_scalar("Learning_Rate", current_lr, epoch) + + if not ( + math.isfinite(train_loss) + and math.isfinite(val_loss) + and (val_loss_qoi is None or math.isfinite(val_loss_qoi)) + ): + logger.warning( + "Skipping checkpoint save for epoch %s because at least " + "one checkpoint metric is NaN or inf: " + "train_loss=%s, val_loss=%s, val_loss_qoi=%s", + epoch, + train_loss, + val_loss, + val_loss_qoi, + ) + else: + best_val_loss = save_best_checkpoint( + checkpoint_dir=Path(checkpoint_dir), + val_loss=val_loss, + best_val_loss=best_val_loss, + save_checkpoint_fn=save_checkpoint, + logger=logger, + models=model, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + epoch=epoch, + metadata={ + "best_val_loss": val_loss, + "train_loss": train_loss, + "val_loss": val_loss, + "val_loss_qoi": val_loss_qoi, + "case_type": case_type, + }, + ) + + if val_loss_qoi is not None and writer: + writer.add_scalar("Loss/val_qoi", val_loss_qoi, epoch) + + if dist.distributed: + torch_dist.barrier() + + if writer: + writer.close() + + logger.info("=" * 70) + logger.info("Training completed!") + if best_val_loss < float("inf"): + logger.info(f"Best validation loss: {best_val_loss:.6f}") + logger.info(f"Checkpoints saved to: {checkpoint_dir}") + logger.info("=" * 70) diff --git a/examples/nuclear_engineering/radiation_transport/src/transforms.py b/examples/nuclear_engineering/radiation_transport/src/transforms.py new file mode 100644 index 0000000000..574d708f52 --- /dev/null +++ b/examples/nuclear_engineering/radiation_transport/src/transforms.py @@ -0,0 +1,348 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math +from typing import Any, Dict, Optional, Tuple + +import torch +from physicsnemo.datapipes.registry import register +from physicsnemo.datapipes.transforms import Transform +from tensordict import TensorDict + +__all__ = [ + "Transform", + "RTEFluxLogClip", + "denormalize_flux", + "GLOBAL_DOMAIN_BOUNDS", + "RTEBackupCoords", + "FourierFeatures", + "coord_bounds_for_case", + "coord_translate_scale_params", + "MaterialPropertyExtractor", + "SpatialSampler", + "FinalTimeSampler", +] + + +def denormalize_flux( + normalized_flux: torch.Tensor, + stats: Dict[str, float], +) -> torch.Tensor: + """Invert the ``RTEFluxLogClip + Normalize`` chain for evaluation/inference. + + ``normalized_flux`` is the model output in z-score-of-log space; + ``stats`` is the ``flux_normalization_stats`` dict that ``RTEFluxLogClip`` + recorded on the sample. + """ + mean = stats["log_flux_mean"] + std = stats["log_flux_std"] + clip = stats["clip_threshold"] + log_flux = normalized_flux * std + mean + log_flux = torch.clamp(log_flux, min=-38, max=38) + flux = torch.pow(10.0, log_flux) - clip + return torch.clamp(flux, min=0.0) + + +@register("RTEFluxLogClip") +class RTEFluxLogClip(Transform): + """Clip flux to a threshold, apply ``log10``, and record denorm stats. + + Input: + ``scalar_flux`` -- shape ``(T, N)`` or ``(N,)``, float tensor. + + Output: + ``scalar_flux`` -- same shape, ``log10(clamp(x, clip) + clip)``. 
+ ``flux_normalization_stats`` -- non-tensor dict with ``log_flux_mean``, + ``log_flux_std``, ``clip_threshold`` for downstream denormalization. + """ + + def __init__( + self, + clip_threshold: float, + log_flux_mean: float, + log_flux_std: float, + ) -> None: + super().__init__() + self.clip_threshold = float(clip_threshold) + self.log_flux_mean = float(log_flux_mean) + self.log_flux_std = float(log_flux_std) + + def __call__(self, data: TensorDict) -> TensorDict: + flux = data["scalar_flux"] + clip = torch.tensor(self.clip_threshold, dtype=flux.dtype, device=flux.device) + flux = torch.clamp(flux, min=clip) + data["scalar_flux"] = torch.log10(flux + clip) + data.set_non_tensor( + "flux_normalization_stats", + { + "log_flux_mean": self.log_flux_mean, + "log_flux_std": self.log_flux_std, + "clip_threshold": self.clip_threshold, + }, + ) + return data + + def extra_repr(self) -> str: + return ( + f"clip_threshold={self.clip_threshold}, " + f"log_flux_mean={self.log_flux_mean:.4f}, " + f"log_flux_std={self.log_flux_std:.4f}" + ) + + +GLOBAL_DOMAIN_BOUNDS = { + "lattice": { + "min": torch.tensor([-3.5, -3.5], dtype=torch.float32), + "max": torch.tensor([3.5, 3.5], dtype=torch.float32), + }, + "hohlraum": { + "min": torch.tensor([-0.65, -0.65], dtype=torch.float32), + "max": torch.tensor([0.65, 0.65], dtype=torch.float32), + }, +} + + +@register("RTEBackupCoords") +class RTEBackupCoords(Transform): + """Clone ``coordinates`` into ``coordinates_unnormalized`` before Translate/Scale. + + Downstream consumers (e.g. graph construction or rasterization) read + ``coordinates_unnormalized`` for physical-space operations. Place this + transform immediately before + ``physicsnemo.datapipes.transforms.Translate`` + ``Scale`` in the + pipeline so the raw coords survive the normalization. 
+ """ + + def __init__(self) -> None: + super().__init__() + + def __call__(self, data: TensorDict) -> TensorDict: + data["coordinates_unnormalized"] = data["coordinates"].clone() + return data + + def extra_repr(self) -> str: + return "preserve raw coordinates" + + +@register("RTEFourierFeatures") +class FourierFeatures(Transform): + """Sin/cos positional encoding features at multiple frequency scales.""" + + def __init__( + self, + num_frequencies: int = 3, + coord_dims: int = 2, + base_frequency: float = 1.0, + append_to_coordinates: bool = True, + ): + super().__init__() + self.num_frequencies = num_frequencies + self.coord_dims = coord_dims + self.base_frequency = base_frequency + self.append_to_coordinates = append_to_coordinates + self.frequency_multipliers = [ + 2**i * base_frequency for i in range(num_frequencies) + ] + + def get_output_dim(self) -> int: + """Number of Fourier-feature channels emitted (``2 * num_frequencies * coord_dims``).""" + return 2 * self.num_frequencies * self.coord_dims + + def __call__(self, data: TensorDict) -> TensorDict: + coords = data["coordinates"] + coords_subset = coords[:, : self.coord_dims].to(dtype=torch.float32) + + two_pi = 2.0 * math.pi + parts = [] + for freq_mult in self.frequency_multipliers: + angle = two_pi * float(freq_mult) * coords_subset + parts.append(torch.sin(angle)) + parts.append(torch.cos(angle)) + + fourier_features = torch.cat(parts, dim=-1).to(dtype=torch.float32) + data["fourier_features"] = fourier_features + + if self.append_to_coordinates: + data["coordinates"] = torch.cat( + [coords.to(dtype=torch.float32), fourier_features], dim=-1 + ) + return data + + def extra_repr(self) -> str: + return ( + f"num_frequencies={self.num_frequencies}, coord_dims={self.coord_dims}, " + f"base_frequency={self.base_frequency}, " + f"append_to_coordinates={self.append_to_coordinates}" + ) + + +@register("RTESpatialSampler") +class SpatialSampler(Transform): + """Randomly subsample spatial points to 
``num_points``. + + ``num_points = -1`` is a passthrough. Otherwise ``num_available`` must be + ``>= num_points`` (the shipped lattice / hohlraum meshes have tens of + thousands of cells, far above any practical ``num_points``). + """ + + # Stride used when re-seeding per epoch; large prime keeps streams disjoint. + _EPOCH_PRIME: int = 1_000_003 + + def __init__(self, num_points: int, seed: Optional[int] = None): + super().__init__() + self.num_points = num_points + self.seed = seed + self.gen = torch.Generator() + if seed is not None: + self.gen.manual_seed(int(seed)) + + def set_epoch(self, epoch: int) -> None: + """Re-seed the generator for a new epoch (deterministic reshuffle). + + No-op when ``self.seed`` is ``None`` (caller opted into a non-deterministic + run; preserve current generator state). + """ + if self.seed is None: + return + self.gen.manual_seed(int(self.seed) + int(epoch) * self._EPOCH_PRIME) + + def to(self, device): + """No-op device move. ``self.gen`` stays pinned to CPU because + ``torch.randperm`` requires its generator and output to share a + device; selected indices are moved inside ``__call__``. + """ + return self + + def __call__(self, data: TensorDict) -> TensorDict: + if self.num_points == -1: + return data + + num_available = data["coordinates"].shape[0] + if num_available == self.num_points: + return data + if num_available < self.num_points: + raise ValueError( + f"SpatialSampler: num_available={num_available} < " + f"num_points={self.num_points}; the shipped meshes are larger " + "than any configured num_points, so this should never happen." 
+ ) + + indices = torch.randperm(num_available, generator=self.gen)[: self.num_points] + indices = indices.to(torch.int64).to(data["coordinates"].device) + + spatial_keys = [ + "coordinates", + "cell_areas", + "material_properties", + "physical_properties", + "geometric_features", + "sigma_t", + "sigma_s", + "sigma_a", + "Q", + ] + for key in spatial_keys: + if key in data and data[key] is not None: + data[key] = data[key][indices] + + if "scalar_flux" in data: + data["scalar_flux"] = data["scalar_flux"][:, indices] + + for flux_key in ("flux_input", "flux_target"): + if flux_key in data: + data[flux_key] = data[flux_key][indices] + + return data + + def extra_repr(self) -> str: + return f"num_points={self.num_points}" + + +@register("RTEFinalTimeSampler") +class FinalTimeSampler(Transform): + """Extract the fixed final-time mapping: first flux -> final flux.""" + + def __init__(self): + super().__init__() + + def __call__(self, data: TensorDict) -> TensorDict: + flux_all = data["scalar_flux"] + if flux_all.shape[0] == 0: + raise ValueError("scalar_flux must contain at least one snapshot") + + input_idx = 0 + target_idx = flux_all.shape[0] - 1 + + data["flux_input"] = flux_all[input_idx].clone() + data["flux_target"] = flux_all[target_idx].clone() + data.set_non_tensor("timestep_input", 0) + data.set_non_tensor("timestep_target", int(target_idx)) + return data + + +@register("RTEMaterialPropertyExtractor") +class MaterialPropertyExtractor(Transform): + """Stack precomputed sigma fields into a per-cell ``(N, 4)`` tensor. + + Q must be present in the source data; it may be all-zero for source-free + regimes (e.g., hohlraum). + """ + + def __call__(self, data: TensorDict) -> TensorDict: + for key in ("sigma_a", "sigma_s", "sigma_t", "Q"): + if key not in data: + raise KeyError( + f"Mesh store is missing required field {key!r}. " + "All four fields (sigma_a, sigma_s, sigma_t, Q) must be precomputed." 
+ ) + + data["physical_properties"] = torch.stack( + [data["sigma_a"], data["sigma_s"], data["sigma_t"], data["Q"]], + dim=-1, + ).to(dtype=torch.float32) + return data + + +def coord_bounds_for_case(case_type: str) -> Tuple[torch.Tensor, torch.Tensor]: + """Return ``(bbox_min, bbox_max)`` as float32 tensors for a known case.""" + if case_type not in GLOBAL_DOMAIN_BOUNDS: + raise ValueError( + f"Unknown case_type '{case_type}'. " + f"Expected one of: {list(GLOBAL_DOMAIN_BOUNDS.keys())}" + ) + bounds = GLOBAL_DOMAIN_BOUNDS[case_type] + return ( + torch.as_tensor(bounds["min"], dtype=torch.float32), + torch.as_tensor(bounds["max"], dtype=torch.float32), + ) + + +def coord_translate_scale_params( + case_type: str, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute ``(center, half_extent)`` for ``Translate`` + ``Scale``. + + Returns the tensors so the caller can wire them straight into + ``Translate(center_key_or_value=center, subtract=True)`` followed by + ``Scale(scale=half_extent, divide=True)`` — i.e. the standard + ``(x - center) / half_extent`` normalization into ``[-1, 1]``. + """ + bbox_min, bbox_max = coord_bounds_for_case(case_type) + center = 0.5 * (bbox_min + bbox_max) + half_extent = 0.5 * (bbox_max - bbox_min) + return center, half_extent diff --git a/examples/nuclear_engineering/radiation_transport/src/viz.py b/examples/nuclear_engineering/radiation_transport/src/viz.py new file mode 100644 index 0000000000..6e5953b6ce --- /dev/null +++ b/examples/nuclear_engineering/radiation_transport/src/viz.py @@ -0,0 +1,176 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import Dict, Tuple, Union + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt +from matplotlib.colors import LogNorm +import numpy as np + +__all__ = ["plot_flux_panels", "plot_qoi_true_vs_pred"] + + +def plot_flux_panels( + coordinates: np.ndarray, + target: np.ndarray, + prediction: np.ndarray, + output_path: Union[str, Path], + log_flux: bool = False, + figsize: Tuple[int, int] = (16, 5), + dpi: int = 150, +) -> Path: + """Render a 3-panel figure: target | prediction | absolute error.""" + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + target = target.flatten() + prediction = prediction.flatten() + error = np.abs(prediction - target) + + x, y = coordinates[:, 0], coordinates[:, 1] + x_pad = (x.max() - x.min()) * 0.01 + y_pad = (y.max() - y.min()) * 0.01 + xlim = (x.min() - x_pad, x.max() + x_pad) + ylim = (y.min() - y_pad, y.max() + y_pad) + + fig, axes = plt.subplots(1, 3, figsize=figsize, dpi=dpi) + flux_vmin = min(target.min(), prediction.min()) + flux_vmax = max(target.max(), prediction.max()) + flux_norm = None + if log_flux: + positive_flux = np.concatenate( + [target[target > 0.0], prediction[prediction > 0.0]] + ) + if positive_flux.size: + flux_vmin = float(positive_flux.min()) + flux_vmax = float(positive_flux.max()) + if flux_vmin == flux_vmax: + flux_vmax = flux_vmin * 1.01 + flux_norm = LogNorm(vmin=flux_vmin, vmax=flux_vmax) + else: + log_flux = False + cmap_flux = plt.get_cmap("viridis") + cmap_err = plt.get_cmap("hot") + + for 
ax, label, vals, cmap, vmin, vmax, norm in ( + (axes[0], "Target", target, cmap_flux, flux_vmin, flux_vmax, flux_norm), + ( + axes[1], + "Prediction", + prediction, + cmap_flux, + flux_vmin, + flux_vmax, + flux_norm, + ), + (axes[2], "Absolute Error", error, cmap_err, 0.0, float(error.max()), None), + ): + plot_vals = np.clip(vals, flux_vmin, None) if norm is not None else vals + sc = ax.scatter( + x, + y, + c=plot_vals, + cmap=cmap, + vmin=None if norm is not None else vmin, + vmax=None if norm is not None else vmax, + norm=norm, + s=1, + ) + ax.set_aspect("equal") + ax.set_xlim(xlim) + ax.set_ylim(ylim) + ax.set_title(f"{label} (log)" if norm is not None else label) + plt.colorbar(sc, ax=ax) + + plt.tight_layout() + plt.savefig(output_path, dpi=dpi, bbox_inches="tight") + plt.close(fig) + return output_path + + +def plot_qoi_true_vs_pred( + per_sample_qoi: list[Dict[str, Dict[str, float]]], + output_path: Union[str, Path], + dpi: int = 150, +) -> Path: + """Scatter predicted vs ground-truth QoI values for each component. + + Takes the same per-sample QoI list that ``aggregate_qoi`` consumes; the + per-component arrays are flattened inline rather than via a separate + collector. + """ + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Preserve first-seen order of component names across samples. 
+ component_names: list[str] = [] + for sample in per_sample_qoi: + for name in sample: + if name not in component_names: + component_names.append(name) + + series: Dict[str, Tuple[np.ndarray, np.ndarray]] = {} + for name in component_names: + target_vals: list[float] = [] + pred_vals: list[float] = [] + for sample in per_sample_qoi: + entry = sample.get(name) + if entry is None: + continue + target_vals.append(entry["ground_truth"]) + pred_vals.append(entry["predicted"]) + if target_vals: + series[name] = (np.array(target_vals), np.array(pred_vals)) + + items = list(series.items()) + if not items: + plt.close(plt.figure()) + return output_path + + ncols = min(len(items), 3) + nrows = int(np.ceil(len(items) / ncols)) + fig, axes = plt.subplots( + nrows, ncols, figsize=(5 * ncols, 4.5 * nrows), dpi=dpi, squeeze=False + ) + + for ax, (name, (target, prediction)) in zip(axes.flat, items): + lo = float(min(target.min(), prediction.min())) + hi = float(max(target.max(), prediction.max())) + if lo == hi: + pad = max(abs(lo) * 0.05, 1e-12) + lo -= pad + hi += pad + + ax.scatter(target, prediction, s=18, alpha=0.75) + ax.plot([lo, hi], [lo, hi], "r--", linewidth=1.0, label="y = x") + ax.set_title(name) + ax.set_xlabel("Ground truth QoI") + ax.set_ylabel("Predicted QoI") + ax.set_aspect("equal") + ax.legend(loc="best") + + for ax in axes.flat[len(items) :]: + ax.axis("off") + + fig.suptitle("QoI predicted vs. ground truth") + plt.tight_layout() + plt.savefig(output_path, dpi=dpi, bbox_inches="tight") + plt.close(fig) + return output_path
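
Note on the `RTEFluxLogClip` / `denormalize_flux` pair in `transforms.py` above: the round trip can be checked with a minimal scalar sketch. It assumes the intermediate z-score step is `(log_flux - mean) / std`, which is what `denormalize_flux` inverts. The statistics below are hypothetical placeholders, and `normalize_flux` / `denormalize_flux_scalar` are illustrative names that are not part of this PR.

```python
import math

# Hypothetical statistics; real values would come from the dataset config.
STATS = {"log_flux_mean": -2.0, "log_flux_std": 1.5, "clip_threshold": 1e-6}


def normalize_flux(flux: float, stats: dict) -> float:
    """Scalar analogue of RTEFluxLogClip followed by a z-score (assumed form)."""
    clip = stats["clip_threshold"]
    # Clamp from below at the threshold, shift by it, then take log10.
    log_flux = math.log10(max(flux, clip) + clip)
    return (log_flux - stats["log_flux_mean"]) / stats["log_flux_std"]


def denormalize_flux_scalar(z: float, stats: dict) -> float:
    """Scalar analogue of denormalize_flux from transforms.py."""
    clip = stats["clip_threshold"]
    log_flux = z * stats["log_flux_std"] + stats["log_flux_mean"]
    # Same exponent guard as the tensor version.
    log_flux = min(max(log_flux, -38.0), 38.0)
    return max(10.0 ** log_flux - clip, 0.0)


# A flux above the threshold survives the round trip.
recovered = denormalize_flux_scalar(normalize_flux(0.25, STATS), STATS)
# A flux below the threshold comes back as the threshold itself.
floor = denormalize_flux_scalar(normalize_flux(0.0, STATS), STATS)
```

Under these assumptions, fluxes at or below `clip_threshold` are recovered as `clip_threshold` rather than their original value, so the inversion is exact only above the threshold.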