From 4b740424453a874e1289c29824c2e64388132bf0 Mon Sep 17 00:00:00 2001 From: Greg Partin Date: Wed, 11 Mar 2026 11:57:39 -0700 Subject: [PATCH 1/2] Add 2D wave equation FNO training recipe Add a complete Fourier Neural Operator example for the 2D wave equation, filling the hyperbolic PDE gap in PhysicsNeMo examples. - train_fno_wave.py: Training script following darcy_fno patterns - wave_data.py: On-the-fly leapfrog data generator with periodic BCs - validator.py: Side-by-side visualization (initial/truth/prediction/error) - config.yaml: Hydra config for 128x128 grid, 4-layer FNO - README.md: Problem description and usage instructions The model learns u(x,y,0) -> u(x,y,T) using random Fourier initial conditions. Data is generated in-situ, matching the Darcy2D pattern. --- examples/wave/wave_fno/README.md | 51 ++++++ examples/wave/wave_fno/config.yaml | 51 ++++++ examples/wave/wave_fno/requirements.txt | 2 + examples/wave/wave_fno/train_fno_wave.py | 170 ++++++++++++++++++++ examples/wave/wave_fno/validator.py | 95 ++++++++++++ examples/wave/wave_fno/wave_data.py | 188 +++++++++++++++++++++++ 6 files changed, 557 insertions(+) create mode 100644 examples/wave/wave_fno/README.md create mode 100644 examples/wave/wave_fno/config.yaml create mode 100644 examples/wave/wave_fno/requirements.txt create mode 100644 examples/wave/wave_fno/train_fno_wave.py create mode 100644 examples/wave/wave_fno/validator.py create mode 100644 examples/wave/wave_fno/wave_data.py diff --git a/examples/wave/wave_fno/README.md b/examples/wave/wave_fno/README.md new file mode 100644 index 0000000000..283a205282 --- /dev/null +++ b/examples/wave/wave_fno/README.md @@ -0,0 +1,51 @@ +# Fourier Neural Operator for 2D Wave Equation + +This example demonstrates how to train a Fourier Neural Operator (FNO) to learn +the solution operator for the 2D wave equation inside of PhysicsNeMo. + +The wave equation is a fundamental hyperbolic PDE: + +$$\frac{\partial^2 u}{\partial t^2} = c^2 \nabla^2 u$$ + +The FNO learns to map the initial wavefield $u(x, y, 0)$ to the solution at a +later time $u(x, y, T)$. + +Training data is generated on the fly using a leapfrog finite-difference solver +with periodic boundary conditions. + +## Problem Setup + +- **Domain**: $[0, 1]^2$ with periodic boundaries +- **Wave speed**: $c = 1.0$ +- **Initial condition**: Superposition of random Fourier modes +- **Target**: Solution at $T = 0.5$ +- **Resolution**: $128 \times 128$ + +## Prerequisites + +Install the required dependencies by running below: + +```bash +pip install -r requirements.txt +``` + +## Getting Started + +To train the model, run + +```bash +python train_fno_wave.py +``` + +Training data is generated on the fly. + +## Additional Information + +This fills the hyperbolic PDE gap in PhysicsNeMo examples. The existing examples +focus on elliptic (Darcy) and parabolic (Navier-Stokes) problems. Wave equations +are critical for acoustics, seismology, and electromagnetic applications. + +## References + +- [Fourier Neural Operator for Parametric Partial Differential Equations](https://arxiv.org/abs/2010.08895) +- [PDEBench: An Extensive Benchmark for Scientific Machine Learning](https://arxiv.org/abs/2210.07182) diff --git a/examples/wave/wave_fno/config.yaml b/examples/wave/wave_fno/config.yaml new file mode 100644 index 0000000000..291db3227c --- /dev/null +++ b/examples/wave/wave_fno/config.yaml @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +arch: + decoder: + out_features: 1 + layers: 1 + layer_size: 32 + + fno: + in_channels: 1 + dimension: 2 + latent_channels: 32 + fno_layers: 4 + fno_modes: 12 + padding: 9 + +scheduler: + initial_lr: 1.E-3 + decay_rate: .85 + decay_pseudo_epochs: 8 + +training: + resolution: 128 + batch_size: 32 + rec_results_freq: 8 + max_pseudo_epochs: 128 + pseudo_epoch_sample_size: 1024 + +validation: + validation_pseudo_epochs: 4 + sample_size: 128 + +wave: + speed: 1.0 + target_time: 0.5 + nr_modes: 5 + cfl: 0.25 diff --git a/examples/wave/wave_fno/requirements.txt b/examples/wave/wave_fno/requirements.txt new file mode 100644 index 0000000000..001dc23f09 --- /dev/null +++ b/examples/wave/wave_fno/requirements.txt @@ -0,0 +1,2 @@ +hydra-core>=1.2.0 +termcolor>=2.1.1 diff --git a/examples/wave/wave_fno/train_fno_wave.py b/examples/wave/wave_fno/train_fno_wave.py new file mode 100644 index 0000000000..8facf9f003 --- /dev/null +++ b/examples/wave/wave_fno/train_fno_wave.py @@ -0,0 +1,170 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hydra +from omegaconf import DictConfig +from math import ceil + +from torch.nn import MSELoss +from torch.optim import Adam, lr_scheduler + +from physicsnemo.models.fno import FNO +from physicsnemo.distributed import DistributedManager +from physicsnemo.utils import StaticCaptureTraining, StaticCaptureEvaluateNoGrad +from physicsnemo.utils import load_checkpoint, save_checkpoint +from physicsnemo.utils.logging import PythonLogger, LaunchLogger + +from wave_data import WaveDataLoader +from validator import WaveValidator + + +@hydra.main(version_base="1.3", config_path=".", config_name="config.yaml") +def wave_trainer(cfg: DictConfig) -> None: + """Training for the 2D wave equation benchmark problem. + + This training script demonstrates how to set up a data-driven model for a 2D + wave equation using Fourier Neural Operators (FNO) and acts as a benchmark for + hyperbolic PDE operator learning. Training data is generated on the fly via + leapfrog finite-difference integration. The model learns to map an initial + wavefield u(x, y, 0) to the solution u(x, y, T) at a specified target time. + """ + DistributedManager.initialize() # Only call this once in the entire script! + dist = DistributedManager() # call if required elsewhere + + # initialize monitoring + log = PythonLogger(name="wave_fno") + log.file_logging() + LaunchLogger.initialize() # PhysicsNeMo launch logger + + # define model, loss, optimiser, scheduler, data loader + model = FNO( + in_channels=cfg.arch.fno.in_channels, + out_channels=cfg.arch.decoder.out_features, + decoder_layers=cfg.arch.decoder.layers, + decoder_layer_size=cfg.arch.decoder.layer_size, + dimension=cfg.arch.fno.dimension, + latent_channels=cfg.arch.fno.latent_channels, + num_fno_layers=cfg.arch.fno.fno_layers, + num_fno_modes=cfg.arch.fno.fno_modes, + padding=cfg.arch.fno.padding, + ).to(dist.device) + loss_fun = MSELoss(reduction="mean") + optimizer = Adam(model.parameters(), lr=cfg.scheduler.initial_lr) + scheduler = lr_scheduler.LambdaLR( + optimizer, lr_lambda=lambda step: cfg.scheduler.decay_rate**step + ) + + dataloader = WaveDataLoader( + resolution=cfg.training.resolution, + batch_size=cfg.training.batch_size, + wave_speed=cfg.wave.speed, + target_time=cfg.wave.target_time, + nr_modes=cfg.wave.nr_modes, + cfl=cfg.wave.cfl, + device=dist.device, + ) + validator = WaveValidator(loss_fun=MSELoss(reduction="mean")) + + ckpt_args = { + "path": f"./checkpoints", + "optimizer": optimizer, + "scheduler": scheduler, + "models": model, + } + loaded_pseudo_epoch = load_checkpoint(device=dist.device, **ckpt_args) + + # calculate steps per pseudo epoch + steps_per_pseudo_epoch = ceil( + cfg.training.pseudo_epoch_sample_size / cfg.training.batch_size + ) + validation_iters = ceil(cfg.validation.sample_size / cfg.training.batch_size) + log_args = { + "name_space": "train", + "num_mini_batch": steps_per_pseudo_epoch, + "epoch_alert_freq": 1, + } + if cfg.training.pseudo_epoch_sample_size % cfg.training.batch_size != 0: + log.warning( + f"increased pseudo_epoch_sample_size to multiple of " + f"batch size: {steps_per_pseudo_epoch * cfg.training.batch_size}" + ) + if cfg.validation.sample_size % cfg.training.batch_size != 0: + log.warning( + f"increased validation sample size to multiple of " + f"batch size: {validation_iters * cfg.training.batch_size}" + ) + + # define forward passes for training and inference + @StaticCaptureTraining( + model=model, optim=optimizer, logger=log, use_amp=False, use_graphs=False + ) + def forward_train(invars, target): + pred = model(invars) + loss = loss_fun(pred, target) + return loss + + @StaticCaptureEvaluateNoGrad( + model=model, logger=log, use_amp=False, use_graphs=False + ) + def forward_eval(invars): + return model(invars) + + if loaded_pseudo_epoch == 0: + log.success("Training started...") + else: + log.warning(f"Resuming training from pseudo epoch {loaded_pseudo_epoch + 1}.") + + for pseudo_epoch in range( + max(1, loaded_pseudo_epoch + 1), cfg.training.max_pseudo_epochs + 1 + ): + # Wrap epoch in launch logger for console / MLFlow logs + with LaunchLogger(**log_args, epoch=pseudo_epoch) as logger: + for _, batch in zip(range(steps_per_pseudo_epoch), dataloader): + loss = forward_train(batch["initial"], batch["target"]) + logger.log_minibatch({"loss": loss.detach()}) + logger.log_epoch({"Learning Rate": optimizer.param_groups[0]["lr"]}) + + # save checkpoint + if pseudo_epoch % cfg.training.rec_results_freq == 0: + save_checkpoint(**ckpt_args, epoch=pseudo_epoch) + + # validation step + if pseudo_epoch % cfg.validation.validation_pseudo_epochs == 0: + with LaunchLogger("valid", epoch=pseudo_epoch) as logger: + total_loss = 0.0 + for _, batch in zip(range(validation_iters), dataloader): + val_loss = validator.compare( + batch["initial"], + batch["target"], + forward_eval(batch["initial"]), + pseudo_epoch, + logger, + ) + total_loss += val_loss + logger.log_epoch( + {"Validation error": total_loss / validation_iters} + ) + + # update learning rate + if pseudo_epoch % cfg.scheduler.decay_pseudo_epochs == 0: + scheduler.step() + + save_checkpoint(**ckpt_args, epoch=cfg.training.max_pseudo_epochs) + log.success("Training completed *yay*") + + +if __name__ == "__main__": + wave_trainer() diff --git a/examples/wave/wave_fno/validator.py b/examples/wave/wave_fno/validator.py new file mode 100644 index 0000000000..b5ac67e668 --- /dev/null +++ b/examples/wave/wave_fno/validator.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import matplotlib.pyplot as plt +from torch import FloatTensor +from physicsnemo.utils.logging import LaunchLogger + + +class WaveValidator: + """Grid Validator for wave equation predictions. + + Compares model prediction against ground truth, computes loss, and logs + a side-by-side visualization of the initial condition, truth, prediction, + and point-wise error. + + Parameters + ---------- + loss_fun : torch.nn.Module + Loss function for validation error + font_size : float, optional + Font size for plots + """ + + def __init__(self, loss_fun, font_size: float = 28.0): + self.criterion = loss_fun + self.font_size = font_size + self.headers = ("initial u(0)", "truth u(T)", "prediction", "abs error") + + def compare( + self, + invar: FloatTensor, + target: FloatTensor, + prediction: FloatTensor, + step: int, + logger: LaunchLogger, + ) -> float: + """Compare prediction to ground truth and log visualization. + + Parameters + ---------- + invar : FloatTensor + Initial condition input + target : FloatTensor + Ground truth solution at time T + prediction : FloatTensor + Model prediction + step : int + Current epoch/step for labeling + logger : LaunchLogger + Logger for figure output + + Returns + ------- + float + Validation loss + """ + loss = self.criterion(prediction, target) + + # Extract first sample for plotting + invar_np = invar.cpu().numpy()[0, 0, :, :] + target_np = target.cpu().numpy()[0, 0, :, :] + pred_np = prediction.detach().cpu().numpy()[0, 0, :, :] + error_np = abs(pred_np - target_np) + + plt.close("all") + plt.rcParams.update({"font.size": self.font_size}) + fig, ax = plt.subplots(1, 4, figsize=(15 * 4, 15), sharey=True) + im = [] + im.append(ax[0].imshow(invar_np, cmap="RdBu_r")) + im.append(ax[1].imshow(target_np, cmap="RdBu_r")) + im.append(ax[2].imshow(pred_np, cmap="RdBu_r")) + im.append(ax[3].imshow(error_np, cmap="hot")) + + for ii in range(len(im)): + fig.colorbar( + im[ii], ax=ax[ii], location="bottom", fraction=0.046, pad=0.04 + ) + ax[ii].set_title(self.headers[ii]) + + logger.log_figure(figure=fig, artifact_file=f"validation_step_{step:03d}.png") + + return loss diff --git a/examples/wave/wave_fno/wave_data.py b/examples/wave/wave_fno/wave_data.py new file mode 100644 index 0000000000..c187d61666 --- /dev/null +++ b/examples/wave/wave_fno/wave_data.py @@ -0,0 +1,188 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""On-the-fly 2D wave equation data generator using leapfrog finite differences. + +Generates random initial wavefields from a superposition of Fourier modes and +evolves them forward in time using the standard second-order leapfrog scheme +with periodic boundary conditions. +""" + +import numpy as np +import torch + + +def generate_wave_batch( + batch_size: int, + resolution: int, + wave_speed: float = 1.0, + target_time: float = 0.5, + nr_modes: int = 5, + cfl: float = 0.25, + device: str = "cuda", + seed: int | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Generate a batch of 2D wave equation initial conditions and solutions. + + Parameters + ---------- + batch_size : int + Number of samples to generate + resolution : int + Spatial resolution (NxN grid) + wave_speed : float + Wave propagation speed c + target_time : float + Time at which to evaluate the solution + nr_modes : int + Number of Fourier modes per axis for random initial conditions + cfl : float + CFL number (dt = cfl * dx / c) + device : str + Device to return tensors on + seed : int or None + Random seed for reproducibility + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + (initial_condition, target_solution) each of shape (batch, 1, N, N) + """ + rng = np.random.default_rng(seed) + dx = 1.0 / resolution + dt = cfl * dx / wave_speed + n_steps = int(np.ceil(target_time / dt)) + dt = target_time / n_steps # adjust for exact target time + + # Coordinate grids + x = np.linspace(0, 1, resolution, endpoint=False) + y = np.linspace(0, 1, resolution, endpoint=False) + xx, yy = np.meshgrid(x, y, indexing="ij") + + c2_ratio = (wave_speed * dt / dx) ** 2 + + u0_all = np.zeros((batch_size, resolution, resolution), dtype=np.float32) + uT_all = np.zeros((batch_size, resolution, resolution), dtype=np.float32) + + for b in range(batch_size): + # Random superposition of Fourier modes + u = np.zeros((resolution, resolution), dtype=np.float64) + for _ in range(nr_modes): + kx = rng.integers(-nr_modes, nr_modes + 1) + ky = rng.integers(-nr_modes, nr_modes + 1) + amp = rng.standard_normal() + phase = rng.uniform(0, 2 * np.pi) + u += amp * np.sin( + 2 * np.pi * (kx * xx + ky * yy) + phase + ) + + # Normalize to unit variance + std = np.std(u) + if std > 1e-10: + u /= std + + u0_all[b] = u.astype(np.float32) + + # Leapfrog time integration with zero initial velocity + u_prev = u.copy() + # Taylor expansion for first step: u(dt) = u(0) + 0.5*dt^2*c^2*laplacian(u) + lap = ( + np.roll(u, 1, axis=0) + np.roll(u, -1, axis=0) + + np.roll(u, 1, axis=1) + np.roll(u, -1, axis=1) + - 4.0 * u + ) / dx**2 + u_curr = u + 0.5 * dt**2 * wave_speed**2 * lap + + for _ in range(n_steps - 1): + lap = ( + np.roll(u_curr, 1, axis=0) + np.roll(u_curr, -1, axis=0) + + np.roll(u_curr, 1, axis=1) + np.roll(u_curr, -1, axis=1) + - 4.0 * u_curr + ) / dx**2 + u_next = 2.0 * u_curr - u_prev + dt**2 * wave_speed**2 * lap + u_prev = u_curr + u_curr = u_next + + uT_all[b] = u_curr.astype(np.float32) + + # Convert to tensors: (batch, 1, N, N) + initial = torch.from_numpy(u0_all).unsqueeze(1).to(device) + target = torch.from_numpy(uT_all).unsqueeze(1).to(device) + return initial, target + + +class WaveDataLoader: + """Iterable data loader that generates wave equation samples on the fly. + + Parameters + ---------- + resolution : int + Spatial resolution + batch_size : int + Batch size + wave_speed : float + Wave speed c + target_time : float + Target evolution time T + nr_modes : int + Number of Fourier modes for initial conditions + cfl : float + CFL number for time stepping + normaliser : dict or None + Normalisation parameters {"input": (mean, std), "output": (mean, std)} + device : str + Device for output tensors + """ + + def __init__( + self, + resolution: int = 128, + batch_size: int = 32, + wave_speed: float = 1.0, + target_time: float = 0.5, + nr_modes: int = 5, + cfl: float = 0.25, + normaliser: dict | None = None, + device: str = "cuda", + ): + self.resolution = resolution + self.batch_size = batch_size + self.wave_speed = wave_speed + self.target_time = target_time + self.nr_modes = nr_modes + self.cfl = cfl + self.normaliser = normaliser + self.device = device + + def __iter__(self): + return self + + def __next__(self) -> dict[str, torch.Tensor]: + initial, target = generate_wave_batch( + batch_size=self.batch_size, + resolution=self.resolution, + wave_speed=self.wave_speed, + target_time=self.target_time, + nr_modes=self.nr_modes, + cfl=self.cfl, + device=self.device, + ) + if self.normaliser is not None: + im, isd = self.normaliser.get("input", (0.0, 1.0)) + om, osd = self.normaliser.get("output", (0.0, 1.0)) + initial = (initial - im) / isd + target = (target - om) / osd + return {"initial": initial, "target": target} From 5498025a4fe2f9382cd952cb884482f381573906 Mon Sep 17 00:00:00 2001 From: Greg Partin Date: Thu, 12 Mar 2026 06:27:11 -0700 Subject: [PATCH 2/2] Address review: remove dead code, fix device defaults, add input validation, reuse loss_fun, add seed/reproducibility support, fix f-string, add vectorization comment --- examples/wave/wave_fno/train_fno_wave.py | 4 +-- examples/wave/wave_fno/wave_data.py | 40 +++++++++++++++++++----- 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/examples/wave/wave_fno/train_fno_wave.py b/examples/wave/wave_fno/train_fno_wave.py index 8facf9f003..61f453bc3c 100644 --- a/examples/wave/wave_fno/train_fno_wave.py +++ b/examples/wave/wave_fno/train_fno_wave.py @@ -76,10 +76,10 @@ def wave_trainer(cfg: DictConfig) -> None: cfl=cfg.wave.cfl, device=dist.device, ) - validator = WaveValidator(loss_fun=MSELoss(reduction="mean")) + validator = WaveValidator(loss_fun=loss_fun) ckpt_args = { - "path": f"./checkpoints", + "path": "./checkpoints", "optimizer": optimizer, "scheduler": scheduler, "models": model, diff --git a/examples/wave/wave_fno/wave_data.py b/examples/wave/wave_fno/wave_data.py index c187d61666..cc0f527ed8 100644 --- a/examples/wave/wave_fno/wave_data.py +++ b/examples/wave/wave_fno/wave_data.py @@ -32,7 +32,7 @@ def generate_wave_batch( target_time: float = 0.5, nr_modes: int = 5, cfl: float = 0.25, - device: str = "cuda", + device: str | torch.device = "cpu", seed: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Generate a batch of 2D wave equation initial conditions and solutions. @@ -51,8 +51,8 @@ def generate_wave_batch( Number of Fourier modes per axis for random initial conditions cfl : float CFL number (dt = cfl * dx / c) - device : str - Device to return tensors on + device : str or torch.device + Device to return tensors on (default: ``"cpu"``) seed : int or None Random seed for reproducibility @@ -61,6 +61,17 @@ def generate_wave_batch( tuple[torch.Tensor, torch.Tensor] (initial_condition, target_solution) each of shape (batch, 1, N, N) """ + if resolution <= 0: + raise ValueError(f"resolution must be positive, got {resolution}") + if batch_size <= 0: + raise ValueError(f"batch_size must be positive, got {batch_size}") + if wave_speed <= 0: + raise ValueError(f"wave_speed must be positive, got {wave_speed}") + if cfl <= 0: + raise ValueError(f"cfl must be positive, got {cfl}") + if target_time <= 0: + raise ValueError(f"target_time must be positive, got {target_time}") + rng = np.random.default_rng(seed) dx = 1.0 / resolution dt = cfl * dx / wave_speed @@ -72,11 +83,13 @@ def generate_wave_batch( y = np.linspace(0, 1, resolution, endpoint=False) xx, yy = np.meshgrid(x, y, indexing="ij") - c2_ratio = (wave_speed * dt / dx) ** 2 - u0_all = np.zeros((batch_size, resolution, resolution), dtype=np.float32) uT_all = np.zeros((batch_size, resolution, resolution), dtype=np.float32) + # NOTE: The per-sample loop is intentional — each sample draws a different + # random mode set, and the leapfrog time-stepper keeps a small memory + # footprint. For high throughput a fully vectorized or GPU-based solver + # would be preferable, but this keeps the example dependency-free. for b in range(batch_size): # Random superposition of Fourier modes u = np.zeros((resolution, resolution), dtype=np.float64) @@ -143,8 +156,10 @@ class WaveDataLoader: CFL number for time stepping normaliser : dict or None Normalisation parameters {"input": (mean, std), "output": (mean, std)} - device : str - Device for output tensors + device : str or torch.device + Device for output tensors (default: ``"cpu"``) + seed : int or None + Base random seed; incremented each batch for reproducibility """ def __init__( @@ -156,7 +171,8 @@ def __init__( nr_modes: int = 5, cfl: float = 0.25, normaliser: dict | None = None, - device: str = "cuda", + device: str | torch.device = "cpu", + seed: int | None = None, ): self.resolution = resolution self.batch_size = batch_size @@ -166,11 +182,18 @@ def __init__( self.cfl = cfl self.normaliser = normaliser self.device = device + self.seed = seed + self._batch_counter = 0 def __iter__(self): return self def __next__(self) -> dict[str, torch.Tensor]: + batch_seed = None + if self.seed is not None: + batch_seed = self.seed + self._batch_counter + self._batch_counter += 1 + initial, target = generate_wave_batch( batch_size=self.batch_size, resolution=self.resolution, @@ -179,6 +202,7 @@ def __next__(self) -> dict[str, torch.Tensor]: nr_modes=self.nr_modes, cfl=self.cfl, device=self.device, + seed=batch_seed, ) if self.normaliser is not None: im, isd = self.normaliser.get("input", (0.0, 1.0))