From 36ad0bd78aa9f872ef993738bc5ce6e07e12f75c Mon Sep 17 00:00:00 2001 From: Yiheng Du Date: Wed, 19 Nov 2025 17:53:45 -0800 Subject: [PATCH 01/11] implement EddyFormer --- examples/cfd/isotropic_eddyformer/README.md | 95 +++++++ examples/cfd/isotropic_eddyformer/config.yaml | 23 ++ .../isotropic_eddyformer/download_dataset.sh | 1 + .../cfd/isotropic_eddyformer/requirements.txt | 2 + .../train_ef_isotropic.py | 110 +++++++++ physicsnemo/models/eddyformer/__init__.py | 5 + physicsnemo/models/eddyformer/_basis.py | 112 +++++++++ physicsnemo/models/eddyformer/_datatype.py | 233 ++++++++++++++++++ physicsnemo/models/eddyformer/eddyformer.py | 181 ++++++++++++++ physicsnemo/models/eddyformer/sem_attn.py | 74 ++++++ physicsnemo/models/eddyformer/sem_conv.py | 150 +++++++++++ pyproject.toml | 1 + 12 files changed, 987 insertions(+) create mode 100644 examples/cfd/isotropic_eddyformer/README.md create mode 100644 examples/cfd/isotropic_eddyformer/config.yaml create mode 100644 examples/cfd/isotropic_eddyformer/download_dataset.sh create mode 100644 examples/cfd/isotropic_eddyformer/requirements.txt create mode 100644 examples/cfd/isotropic_eddyformer/train_ef_isotropic.py create mode 100644 physicsnemo/models/eddyformer/__init__.py create mode 100644 physicsnemo/models/eddyformer/_basis.py create mode 100644 physicsnemo/models/eddyformer/_datatype.py create mode 100644 physicsnemo/models/eddyformer/eddyformer.py create mode 100644 physicsnemo/models/eddyformer/sem_attn.py create mode 100644 physicsnemo/models/eddyformer/sem_conv.py diff --git a/examples/cfd/isotropic_eddyformer/README.md b/examples/cfd/isotropic_eddyformer/README.md new file mode 100644 index 0000000000..1e22dd9485 --- /dev/null +++ b/examples/cfd/isotropic_eddyformer/README.md @@ -0,0 +1,95 @@ +# EddyFormer for 3D Isotropic Turbulence + +This example demonstrates how to use the EddyFormer model for simulating +a three-dimensional isotropic turbulence. This example runs on a single GPU. + +## Problem Overview + +This example focuses on **three-dimensional homogeneous isotropic turbulence (HIT)** sustained by large-scale forcing. The flow is governed by the incompressible Navier–Stokes equations with an external forcing term: + +\[ +\frac{\partial \mathbf{u}}{\partial t} + \mathbf{u} \cdot \nabla \mathbf{u} += \nu \nabla^2 \mathbf{u} + \mathbf{f}(\mathbf{x}) +\] + +where: + +- **\(\mathbf{u}(\mathbf{x}, t)\)** — velocity field in a 3D periodic domain +- **\(\nu = 0.01\)** — kinematic viscosity +- **\(\mathbf{f}(\mathbf{x})\)** — isotropic forcing applied at the largest scales + +### Forcing Mechanism + +To maintain statistically steady turbulence, a **constant-power forcing** is applied to the lowest Fourier modes (\(|\mathbf{k}| \le 1\)). The forcing injects a prescribed amount of energy \(P_{\text{in}} = 1.0\) into the system: + +\[ +\mathbf{f}(\mathbf{x}) = +\frac{P_{\text{in}}}{E_1} +\sum_{\substack{|\mathbf{k}| \le 1 \\ \mathbf{k} \neq 0}} +\hat{\mathbf{u}}_{\mathbf{k}} e^{i \mathbf{k} \cdot \mathbf{x}} +\] + +where: + +\[ +E_1 = \frac{1}{2} +\sum_{|\mathbf{k}| \le 1} +\hat{\mathbf{u}}_{\mathbf{k}} \cdot \hat{\mathbf{u}}_{\mathbf{k}}^{*} +\] + +is the kinetic energy contained in the forced low-wavenumber modes. + +Under this forcing, the flow reaches a **statistically steady state** with a Taylor-scale Reynolds number of: + +**\(\mathrm{Re}_\lambda \approx 94\)** + +### Task Description + +The objective of this example is to **predict the future velocity field** of the turbulent flow. 
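+
+In tensor terms, each training sample pairs a velocity snapshot with the
+snapshot \(\Delta t = 0.5\) later. The following minimal sketch only
+illustrates the expected data layout (channels-last velocity fields on the
+stored \(96^3\) grid); the tensors are random placeholders, not real data:
+
+```python
+import torch
+
+u_t      = torch.randn(96, 96, 96, 3)  # velocity field u(x, t) on the stored grid
+u_target = torch.randn(96, 96, 96, 3)  # velocity field u(x, t + 0.5) to be predicted
+```
+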
Given \(\mathbf{u}(\mathbf{x}, t)\), the task is: + +> **Predict the velocity field \(\mathbf{u}(\mathbf{x}, t + \Delta t)\) with \(\Delta t = 0.5\).** + +This requires modeling nonlinear, chaotic, multi-scale turbulent dynamics, including: + +- energy injection at large scales +- nonlinear transfer across the inertial range +- dissipation at the smallest scales + +### Dataset Summary + +- **DNS resolution:** \(384^3\) (used to generate the dataset) +- **Stored dataset resolution:** \(96^3\) +- **Kolmogorov scale resolution:** ~0.5 η +- **Forcing:** applied to modes with \(|\mathbf{k}| \le 1\) +- **Viscosity:** \(\nu = 0.01\) +- **Input power:** \(P_{\text{in}} = 1.0\) +- **Flow regime:** statistically steady HIT at \(\mathrm{Re}_\lambda \approx 94\) + +## Prerequisites + +Install the required dependencies by running below: + +```bash +pip install -r requirements.txt +``` + +## Download the Dataset + +The dataset is publicly available at [Huggingface](https://huggingface.co/datasets/ydu11/re94). +To download the dataset, run (you might need to install the Huggingface CLI): + +```bash +bash download_dataset.sh +``` + +## Getting Started + +To train the model, run + +```bash +python train_ef_isotropic.py +``` + +## References + +- [EddyFormer: EddyFormer: Accelerated Neural Simulations of Three-Dimensional Turbulence at Scale](https://arxiv.org/abs/2510.24173) diff --git a/examples/cfd/isotropic_eddyformer/config.yaml b/examples/cfd/isotropic_eddyformer/config.yaml new file mode 100644 index 0000000000..e7018f54d0 --- /dev/null +++ b/examples/cfd/isotropic_eddyformer/config.yaml @@ -0,0 +1,23 @@ +model: + idim: 3 + odim: 3 + hdim: 32 + num_layers: 4 + layer_config: + basis: legendre + mesh: [8, 8, 8] + mode: [10, 10, 10] + mode_les: [5, 5, 5] + kernel_size: [2, 2, 2] + kernel_size_les: [2, 2, 2] + ffn_dim: 128 + activation: GELU + num_heads: 4 + heads_dim: 32 + +training: + dataset: data/ns3d-re94 + t: 0.5 + batch_size: 4 + num_epochs: 100 + learning_rate: 1e-3 diff --git a/examples/cfd/isotropic_eddyformer/download_dataset.sh b/examples/cfd/isotropic_eddyformer/download_dataset.sh new file mode 100644 index 0000000000..7b50328c92 --- /dev/null +++ b/examples/cfd/isotropic_eddyformer/download_dataset.sh @@ -0,0 +1 @@ +hf download --repo-type dataset ydu11/re94 --local-dir ${1:-data/ns3d-re94} \ No newline at end of file diff --git a/examples/cfd/isotropic_eddyformer/requirements.txt b/examples/cfd/isotropic_eddyformer/requirements.txt new file mode 100644 index 0000000000..001dc23f09 --- /dev/null +++ b/examples/cfd/isotropic_eddyformer/requirements.txt @@ -0,0 +1,2 @@ +hydra-core>=1.2.0 +termcolor>=2.1.1 diff --git a/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py b/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py new file mode 100644 index 0000000000..6546d20ca7 --- /dev/null +++ b/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py @@ -0,0 +1,110 @@ +import hydra +from typing import Tuple +from torch import Tensor +from omegaconf import DictConfig + +import os +import numpy as np + +import torch +from torch.nn import MSELoss +from torch.optim import Adam +from torch.utils.data import Dataset, DataLoader + +from physicsnemo.models.eddyformer import EddyFormer, EddyFormerConfig +from physicsnemo.distributed import DistributedManager +from physicsnemo.utils import StaticCaptureTraining +from physicsnemo.launch.logging import PythonLogger, LaunchLogger + + +class Re94(Dataset): + + root: str + t: float + + n: int = 50 + dt: float = 0.1 + + def __init__(self, root: str, 
split: str, *, t: float = 0.5) -> None: + """ + """ + super().__init__() + self.root = root + self.t = t + + self.file = [] + for fname in sorted(os.listdir(root)): + if fname.startswith(split): + self.file.append(fname) + + @property + def stride(self) -> int: + k = int(self.t / self.dt) + assert self.dt * k == self.t + return k + + @property + def samples_per_file(self) -> int: + return self.n - self.stride + 1 + + def __len__(self) -> int: + return len(self.file) * self.samples_per_file + + def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor]: + file_idx, time_idx = divmod(idx, self.samples_per_file) + + data = np.load(f"{self.root}/{self.file[file_idx]}", allow_pickle=True).item() + return torch.from_numpy(data["u"][time_idx]), torch.from_numpy(data["u"][time_idx + self.stride]) + +@hydra.main(version_base="1.3", config_path=".", config_name="config.yaml") +def isotropic_trainer(cfg: DictConfig) -> None: + """ + """ + DistributedManager.initialize() # Only call this once in the entire script! + dist = DistributedManager() # call if required elsewhere + + # initialize monitoring + log = PythonLogger(name="re94_ef") + log.file_logging() + LaunchLogger.initialize() # PhysicsNeMo launch logger + + # define model, loss, optimiser, scheduler, data loader + model = EddyFormer( + idim=cfg.model.idim, + odim=cfg.model.odim, + hdim=cfg.model.hdim, + num_layers=cfg.model.num_layers, + cfg=EddyFormerConfig(**cfg.model.layer_config), + ).to(dist.device) + loss_fun = MSELoss(reduction="mean") + optimizer = Adam(model.parameters(), lr=cfg.training.learning_rate) + dataset = Re94(root=cfg.training.dataset, split="train", t=cfg.training.t) + + # define forward passes for training and inference + @StaticCaptureTraining( + model=model, optim=optimizer, logger=log, use_amp=False, use_graphs=False + ) + def training_step(input, target): + pred = torch.vmap(model)(input) + loss = loss_fun(pred, target) + return loss + + for epoch in range(cfg.training.num_epochs): + + dataloader = DataLoader(dataset, cfg.training.batch_size, shuffle=True) + + for input, target in dataloader: + + input = input.to(dist.device) + target = target.to(dist.device) + with torch.autograd.set_detect_anomaly(True): + loss = training_step(input, target) + + with LaunchLogger("train", epoch=epoch) as logger: + logger.log_minibatch({"Training loss": loss.item()}) + + log.success("Training completed") + + +if __name__ == "__main__": + isotropic_trainer() diff --git a/physicsnemo/models/eddyformer/__init__.py b/physicsnemo/models/eddyformer/__init__.py new file mode 100644 index 0000000000..db0569fda6 --- /dev/null +++ b/physicsnemo/models/eddyformer/__init__.py @@ -0,0 +1,5 @@ +from ._basis import Legendre +from ._datatype import SEM +from .eddyformer import EddyFormer, EddyFormerLayer + +EddyFormerConfig = EddyFormerLayer.Config diff --git a/physicsnemo/models/eddyformer/_basis.py b/physicsnemo/models/eddyformer/_basis.py new file mode 100644 index 0000000000..e3906ca529 --- /dev/null +++ b/physicsnemo/models/eddyformer/_basis.py @@ -0,0 +1,112 @@ +from typing import Protocol +from torch import Tensor + +import torch +import torch.nn as nn + +import numpy as np +import functools + +class Basis(Protocol): + + grid: Tensor + quad: Tensor + + m: int + f: Tensor + + def fn(self, xs: Tensor) -> Tensor: + """ + Evaluate basis functions at given points. + """ + + def at(self, coef: Tensor, xs: Tensor) -> Tensor: + """ + Evaluate basis expansion at given points. 
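+
+        Args:
+            coef: Modal coefficients, with the mode axis leading.
+            xs: Query points within the reference element.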
+ """ + return torch.tensordot(self.fn(xs), coef, dims=1) + + def modal(self, vals: Tensor) -> Tensor: + """ + Convert nodal values to modal coefficients. + """ + + def nodal(self, coef: Tensor) -> Tensor: + """ + Convert modal coefficients to nodal values. + """ + +class Element(Basis): + + def __init__(self, base: Basis): + """ + """ + +# ---------------------------------------------------------------------------- # +# LEGENDRE # +# ---------------------------------------------------------------------------- # + +from numpy.polynomial import legendre + +@functools.cache +class Legendre(nn.Module, Basis): + + """ + Shifted Legendre polynomials: + - `(1 - x^2) Pn''(x) - 2 x Pn(x) + n (n + 1) Pn(x) = 0` + - `Pn^~(x) = Pn(2 x - 1)` + """ + + def extra_repr(self) -> str: + return f"m={self.m}" + + def __init__(self, m: int, endpoint: bool = False): + """ + """ + super().__init__() + self.m = m + + if endpoint: m -= 1 + c = (0, ) * m + (1, ) + dc = legendre.legder(c) + + x = legendre.legroots(dc if endpoint else c) + y = legendre.legval(x, c if endpoint else dc) + + if endpoint: + x = np.concatenate([[-1], x, [1]]) + y = np.concatenate([[1], y, [1]]) + + w = 1 / y ** 2 + if endpoint: w /= m * (m + 1) + else: w /= 1 - x ** 2 + + self.register_buffer("grid", torch.tensor((1 + x) / 2, dtype=torch.float)) + self.register_buffer("quad", torch.tensor(w, dtype=torch.float)) + + self.register_buffer("f", self.fn(self.grid)) + + def fn(self, xs: Tensor) -> Tensor: + """ + """ + P = torch.ones_like(xs), 2 * xs - 1 + + for i in range(2, self.m): + a, b = (i * 2 - 1) / i, (i - 1) / i + P += a * P[-1] * P[1] - b * P[-2], + + return torch.stack(P, dim=-1) + +# --------------------------------- TRANSFORM -------------------------------- # + + def modal(self, vals: Tensor) -> Tensor: + """ + """ + norm = 2 * torch.arange(self.m, device=vals.device) + 1 + coef = self.f * norm * self.quad[:, None] + return torch.tensordot(coef.T, vals, dims=1) + + def nodal(self, coef: Tensor) -> Tensor: + """ + """ + return self.at(coef, self.grid) diff --git a/physicsnemo/models/eddyformer/_datatype.py b/physicsnemo/models/eddyformer/_datatype.py new file mode 100644 index 0000000000..ea1e5514bf --- /dev/null +++ b/physicsnemo/models/eddyformer/_datatype.py @@ -0,0 +1,233 @@ +from typing import Tuple +from torch import Tensor + +import torch +import torch.nn.functional as F + +from dataclasses import dataclass, replace +from functools import cached_property + +from ._basis import Basis, Legendre + +def interp1d(value: Tensor, xs: Tensor, method: str) -> Tensor: + """ + Interpolate from 1D regular grid to a target points. + + Args: + value: Values on a uniform grid along the first axis. + xs: Resolution or an array normalized by the domain size. + method: Interpolation method. One of "fft", "linear", or + f"lag{n}" for n-point Lagrangian interpolation. + """ + if method == "fft": + coef = torch.fft.rfft(value, dim=0, norm="forward") + + k = 2 * torch.pi * torch.arange(len(coef)) + f = torch.exp(1j * k * xs[..., None]); f[..., 1:-1] *= 2 + return torch.tensordot(f.real, coef.real, dims=1) \ + - torch.tensordot(f.imag, coef.imag, dims=1) + + if method.startswith("lag"): + n_points = int(method[3:]) + + assert n_points % 2 == 0 + r = n_points // 2 - 1 + + n = len(value) + + i = (xs * (N := n - 1)).int() + i = torch.clip(i, r, n - n_points + r) + + # 1. pad the input grid + + v_pad = value, value[:r+2] + + if r > 0: v_pad = (value[-r:], ) + v_pad + value = torch.concatenate(v_pad, dim=0) + + # 2. 
construct polynomials + + out = 0 + + for j in range(n_points): + lag = value[i + j] + + for k in range(n_points): + if j == k: continue + fac = xs - (i + k - r) / N + while fac.ndim < lag.ndim: + fac = fac.unsqueeze(-1) + lag *= fac * N / (j - k) + + out += lag + return out + + raise ValueError(f"invalid interpolation {method=}") + +# ---------------------------------------------------------------------------- # +# SPECTRAL ELEMENT # +# ---------------------------------------------------------------------------- # + +@dataclass +class SEM: + + """ + Spectral element expansion. The sub-domain partition is + given by the `mesh` attribute. The spectral coefficients + of each element is stored in the first channel dimension, + whose size must equal to the number of elements. + """ + + T_: str + + # Mesh + + size: Tensor + mesh: Tuple[int] + + # Data + + mode_: Tuple[int] = None + nodal: Tensor = None + + @property + def ndim(self) -> int: + return len(self.mesh) + + @property + def mode(self) -> Tuple[int]: + if self.mode_: return self.mode_ + return self.nodal.shape[:self.ndim] + + @property + def use_elem(self) -> bool: + return self.T_.endswith("elem") + + @staticmethod + def basis(T_: str) -> Basis: + if T_.startswith("leg"): return Legendre + raise ValueError(f"invalid basis {T_=}") + + @cached_property + def T(self) -> Tuple[Basis]: + """ + Basis on each dimension. + """ + T = self.basis(self.T_) + return tuple(map(T, self.mode)) + + def to(self, mode: Tuple[int]) -> "SEM": + """ + Resample to another mode. + + Args: + mode: Number of modes. + """ + out = SEM(self.T_, self.size, self.mesh, mode) + + value = self.nodal + for n in range(self.ndim): + coef = self.T[n].modal(value) + if (pad:=out.mode[n] - self.mode[n]) <= 0: coef = coef[:mode[n]] + else: coef = torch.concat([coef, torch.zeros(pad, *coef.shape[1:], device=coef.device)], dim=0) + value = out.T[n].nodal(coef).movedim(0, self.ndim - 1) + + return out.new(value) + + def at(self, *xs: Tensor, uniform: bool = False) -> Tensor: + """ + Evaluate on rectilinear grids. + + Args: + xs: Coordinate of each dimension. + uniform: Whether `xs` are uniformly spaced. 
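+
+        Returns:
+            Field values evaluated on the tensor-product grid of `xs`.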
+ """ + value = self.nodal + for n in range(self.ndim): + x = xs[n] / self.size[n] + coef = self.T[n].modal(value) + + # indices of each global coordinate `x` + idx = torch.floor(x * float(self.mesh[n])).int() + idx = torch.minimum(idx, torch.tensor(self.mesh[n] - 1)) + + # global coordinate to local coordinate + ys = x * float(self.mesh[n]) - torch.arange(self.mesh[n], device=x.device)[idx] + + if not uniform: + + # coefficients where each `x` belongs + coef = coef.movedim(self.ndim, 0)[idx] + + # evaluate at each coordonate and move the output axis to the last dimension + # after `ndim` iterations, the axes are automatically rolled to the correct order + value = torch.vmap(self.T[n].at, out_dims=self.ndim - 1)(coef, ys) + + else: + + # coordinates within each element + ys = ys.reshape(self.mesh[n], -1) + + # batched evaluation of all coordinates + value = torch.vmap(self.T[n].at, (self.ndim, 0))(coef, ys) + value = torch.movedim(value.flatten(end_dim=1), 0, self.ndim - 1) + + return value + +# ---------------------------------- COORDS ---------------------------------- # + + @cached_property + def grid(self) -> Tensor: + axes = [self.T[n].grid.to(self.size.device) for n in range(self.ndim)] + return torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1) + + @cached_property + def coords(self) -> Tensor: + local = self.grid + for _ in range(self.ndim): + local = local.unsqueeze(self.ndim) + return self.origins + local * self.lengths + + @cached_property + def origins(self) -> Tensor: + left = [torch.arange(m, device=self.size.device) / m for m in self.mesh] + return torch.stack(torch.meshgrid(*left, indexing="ij"), dim=-1) * self.size + + @cached_property + def lengths(self) -> Tensor: + ns = torch.tensor(self.mesh, device=self.size.device) + return self.size / ns.float() + +# --------------------------------- DATA TYPE -------------------------------- # + + def new(self, nodal: Tensor) -> "SEM": + assert nodal.shape[:self.ndim] == self.mode + return replace(self, mode_=None, nodal=nodal) + + def eval(self, resolution: Tuple[int]) -> Tensor: + xs = [torch.linspace(0, s, n, device=self.size.device) for n, s in zip(resolution, self.size)] + return self.at(*xs, uniform=False) + + def from_grid(self, value: Tensor, method: str) -> "SEM": + """ + Interpolate grid values to a target datatype. + + Args: + out: Target datatype. + method: Interpolation method along each axis. + See `interp1d::method` for details. + """ + xs = self.coords / self.size + for n in range(self.ndim): + + # interpolate at each collocation points. `idx` is the + # index of the elements along the `n`'th dimension. + idx = [slice(None) if i == n else 0 for i in range(self.ndim)] + value = interp1d(value, xs[tuple(idx * 2 + [n])], method) + + # roll the output. The interpolated values have shape `(mode, mesh)`, + # which are moved to the middle (`ndim - 1`) and the end (`ndim + n`) of + # the dimensions. After `ndim` iterations, all axes are ordered correctly. 
+ value = torch.moveaxis(value, (0, 1), (self.ndim - 1, self.ndim + n)) + + return self.new(value) diff --git a/physicsnemo/models/eddyformer/eddyformer.py b/physicsnemo/models/eddyformer/eddyformer.py new file mode 100644 index 0000000000..569eb95857 --- /dev/null +++ b/physicsnemo/models/eddyformer/eddyformer.py @@ -0,0 +1,181 @@ +from typing import Tuple, Union +from torch import Tensor + +import torch +import torch.nn as nn + +from dataclasses import dataclass +from functools import partial + +from ..module import Module +from ..meta import ModelMetaData +from ..layers.mlp_layers import Mlp + +from ._datatype import SEM +from .sem_conv import SEMConv +from .sem_attn import SEMAttn + +# Layer + +class EddyFormerLayer(nn.Module): + + @dataclass + class Config: + + basis: str + mesh: Tuple[int] + mode: Tuple[int] + + # SGS STREAM + kernel_size: Tuple[int] + + ffn_dim: int + activation: str + + # LES STREAM + mode_les: Tuple[int] + kernel_size_les: Tuple[int] + + num_heads: int + heads_dim: int + + @property + def ffn(self) -> partial[Mlp]: + return partial(Mlp, + hidden_features=self.ffn_dim, + act_layer=getattr(nn, self.activation), + ) + + @property + def attn(self) -> partial[SEMAttn]: + return partial(SEMAttn, + mode=self.mode_les, + num_heads=self.num_heads, + heads_dim=self.heads_dim, + ) + + def conv(self, stream: str) -> partial[SEMConv]: + return partial(SEMConv, + kernel_mode=(mode:=self.mode if stream == "sgs" else self.mode_les), + kernel_size=self.kernel_size if stream == "sgs" else self.kernel_size_les, + T=tuple(map(SEM.basis(self.basis), mode)), + ) + + def __init__(self, hdim: int, cfg: Config, *, layer_scale: float = 1e-7): + """ + EddyFormer layer. + """ + super().__init__() + + self.mode = cfg.mode + self.mode_les = cfg.mode_les + + self.eps = nn.Parameter(torch.ones(hdim) * layer_scale) + self.ffn_les, self.ffn_sgs = cfg.ffn(hdim), cfg.ffn(hdim) + + self.sem_conv_sgs = cfg.conv("sgs")(hdim, hdim) + self.sem_conv_les = cfg.conv("les")(hdim, hdim) + self.sem_attn = cfg.attn(hdim, hdim, conv=cfg.conv("les")) + + def __call__(self, les: SEM, sgs: SEM) -> Tuple[SEM, SEM]: + """ + """ + les.nodal = les.nodal + self.sem_attn(les).nodal + les.nodal = les.nodal + self.ffn_les(self.sem_conv_les(les).nodal) + + sgs.nodal = sgs.nodal + self.eps * les.to(self.mode).nodal + sgs.nodal = sgs.nodal + self.ffn_sgs(self.sem_conv_sgs(sgs).nodal) + + return les, sgs + +# Model + +@dataclass +class MetaData(ModelMetaData): + name: str = "EddyFormer" + # Optimization + jit: bool = True + cuda_graphs: bool = True + amp: bool = False + # Inference + onnx_cpu: bool = False + onnx_gpu: bool = False + onnx_runtime: bool = False + # Physics informed + var_dim: int = 1 + func_torch: bool = False + auto_grad: bool = False + +class EddyFormer(Module): + + cfg: EddyFormerLayer.Config + + lift_les: nn.Linear + lift_sgs: nn.Linear + + layers: nn.ModuleList + + proj_les: Mlp + proj_sgs: Mlp + + scale: nn.Parameter + + def __init__(self, + idim: int, + odim: int, + hdim: int, + num_layers: int, + cfg: EddyFormerLayer.Config): + """ + EddyFormer model. 
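+
+        Args:
+            idim: Number of input channels.
+            odim: Number of output channels.
+            hdim: Hidden feature width shared by the LES and SGS streams.
+            num_layers: Number of stacked EddyFormer layers.
+            cfg: Per-layer configuration (see `EddyFormerLayer.Config`).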
+ """ + super().__init__(meta=MetaData()) + + self.cfg = cfg + self.ndim = len(cfg.mesh) + + self.lift_les = nn.Linear(idim + self.ndim, hdim) + self.lift_sgs = nn.Linear(idim + self.ndim, hdim) + + self.layers = nn.ModuleList() + for _ in range(num_layers): + layer = EddyFormerLayer(hdim, cfg) + self.layers.append(layer) + + self.proj_les = cfg.ffn(hdim, out_features=odim) + self.proj_sgs = cfg.ffn(hdim, out_features=odim) + + self.scale = nn.Parameter(torch.zeros(odim)) + + def __call__(self, input: Union[SEM, Tensor], return_sem: bool = False) -> Union[SEM, Tensor]: + """ + """ + if isinstance(input, Tensor): + size = 2 * torch.pi * torch.ones(self.ndim, device=input.device) + ϕ = SEM(self.cfg.basis, size, self.cfg.mesh, self.cfg.mode) \ + .from_grid(input, "lag8") # default interpolation method + else: + ϕ = input + + x = ϕ.grid.to(ϕ.nodal) + for n, mesh in enumerate(ϕ.mesh): + x = x.unsqueeze(dim:=self.ndim + n) + x = torch.repeat_interleave(x, mesh, dim) + x = torch.concatenate(torch.broadcast_tensors(ϕ.nodal, x), dim=-1) + + sgs = ϕ.new(x) + les = sgs.to(self.cfg.mode_les) + + sgs.nodal = self.lift_sgs(sgs.nodal) + les.nodal = self.lift_les(les.nodal) + + for layer in self.layers: + les, sgs = layer(les, sgs) + + sgs.nodal = self.proj_sgs(sgs.nodal) + les.nodal = self.proj_les(les.nodal) + + out = ϕ.new(les.to(ϕ.mode).nodal + sgs.nodal) + if not return_sem: out = out.eval(input.shape[:-1]) + + return out diff --git a/physicsnemo/models/eddyformer/sem_attn.py b/physicsnemo/models/eddyformer/sem_attn.py new file mode 100644 index 0000000000..a9b5cc3674 --- /dev/null +++ b/physicsnemo/models/eddyformer/sem_attn.py @@ -0,0 +1,74 @@ +from typing import Tuple +from torch import Tensor + +import torch +import torch.nn as nn + +from functools import partial + +from ._datatype import SEM +from .sem_conv import SEMConv + +class SEMAttn(nn.Module): + + proj: nn.ModuleDict + bias: nn.ParameterDict + norm: nn.ModuleDict + + out: nn.Linear + + def __init__(self, + idim: int, + odim: int, + mode: Tuple[int], + num_heads: int, + heads_dim: int, + *, + conv: partial[SEMConv], + bias_init = torch.zeros): + """ + """ + super().__init__() + + self.proj = nn.ModuleDict() + self.bias = nn.ParameterDict() + self.norm = nn.ModuleDict() + + for name in "QKV": + self.proj[name] = conv(idim, (num_heads, heads_dim)) + + for n in range(len(mode)): + self.bias[f"{name}{n}"] = nn.Parameter(bias_init((num_heads, heads_dim))) + self.norm[f"{name}{n}"] = nn.LayerNorm(heads_dim) + + self.out = nn.Linear(num_heads * heads_dim * len(mode), odim) + + def project(self, ϕ: SEM, name: str) -> Tensor: + """ + Project the features to attention space. + """ + xs = [] + + for n in range(ϕ.ndim): + x = self.proj[name].factor(ϕ, n).nodal + + if name in ["Q", "K"]: + x = x + self.bias[f"{name}{n}"] + + f, g = torch.split(self.norm[f"{name}{n}"](x), x.shape[-1] // 2, dim=-1) + k = ϕ.coords[..., None, [n]] * torch.arange(f.shape[-1], device=x.device) + + f, g = torch.cos(k) * f - torch.sin(k) * g, torch.sin(k) * f + torch.cos(k) * g + x = torch.concatenate([torch.cos(k) + f, torch.sin(k) + g], dim=-1) + + xs.append(x.reshape(ϕ.mode + (-1, ) + x.shape[-2:])) + return torch.concatenate(xs, dim=-1) + + def __call__(self, ϕ: SEM) -> SEM: + """ + Self-attention on SEM features. 
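+
+        Args:
+            ϕ: Input SEM feature field.
+
+        Returns:
+            A new SEM field on the same mesh containing the attended features.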
+ """ + q, k, v = (self.project(ϕ, name) for name in "QKV") + + attn = nn.functional.scaled_dot_product_attention(q, k, v) + return ϕ.new(self.out(attn.reshape(*ϕ.mode, *ϕ.mesh, -1))) diff --git a/physicsnemo/models/eddyformer/sem_conv.py b/physicsnemo/models/eddyformer/sem_conv.py new file mode 100644 index 0000000000..2dc4c9a42f --- /dev/null +++ b/physicsnemo/models/eddyformer/sem_conv.py @@ -0,0 +1,150 @@ +from typing import Tuple, Union +from torch import Tensor + +import torch +import torch.nn as nn + +import numpy as np +from functools import partial, cache +from scipy import integrate + +from ._basis import Basis +from ._datatype import SEM + +class SEMConv(nn.Module): + + odim: Tuple[int] + kernel: nn.ParameterList + + def __init__(self, + idim: int, + odim: Union[int, Tuple[int]], + T: Tuple[Basis], + kernel_mode: Tuple[int], + kernel_size: Tuple[int], + kernel_init_std: float = 1e-7): + """ + """ + super().__init__() + self.T = nn.ModuleList(T) + + if isinstance(odim, int): + self.odim = (odim, ) + else: + self.odim = odim + odim = np.prod(odim) + + self.kernel = nn.ParameterList() + for n, (m, s) in enumerate(zip(kernel_mode, kernel_size)): + self.kernel.append(nn.Parameter(coef:=torch.empty(s * m, idim, odim))) + + torch.nn.init.normal_(coef, std=kernel_init_std) + self.register_buffer(f"ws_{n}", weight(T[n], s)) + + def factor(self, ϕ: SEM, dim: int) -> SEM: + """ + Factorized SEM convolution. + + Args: + ϕ: Input SEM feature field. + dim: Dimension to convolve over. + """ + coef, ws = self.kernel[dim], getattr(self, f"ws_{dim}") + out = sem_conv(ϕ.nodal, coef, ws, T=ϕ.T[dim], ndim=ϕ.ndim, dim=dim) + return ϕ.new(out.reshape(out.shape[:-1] + self.odim)) + + def __call__(self, ϕ: SEM) -> SEM: + return ϕ.new(sum(self.factor(ϕ, n).nodal for n in range(ϕ.ndim))) + +# ---------------------------------------------------------------------------- # +# CONVOLUTION # +# ---------------------------------------------------------------------------- # + +def kernel(coef: Tensor, xs: Tensor) -> Tensor: + """ + Evaluate the Fourier kernel. + + Args: + coef: Fourier coefficients. + xs: Query coordinates. 
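+
+    Returns:
+        Kernel values evaluated at `xs`, with any trailing axes of `coef` appended.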
+ """ + r, i_ = torch.split(coef, (m:=(n:=len(coef)) // 2 + 1, n - m)) + i = torch.zeros_like(r); i[1:n-m+1] = torch.flip(i_, dims=[0]) + + k = 2 * torch.pi * torch.arange(m, device=xs.device) + f = torch.exp(1j * k * xs[..., None]); f[..., 1:-1] *= 2 + + return torch.tensordot(f.real, r, 1) \ + - torch.tensordot(f.imag, i, 1) + +@cache +def weight(T: Basis, s: int, use_mp: bool = True) -> Tensor: + """ + """ + print(f"Pre-computing weights for `{T=}` and `{s=}`...") + + eps = torch.finfo(torch.float).eps + ab = T.grid[..., None] + torch.tensor([-s/2, s/2]) + + map_ = map + if use_mp: + from concurrent.futures import ThreadPoolExecutor + map_ = (pool := ThreadPoolExecutor()).map + + def quad(T: Basis, m: int, a: float, b: float) -> Tensor: + f = lambda x: T.fn(torch.tensor(x))[m] + y, e = integrate.quad(f, a, b) + return y + + ws = [] + for i in range(-s//2, s//2 + 1): + ws.append(w:=[]) + + from tqdm import tqdm + for a, b in tqdm(ab, f"{i=}"): + a = torch.clip(a - i, -eps, 1 + eps) + b = torch.clip(b - i, -eps, 1 + eps) + + q = torch.tensor(list(map_(partial(quad, T, a=a, b=b), range(T.m)))) + w.append(torch.linalg.solve(T.fn(T.grid).T, q).tolist()) + + if use_mp: pool.shutdown() + return torch.tensor(ws) + +def sem_conv(nodal: Tensor, coef: Tensor, ws: Tensor, *, T: Basis, ndim: int, dim: int): + """ + Args: + w: An (s + 1, m, n) array where s is the window size, m is the number + of quadrature nodes, and n is the number of output nodes. + """ + n = ndim + dim # mesh dim + + ns = "".join(map(chr, range(120, 120 + ndim))) + ms = ns.replace(i:=ns[dim], o:=ns[dim].upper()) + + pad_r = nodal.index_select(n, torch.arange(0, r:=len(ws)//2, device=nodal.device)) + pad_l = nodal.index_select(n, torch.arange(nodal.shape[n]-r, nodal.shape[n], device=nodal.device)) + + # pad_r = torch.slice_copy(nodal, n, 0, r:=len(ws)//2) + # pad_l = torch.slice_copy(nodal, n, -r, end=None) + + f = torch.concatenate([pad_l, nodal, pad_r], dim=n) + out = [] + # out = torch.zeros(*nodal.shape[:-1], coef.shape[-1], device=nodal.device) + + for k, w in enumerate(ws): + + x = T.grid + k - r + xy = T.grid[:, None] - x + + fx = torch.narrow(f, n, k, nodal.shape[n]) + gxy = kernel(coef, xy / (len(ws) - 1)) + + eqn = f"{ns}...i, {o}{i}io, {o}{i} -> {ms}...o" + # print(f"{eqn}: {tuple(fx.shape)}, {tuple(gxy.shape)}, {tuple(w.shape)}") + + # print(out.shape, torch.einsum(eqn, fx, gxy, w).shape) + # out += torch.einsum(eqn, fx, gxy, w) + out.append(torch.einsum(eqn, fx, gxy, w)) + + return sum(out) diff --git a/pyproject.toml b/pyproject.toml index 5415416427..af58333f31 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,7 @@ all = [ "ruamel.yaml>=0.17.22", "scikit-learn>=1.0.2", "scikit-image>=0.24.0", + "scipy>=1.15.0", "warp-lang>=1.0", "vtk>=9.2.6", "pyvista>=0.40.1", From 9cb780a57d488ee247761d96669eabdf1a81bbef Mon Sep 17 00:00:00 2001 From: Yiheng Du Date: Thu, 20 Nov 2025 21:15:50 -0800 Subject: [PATCH 02/11] fix format issue --- examples/cfd/isotropic_eddyformer/README.md | 2 +- .../isotropic_eddyformer/download_dataset.sh | 2 +- .../isotropic_eddyformer/train_ef_isotropic.py | 18 +++++++++++------- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/examples/cfd/isotropic_eddyformer/README.md b/examples/cfd/isotropic_eddyformer/README.md index 1e22dd9485..027b668367 100644 --- a/examples/cfd/isotropic_eddyformer/README.md +++ b/examples/cfd/isotropic_eddyformer/README.md @@ -92,4 +92,4 @@ python train_ef_isotropic.py ## References -- [EddyFormer: EddyFormer: Accelerated Neural Simulations of 
Three-Dimensional Turbulence at Scale](https://arxiv.org/abs/2510.24173) +- [EddyFormer: Accelerated Neural Simulations of Three-Dimensional Turbulence at Scale](https://arxiv.org/abs/2510.24173) diff --git a/examples/cfd/isotropic_eddyformer/download_dataset.sh b/examples/cfd/isotropic_eddyformer/download_dataset.sh index 7b50328c92..52da8b034d 100644 --- a/examples/cfd/isotropic_eddyformer/download_dataset.sh +++ b/examples/cfd/isotropic_eddyformer/download_dataset.sh @@ -1 +1 @@ -hf download --repo-type dataset ydu11/re94 --local-dir ${1:-data/ns3d-re94} \ No newline at end of file +hf download --repo-type dataset ydu11/re94 --local-dir ${1:-data/ns3d-re94} diff --git a/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py b/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py index 6546d20ca7..a433e60bc1 100644 --- a/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py +++ b/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py @@ -68,7 +68,7 @@ def isotropic_trainer(cfg: DictConfig) -> None: log.file_logging() LaunchLogger.initialize() # PhysicsNeMo launch logger - # define model, loss, optimiser, scheduler, data loader + # define model, loss, optimizer model = EddyFormer( idim=cfg.model.idim, odim=cfg.model.odim, @@ -78,11 +78,18 @@ def isotropic_trainer(cfg: DictConfig) -> None: ).to(dist.device) loss_fun = MSELoss(reduction="mean") optimizer = Adam(model.parameters(), lr=cfg.training.learning_rate) + + # define dataset and dataloader dataset = Re94(root=cfg.training.dataset, split="train", t=cfg.training.t) + dataloader = DataLoader(dataset, cfg.training.batch_size, shuffle=True) - # define forward passes for training and inference + # define forward passes for training @StaticCaptureTraining( - model=model, optim=optimizer, logger=log, use_amp=False, use_graphs=False + model=model, + optim=optimizer, + logger=log, + use_amp=False, + use_graphs=False ) def training_step(input, target): pred = torch.vmap(model)(input) @@ -91,14 +98,11 @@ def training_step(input, target): for epoch in range(cfg.training.num_epochs): - dataloader = DataLoader(dataset, cfg.training.batch_size, shuffle=True) - for input, target in dataloader: input = input.to(dist.device) target = target.to(dist.device) - with torch.autograd.set_detect_anomaly(True): - loss = training_step(input, target) + loss = training_step(input, target) with LaunchLogger("train", epoch=epoch) as logger: logger.log_minibatch({"Training loss": loss.item()}) From a4d7a6505353cbcb7cb3cf8fdc79c1241ed93605 Mon Sep 17 00:00:00 2001 From: Yiheng Du Date: Thu, 20 Nov 2025 21:18:03 -0800 Subject: [PATCH 03/11] verify rope dimension --- physicsnemo/models/eddyformer/sem_attn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/physicsnemo/models/eddyformer/sem_attn.py b/physicsnemo/models/eddyformer/sem_attn.py index a9b5cc3674..2294d0c47b 100644 --- a/physicsnemo/models/eddyformer/sem_attn.py +++ b/physicsnemo/models/eddyformer/sem_attn.py @@ -54,6 +54,7 @@ def project(self, ϕ: SEM, name: str) -> Tensor: if name in ["Q", "K"]: x = x + self.bias[f"{name}{n}"] + assert x.shape[-1] % 2 == 0 f, g = torch.split(self.norm[f"{name}{n}"](x), x.shape[-1] // 2, dim=-1) k = ϕ.coords[..., None, [n]] * torch.arange(f.shape[-1], device=x.device) From 6bf561773c6cbd598a719cb58c460a22f64fdc87 Mon Sep 17 00:00:00 2001 From: Yiheng Du Date: Sat, 22 Nov 2025 22:12:02 -0800 Subject: [PATCH 04/11] fix device and docstring --- physicsnemo/models/eddyformer/_datatype.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git 
a/physicsnemo/models/eddyformer/_datatype.py b/physicsnemo/models/eddyformer/_datatype.py index ea1e5514bf..e2a248f7d4 100644 --- a/physicsnemo/models/eddyformer/_datatype.py +++ b/physicsnemo/models/eddyformer/_datatype.py @@ -2,7 +2,6 @@ from torch import Tensor import torch -import torch.nn.functional as F from dataclasses import dataclass, replace from functools import cached_property @@ -22,7 +21,7 @@ def interp1d(value: Tensor, xs: Tensor, method: str) -> Tensor: if method == "fft": coef = torch.fft.rfft(value, dim=0, norm="forward") - k = 2 * torch.pi * torch.arange(len(coef)) + k = 2 * torch.pi * torch.arange(len(coef), device=xs.device) f = torch.exp(1j * k * xs[..., None]); f[..., 1:-1] *= 2 return torch.tensordot(f.real, coef.real, dims=1) \ - torch.tensordot(f.imag, coef.imag, dims=1) @@ -149,7 +148,7 @@ def at(self, *xs: Tensor, uniform: bool = False) -> Tensor: # indices of each global coordinate `x` idx = torch.floor(x * float(self.mesh[n])).int() - idx = torch.minimum(idx, torch.tensor(self.mesh[n] - 1)) + idx = torch.minimum(idx, torch.tensor(self.mesh[n] - 1, device=idx.device)) # global coordinate to local coordinate ys = x * float(self.mesh[n]) - torch.arange(self.mesh[n], device=x.device)[idx] @@ -159,7 +158,7 @@ def at(self, *xs: Tensor, uniform: bool = False) -> Tensor: # coefficients where each `x` belongs coef = coef.movedim(self.ndim, 0)[idx] - # evaluate at each coordonate and move the output axis to the last dimension + # evaluate at each coordinate and move the output axis to the last dimension # after `ndim` iterations, the axes are automatically rolled to the correct order value = torch.vmap(self.T[n].at, out_dims=self.ndim - 1)(coef, ys) @@ -213,7 +212,7 @@ def from_grid(self, value: Tensor, method: str) -> "SEM": Interpolate grid values to a target datatype. Args: - out: Target datatype. + value: Input tensor (include boundary points). method: Interpolation method along each axis. See `interp1d::method` for details. 
""" From 5ca494cbaf93b72acf182735b4ed0c7413135037 Mon Sep 17 00:00:00 2001 From: Yiheng Du Date: Sat, 22 Nov 2025 22:13:49 -0800 Subject: [PATCH 05/11] fix import and remove comments --- physicsnemo/models/eddyformer/sem_conv.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/physicsnemo/models/eddyformer/sem_conv.py b/physicsnemo/models/eddyformer/sem_conv.py index 2dc4c9a42f..fa51e5006c 100644 --- a/physicsnemo/models/eddyformer/sem_conv.py +++ b/physicsnemo/models/eddyformer/sem_conv.py @@ -3,10 +3,11 @@ import torch import torch.nn as nn - import numpy as np + from functools import partial, cache from scipy import integrate +from tqdm import tqdm from ._basis import Basis from ._datatype import SEM @@ -100,7 +101,6 @@ def quad(T: Basis, m: int, a: float, b: float) -> Tensor: for i in range(-s//2, s//2 + 1): ws.append(w:=[]) - from tqdm import tqdm for a, b in tqdm(ab, f"{i=}"): a = torch.clip(a - i, -eps, 1 + eps) b = torch.clip(b - i, -eps, 1 + eps) @@ -125,12 +125,8 @@ def sem_conv(nodal: Tensor, coef: Tensor, ws: Tensor, *, T: Basis, ndim: int, di pad_r = nodal.index_select(n, torch.arange(0, r:=len(ws)//2, device=nodal.device)) pad_l = nodal.index_select(n, torch.arange(nodal.shape[n]-r, nodal.shape[n], device=nodal.device)) - # pad_r = torch.slice_copy(nodal, n, 0, r:=len(ws)//2) - # pad_l = torch.slice_copy(nodal, n, -r, end=None) - f = torch.concatenate([pad_l, nodal, pad_r], dim=n) out = [] - # out = torch.zeros(*nodal.shape[:-1], coef.shape[-1], device=nodal.device) for k, w in enumerate(ws): @@ -141,10 +137,6 @@ def sem_conv(nodal: Tensor, coef: Tensor, ws: Tensor, *, T: Basis, ndim: int, di gxy = kernel(coef, xy / (len(ws) - 1)) eqn = f"{ns}...i, {o}{i}io, {o}{i} -> {ms}...o" - # print(f"{eqn}: {tuple(fx.shape)}, {tuple(gxy.shape)}, {tuple(w.shape)}") - - # print(out.shape, torch.einsum(eqn, fx, gxy, w).shape) - # out += torch.einsum(eqn, fx, gxy, w) out.append(torch.einsum(eqn, fx, gxy, w)) return sum(out) From b839de8f708321a2db05a54376dd33f1cc17385b Mon Sep 17 00:00:00 2001 From: Yiheng Du Date: Mon, 24 Nov 2025 12:30:22 -0800 Subject: [PATCH 06/11] use ddp; change to rel l2 loss; add checkpointing --- examples/cfd/isotropic_eddyformer/config.yaml | 5 ++- .../train_ef_isotropic.py | 45 ++++++++++++++----- 2 files changed, 39 insertions(+), 11 deletions(-) diff --git a/examples/cfd/isotropic_eddyformer/config.yaml b/examples/cfd/isotropic_eddyformer/config.yaml index e7018f54d0..8c0198c4d1 100644 --- a/examples/cfd/isotropic_eddyformer/config.yaml +++ b/examples/cfd/isotropic_eddyformer/config.yaml @@ -3,6 +3,7 @@ model: odim: 3 hdim: 32 num_layers: 4 + use_scale: true layer_config: basis: legendre mesh: [8, 8, 8] @@ -17,7 +18,9 @@ model: training: dataset: data/ns3d-re94 + result_dir: outputs/ef-re94 t: 0.5 batch_size: 4 - num_epochs: 100 + num_epochs: 1 learning_rate: 1e-3 + ckpt_every: 1000 diff --git a/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py b/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py index a433e60bc1..d246b41883 100644 --- a/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py +++ b/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py @@ -7,13 +7,14 @@ import numpy as np import torch -from torch.nn import MSELoss from torch.optim import Adam from torch.utils.data import Dataset, DataLoader +from torch.nn.parallel import DistributedDataParallel from physicsnemo.models.eddyformer import EddyFormer, EddyFormerConfig from physicsnemo.distributed import DistributedManager from physicsnemo.utils import 
StaticCaptureTraining +from physicsnemo.launch.utils import save_checkpoint from physicsnemo.launch.logging import PythonLogger, LaunchLogger @@ -65,25 +66,43 @@ def isotropic_trainer(cfg: DictConfig) -> None: # initialize monitoring log = PythonLogger(name="re94_ef") - log.file_logging() + log.file_logging(f"{cfg.training.result_dir}/log.txt") LaunchLogger.initialize() # PhysicsNeMo launch logger - # define model, loss, optimizer + # define model and optimizer model = EddyFormer( idim=cfg.model.idim, odim=cfg.model.odim, hdim=cfg.model.hdim, num_layers=cfg.model.num_layers, + use_scale=cfg.model.use_scale, cfg=EddyFormerConfig(**cfg.model.layer_config), ).to(dist.device) - loss_fun = MSELoss(reduction="mean") + + if dist.distributed: + ddps = torch.cuda.Stream() + with torch.cuda.stream(ddps): + model = DistributedDataParallel( + model, + device_ids=[dist.local_rank], + output_device=dist.device, + broadcast_buffers=dist.broadcast_buffers, + find_unused_parameters=dist.find_unused_parameters, + ) + torch.cuda.current_stream().wait_stream(ddps) + log.success("Initialized DDP training") + optimizer = Adam(model.parameters(), lr=cfg.training.learning_rate) # define dataset and dataloader dataset = Re94(root=cfg.training.dataset, split="train", t=cfg.training.t) dataloader = DataLoader(dataset, cfg.training.batch_size, shuffle=True) - # define forward passes for training + # define relative l2 error as the loss function + def loss_fun(pred: Tensor, target: Tensor) -> Tensor: + return torch.linalg.norm(pred - target) / torch.linalg.norm(target) + + # define training step @StaticCaptureTraining( model=model, optim=optimizer, @@ -91,14 +110,16 @@ def isotropic_trainer(cfg: DictConfig) -> None: use_amp=False, use_graphs=False ) - def training_step(input, target): + def training_step(input: Tensor, target: Tensor) -> Tensor: pred = torch.vmap(model)(input) - loss = loss_fun(pred, target) - return loss + loss = torch.vmap(loss_fun)(pred, target) + return torch.mean(loss) - for epoch in range(cfg.training.num_epochs): + it = 0 + log.info("Training started") - for input, target in dataloader: + for epoch in range(cfg.training.num_epochs): + for it, (input, target) in enumerate(dataloader, it): input = input.to(dist.device) target = target.to(dist.device) @@ -107,7 +128,11 @@ def training_step(input, target): with LaunchLogger("train", epoch=epoch) as logger: logger.log_minibatch({"Training loss": loss.item()}) + if it and it % cfg.training.ckpt_every == 0 and dist.rank == 0: + save_checkpoint(f"{cfg.training.result_dir}/ckpt.pt", model, optimizer, epoch=it) + log.success("Training completed") + save_checkpoint(f"{cfg.training.result_dir}/ckpt.pt", model, optimizer) if __name__ == "__main__": From ff2947c02f0d441b6e7582c8304cf1d9eb7f49e9 Mon Sep 17 00:00:00 2001 From: Yiheng Du Date: Mon, 24 Nov 2025 12:33:51 -0800 Subject: [PATCH 07/11] switch to physicsnemo.Module; add use_scale; separate EddyFormerConfig class --- physicsnemo/models/eddyformer/__init__.py | 4 +- physicsnemo/models/eddyformer/eddyformer.py | 111 ++++++++++++-------- physicsnemo/models/eddyformer/sem_attn.py | 10 +- physicsnemo/models/eddyformer/sem_conv.py | 3 +- 4 files changed, 76 insertions(+), 52 deletions(-) diff --git a/physicsnemo/models/eddyformer/__init__.py b/physicsnemo/models/eddyformer/__init__.py index db0569fda6..63144343ba 100644 --- a/physicsnemo/models/eddyformer/__init__.py +++ b/physicsnemo/models/eddyformer/__init__.py @@ -1,5 +1,3 @@ from ._basis import Legendre from ._datatype import SEM -from .eddyformer 
import EddyFormer, EddyFormerLayer - -EddyFormerConfig = EddyFormerLayer.Config +from .eddyformer import EddyFormer, EddyFormerConfig diff --git a/physicsnemo/models/eddyformer/eddyformer.py b/physicsnemo/models/eddyformer/eddyformer.py index 569eb95857..ec65c3a5ff 100644 --- a/physicsnemo/models/eddyformer/eddyformer.py +++ b/physicsnemo/models/eddyformer/eddyformer.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union +from typing import Tuple, Union, Optional from torch import Tensor import torch @@ -15,53 +15,72 @@ from .sem_conv import SEMConv from .sem_attn import SEMAttn -# Layer - -class EddyFormerLayer(nn.Module): +class EddyFormerConfig(Module): - @dataclass - class Config: + basis: str + mesh: Tuple[int] + mode: Tuple[int] - basis: str - mesh: Tuple[int] - mode: Tuple[int] + # SGS STREAM + kernel_size: Tuple[int] - # SGS STREAM - kernel_size: Tuple[int] + ffn_dim: int + activation: str - ffn_dim: int - activation: str + # LES STREAM + mode_les: Tuple[int] + kernel_size_les: Tuple[int] - # LES STREAM - mode_les: Tuple[int] - kernel_size_les: Tuple[int] + num_heads: int + heads_dim: int - num_heads: int - heads_dim: int + def __init__(self, basis: str, mesh: Tuple[int], mode: Tuple[int], + kernel_size: Tuple[int], ffn_dim: int, activation: str, + mode_les: Tuple[int], kernel_size_les: Tuple[int], num_heads: int, heads_dim: int): + """ + """ + super().__init__() - @property - def ffn(self) -> partial[Mlp]: - return partial(Mlp, - hidden_features=self.ffn_dim, - act_layer=getattr(nn, self.activation), - ) + self.basis = basis + self.mesh = mesh + self.mode = mode + + self.kernel_size = kernel_size + self.ffn_dim = ffn_dim + self.activation = activation + + self.mode_les = mode_les + self.kernel_size_les = kernel_size_les + self.num_heads = num_heads + self.heads_dim = heads_dim + + @property + def ffn(self) -> partial[Mlp]: + return partial(Mlp, + hidden_features=self.ffn_dim, + act_layer=getattr(nn, self.activation), + ) + + @property + def attn(self) -> partial[SEMAttn]: + return partial(SEMAttn, + mode=self.mode_les, + num_heads=self.num_heads, + heads_dim=self.heads_dim, + ) + + def conv(self, stream: str) -> partial[SEMConv]: + return partial(SEMConv, + kernel_mode=(mode:=self.mode if stream == "sgs" else self.mode_les), + kernel_size=self.kernel_size if stream == "sgs" else self.kernel_size_les, + T=tuple(map(SEM.basis(self.basis), mode)), + ) - @property - def attn(self) -> partial[SEMAttn]: - return partial(SEMAttn, - mode=self.mode_les, - num_heads=self.num_heads, - heads_dim=self.heads_dim, - ) +# Layer - def conv(self, stream: str) -> partial[SEMConv]: - return partial(SEMConv, - kernel_mode=(mode:=self.mode if stream == "sgs" else self.mode_les), - kernel_size=self.kernel_size if stream == "sgs" else self.kernel_size_les, - T=tuple(map(SEM.basis(self.basis), mode)), - ) +class EddyFormerLayer(Module): - def __init__(self, hdim: int, cfg: Config, *, layer_scale: float = 1e-7): + def __init__(self, hdim: int, cfg: EddyFormerConfig, *, layer_scale: float = 1e-7): """ EddyFormer layer. """ @@ -108,7 +127,7 @@ class MetaData(ModelMetaData): class EddyFormer(Module): - cfg: EddyFormerLayer.Config + cfg: EddyFormerConfig lift_les: nn.Linear lift_sgs: nn.Linear @@ -118,14 +137,16 @@ class EddyFormer(Module): proj_les: Mlp proj_sgs: Mlp - scale: nn.Parameter + scale: Optional[nn.Parameter] def __init__(self, idim: int, odim: int, hdim: int, num_layers: int, - cfg: EddyFormerLayer.Config): + *, + use_scale: bool = True, + cfg: EddyFormerConfig): """ EddyFormer model. 
""" @@ -145,7 +166,7 @@ def __init__(self, self.proj_les = cfg.ffn(hdim, out_features=odim) self.proj_sgs = cfg.ffn(hdim, out_features=odim) - self.scale = nn.Parameter(torch.zeros(odim)) + self.scale = nn.Parameter(torch.zeros(odim)) if use_scale else None def __call__(self, input: Union[SEM, Tensor], return_sem: bool = False) -> Union[SEM, Tensor]: """ @@ -175,7 +196,9 @@ def __call__(self, input: Union[SEM, Tensor], return_sem: bool = False) -> Union sgs.nodal = self.proj_sgs(sgs.nodal) les.nodal = self.proj_les(les.nodal) - out = ϕ.new(les.to(ϕ.mode).nodal + sgs.nodal) - if not return_sem: out = out.eval(input.shape[:-1]) + scale = self.scale if self.scale is not None else 1. + out = ϕ.new(les.to(ϕ.mode).nodal + scale * sgs.nodal) + if not return_sem: + out = out.eval(input.shape[:-1]) return out diff --git a/physicsnemo/models/eddyformer/sem_attn.py b/physicsnemo/models/eddyformer/sem_attn.py index 2294d0c47b..261a6ea8c5 100644 --- a/physicsnemo/models/eddyformer/sem_attn.py +++ b/physicsnemo/models/eddyformer/sem_attn.py @@ -6,10 +6,11 @@ from functools import partial +from ..module import Module from ._datatype import SEM from .sem_conv import SEMConv -class SEMAttn(nn.Module): +class SEMAttn(Module): proj: nn.ModuleDict bias: nn.ParameterDict @@ -37,9 +38,10 @@ def __init__(self, for name in "QKV": self.proj[name] = conv(idim, (num_heads, heads_dim)) - for n in range(len(mode)): - self.bias[f"{name}{n}"] = nn.Parameter(bias_init((num_heads, heads_dim))) - self.norm[f"{name}{n}"] = nn.LayerNorm(heads_dim) + if name in ["Q", "K"]: + for n in range(len(mode)): + self.bias[f"{name}{n}"] = nn.Parameter(bias_init((num_heads, heads_dim))) + self.norm[f"{name}{n}"] = nn.LayerNorm(heads_dim) self.out = nn.Linear(num_heads * heads_dim * len(mode), odim) diff --git a/physicsnemo/models/eddyformer/sem_conv.py b/physicsnemo/models/eddyformer/sem_conv.py index fa51e5006c..6d8b55d021 100644 --- a/physicsnemo/models/eddyformer/sem_conv.py +++ b/physicsnemo/models/eddyformer/sem_conv.py @@ -9,10 +9,11 @@ from scipy import integrate from tqdm import tqdm +from ..module import Module from ._basis import Basis from ._datatype import SEM -class SEMConv(nn.Module): +class SEMConv(Module): odim: Tuple[int] kernel: nn.ParameterList From 822f59e525144d69dfd44439b37b64da395d5130 Mon Sep 17 00:00:00 2001 From: Yiheng Du Date: Mon, 8 Dec 2025 11:31:10 -0800 Subject: [PATCH 08/11] add amp --- examples/cfd/isotropic_eddyformer/config.yaml | 1 + .../train_ef_isotropic.py | 16 +++++++++++-- physicsnemo/models/eddyformer/eddyformer.py | 24 +++++++++---------- 3 files changed, 27 insertions(+), 14 deletions(-) diff --git a/examples/cfd/isotropic_eddyformer/config.yaml b/examples/cfd/isotropic_eddyformer/config.yaml index 8c0198c4d1..3f63e51353 100644 --- a/examples/cfd/isotropic_eddyformer/config.yaml +++ b/examples/cfd/isotropic_eddyformer/config.yaml @@ -20,6 +20,7 @@ training: dataset: data/ns3d-re94 result_dir: outputs/ef-re94 t: 0.5 + amp: false batch_size: 4 num_epochs: 1 learning_rate: 1e-3 diff --git a/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py b/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py index d246b41883..d20be89ef3 100644 --- a/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py +++ b/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py @@ -8,6 +8,7 @@ import torch from torch.optim import Adam +from torch.cuda.amp import GradScaler, autocast from torch.utils.data import Dataset, DataLoader from torch.nn.parallel import DistributedDataParallel @@ -76,7 +77,18 @@ def 
isotropic_trainer(cfg: DictConfig) -> None: hdim=cfg.model.hdim, num_layers=cfg.model.num_layers, use_scale=cfg.model.use_scale, - cfg=EddyFormerConfig(**cfg.model.layer_config), + cfg=EddyFormerConfig( + basis=cfg.model.layer_config.basis, + mesh=tuple(cfg.model.layer_config.mesh), + mode=tuple(cfg.model.layer_config.mode), + mode_les=tuple(cfg.model.layer_config.mode_les), + kernel_size=tuple(cfg.model.layer_config.kernel_size), + kernel_size_les=tuple(cfg.model.layer_config.kernel_size_les), + ffn_dim=cfg.model.layer_config.ffn_dim, + activation=cfg.model.layer_config.activation, + num_heads=cfg.model.layer_config.num_heads, + heads_dim=cfg.model.layer_config.heads_dim, + ), ).to(dist.device) if dist.distributed: @@ -107,7 +119,7 @@ def loss_fun(pred: Tensor, target: Tensor) -> Tensor: model=model, optim=optimizer, logger=log, - use_amp=False, + use_amp=cfg.training.amp, use_graphs=False ) def training_step(input: Tensor, target: Tensor) -> Tensor: diff --git a/physicsnemo/models/eddyformer/eddyformer.py b/physicsnemo/models/eddyformer/eddyformer.py index ec65c3a5ff..b599d8e933 100644 --- a/physicsnemo/models/eddyformer/eddyformer.py +++ b/physicsnemo/models/eddyformer/eddyformer.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union, Optional +from typing import Tuple, Union, Sequence, Optional from torch import Tensor import torch @@ -34,23 +34,23 @@ class EddyFormerConfig(Module): num_heads: int heads_dim: int - def __init__(self, basis: str, mesh: Tuple[int], mode: Tuple[int], - kernel_size: Tuple[int], ffn_dim: int, activation: str, - mode_les: Tuple[int], kernel_size_les: Tuple[int], num_heads: int, heads_dim: int): + def __init__(self, basis: str, mesh: Sequence[int], mode: Sequence[int], + kernel_size: Sequence[int], ffn_dim: int, activation: str, + mode_les: Sequence[int], kernel_size_les: Sequence[int], num_heads: int, heads_dim: int): """ """ super().__init__() self.basis = basis - self.mesh = mesh - self.mode = mode + self.mesh = tuple(mesh) + self.mode = tuple(mode) - self.kernel_size = kernel_size + self.kernel_size = tuple(kernel_size) self.ffn_dim = ffn_dim self.activation = activation - self.mode_les = mode_les - self.kernel_size_les = kernel_size_les + self.mode_les = tuple(mode_les) + self.kernel_size_les = tuple(kernel_size_les) self.num_heads = num_heads self.heads_dim = heads_dim @@ -110,12 +110,12 @@ def __call__(self, les: SEM, sgs: SEM) -> Tuple[SEM, SEM]: # Model @dataclass -class MetaData(ModelMetaData): +class EddyFormerMetaData(ModelMetaData): name: str = "EddyFormer" # Optimization jit: bool = True cuda_graphs: bool = True - amp: bool = False + amp: bool = True # Inference onnx_cpu: bool = False onnx_gpu: bool = False @@ -150,7 +150,7 @@ def __init__(self, """ EddyFormer model. 
""" - super().__init__(meta=MetaData()) + super().__init__(meta=EddyFormerMetaData()) self.cfg = cfg self.ndim = len(cfg.mesh) From 1e15a49ca760e5882ec40c03eee4e2f884a17ce0 Mon Sep 17 00:00:00 2001 From: Yiheng Du Date: Mon, 8 Dec 2025 17:32:13 -0800 Subject: [PATCH 09/11] implement C0-Legendre SEM basis --- examples/cfd/isotropic_eddyformer/config.yaml | 4 +- physicsnemo/models/eddyformer/_basis.py | 47 ++++++++++++++----- physicsnemo/models/eddyformer/_datatype.py | 7 ++- 3 files changed, 43 insertions(+), 15 deletions(-) diff --git a/examples/cfd/isotropic_eddyformer/config.yaml b/examples/cfd/isotropic_eddyformer/config.yaml index 3f63e51353..9ffea00988 100644 --- a/examples/cfd/isotropic_eddyformer/config.yaml +++ b/examples/cfd/isotropic_eddyformer/config.yaml @@ -5,7 +5,7 @@ model: num_layers: 4 use_scale: true layer_config: - basis: legendre + basis: leg_elem mesh: [8, 8, 8] mode: [10, 10, 10] mode_les: [5, 5, 5] @@ -18,7 +18,7 @@ model: training: dataset: data/ns3d-re94 - result_dir: outputs/ef-re94 + result_dir: outputs/ef-leg-re94 t: 0.5 amp: false batch_size: 4 diff --git a/physicsnemo/models/eddyformer/_basis.py b/physicsnemo/models/eddyformer/_basis.py index e3906ca529..99b77323e1 100644 --- a/physicsnemo/models/eddyformer/_basis.py +++ b/physicsnemo/models/eddyformer/_basis.py @@ -36,19 +36,12 @@ def nodal(self, coef: Tensor) -> Tensor: Convert modal coefficients to nodal values. """ -class Element(Basis): - - def __init__(self, base: Basis): - """ - """ - # ---------------------------------------------------------------------------- # # LEGENDRE # # ---------------------------------------------------------------------------- # from numpy.polynomial import legendre -@functools.cache class Legendre(nn.Module, Basis): """ @@ -60,13 +53,14 @@ class Legendre(nn.Module, Basis): def extra_repr(self) -> str: return f"m={self.m}" + @functools.cache def __init__(self, m: int, endpoint: bool = False): """ """ super().__init__() self.m = m - if endpoint: m -= 1 + if endpoint: m += 1 c = (0, ) * m + (1, ) dc = legendre.legder(c) @@ -100,13 +94,44 @@ def fn(self, xs: Tensor) -> Tensor: # --------------------------------- TRANSFORM -------------------------------- # def modal(self, vals: Tensor) -> Tensor: - """ - """ norm = 2 * torch.arange(self.m, device=vals.device) + 1 coef = self.f * norm * self.quad[:, None] return torch.tensordot(coef.T, vals, dims=1) def nodal(self, coef: Tensor) -> Tensor: + return Legendre.at(self, coef, self.grid) + +class LegendreSEM(Legendre): + + def __init__(self, m: int): + super().__init__(m - 2, True) + + xs = self.grid[:, torch.newaxis] + mat = super().modal(self.f * xs * (1 - xs)) + self.register_buffer("inv", torch.linalg.inv(mat)) + + def at(self, coef: Tensor, xs: Tensor) -> Tensor: + vals = super().at(coef[2:], xs) + while xs.ndim < vals.ndim: + xs = xs.unsqueeze(-1) + return coef[0] * (1 - xs) + coef[1] * xs + vals * xs * (1 - xs) + + def modal(self, vals: Tensor) -> Tensor: """ """ - return self.at(coef, self.grid) + xs = self.grid + for _ in range(vals.ndim - 1): + xs = xs.unsqueeze(-1) + + coef = super().modal(vals - (1 - xs) * vals[0] - xs * vals[-1]) + return torch.concat([vals[[0, -1], ...], torch.tensordot(self.inv, coef, 1)], axis=0) + + def nodal(self, coef: Tensor) -> Tensor: + """ + """ + xs = self.grid + for _ in range(coef.ndim - 1): + xs = xs.unsqueeze(-1) + + vals = xs * (1 - xs) * super().nodal(coef[2:]) + return coef[0] * (1 - xs) + coef[1] * xs + vals \ No newline at end of file diff --git 
index e2a248f7d4..4cc895d38b 100644
--- a/physicsnemo/models/eddyformer/_datatype.py
+++ b/physicsnemo/models/eddyformer/_datatype.py
@@ -6,7 +6,7 @@
 from dataclasses import dataclass, replace
 from functools import cached_property
 
-from ._basis import Basis, Legendre
+from ._basis import Basis, Legendre, LegendreSEM
 
 def interp1d(value: Tensor, xs: Tensor, method: str) -> Tensor:
     """
@@ -104,7 +104,10 @@ def use_elem(self) -> bool:
 
     @staticmethod
     def basis(T_: str) -> Basis:
-        if T_.startswith("leg"): return Legendre
+        if T_.startswith("leg"):
+            if T_.endswith("elem"):
+                return LegendreSEM
+            return Legendre
         raise ValueError(f"invalid basis {T_=}")
 
     @cached_property

From 83a28a72f15b604bca0afb28a67fb02d2b96b86d Mon Sep 17 00:00:00 2001
From: Yiheng Du
Date: Tue, 9 Dec 2025 14:55:25 -0800
Subject: [PATCH 10/11] fix C0 Legendre SEM basis

---
 examples/cfd/isotropic_eddyformer/config.yaml      |  2 +-
 .../cfd/isotropic_eddyformer/train_ef_isotropic.py |  3 +++
 physicsnemo/models/eddyformer/_basis.py            | 14 +++++++-------
 physicsnemo/models/eddyformer/_datatype.py         |  5 ++++-
 4 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/examples/cfd/isotropic_eddyformer/config.yaml b/examples/cfd/isotropic_eddyformer/config.yaml
index 9ffea00988..64d998a74b 100644
--- a/examples/cfd/isotropic_eddyformer/config.yaml
+++ b/examples/cfd/isotropic_eddyformer/config.yaml
@@ -7,7 +7,7 @@ model:
   layer_config:
     basis: leg_elem
     mesh: [8, 8, 8]
-    mode: [10, 10, 10]
+    mode: [13, 13, 13]
     mode_les: [5, 5, 5]
     kernel_size: [2, 2, 2]
     kernel_size_les: [2, 2, 2]
diff --git a/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py b/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py
index d20be89ef3..0420e1dc88 100644
--- a/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py
+++ b/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py
@@ -110,6 +110,9 @@ def isotropic_trainer(cfg: DictConfig) -> None:
     dataset = Re94(root=cfg.training.dataset, split="train", t=cfg.training.t)
     dataloader = DataLoader(dataset, cfg.training.batch_size, shuffle=True)
 
+    testset = Re94(root=cfg.training.dataset, split="test", t=cfg.training.t)
+    testloader = DataLoader(testset)
+
     # define relative l2 error as the loss function
     def loss_fun(pred: Tensor, target: Tensor) -> Tensor:
         return torch.linalg.norm(pred - target) / torch.linalg.norm(target)
diff --git a/physicsnemo/models/eddyformer/_basis.py b/physicsnemo/models/eddyformer/_basis.py
index 99b77323e1..9967b5c605 100644
--- a/physicsnemo/models/eddyformer/_basis.py
+++ b/physicsnemo/models/eddyformer/_basis.py
@@ -24,7 +24,8 @@ def at(self, coef: Tensor, xs: Tensor) -> Tensor:
         """
         Evaluate basis expansion at given points.
""" - return torch.tensordot(self.fn(xs), coef, dims=1) + mat = self.fn(xs)[..., :len(coef)] + return torch.tensordot(mat, coef, dims=1) def modal(self, vals: Tensor) -> Tensor: """ @@ -53,14 +54,13 @@ class Legendre(nn.Module, Basis): def extra_repr(self) -> str: return f"m={self.m}" - @functools.cache def __init__(self, m: int, endpoint: bool = False): """ """ super().__init__() self.m = m - if endpoint: m += 1 + if endpoint: m -= 1 c = (0, ) * m + (1, ) dc = legendre.legder(c) @@ -104,11 +104,11 @@ def nodal(self, coef: Tensor) -> Tensor: class LegendreSEM(Legendre): def __init__(self, m: int): - super().__init__(m - 2, True) + super().__init__(m, True) xs = self.grid[:, torch.newaxis] - mat = super().modal(self.f * xs * (1 - xs)) - self.register_buffer("inv", torch.linalg.inv(mat)) + mat = super().modal(self.f[:, :-2] * xs * (1 - xs)) + self.register_buffer("inv", torch.linalg.inv(mat[:-2])) def at(self, coef: Tensor, xs: Tensor) -> Tensor: vals = super().at(coef[2:], xs) @@ -123,7 +123,7 @@ def modal(self, vals: Tensor) -> Tensor: for _ in range(vals.ndim - 1): xs = xs.unsqueeze(-1) - coef = super().modal(vals - (1 - xs) * vals[0] - xs * vals[-1]) + coef = super().modal(vals - (1 - xs) * vals[0] - xs * vals[-1])[:-2] return torch.concat([vals[[0, -1], ...], torch.tensordot(self.inv, coef, 1)], axis=0) def nodal(self, coef: Tensor) -> Tensor: diff --git a/physicsnemo/models/eddyformer/_datatype.py b/physicsnemo/models/eddyformer/_datatype.py index 4cc895d38b..2988df9c78 100644 --- a/physicsnemo/models/eddyformer/_datatype.py +++ b/physicsnemo/models/eddyformer/_datatype.py @@ -4,10 +4,13 @@ import torch from dataclasses import dataclass, replace -from functools import cached_property +from functools import cache, cached_property from ._basis import Basis, Legendre, LegendreSEM +Legendre = cache(Legendre) +LegendreSEM = cache(LegendreSEM) + def interp1d(value: Tensor, xs: Tensor, method: str) -> Tensor: """ Interpolate from 1D regular grid to a target points. 
From 5b0e426aaadd1bd87f99b8582b6642197dcffd7a Mon Sep 17 00:00:00 2001
From: Yiheng Du
Date: Mon, 15 Dec 2025 18:09:33 -0800
Subject: [PATCH 11/11] add test every 100 steps; add torch.compile; support amp training

---
 examples/cfd/isotropic_eddyformer/config.yaml |  2 +
 .../train_ef_isotropic.py                     | 68 ++++++++++++++++---
 physicsnemo/models/eddyformer/_basis.py       | 11 ++-
 physicsnemo/models/eddyformer/sem_attn.py     |  3 +
 4 files changed, 67 insertions(+), 17 deletions(-)

diff --git a/examples/cfd/isotropic_eddyformer/config.yaml b/examples/cfd/isotropic_eddyformer/config.yaml
index 64d998a74b..4dc303a76b 100644
--- a/examples/cfd/isotropic_eddyformer/config.yaml
+++ b/examples/cfd/isotropic_eddyformer/config.yaml
@@ -21,7 +21,9 @@ training:
   result_dir: outputs/ef-leg-re94
   t: 0.5
   amp: false
+  compile: true
   batch_size: 4
   num_epochs: 1
   learning_rate: 1e-3
+  test_every: 100
   ckpt_every: 1000
diff --git a/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py b/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py
index 0420e1dc88..454a7b17c8 100644
--- a/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py
+++ b/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py
@@ -1,24 +1,29 @@
 import hydra
+from tqdm import tqdm
+
 from typing import Tuple
 from torch import Tensor
 from omegaconf import DictConfig
 
 import os
+import collections
 import numpy as np
 
 import torch
 from torch.optim import Adam
-from torch.cuda.amp import GradScaler, autocast
 from torch.utils.data import Dataset, DataLoader
 from torch.nn.parallel import DistributedDataParallel
 
 from physicsnemo.models.eddyformer import EddyFormer, EddyFormerConfig
 from physicsnemo.distributed import DistributedManager
-from physicsnemo.utils import StaticCaptureTraining
+from physicsnemo.utils import StaticCaptureTraining, StaticCaptureEvaluateNoGrad
 from physicsnemo.launch.utils import save_checkpoint
 from physicsnemo.launch.logging import PythonLogger, LaunchLogger
 
+def rel_l2(pred: Tensor, target: Tensor) -> Tensor:
+    return torch.linalg.norm(pred - target) / torch.linalg.norm(target)
+
 class Re94(Dataset):
 
     root: str
     t: float
 
     n: int = 50
     dt: float = 0.1
 
-    def __init__(self, root: str, split: str, *, t: float = 0.5) -> None:
+    def __init__(self, root: str, split: str, *, t: float = 0.5,
+                 n: int = 50, dt: float = 0.1) -> None:
         """
         """
         super().__init__()
 
         self.root = root
         self.t = t
 
+        self.n = n
+        self.dt = dt
+
         self.file = []
         for fname in sorted(os.listdir(root)):
             if fname.startswith(split):
@@ -58,6 +67,12 @@ def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor]:
         data = np.load(f"{self.root}/{self.file[file_idx]}", allow_pickle=True).item()
         return torch.from_numpy(data["u"][time_idx]), torch.from_numpy(data["u"][time_idx + self.stride])
 
+    def metric(self, pred: Tensor, target: Tensor) -> dict[str, float]:
+        """
+        """
+        l2 = [rel_l2(pred[..., i], target[..., i]).item() for i in range(3)]
+        return { f"err_{ax}": value for ax, value in (zip("xyz", l2)) }
+
 @hydra.main(version_base="1.3", config_path=".", config_name="config.yaml")
 def isotropic_trainer(cfg: DictConfig) -> None:
     """
@@ -110,27 +125,37 @@ def isotropic_trainer(cfg: DictConfig) -> None:
     dataset = Re94(root=cfg.training.dataset, split="train", t=cfg.training.t)
     dataloader = DataLoader(dataset, cfg.training.batch_size, shuffle=True)
 
-    testset = Re94(root=cfg.training.dataset, split="test", t=cfg.training.t)
-    testloader = DataLoader(testset)
-
-    # define relative l2 error as the loss function
-    def loss_fun(pred: Tensor, target: Tensor) -> Tensor:
-        return torch.linalg.norm(pred - target) / torch.linalg.norm(target)
+    testset = Re94(root=cfg.training.dataset, split="test", t=cfg.training.t, n=40, dt=0.5)
+    testloader = DataLoader(testset, batch_size=None)
 
     # define training step
     @StaticCaptureTraining(
         model=model,
         optim=optimizer,
         logger=log,
+        use_graphs=False,
         use_amp=cfg.training.amp,
-        use_graphs=False
+        compile=cfg.training.compile
     )
     def training_step(input: Tensor, target: Tensor) -> Tensor:
         pred = torch.vmap(model)(input)
-        loss = torch.vmap(loss_fun)(pred, target)
+        loss = torch.vmap(rel_l2)(pred, target)
         return torch.mean(loss)
 
+    # define evaluation step
+    @StaticCaptureEvaluateNoGrad(
+        model=model,
+        logger=log,
+        use_graphs=False,
+        use_amp=cfg.training.amp,
+        compile=cfg.training.compile
+    )
+    def forward_eval(input):
+        return model(input)
+
     it = 0
+
+    model.train()
 
     log.info("Training started")
     for epoch in range(cfg.training.num_epochs):
@@ -146,6 +171,27 @@ def training_step(input: Tensor, target: Tensor) -> Tensor:
 
         if it and it % cfg.training.ckpt_every == 0 and dist.rank == 0:
             save_checkpoint(f"{cfg.training.result_dir}/ckpt.pt", model, optimizer, epoch=it)
 
+        if it and it % cfg.training.test_every == 0:
+
+            model.eval()
+            metrics = collections.defaultdict(float)
+
+            for input, target in tqdm(testloader, desc="Test"):
+
+                input = input.to(dist.device)
+                target = target.to(dist.device)
+
+                pred = forward_eval(input)
+                metric = testset.metric(pred, target)
+
+                for key, value in metric.items():
+                    metrics[key] += value / len(testset)
+
+            with LaunchLogger("test", epoch=epoch) as logger:
+                logger.log_minibatch(metrics)
+
+            model.train()
+
     log.success("Training completed")
     save_checkpoint(f"{cfg.training.result_dir}/ckpt.pt", model, optimizer)
diff --git a/physicsnemo/models/eddyformer/_basis.py b/physicsnemo/models/eddyformer/_basis.py
index 9967b5c605..3969888113 100644
--- a/physicsnemo/models/eddyformer/_basis.py
+++ b/physicsnemo/models/eddyformer/_basis.py
@@ -5,7 +5,6 @@
 import torch.nn as nn
 
 import numpy as np
-import functools
 
 class Basis(Protocol):
 
@@ -96,10 +95,10 @@ def fn(self, xs: Tensor) -> Tensor:
     def modal(self, vals: Tensor) -> Tensor:
         norm = 2 * torch.arange(self.m, device=vals.device) + 1
         coef = self.f * norm * self.quad[:, None]
-        return torch.tensordot(coef.T, vals, dims=1)
+        return torch.tensordot(coef.to(vals.dtype).T, vals, dims=1)
 
     def nodal(self, coef: Tensor) -> Tensor:
-        return Legendre.at(self, coef, self.grid)
+        return Legendre.at(self, coef, self.grid.to(coef.dtype))
 
 class LegendreSEM(Legendre):
@@ -119,17 +118,17 @@ def at(self, coef: Tensor, xs: Tensor) -> Tensor:
 
     def modal(self, vals: Tensor) -> Tensor:
         """
         """
-        xs = self.grid
+        xs = self.grid.to(vals.dtype)
         for _ in range(vals.ndim - 1):
             xs = xs.unsqueeze(-1)
 
         coef = super().modal(vals - (1 - xs) * vals[0] - xs * vals[-1])[:-2]
-        return torch.concat([vals[[0, -1], ...], torch.tensordot(self.inv, coef, 1)], axis=0)
+        return torch.concat([vals[[0, -1], ...], torch.tensordot(self.inv.to(coef.dtype), coef, 1)], axis=0)
 
     def nodal(self, coef: Tensor) -> Tensor:
         """
         """
-        xs = self.grid
+        xs = self.grid.to(coef.dtype)
         for _ in range(coef.ndim - 1):
             xs = xs.unsqueeze(-1)
 
         vals = xs * (1 - xs) * super().nodal(coef[2:])
diff --git a/physicsnemo/models/eddyformer/sem_attn.py b/physicsnemo/models/eddyformer/sem_attn.py
index 261a6ea8c5..57dd5aebde 100644
--- a/physicsnemo/models/eddyformer/sem_attn.py
+++ b/physicsnemo/models/eddyformer/sem_attn.py
@@ -73,5 +73,8 @@ def __call__(self, ϕ: SEM) -> SEM:
         """
         q, k, v = (self.project(ϕ, name) for name in "QKV")
 
+        q = q.to(v.dtype)
+        k = k.to(v.dtype)
+
         attn = nn.functional.scaled_dot_product_attention(q, k, v)
         return ϕ.new(self.out(attn.reshape(*ϕ.mode, *ϕ.mesh, -1)))
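The casts added in `sem_attn.py` make the query and key match the value's dtype before `scaled_dot_product_attention`, which requires all three inputs to share a dtype; under AMP, parts of the SEM projection can otherwise remain in float32 while the rest runs in reduced precision. A small standalone illustration of the failure mode and the fix (hypothetical shapes, not the EddyFormer projections; bfloat16 is used so the sketch also runs on CPU):

```python
import torch
import torch.nn.functional as F

# scaled_dot_product_attention expects q, k and v to share a dtype; a float32
# query against reduced-precision key/value raises an error instead of
# upcasting, so the query and key are cast to the value's dtype first.
q = torch.randn(1, 4, 8, 16, dtype=torch.float32)
k = torch.randn(1, 4, 8, 16, dtype=torch.bfloat16)
v = torch.randn(1, 4, 8, 16, dtype=torch.bfloat16)

q, k = q.to(v.dtype), k.to(v.dtype)
out = F.scaled_dot_product_attention(q, k, v)
print(out.dtype)  # torch.bfloat16
```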