diff --git a/CHANGELOG.md b/CHANGELOG.md index 8062a3fe42..36466c7185 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Adds xDeepONet to experimental models + (`physicsnemo.experimental.models.xdeeponet.DeepONet`). A single + dimension-generic (2D/3D) DeepONet that accepts a spatial or MLP branch, + an optional trunk, and an optional second branch as `nn.Module` inputs + (dependency injection). Six forward-call conventions cover trunked, + trunkless, packed/auto-padded, and xFNO-style time-axis-extend modes. + Supports multi-channel output, multiple decoder types (MLP, Conv, + temporal projection), composable Fourier / UNet / Conv spatial branches + (`SpatialBranch`), and coordinate features. +- Adds `Sin` elementwise sine activation to `physicsnemo.nn`, registered + in `ACT2FN` so it can be looked up by name (`get_activation("sin")`). - Adds GLOBE model (`physicsnemo.experimental.models.globe.model.GLOBE`), including new variant that uses a dual tree traversal algorithm to reduce the complexity of the kernel evaluations from O(N^2) to O(N). diff --git a/physicsnemo/experimental/models/xdeeponet/__init__.py b/physicsnemo/experimental/models/xdeeponet/__init__.py new file mode 100644 index 0000000000..8f29a2ce09 --- /dev/null +++ b/physicsnemo/experimental/models/xdeeponet/__init__.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""xDeepONet — the extended DeepONet family. + +A single :class:`DeepONet` class assembles operator-learning +architectures spanning the DeepONet and FNO families: + +- ``deeponet``, ``u_deeponet``, ``fourier_deeponet``, ``conv_deeponet``, + ``hybrid_deeponet`` — single-branch + trunk variants. +- ``mionet``, ``fourier_mionet`` — two-branch multi-input + trunk variants. +- ``tno`` — Temporal Neural Operator (branch2 = previous solution) + trunk. +- ``ufno`` / xFNO-style trunkless operators — trunkless spatial branch + with composable Fourier / UNet / Conv layers; the last spatial axis + can be interpreted as time for autoregressive bundling via the + :attr:`DeepONet.time_modes` parameter. + +The :class:`DeepONet` class is dimension-generic (``dimension=2|3`` +constructor argument; per-dimension primitives are dispatched +internally) and dispatches forward by two flags +(:attr:`auto_pad`, :attr:`trunk`-is-None) over six valid call +conventions: packed-input vs core-input × trunked vs trunkless, +plus the ``temporal_projection`` decoder variant. See the +:class:`DeepONet` class docstring for the full matrix and worked +examples; see :class:`SpatialBranch` for the spatial-encoder +composition options (Fourier / UNet / Conv layers, multi-layer +pointwise lift, optional coordinate-feature channels). 
+""" + +from .branches import SpatialBranch +from .deeponet import DeepONet + +__all__ = [ + "DeepONet", + "SpatialBranch", +] diff --git a/physicsnemo/experimental/models/xdeeponet/_padding.py b/physicsnemo/experimental/models/xdeeponet/_padding.py new file mode 100644 index 0000000000..342a7c4e65 --- /dev/null +++ b/physicsnemo/experimental/models/xdeeponet/_padding.py @@ -0,0 +1,253 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Right-side spatial padding helpers used by the xDeepONet packed-input +forward path. + +When :class:`~physicsnemo.experimental.models.xdeeponet.DeepONet` is +constructed with ``auto_pad=True`` it aligns spatial dimensions to a +multiple (typically 8) so that spectral and convolutional sub-branches +operate on compatible shapes. These helpers are dimension-agnostic and +support 2D, 3D, or 4D spatial layouts. + +Tensor layouts used here: +- 2D spatial samples: ``(B, H, W, T, C)`` +- 3D spatial samples: ``(B, X, Y, Z, T, C)`` + +This module is private (leading underscore): the helpers are part of the +xdeeponet package's internal API surface only and may be renamed or +restructured without notice. 
def compute_right_pad_to_multiple(
    spatial_shape: Sequence[int],
    *,
    multiple: int = 8,
    min_right_pad: int = 0,
) -> tuple[int, ...]:
    """Return the per-dimension right padding that aligns each size.

    For every entry ``d`` of ``spatial_shape`` the result is the smallest
    ``pad`` with ``pad >= min_right_pad`` such that ``d + pad`` is divisible
    by ``multiple``.

    Parameters
    ----------
    spatial_shape : Sequence[int]
        Current spatial dimension sizes; every entry must be positive.
    multiple : int, optional
        Target alignment (default ``8``). Must be positive.
    min_right_pad : int, optional
        Lower bound on the padding of each dimension (default ``0``).
        Must be non-negative.

    Returns
    -------
    tuple[int, ...]
        One right-padding amount per entry of ``spatial_shape``.

    Raises
    ------
    ValueError
        If ``multiple <= 0``, ``min_right_pad < 0``, or any spatial size is
        not positive.
    """
    if multiple <= 0:
        raise ValueError(f"multiple must be > 0, got {multiple}")
    if min_right_pad < 0:
        raise ValueError(f"min_right_pad must be >= 0, got {min_right_pad}")

    def _pad_for(size):
        if size <= 0:
            raise ValueError(
                f"spatial dimensions must be positive, got {spatial_shape}"
            )
        # Distance from ``size`` up to the next multiple (0 if already aligned).
        amount = -size % multiple
        shortfall = min_right_pad - amount
        if shortfall > 0:
            # Add whole multiples (ceil division) until the minimum is met,
            # which keeps ``size + amount`` aligned.
            amount += -(-shortfall // multiple) * multiple
        return int(amount)

    return tuple(_pad_for(size) for size in spatial_shape)
def pad_right_nd(
    x: Shaped[Tensor, "..."],
    *,
    dims: Sequence[int],
    right_pad: Sequence[int],
    mode: Literal["replicate", "constant"] = "replicate",
    constant_value: float = 0.0,
) -> Shaped[Tensor, "..."]:
    """Right-pad arbitrary dimensions of an N-D tensor.

    Implemented with :func:`torch.cat` rather than
    :func:`torch.nn.functional.pad` so that ``mode="replicate"`` works for
    any tensor rank (e.g. 6D tensors in the 3D-spatial case).

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of any rank and dtype.
    dims : Sequence[int]
        Dimensions to right-pad. Negative indices are supported.
    right_pad : Sequence[int]
        Right-side padding amounts per ``dims`` entry. Non-positive
        entries are no-ops.
    mode : str, optional
        ``"replicate"`` (default) repeats the last slice along each
        padded dim; ``"constant"`` uses ``constant_value``.
    constant_value : float, optional
        Fill value when ``mode="constant"`` (default ``0.0``).

    Returns
    -------
    torch.Tensor
        Tensor of the same rank and dtype as ``x`` with the specified
        dimensions right-padded. ``x`` itself is returned when nothing
        needs padding.

    Raises
    ------
    ValueError
        If ``dims`` and ``right_pad`` differ in length, ``mode`` is not
        supported, a dimension index is out of range, or replicate padding
        is requested along an empty dimension.
    """
    if len(dims) != len(right_pad):
        raise ValueError("dims and right_pad must have the same length")
    # Validate the mode up front so a bad value is reported regardless of
    # whether any padding amount happens to be positive.
    if mode not in ("replicate", "constant"):
        raise ValueError(
            f"pad_right_nd supports mode='replicate' or 'constant', got {mode}"
        )

    rank = x.dim()  # torch.cat never changes the rank, so compute it once.
    for dim, pad in zip(dims, right_pad):
        pad = int(pad)
        if pad <= 0:
            continue
        if dim < 0:
            dim = rank + dim
        if dim < 0 or dim >= rank:
            raise ValueError(f"invalid dim {dim} for x.dim()={rank}")

        pad_shape = list(x.shape)
        pad_shape[dim] = pad

        if mode == "constant":
            pad_tensor = torch.full(
                pad_shape, float(constant_value), dtype=x.dtype, device=x.device
            )
        else:
            if x.size(dim) == 0:
                # There is no "last slice" to repeat along an empty dim.
                raise ValueError(
                    f"cannot replicate-pad dim {dim}: it has size 0 "
                    f"(shape {tuple(x.shape)})"
                )
            # Keep the padded dim (size 1) and broadcast it to ``pad``
            # without copying; torch.cat materializes the result.
            last = x.narrow(dim, x.size(dim) - 1, 1)
            pad_tensor = last.expand(*pad_shape)

        x = torch.cat([x, pad_tensor], dim=dim)

    return x
def pad_spatial_right(
    x: Shaped[Tensor, "..."],
    *,
    spatial_ndim: int,
    right_pad: Sequence[int],
    mode: Literal["replicate", "constant"] = "replicate",
    constant_value: float = 0.0,
) -> Shaped[Tensor, "..."]:
    """Right-pad the first *spatial_ndim* dimensions after the batch dim.

    Assumes ``x`` is shaped ``(B, *spatial, *rest)``.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor shaped ``(B, *spatial, *rest)``; any dtype is
        accepted. Must satisfy ``x.dim() >= 1 + spatial_ndim``.
    spatial_ndim : int
        Number of spatial dimensions immediately following the batch
        dim. Must be ``2``, ``3``, or ``4``.
    right_pad : Sequence[int]
        Right-side padding amounts per spatial dimension; must have
        length ``spatial_ndim``. Non-positive entries are no-ops.
    mode : str, optional
        ``"replicate"`` (default) or ``"constant"``.
    constant_value : float, optional
        Fill value when ``mode="constant"`` (default ``0.0``).

    Returns
    -------
    torch.Tensor
        Tensor of the same rank and dtype as ``x`` with the spatial
        dimensions right-padded. ``x`` itself is returned when no entry of
        ``right_pad`` is positive.

    Raises
    ------
    ValueError
        If ``spatial_ndim`` is unsupported, ``right_pad`` has the wrong
        length, or ``x`` has too few dimensions.
    """
    if spatial_ndim not in (2, 3, 4):
        raise ValueError(f"spatial_ndim must be 2, 3, or 4, got {spatial_ndim}")
    if len(right_pad) != spatial_ndim:
        raise ValueError(
            f"right_pad must have length {spatial_ndim}, got {len(right_pad)}"
        )
    if x.dim() < 1 + spatial_ndim:
        raise ValueError(
            f"expected x.dim() >= {1 + spatial_ndim}, got x.dim()={x.dim()}"
        )

    # Clamp so that non-positive entries really are no-ops, as documented.
    # F.pad would otherwise treat a negative amount as a crop request.
    pads = tuple(max(int(p), 0) for p in right_pad)
    if not any(pads):
        return x

    # The generic cat-based implementation handles 4 spatial dims (6D+
    # tensors). It is also used for replicate padding of non-floating
    # tensors, where F.pad's replicate kernels may not be implemented.
    if spatial_ndim == 4 or (mode == "replicate" and not x.is_floating_point()):
        return pad_right_nd(
            x,
            dims=list(range(1, 1 + spatial_ndim)),
            right_pad=pads,
            mode=mode,
            constant_value=constant_value,
        )

    # For 2D/3D spatial, fold all trailing dims into one channel-like axis
    # and move it next to the batch dim so F.pad(replicate) applies.
    b = x.shape[0]
    spatial_shape = x.shape[1 : 1 + spatial_ndim]
    rest_shape = x.shape[1 + spatial_ndim :]
    rest_prod = math.prod(rest_shape)

    x_reshaped = x.reshape(b, *spatial_shape, rest_prod).permute(
        0, spatial_ndim + 1, *range(1, 1 + spatial_ndim)
    )

    # F.pad expects (left, right) pairs starting from the LAST dimension.
    pad = tuple(v for p in reversed(pads) for v in (0, p))

    if mode == "constant":
        x_padded = F.pad(x_reshaped, pad, mode="constant", value=float(constant_value))
    else:
        x_padded = F.pad(x_reshaped, pad, mode=mode)

    padded_spatial = x_padded.shape[2 : 2 + spatial_ndim]
    return x_padded.permute(0, *range(2, 2 + spatial_ndim), 1).reshape(
        b, *padded_spatial, *rest_shape
    )
+ +Provides a single dimension-generic spatial encoder: + +- :class:`SpatialBranch` — composable from Fourier, UNet, and Conv layers, + parameterized by ``dimension`` (``2`` or ``3``) to operate on either + :math:`(B, H, W, C)` or :math:`(B, X, Y, Z, C)` inputs. Per-dimension + primitives are dispatched through the module-level :data:`_DIM_LAYERS` + lookup table. + +The MLP trunk and the optional MLP branch are built directly from +:class:`physicsnemo.models.mlp.FullyConnected` by the helpers in +``deeponet.py`` (``_build_trunk_mlp`` and ``_build_mlp_branch``). + +UNet sub-modules inside the spatial branch use +:class:`physicsnemo.models.unet.UNet` (3D). A small adapter +:class:`_UNet2DFromUNet3D` is provided locally for the 2D path: it wraps +the 3D UNet with a singleton time dimension so the same library model covers +both spatial dimensionalities. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F +from jaxtyping import Float +from torch import Tensor + +from physicsnemo.core.meta import ModelMetaData +from physicsnemo.core.module import Module +from physicsnemo.models.unet import UNet as _PhysicsNeMoUNet +from physicsnemo.nn import SpectralConv2d, SpectralConv3d, get_activation + +# Per-dimension layer lookup table used by :class:`SpatialBranch` to dispatch +# spectral / conv / pooling / UNet primitives without code duplication. The +# UNet adapter entries are populated lazily below (after the adapter classes +# are defined) so this module remains importable in any order. 
# Per-dimension primitives consumed by ``SpatialBranch``. Keys are the
# supported spatial dimensionalities; each value maps a role name to the
# torch / physicsnemo class (or setting) that fulfils it for that
# dimensionality. A "UNetAdapter" entry is added to each sub-table further
# down the module, once the adapter classes exist.
_DIM_LAYERS: dict[int, dict] = {
    2: dict(
        SpectralConv=SpectralConv2d,
        Conv=nn.Conv2d,
        BatchNorm=nn.BatchNorm2d,
        AdaptiveAvgPool=nn.AdaptiveAvgPool2d,
        interp_mode="bilinear",
        default_modes=(12, 12),
    ),
    3: dict(
        SpectralConv=SpectralConv3d,
        Conv=nn.Conv3d,
        BatchNorm=nn.BatchNorm3d,
        AdaptiveAvgPool=nn.AdaptiveAvgPool3d,
        interp_mode="trilinear",
        default_modes=(10, 10, 8),
    ),
}


def _channel_first_permute(dimension: int) -> tuple[int, ...]:
    """Return the axis order turning ``(B, *spatial, C)`` into
    ``(B, C, *spatial)`` for ``dimension`` spatial axes."""
    spatial_axes = tuple(range(1, dimension + 1))
    channel_axis = dimension + 1
    return (0, channel_axis) + spatial_axes


def _channel_last_permute(dimension: int) -> tuple[int, ...]:
    """Return the axis order turning ``(B, C, *spatial)`` back into
    ``(B, *spatial, C)``; the inverse of :func:`_channel_first_permute`."""
    spatial_axes = tuple(range(2, dimension + 2))
    return (0,) + spatial_axes + (1,)
def _build_library_unet(
    in_channels: int,
    out_channels: int,
    kernel_size: int,
    model_depth: int,
    feature_map_channels: list[int] | None,
) -> nn.Module:
    """Build the library 3D UNet with the fixed configuration shared by
    both spatial-branch adapters.

    ``feature_map_channels`` defaults to ``[in_channels] * model_depth``.
    """
    if feature_map_channels is None:
        feature_map_channels = [in_channels] * model_depth
    return _PhysicsNeMoUNet(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        model_depth=model_depth,
        feature_map_channels=feature_map_channels,
        num_conv_blocks=1,
        conv_activation="leaky_relu",
        conv_transpose_activation="leaky_relu",
        padding=kernel_size // 2,
        pooling_type="MaxPool3d",
        normalization="batchnorm",
        gradient_checkpointing=False,
    )


class _UNet2DFromUNet3D(nn.Module):
    r"""Adapter using :class:`physicsnemo.models.unet.UNet` for 2D inputs.

    The library UNet is 3D only. To reuse it for 2D, this adapter adds a
    short tiled time axis of length :math:`2^{\text{model\_depth}}` (long
    enough to survive the UNet's ``model_depth`` pooling stages), runs the
    3D UNet, and averages the result back to 2D. Channel-first layout
    :math:`(B, C, H, W)` is preserved on input and output.

    .. important::

        Selecting ``num_unet_layers > 0`` in a 2D
        :class:`~physicsnemo.experimental.models.xdeeponet.SpatialBranch`
        (i.e. when this 2D adapter is used) makes the UNet branch operate
        on a tiled :math:`2^{\text{model\_depth}}`-deep volume. With the
        default ``model_depth=3`` this is an **8x** memory and compute
        cost relative to a native 2D UNet of the same width and depth.
        This overhead is a property of the upstream library UNet being
        3D-only, not of this branch. When ``num_unet_layers == 0`` the
        branch is bypassed and there is no overhead.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        model_depth: int = 3,
        feature_map_channels: list[int] | None = None,
    ):
        super().__init__()
        # Length of the synthetic trailing axis fed to the 3D UNet.
        self._t_tile = 2**model_depth
        self.unet = _build_library_unet(
            in_channels, out_channels, kernel_size, model_depth, feature_map_channels
        )

    def forward(
        self,
        x: Float[Tensor, "batch channels h w"],
    ) -> Float[Tensor, "batch out_channels h w"]:
        """Forward through the 3D UNet via a tiled time axis."""
        # (B, C, H, W) -> (B, C, H, W, T): identical copies along T.
        x = x.unsqueeze(-1).repeat(1, 1, 1, 1, self._t_tile)
        x = self.unet(x)
        # Collapse the synthetic axis again.
        return x.mean(dim=-1)


class _UNet3DFromUNet3D(nn.Module):
    r"""Thin wrapper exposing :class:`physicsnemo.models.unet.UNet`.

    Exposes the library 3D UNet with a fixed default configuration suitable
    for skip-connection reuse inside :class:`SpatialBranch` (``dimension=3``).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        model_depth: int = 3,
        feature_map_channels: list[int] | None = None,
    ):
        super().__init__()
        self.unet = _build_library_unet(
            in_channels, out_channels, kernel_size, model_depth, feature_map_channels
        )

    def forward(
        self,
        x: Float[Tensor, "batch channels x y z"],
    ) -> Float[Tensor, "batch out_channels x y z"]:
        """Forward pass through the library 3D UNet."""
        return self.unet(x)


# Populate the UNet adapter entries now that the adapter classes are
# defined; keeps the lookup table self-contained for callers below.
_DIM_LAYERS[2]["UNetAdapter"] = _UNet2DFromUNet3D
_DIM_LAYERS[3]["UNetAdapter"] = _UNet3DFromUNet3D


# ---------------------------------------------------------------------------
# Spatial branch (dimension-generic)
# ---------------------------------------------------------------------------


@dataclass
class _SpatialBranchMetaData(ModelMetaData):
    """PhysicsNeMo model metadata for :class:`SpatialBranch`."""
class SpatialBranch(Module):
    r"""Dimension-generic spatial branch composable from Fourier, UNet, and
    Conv layers.

    Operates on 2D :math:`(B, H, W, C)` or 3D :math:`(B, X, Y, Z, C)` inputs
    selected via the ``dimension`` constructor argument; the spectral /
    convolutional / pooling / UNet sub-modules are dispatched through the
    module-level :data:`_DIM_LAYERS` lookup table so no per-dimension
    subclasses are needed. The branch can be configured to use any
    combination of spectral, UNet, and plain convolutional layers. When
    Fourier layers are present (the "base" mode) UNet/Conv layers are
    added alongside the spectral path (hybrid residual). When no Fourier
    layers are present UNet/Conv act as independent sequential layers.

    Parameters
    ----------
    dimension : int, optional
        Spatial dimensionality of the inputs. Must be ``2`` (default) or
        ``3``.
    in_channels : int
        Number of input channels (used only for documentation; the lift is
        :class:`torch.nn.LazyLinear`).
    width : int
        Latent/output width.
    num_fourier_layers : int
        Number of spectral layers. Must be non-negative.
    num_unet_layers : int
        Number of UNet layers (uses :class:`physicsnemo.models.unet.UNet`).
        Must be non-negative.
    num_conv_layers : int
        Number of Conv+BN layers. Must be non-negative.
    modes1, modes2 : int
        Fourier modes along the first two spatial axes.
    modes3 : int, optional
        Fourier modes along the third spatial axis. Used only when
        ``dimension == 3``, where it falls back to the per-dimension
        default if omitted; ignored when ``dimension == 2``.
    kernel_size : int
        Kernel size for UNet and Conv layers.
    activation_fn : str
        Activation function name.
    internal_resolution : list, optional
        If set, inputs are adaptively pooled to this resolution before
        processing and upsampled back, decoupling model size from grid size.
        Must have one entry per spatial dimension.
    coord_features : bool, optional
        When ``True``, concatenates ``dimension`` channels containing
        the per-axis normalized coordinates (each spanning :math:`[0, 1]`)
        to the input before the lift. Useful for operator-learning
        architectures that don't carry coordinates through a trunk MLP
        (e.g. the xFNO family) and instead inject them as extra channels.
        Default ``False``.
    lift_layers : int, optional
        Number of layers in the lifting network (default ``1``, a single
        :class:`torch.nn.LazyLinear`). When ``> 1`` the lift becomes a
        multi-layer pointwise MLP equivalent to a stack of 1x1 (1x1x1)
        convolutions.
    lift_hidden_width : int, optional
        Hidden width inside the multi-layer lift. Only consulted when
        ``lift_layers > 1``; defaults to ``width // 2``.

    Attributes
    ----------
    modes_per_dim : tuple[int, ...]
        The Fourier mode counts the branch was built with, in spatial-axis
        order. Length matches ``dimension``.

    Raises
    ------
    ValueError
        If ``dimension`` is unsupported, any layer count is negative, all
        layer counts are zero, ``lift_layers < 1``, or
        ``internal_resolution`` does not have ``dimension`` entries.

    Forward
    -------
    x : torch.Tensor
        Channels-last input of shape :math:`(B, H, W, C)` for
        ``dimension=2`` or :math:`(B, X, Y, Z, C)` for ``dimension=3``.

    Outputs
    -------
    torch.Tensor
        Channels-last output with the same spatial layout as the input and
        the channels dimension replaced by ``width``.

    Examples
    --------
    2D:

    >>> import torch
    >>> from physicsnemo.experimental.models.xdeeponet import SpatialBranch
    >>> branch = SpatialBranch(
    ...     dimension=2, in_channels=5, width=64, num_unet_layers=1, kernel_size=3
    ... )
    >>> x = torch.randn(2, 32, 32, 5)  # (B, H, W, C)
    >>> out = branch(x)  # (2, 32, 32, 64)

    3D:

    >>> branch = SpatialBranch(
    ...     dimension=3, in_channels=5, width=64, num_unet_layers=1, kernel_size=3
    ... )
    >>> x = torch.randn(1, 16, 16, 16, 5)  # (B, X, Y, Z, C)
    >>> out = branch(x)  # (1, 16, 16, 16, 64)

    With coordinate features (xFNO-style trunkless operator):

    >>> branch = SpatialBranch(
    ...     dimension=3, in_channels=5, width=64,
    ...     num_fourier_layers=4, modes1=12, modes2=12, modes3=8,
    ...     coord_features=True, lift_layers=2,
    ... )
    >>> x = torch.randn(1, 16, 16, 16, 5)  # (B, X, Y, Z, C)
    >>> out = branch(x)  # (1, 16, 16, 16, 64)
    """

    def __init__(
        self,
        dimension: int = 2,
        in_channels: int = 12,
        width: int = 64,
        num_fourier_layers: int = 0,
        num_unet_layers: int = 0,
        num_conv_layers: int = 0,
        modes1: int = 12,
        modes2: int = 12,
        modes3: int | None = None,
        kernel_size: int = 3,
        activation_fn: str = "gelu",
        internal_resolution: list | None = None,
        coord_features: bool = False,
        lift_layers: int = 1,
        lift_hidden_width: int | None = None,
    ):
        super().__init__(meta=_SpatialBranchMetaData())

        if dimension not in _DIM_LAYERS:
            raise ValueError(
                f"SpatialBranch only supports dimension=2 or dimension=3, "
                f"got dimension={dimension!r}."
            )
        layers = _DIM_LAYERS[dimension]
        self.dimension = dimension

        # Reject negative counts explicitly: they would otherwise cancel
        # against the other counts in ``total_layers`` and build a
        # configuration that does not match what was requested.
        for count_name, count in (
            ("num_fourier_layers", num_fourier_layers),
            ("num_unet_layers", num_unet_layers),
            ("num_conv_layers", num_conv_layers),
        ):
            if count < 0:
                raise ValueError(f"{count_name} must be >= 0, got {count}.")

        if dimension == 3 and modes3 is None:
            modes3 = layers["default_modes"][2]
        modes_for_spec = (
            (modes1, modes2) if dimension == 2 else (modes1, modes2, modes3)
        )
        # Public attribute so downstream code (e.g.
        # :class:`DeepONet`'s time-axis-extend) can introspect the
        # branch's mode configuration.
        self.modes_per_dim: tuple[int, ...] = tuple(modes_for_spec)

        self.num_fourier_layers = num_fourier_layers
        self.num_unet_layers = num_unet_layers
        self.num_conv_layers = num_conv_layers
        self.use_fourier_base = num_fourier_layers > 0
        self.internal_resolution = (
            tuple(internal_resolution) if internal_resolution else None
        )
        self.coord_features = coord_features

        # Fail at construction rather than inside the pooling layer on the
        # first forward call.
        if (
            self.internal_resolution is not None
            and len(self.internal_resolution) != dimension
        ):
            raise ValueError(
                f"internal_resolution must have {dimension} entries for "
                f"dimension={dimension}, got {self.internal_resolution}."
            )

        total_layers = num_fourier_layers + num_unet_layers + num_conv_layers
        if total_layers == 0:
            raise ValueError("SpatialBranch requires at least one layer type")

        if lift_layers < 1:
            raise ValueError(f"lift_layers must be >= 1, got {lift_layers}.")

        self.activation_fn = get_activation(activation_fn)

        if self.internal_resolution is not None:
            self.adaptive_pool = layers["AdaptiveAvgPool"](self.internal_resolution)

        # Lifting network: single LazyLinear by default, or a multi-layer
        # pointwise MLP when ``lift_layers > 1`` (equivalent to a stack of
        # 1x1 / 1x1x1 convolutions applied channels-last).
        if lift_layers == 1:
            self.lift: nn.Module = nn.LazyLinear(width)
        else:
            hidden = lift_hidden_width if lift_hidden_width is not None else width // 2
            stack: list[nn.Module] = [
                nn.LazyLinear(hidden),
                get_activation(activation_fn),
            ]
            for _ in range(lift_layers - 2):
                stack.extend([nn.Linear(hidden, hidden), get_activation(activation_fn)])
            stack.append(nn.Linear(hidden, width))
            self.lift = nn.Sequential(*stack)

        # With a Fourier base every layer (Fourier, UNet, Conv) gets its own
        # spectral + 1x1 pair; without it no spectral components are built.
        num_fourier_components = (
            total_layers if self.use_fourier_base else num_fourier_layers
        )
        SpectralConv = layers["SpectralConv"]
        Conv = layers["Conv"]
        BatchNorm = layers["BatchNorm"]
        UNetAdapter = layers["UNetAdapter"]

        self.spectral_convs = nn.ModuleList()
        self.conv_1x1s = nn.ModuleList()
        for _ in range(num_fourier_components):
            self.spectral_convs.append(SpectralConv(width, width, *modes_for_spec))
            self.conv_1x1s.append(Conv(width, width, kernel_size=1))

        self.unet_modules = nn.ModuleList()
        for _ in range(num_unet_layers):
            self.unet_modules.append(UNetAdapter(width, width, kernel_size=kernel_size))

        self.conv_modules = nn.ModuleList()
        padding = (kernel_size - 1) // 2
        for _ in range(num_conv_layers):
            self.conv_modules.append(
                nn.Sequential(
                    Conv(
                        width,
                        width,
                        kernel_size=kernel_size,
                        padding=padding,
                        bias=False,
                    ),
                    BatchNorm(width),
                )
            )

        # Cached so the forward path is dimension-agnostic.
        self._channel_first_permute = _channel_first_permute(dimension)
        self._channel_last_permute = _channel_last_permute(dimension)
        self._interp_mode = layers["interp_mode"]

    def _build_coord_features(self, x: Tensor) -> Tensor:
        """Build a channels-last coordinate-feature tensor matching ``x``.

        Returns a tensor of shape ``(B, *spatial, dimension)`` whose
        ``dimension`` trailing channels are the per-axis normalized
        coordinates in :math:`[0, 1]`.
        """
        batch_size = x.shape[0]
        spatial_shape = x.shape[1:-1]
        grids = [
            torch.linspace(0.0, 1.0, s, dtype=x.dtype, device=x.device)
            for s in spatial_shape
        ]
        mesh = torch.meshgrid(*grids, indexing="ij")
        coord = torch.stack(mesh, dim=-1)  # (*spatial, dimension)
        coord = coord.unsqueeze(0).expand(batch_size, *spatial_shape, self.dimension)
        return coord

    def forward(
        self,
        x: Float[Tensor, "..."],
    ) -> Float[Tensor, "..."]:
        """Forward pass of the spatial branch (2D or 3D, selected at init)."""
        if not torch.compiler.is_compiling():
            expected_ndim = self.dimension + 2  # batch + spatial dims + channels
            if x.ndim != expected_ndim:
                raise ValueError(
                    f"Expected {expected_ndim}D input "
                    f"(B, {'H, W' if self.dimension == 2 else 'X, Y, Z'}, C), "
                    f"got {x.ndim}D tensor with shape {tuple(x.shape)}."
                )
        if self.coord_features:
            x = torch.cat([x, self._build_coord_features(x)], dim=-1)
        # The lift acts on the trailing channel axis, so it runs before the
        # switch to channels-first layout.
        x = self.lift(x)
        x = x.permute(*self._channel_first_permute)

        original_size = x.shape[2:]
        if self.internal_resolution is not None:
            x = self.adaptive_pool(x)

        # Pure Fourier layers: spectral path plus pointwise (1x1) path.
        for i in range(self.num_fourier_layers):
            x = self.activation_fn(self.spectral_convs[i](x) + self.conv_1x1s[i](x))

        if self.use_fourier_base:
            # Hybrid layers: each UNet / Conv module is summed with its own
            # spectral + 1x1 pair, stored after the pure Fourier pairs.
            for i in range(self.num_unet_layers):
                j = self.num_fourier_layers + i
                x = self.activation_fn(
                    self.spectral_convs[j](x)
                    + self.conv_1x1s[j](x)
                    + self.unet_modules[i](x)
                )
            for i in range(self.num_conv_layers):
                j = self.num_fourier_layers + self.num_unet_layers + i
                x = self.activation_fn(
                    self.spectral_convs[j](x)
                    + self.conv_1x1s[j](x)
                    + self.conv_modules[i](x)
                )
        else:
            # No Fourier base: UNet and Conv modules run sequentially.
            for unet in self.unet_modules:
                x = self.activation_fn(unet(x))
            for conv in self.conv_modules:
                x = self.activation_fn(conv(x))

        # Undo the internal-resolution pooling, if it changed the grid size.
        if self.internal_resolution is not None and x.shape[2:] != original_size:
            x = F.interpolate(
                x, size=original_size, mode=self._interp_mode, align_corners=True
            )

        return x.permute(*self._channel_last_permute)


__all__ = [
    "SpatialBranch",
]
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Core xDeepONet architectures for 2D and 3D operator learning. + +The xDeepONet family extends the original DeepONet with eight variants +that cover both single-input and multi-input operator learning, including +the Temporal Neural Operator (TNO) for autoregressive temporal bundling: + +- ``deeponet`` — basic DeepONet (MLP branch). +- ``u_deeponet`` — UNet-enhanced spatial branch. +- ``fourier_deeponet`` — spectral (Fourier) spatial branch. +- ``conv_deeponet`` — plain convolutional spatial branch. +- ``hybrid_deeponet`` — Fourier + UNet + Conv spatial branch. +- ``mionet`` — two-branch multi-input operator network. +- ``fourier_mionet`` — MIONet with a Fourier spatial branch. +- ``tno`` — Temporal Neural Operator (branch2 = previous + solution, autoregressive only). + +The core :class:`DeepONet` class is dimension-generic: pass +``dimension=2`` for 2D spatial inputs ``(B, H, W, C)`` and ``dimension=3`` +for 3D volumetric inputs ``(B, X, Y, Z, C)``. Construction is the same +in both cases — a primary branch (``branch1``), an optional secondary +branch (``branch2`` for MIONet/TNO), a coordinate trunk, and a decoder — +with per-dimension primitives dispatched internally through a small +lookup table (see ``_DIM_LAYERS`` in ``branches.py`` and ``_DIM_DEFAULTS`` +in this module). + +References +---------- +- Lu, L. et al. (2021). "Learning nonlinear operators via DeepONet." + *Nature Machine Intelligence*, 3, 218-229. +- Jin, P., Meng, S. & Lu, L. (2022). "MIONet: Learning multiple-input + operators via tensor product." *SIAM J. Sci. 
Comp.*, 44(6), A3490-A3514. +- Diab, W. & Al Kobaisi, M. (2024). "U-DeepONet: U-Net enhanced deep + operator network for geologic carbon sequestration." + *Scientific Reports*, 14, 21298. +- Zhu, M. et al. (2023). "Fourier-DeepONet: Fourier-enhanced deep operator + networks for full waveform inversion." arXiv:2305.17289. +- Diab, W. & Al Kobaisi, M. (2025). "Temporal neural operator for modeling + time-dependent physical phenomena." *Scientific Reports*, 15. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal, get_args + +import torch +import torch.nn as nn +from jaxtyping import Float +from torch import Tensor + +from physicsnemo.core.meta import ModelMetaData +from physicsnemo.core.module import Module +from physicsnemo.experimental.models.xdeeponet._padding import ( + compute_right_pad_to_multiple, + pad_spatial_right, +) +from physicsnemo.experimental.models.xdeeponet.branches import SpatialBranch +from physicsnemo.models.mlp import FullyConnected +from physicsnemo.nn import Conv2dFCLayer, Conv3dFCLayer, get_activation + +# Type alias for the enumerated ``decoder_type`` parameter. Annotating +# with ``Literal`` rather than bare ``str`` lets static type checkers +# and IDEs flag unknown values at the call site; the runtime +# ``.lower()`` normalization and ``ValueError`` guard below remain in +# place so mixed-case strings still flow through (Python does not +# enforce ``Literal`` at runtime). +_DecoderTypeStr = Literal["mlp", "conv", "temporal_projection"] + +# Supported decoder types -- runtime view of the ``_DecoderTypeStr`` +# alias. Used by ``DeepONet.__init__`` to reject unknown decoder +# types at the API boundary instead of deferring to ``_build_decoder`` +# and raising cryptically from deep inside construction. 
# Runtime mirror of the ``_DecoderTypeStr`` literal alias: the constructor
# checks ``decoder_type`` against this set so an unknown value is rejected
# at the API boundary.
_VALID_DECODER_TYPES = frozenset(get_args(_DecoderTypeStr))


@dataclass
class _DeepONetMetaData(ModelMetaData):
    """Metadata record handed to the ``Module`` base class by :class:`DeepONet`."""


# Per-dimension defaults quoted by the ``DeepONet`` class docstring (input
# channel count and Fourier modes a user can copy into their own branch
# construction), plus the *name* of the matching pointwise conv layer.
_DEFAULTS_2D: dict = {
    "default_in_channels": 12,
    "default_modes": (12, 12),
    "ConvNdFC": "Conv2dFCLayer",
}
_DEFAULTS_3D: dict = {
    "default_in_channels": 11,
    "default_modes": (10, 10, 8),
    "ConvNdFC": "Conv3dFCLayer",
}

# Keys are the supported spatial dimensionalities; ``DeepONet.__init__``
# uses membership in this table to validate its ``dimension`` argument.
_DIM_DEFAULTS: dict[int, dict] = {2: _DEFAULTS_2D, 3: _DEFAULTS_3D}


# ---------------------------------------------------------------------------
# DeepONet (dimension-generic: 2D and 3D)
# ---------------------------------------------------------------------------
Takes coordinate queries of shape :math:`(T, D_{in})` + and produces :math:`(T, \text{width})`. Set to ``None`` to build + a trunkless operator (the branch output is fed directly to the + decoder, skipping the branch-trunk Hadamard product). This is + the xFNO operator shape and is the recommended entry point for + Fourier-only operators that don't need coordinate queries. + branch2 : torch.nn.Module, optional + Secondary branch for MIONet / TNO variants. Must produce the same + output rank as ``branch1`` (both spatial or both flat); the + constructor rejects mixed configurations up front. + dimension : Literal[2, 3], optional + Spatial dimensionality of the inputs. Must be ``2`` (default) or + ``3``. + width : int, optional + Latent width. Must match the output channel dim of ``branch1``, + ``branch2`` (if any), and (when present) ``trunk``. + out_channels : int, optional + Number of output channels. Default ``1``. The decoder's final + layer maps the latent width to ``out_channels``; the output + tensor always carries an explicit trailing channel dim of size + ``out_channels``. + decoder_type : Literal["mlp", "conv", "temporal_projection"], optional + Decoder choice: ``"mlp"`` queries the trunk at each target + timestep and applies an MLP decoder; ``"conv"`` uses a + convolutional decoder; ``"temporal_projection"`` queries the + trunk once and projects the combined latent to K timesteps via a + learned linear head for fast autoregressive bundling. + ``"temporal_projection"`` requires a trunk. + Mixed-case strings are accepted at runtime and lowercased. + decoder_width : int, optional + Decoder hidden width. + decoder_layers : int, optional + Decoder layer count. + decoder_activation_fn : str, optional + Activation function name for the decoder. Resolved at decoder + construction time via + :func:`physicsnemo.nn.module.activations.get_activation`; see + :data:`physicsnemo.nn.module.activations.ACT2FN` for the full + set of supported names (e.g. 
``"relu"``, ``"gelu"``, ``"silu"``,
+        ``"tanh"``, ``"sin"``). Default ``"relu"``.
+    output_window : int, optional
+        Output window length K for the ``"temporal_projection"`` decoder.
+        When supplied the temporal head is constructed at ``__init__``, which
+        produces a deterministic ``state_dict`` and makes checkpoint
+        round-tripping straightforward. When omitted,
+        :meth:`set_output_window` must be called before the first forward
+        pass.
+    auto_pad : bool, optional
+        When ``True`` (default ``False``) the packed-input forward path
+        right-pads the spatial dims to a multiple of 8 (with a floor of
+        ``padding``) before running the core operator and crops back
+        afterwards.
+    padding : int, optional
+        Minimum right-side padding for the spatial dims when
+        ``auto_pad=True``. Rounded up to a multiple of 8. Default ``8``.
+    trunk_input : Literal["time", "grid"], optional
+        Used by trunked packed-input mode to decide how trunk coordinates
+        are extracted from the channel dimension of the packed input.
+        ``"time"`` (default) treats the last channel as time;
+        ``"grid"`` treats the last :math:`d+1` channels as
+        ``(x, y, [z,] t)``. Ignored in trunkless mode.
+    time_modes : int, optional
+        Enables xFNO-style time-axis autoregressive bundling. Only
+        meaningful in **trunkless packed-input mode** (``trunk=None``
+        and ``auto_pad=True``). When set and ``target_times`` of length
+        :math:`K` is supplied at forward time, the last spatial axis is
+        treated as the time axis and replicate-padded to
+        :math:`\max(T_{in} + K, 2 \cdot \texttt{time\_modes})` before the
+        branch runs, then cropped to the K future steps. Must equal
+        the Fourier-modes count along the time axis of the branch.
+
+    Forward
+    -------
+    Six valid call conventions, dispatched by :attr:`auto_pad` and
+    whether :attr:`trunk` is ``None``. The output tensor's trailing
+    dim is always :attr:`out_channels` (no implicit squeeze).
+ + +------------+-----------------+--------+--------------------------+---------------------------------+ + | ``auto_pad``| ``trunk`` | branch | call style | output shape | + +============+=================+========+==========================+=================================+ + | ``False`` | module | spatial| ``model(x_b, x_t)`` | ``(B, *spatial, T, oc)`` | + +------------+-----------------+--------+--------------------------+---------------------------------+ + | ``False`` | module | mlp | ``model(x_b, x_t)`` | ``(B, T, oc)`` | + +------------+-----------------+--------+--------------------------+---------------------------------+ + | ``True`` | module | spatial| ``model(x)`` packed | ``(B, *spatial, T, oc)`` | + +------------+-----------------+--------+--------------------------+---------------------------------+ + | ``True`` | module + ``output_window=K`` | ``model(x)`` | ``(B, *spatial, K, oc)`` (temporal_projection) | + +------------+-----------------+--------+--------------------------+---------------------------------+ + | ``False`` | ``None`` | spatial| ``model(x)`` | ``(B, *spatial, oc)`` | + +------------+-----------------+--------+--------------------------+---------------------------------+ + | ``True`` | ``None`` | spatial| ``model(x)`` (+ optional ``target_times`` w/ time_modes set) | ``(B, *spatial', oc)`` | + +------------+-----------------+--------+--------------------------+---------------------------------+ + + In trunkless packed mode with ``time_modes`` set and ``target_times`` + of length K, the last spatial axis is sliced to the K future steps + (:math:`[T_{in} : T_{in}+K]`); otherwise the spatial axes are + cropped back to the original input shape. + + Notes + ----- + The :class:`SpatialBranch` ``in_channels`` defaults to ``12`` for + ``dimension=2`` and ``11`` for ``dimension=3``; default Fourier modes + are ``(12, 12)`` and ``(10, 10, 8)`` respectively. 
+ + Examples + -------- + 2D U-DeepONet (trunked): + + >>> import torch + >>> from physicsnemo.experimental.models.xdeeponet import DeepONet, SpatialBranch + >>> from physicsnemo.models.mlp import FullyConnected + >>> branch1 = SpatialBranch( + ... dimension=2, in_channels=5, width=64, + ... num_unet_layers=1, kernel_size=3, activation_fn="tanh", + ... ) + >>> trunk = FullyConnected( + ... in_features=1, layer_size=64, out_features=64, + ... num_layers=4, activation_fn="sin", + ... ) + >>> model = DeepONet( + ... branch1=branch1, trunk=trunk, + ... dimension=2, width=64, out_channels=1, + ... decoder_type="mlp", decoder_width=64, decoder_layers=2, + ... ) + >>> x_branch = torch.randn(2, 32, 32, 5) + >>> x_time = torch.linspace(0, 1, 3).unsqueeze(-1) + >>> out = model(x_branch, x_time) # (2, 32, 32, 3, 1) + + 3D U-FNO (trunkless, packed input with auto_pad + time-axis-extend): + + >>> branch1 = SpatialBranch( + ... dimension=3, in_channels=2, width=32, + ... num_fourier_layers=4, num_unet_layers=0, + ... modes1=12, modes2=12, modes3=8, + ... coord_features=True, + ... ) + >>> model = DeepONet( + ... branch1=branch1, trunk=None, + ... dimension=3, width=32, out_channels=1, + ... decoder_type="mlp", decoder_width=32, decoder_layers=2, + ... auto_pad=True, padding=8, + ... time_modes=8, # enables time-extend + ... 
) + >>> x = torch.randn(1, 32, 32, 4, 2) # (B, H, W, T_in=4, C) + >>> y = model(x) # (1, 32, 32, 4, 1) -- predict same length + >>> t_future = torch.linspace(0.5, 1.0, 6) # K=6 future steps + >>> y_future = model(x, target_times=t_future) # (1, 32, 32, 6, 1) + """ + + def __init__( + self, + branch1: nn.Module, + *, + trunk: nn.Module | None = None, + branch2: nn.Module | None = None, + dimension: Literal[2, 3] = 2, + width: int = 64, + out_channels: int = 1, + decoder_type: _DecoderTypeStr = "mlp", + decoder_width: int = 128, + decoder_layers: int = 2, + decoder_activation_fn: str = "relu", + output_window: int | None = None, + auto_pad: bool = False, + padding: int = 8, + trunk_input: Literal["time", "grid"] = "time", + time_modes: int | None = None, + ): + super().__init__(meta=_DeepONetMetaData()) + + if dimension not in _DIM_DEFAULTS: + raise ValueError( + f"DeepONet only supports dimension=2 or dimension=3, " + f"got dimension={dimension!r}." + ) + self.dimension = dimension + + if out_channels < 1: + raise ValueError(f"out_channels must be >= 1, got {out_channels}.") + self.out_channels = out_channels + + decoder_type_lc = decoder_type.lower() + if decoder_type_lc not in _VALID_DECODER_TYPES: + raise ValueError( + f"Unknown decoder_type: {decoder_type!r}. Valid: " + f"{sorted(_VALID_DECODER_TYPES)}." + ) + self.decoder_type = decoder_type_lc + self.decoder_activation_fn = decoder_activation_fn + + if trunk_input not in ("time", "grid"): + raise ValueError( + f"trunk_input must be 'time' or 'grid', got {trunk_input!r}." + ) + self.trunk_input = trunk_input + + # Auto-padding: when enabled, the packed-input forward path + # right-pads the spatial dims to a multiple of 8 (with a floor of + # ``padding``) before running the core operator and crops back + # afterwards. 
+ if padding < 0: + raise ValueError(f"padding must be non-negative, got {padding}.") + self.auto_pad = auto_pad + # Round the padding up to a multiple of 8 so the UNet pooling + # chain stays evenly divisible. Stored even when + # ``auto_pad=False`` so callers can inspect the value, but the + # forward path only consults it when ``auto_pad=True``. + self.padding = ((padding + 7) // 8) * 8 if padding % 8 != 0 else padding + + # Time-axis-extend: when set together with ``trunk=None`` and + # ``auto_pad=True``, the packed-input forward path interprets the + # last spatial axis as the time axis. Given ``target_times`` of + # length ``K`` it right-replicate-pads that axis to at least + # ``max(T_in + K, 2 * time_modes)`` before running the branch, + # then crops the output to the predicted ``K`` future steps. + # Only meaningful in trunkless packed mode. + if time_modes is not None and trunk is not None: + raise ValueError( + "time_modes is only meaningful when trunk is None " + "(xFNO-style trunkless operators). Drop time_modes, or " + "set trunk=None." + ) + if time_modes is not None and not auto_pad: + raise ValueError( + "time_modes requires auto_pad=True; the time-axis-extend " + "feature is part of the packed-input forward path." + ) + if time_modes is not None and time_modes < 1: + raise ValueError(f"time_modes must be >= 1, got {time_modes}.") + self.time_modes = time_modes + + self.width = width + + # Cached forward-time permute / ndim values; computed once at + # construction so the forward path can stay dimension-agnostic + # without rebuilding tuples on every call (and so torch.compile + # sees them as Python constants per model instance). + self._spatial_branch_ndim = dimension + 2 + self._spatial_axes = tuple(range(2, dimension + 2)) + # Trunked-mode permutes: ``combined`` has rank ``dimension + 3`` + # ``(B, T, *spatial, channels)``. 
+ self._mlp_decoder_permute = (0, *self._spatial_axes, 1, dimension + 2) + self._conv_decoder_in_permute = (0, 1, dimension + 2, *self._spatial_axes) + self._conv_decoder_out_permute = ( + 0, + *tuple(range(3, dimension + 3)), + 1, + 2, + ) + # Trunkless-mode permutes: branch output has rank ``dimension + 2`` + # ``(B, *spatial, channels)``; the channel axis is moved to / + # from position 1 for conv-decoder dispatch. + self._trunkless_channel_first_permute = ( + 0, + dimension + 1, + *range(1, dimension + 1), + ) + self._trunkless_channel_last_permute = ( + 0, + *range(2, dimension + 2), + 1, + ) + + # Detect MLP vs spatial branches via instance check. This drives + # both the runtime forward dispatch (different unsqueeze / permute + # paths for spatial-vs-MLP branch outputs) and the fail-fast + # validation below. A non-:class:`SpatialBranch` module is + # assumed to produce a flat ``(B, width)`` output, matching the + # MLP-branch shape contract. + self._branch1_is_mlp = not isinstance(branch1, SpatialBranch) + self.has_branch2 = branch2 is not None + self._branch2_is_mlp = self.has_branch2 and not isinstance( + branch2, SpatialBranch + ) + + # ``temporal_projection`` decoder only makes sense with a trunk + # (it projects the trunk-queried combined latent to ``K`` output + # timesteps via a learned linear head). Without a trunk there is + # no temporal-query axis to project from. + if trunk is None and self.decoder_type == "temporal_projection": + raise ValueError( + "decoder_type='temporal_projection' requires a trunk; " + "use decoder_type='mlp' or 'conv' for trunkless operators." + ) + + # ``temporal_projection`` and ``conv`` decoders need a spatial + # ``combined`` tensor. MLP-branch forward produces a flat 3D + # tensor of shape (B, T, width), incompatible with both: ``conv`` + # crashes inside the decoder; ``temporal_projection`` silently + # drops the temporal head. Fail fast at construction. 
+ if self._branch1_is_mlp and self.decoder_type in ( + "temporal_projection", + "conv", + ): + raise ValueError( + f"decoder_type={self.decoder_type!r} is not supported with " + "MLP branches. Use decoder_type='mlp', or pass a " + "SpatialBranch as branch1." + ) + + # Reject mixed (MLP branch1, spatial branch2): forward assumes + # branch2's output has the same rank as branch1's, otherwise the + # Hadamard product broadcasts nonsensically or raises a cryptic + # dim-mismatch error. + if self.has_branch2 and self._branch1_is_mlp and not self._branch2_is_mlp: + raise ValueError( + "When branch1 is an MLP branch, branch2 must also be an " + "MLP branch (i.e. produce a 2D (B, width) output). " + "Swap branch1 and branch2, or pass a SpatialBranch as " + "branch1." + ) + + # Reject MLP branch + auto_pad: packed-input mode assumes the + # input has spatial axes to pad and (in trunked mode) a time + # axis to strip. MLP branches consume flat ``(B, D_in)`` input + # and have neither. + if self._branch1_is_mlp and auto_pad: + raise ValueError( + "auto_pad=True requires a SpatialBranch branch1 (the " + "packed-input forward path operates on spatial dims). " + "Use auto_pad=False with an MLP branch." + ) + + # Register submodules. + self.branch1 = branch1 + if self.has_branch2: + self.branch2 = branch2 + # ``self.trunk`` is registered unconditionally (None or a module); + # PyTorch handles None submodule attributes fine. + self.trunk = trunk + + if self.decoder_type == "temporal_projection": + self._temporal_projection = True + self.decoder = self._build_decoder( + width, + width, + decoder_layers, + decoder_width, + "mlp", + decoder_activation_fn, + ) + # Preferred path: construct the temporal head at __init__ so + # state_dict keys are deterministic and checkpointing just works. + # When ``output_window`` is not provided the user must call + # :meth:`set_output_window` before the first forward pass. 
+ # The head projects to ``output_window * out_channels`` so a + # multi-channel output is reshaped at the end. + if output_window is not None: + if output_window < 1: + raise ValueError( + f"output_window must be a positive integer, got {output_window}" + ) + self.temporal_head = nn.Linear(self.width, output_window * out_channels) + else: + self.temporal_head = None + else: + self._temporal_projection = False + self.decoder = self._build_decoder( + width, + out_channels, + decoder_layers, + decoder_width, + self.decoder_type, + decoder_activation_fn, + ) + + @property + def has_temporal_projection(self) -> bool: + """Whether the model was constructed with the temporal-projection + decoder (``decoder_type="temporal_projection"``). + + Public read-only view of the internal flag; preferred over reaching + into the private attribute from outside the class. + """ + return self._temporal_projection + + def set_output_window(self, K: int): + """Create the temporal-projection head for K output timesteps. + + The head projects to ``K * out_channels`` so the trailing + out-channels dim is preserved in the output. Only effective + when ``decoder_type="temporal_projection"``. + """ + if self._temporal_projection: + device = next(self.parameters()).device + self.temporal_head = nn.Linear(self.width, K * self.out_channels).to(device) + + def _build_decoder( + self, + width: int, + out_channels: int, + num_layers: int, + hidden_width: int, + decoder_type: str, + activation_fn: str, + ) -> nn.Module: + # Per-dimension dispatch for the spatial decoder. 
+ ConvNdFC = Conv2dFCLayer if self.dimension == 2 else Conv3dFCLayer + + if decoder_type == "mlp": + if num_layers == 0: + return nn.Linear(width, out_channels) + return FullyConnected( + width, hidden_width, out_channels, num_layers, activation_fn + ) + + elif decoder_type == "conv": + if num_layers == 0: + return ConvNdFC(width, out_channels) + + layers = [] + in_ch = width + for _ in range(num_layers): + layers.extend( + [ConvNdFC(in_ch, hidden_width), get_activation(activation_fn)] + ) + in_ch = hidden_width + layers.append(ConvNdFC(hidden_width, out_channels)) + return nn.Sequential(*layers) + + else: + raise ValueError(f"Unknown decoder_type: {decoder_type}") + + def forward( + self, + *args: Float[Tensor, "..."], + x_branch2: Float[Tensor, "..."] | None = None, + target_times: Float[Tensor, "..."] | None = None, + ) -> Float[Tensor, "..."]: + """Forward pass. + + Dispatched by the :attr:`auto_pad` constructor flag: + + - **Packed mode** (``auto_pad=True``): ``model(x)`` (or + ``model(x, x_branch2)`` for MIONet-style dual-branch variants). + ``x`` has shape :math:`(B, *spatial, T, C)` with the time axis + and trunk / grid coordinates encoded in the channel dimension. + The model extracts the spatial branch input and trunk + coordinates itself (using :attr:`trunk_input`), right-pads the + spatial dims when :attr:`padding` is positive, runs the core + operator, and crops the output back to the original spatial + extent. ``target_times`` (keyword) selects an explicit set of + trunk query coordinates. + + - **Core mode** (``auto_pad=False``, the default): + ``model(x_branch1, x_time)`` (or + ``model(x_branch1, x_time, x_branch2)``). Required for the + MLP-branch code path (where there is no spatial axis to + extract from) and for power users who assemble the trunk + coordinates themselves. ``target_times`` must be ``None`` + in this mode. + + ``x_branch2`` may be passed positionally (second arg in packed + mode, third arg in core mode) or by keyword. 
+ """ + # Branch on the four (auto_pad, trunk-is-None) combinations. + if self.auto_pad: + if self.trunk is None: + # Trunkless packed mode: ``model(x)`` only. + if len(args) != 1: + raise TypeError( + f"In trunkless packed-input mode (auto_pad=True, " + f"trunk=None), forward expects exactly 1 positional " + f"tensor, got {len(args)}." + ) + if x_branch2 is not None: + raise TypeError("x_branch2 is not supported in trunkless mode.") + return self._forward_packed_trunkless( + args[0], target_times=target_times + ) + + # Trunked packed mode. + if len(args) == 1: + return self._forward_packed( + args[0], + x_branch2=x_branch2, + target_times=target_times, + ) + if len(args) == 2: + if x_branch2 is not None: + raise TypeError( + "x_branch2 supplied both positionally and as a " + "keyword argument." + ) + return self._forward_packed( + args[0], + x_branch2=args[1], + target_times=target_times, + ) + raise TypeError( + f"In trunked packed-input mode (auto_pad=True, trunk!=None), " + f"forward expects 1 or 2 positional tensors ((x,) or " + f"(x, x_branch2)), got {len(args)}." + ) + + # Core mode (auto_pad=False). + if target_times is not None: + raise TypeError( + "target_times is only valid in packed-input mode " + "(construct DeepONet with auto_pad=True)." + ) + if self.trunk is None: + # Trunkless core mode: ``model(x)`` only. + if len(args) != 1: + raise TypeError( + f"In trunkless core mode (auto_pad=False, trunk=None), " + f"forward expects exactly 1 positional tensor, got " + f"{len(args)}." + ) + if x_branch2 is not None: + raise TypeError("x_branch2 is not supported in trunkless mode.") + return self._forward_core(args[0], None, x_branch2=None) + + # Trunked core mode. + if len(args) == 2: + x_branch1, x_time = args + b2 = x_branch2 + elif len(args) == 3: + if x_branch2 is not None: + raise TypeError( + "x_branch2 supplied both positionally and as a keyword argument." 
+ ) + x_branch1, x_time, b2 = args + else: + raise TypeError( + f"In trunked core mode (auto_pad=False, trunk!=None), " + f"forward expects 2 ((x_branch1, x_time)) or 3 " + f"((x_branch1, x_time, x_branch2)) positional tensors, " + f"got {len(args)}." + ) + return self._forward_core(x_branch1, x_time, x_branch2=b2) + + def _forward_packed( + self, + x: Float[Tensor, "..."], + *, + x_branch2: Float[Tensor, "..."] | None = None, + target_times: Float[Tensor, "..."] | None = None, + ) -> Float[Tensor, "..."]: + """Trunked packed-input forward: unpack ``x`` and (optionally) auto-pad. + + ``x`` has shape :math:`(B, *spatial, T, C)`; this method: + + 1. Optionally right-pads the spatial dims to a multiple of 8 + when :attr:`auto_pad` is ``True``. + 2. Extracts the spatial branch input as ``x[..., 0, :]`` (the + ``T=0`` slice). + 3. Builds the trunk coordinates from ``x`` (or ``target_times`` + when provided) according to :attr:`trunk_input`. + 4. Runs :meth:`_forward_core` and crops back to the original + spatial extent if auto-padding was applied. + """ + dim = self.dimension + expected_ndim = dim + 3 # batch + spatial + time + channels + + if not torch.compiler.is_compiling(): + if x.ndim != expected_ndim: + spatial_doc = "H, W" if dim == 2 else "X, Y, Z" + raise ValueError( + f"Packed-input mode (trunked) expects {expected_ndim}D " + f"input (B, {spatial_doc}, T, C), got {x.ndim}D tensor " + f"with shape {tuple(x.shape)}." + ) + if target_times is not None and target_times.ndim not in (1, 2): + raise ValueError( + f"Expected target_times to be 1D (K,) or 2D (K, 1), " + f"got {target_times.ndim}D tensor with shape " + f"{tuple(target_times.shape)}." + ) + + spatial_shape = x.shape[1 : 1 + dim] + + # Right-pad the spatial axes (always to a multiple of 8, with the + # configured floor) when auto-padding is enabled. 
+ if self.auto_pad and self.padding > 0: + pads = compute_right_pad_to_multiple( + spatial_shape, multiple=8, min_right_pad=self.padding + ) + x = pad_spatial_right(x, spatial_ndim=dim, right_pad=pads, mode="replicate") + if x_branch2 is not None and x_branch2.dim() > 2: + x_branch2 = pad_spatial_right( + x_branch2, + spatial_ndim=dim, + right_pad=pads, + mode="replicate", + ) + + # Strip the time axis -- spatial branch sees only the T=0 slice. + # Index = (slice(None),) * (1 + dim) + (0, slice(None)) + idx_strip_T = (slice(None),) * (1 + dim) + (0, slice(None)) + x_spatial = x[idx_strip_T] + # Symmetric handling for branch2: when it's also a packed + # (B, *spatial, T, C) tensor, strip its time axis the same way so + # the second spatial branch sees a 4D (B, *spatial, C) tensor. + # 2D (B, D_in) and already-stripped (B, *spatial, C) inputs are + # left untouched so MLP/spatial-branch consumers stay valid. + if x_branch2 is not None and x_branch2.ndim == expected_ndim: + x_branch2 = x_branch2[idx_strip_T] + + # Build the trunk input. All paths produce a (T_out, in_features) tensor. + if target_times is not None: + if self.trunk_input == "grid": + t_vals = ( + target_times + if target_times.dim() == 1 + else target_times.squeeze(-1) + ) + # Spatial coords of point [0, 0, ..., 0, t=0]: the + # ``dim`` channels preceding the time channel. + idx_spatial = (0,) * (2 + dim) + (slice(-(dim + 1), -1),) + spatial = x[idx_spatial] + spatial_exp = spatial.unsqueeze(0).expand(t_vals.shape[0], -1) + x_trunk = torch.cat([spatial_exp, t_vals.unsqueeze(-1)], dim=-1) + else: + x_trunk = ( + target_times + if target_times.dim() == 2 + else target_times.unsqueeze(-1) + ) + elif self.trunk_input == "grid": + # Sweep all T values at the first spatial point; keep last + # ``dim+1`` channels. + idx_grid_over_time = ( + (0,) * (1 + dim) + (slice(None),) + (slice(-(dim + 1), None),) + ) + x_trunk = x[idx_grid_over_time] + else: + # Time-only coords at the first spatial point. 
+ idx_time_over_time = (0,) * (1 + dim) + (slice(None), -1) + x_trunk = x[idx_time_over_time].unsqueeze(-1) + + out = self._forward_core(x_spatial, x_trunk, x_branch2=x_branch2) + # out: (B, *padded_spatial, T_out, out_channels) + + # Crop back to original spatial extent when auto-padding shifted + # the padded dims out beyond ``spatial_shape``. Trailing two + # axes (T_out, out_channels) are preserved in full. + if self.auto_pad and self.padding > 0: + crop_idx = ( + (slice(None),) + + tuple(slice(0, s) for s in spatial_shape) + + (slice(None), slice(None)) + ) + out = out[crop_idx] + return out + + def _forward_packed_trunkless( + self, + x: Float[Tensor, "..."], + *, + target_times: Float[Tensor, "..."] | None = None, + ) -> Float[Tensor, "..."]: + """Trunkless packed-input forward (xFNO-style operator). + + ``x`` is channels-last ``(B, *spatial, C)``. Steps: + + 1. **Time-axis extension** (only when ``self.time_modes is not + None`` and ``target_times`` is provided with length + :math:`K \\neq T_{in}`): replicate-pad the last spatial axis + to :math:`\\max(T_{in} + K, 2 \\cdot \\texttt{time\\_modes})`. + 2. **Spatial padding** (when ``self.auto_pad`` and + ``self.padding > 0``): right-pad all spatial dims to a + multiple of 8 with a floor of ``self.padding``. + 3. Run the trunkless core forward (branch + decoder). + 4. **Crop** the output back to the original spatial shape. In + the time-axis-extend case, the last spatial axis is sliced to + :math:`[T_{in} : T_{in} + K]` (the predicted future steps); + otherwise it's sliced to :math:`[:T_{in}]`. + + Output shape: ``(B, *spatial_or_K, out_channels)``. 
+ """ + dim = self.dimension + expected_ndim = dim + 2 # batch + spatial + channels + + if not torch.compiler.is_compiling(): + if x.ndim != expected_ndim: + spatial_doc = "H, W" if dim == 2 else "X, Y, Z" + raise ValueError( + f"Packed-input mode (trunkless) expects {expected_ndim}D " + f"input (B, {spatial_doc}, C), got {x.ndim}D tensor " + f"with shape {tuple(x.shape)}." + ) + if target_times is not None: + if self.time_modes is None: + raise ValueError( + "target_times provided but the model was constructed " + "without time_modes; nothing to extend. Either pass " + "time_modes=N at construction (xFNO-style autoregressive " + "bundling) or omit target_times." + ) + if target_times.ndim not in (1, 2): + raise ValueError( + f"Expected target_times to be 1D (K,) or 2D (K, 1), " + f"got {target_times.ndim}D tensor with shape " + f"{tuple(target_times.shape)}." + ) + + original_spatial = x.shape[1 : 1 + dim] + + # Time-axis extension (xFNO autoregressive bundling). The last + # spatial axis is the time axis by convention. ``K`` is the + # number of future steps to predict; when ``K == T_in`` (or + # ``target_times`` is absent) no extension happens and the + # output covers the original time axis. + k_future: int | None = None + if self.time_modes is not None and target_times is not None: + k_candidate = target_times.shape[0] + t_in = original_spatial[-1] + if k_candidate != t_in: + k_future = k_candidate + desired_t = t_in + k_future + min_t = max(desired_t, 2 * self.time_modes) + extra = min_t - t_in + time_pad = (0,) * (dim - 1) + (extra,) + x = pad_spatial_right( + x, spatial_ndim=dim, right_pad=time_pad, mode="replicate" + ) + + # Spatial padding to a multiple of 8 (after any time extension). 
+ if self.auto_pad and self.padding > 0: + current_spatial = x.shape[1 : 1 + dim] + pads = compute_right_pad_to_multiple( + current_spatial, multiple=8, min_right_pad=self.padding + ) + x = pad_spatial_right(x, spatial_ndim=dim, right_pad=pads, mode="replicate") + + # Trunkless core forward. + out = self._forward_core(x, None, x_branch2=None) + # out: (B, *padded_spatial, out_channels) + + # Crop back. When time-extending: the last spatial axis is + # sliced to the K future steps; other spatial axes to original. + if k_future is not None: + t_in = original_spatial[-1] + crop = ( + (slice(None),) + + tuple(slice(0, s) for s in original_spatial[:-1]) + + (slice(t_in, t_in + k_future),) + + (slice(None),) # out_channels axis + ) + else: + crop = ( + (slice(None),) + + tuple(slice(0, s) for s in original_spatial) + + (slice(None),) # out_channels axis + ) + return out[crop] + + def _forward_core( + self, + x_branch1: Float[Tensor, "..."], + x_time: Float[Tensor, "..."] | None, + x_branch2: Float[Tensor, "..."] | None = None, + ) -> Float[Tensor, "..."]: + """Raw branch + (optional) trunk + decoder forward. + + ``x_branch1`` is either 2D ``(B, D_in)`` (MLP branches) or + ``(dimension + 2)``-D channels-last spatial input. ``x_time`` is + 1D ``(T,)``, 2D ``(T, D_trunk)``, or ``None`` (trunkless + operator). Called by :meth:`forward` in core mode and internally + by :meth:`_forward_packed`. 
+ + Output shape: + + - Spatial branch + trunk: ``(B, *spatial, T, out_channels)`` + - Spatial branch + trunkless: ``(B, *spatial, out_channels)`` + - Spatial branch + ``temporal_projection``: + ``(B, *spatial, output_window, out_channels)`` + - MLP branch + trunk: ``(B, T, out_channels)`` + - MLP branch + trunkless: ``(B, out_channels)`` + """ + spatial_ndim = self._spatial_branch_ndim + + if not torch.compiler.is_compiling(): + if x_branch1.ndim not in (2, spatial_ndim): + spatial_shape_doc = ( + "(B, H, W, C)" if self.dimension == 2 else "(B, X, Y, Z, C)" + ) + raise ValueError( + f"Expected x_branch1 to be 2D (B, D_in) for MLP branches " + f"or {spatial_ndim}D {spatial_shape_doc} for spatial " + f"branches, got {x_branch1.ndim}D tensor with shape " + f"{tuple(x_branch1.shape)}" + ) + if x_time is not None and x_time.ndim not in (1, 2): + raise ValueError( + f"Expected x_time to be 1D (T,) or 2D (T, D), got " + f"{x_time.ndim}D tensor with shape {tuple(x_time.shape)}" + ) + if self.has_branch2 and x_branch2 is None: + raise ValueError( + "branch2 is configured but x_branch2 was not provided " + "to forward()." + ) + + b1_out = self.branch1(x_branch1) + + if self.has_branch2: + if x_branch2 is None: + raise ValueError("x_branch2 required for mionet/tno variants") + b2_out = self.branch2(x_branch2) + + # ---- Trunkless path (xFNO-style operator) ---------------------- + if self.trunk is None: + if b1_out.dim() == spatial_ndim: + # Spatial branch: combine with optional branch2 directly. + combined = b1_out + if self.has_branch2: + combined = combined * b2_out + return self._apply_decoder_trunkless(combined) + # MLP branch: flat ``(B, width)`` -> ``(B, out_channels)``. 
+ combined = b1_out + if self.has_branch2: + combined = combined * b2_out + return self.decoder(combined) + + # ---- Trunked path ---------------------------------------------- + if x_time.dim() == 1: + x_time = x_time.unsqueeze(-1) + trunk_out = self.trunk(x_time) + + if b1_out.dim() == spatial_ndim: # Spatial branch path + if self._temporal_projection: + # Broadcast a single trunk value across every spatial point: + # trunk_single: (1, width) -> (1, 1, ..., 1, width) with + # ``dimension`` spatial singleton axes inserted at position 1. + trunk_single = trunk_out[0:1] + trunk_exp = trunk_single + for _ in range(self.dimension): + trunk_exp = trunk_exp.unsqueeze(1) + combined = b1_out * trunk_exp + if self.has_branch2: + if b2_out.dim() == spatial_ndim: + combined = combined * b2_out + else: + b2_exp = b2_out + for _ in range(self.dimension): + b2_exp = b2_exp.unsqueeze(1) + combined = combined * b2_exp + combined = self.decoder(combined) + if self.temporal_head is None: + raise RuntimeError( + "decoder_type='temporal_projection' requires either " + "output_window to be provided at construction time, " + "or set_output_window(K) to be called before forward." + ) + # temporal_head: width -> (output_window * out_channels); + # reshape splits the trailing dim into (output_window, + # out_channels) so the final shape is + # (B, *spatial, output_window, out_channels). + head_out = self.temporal_head(combined) + head_shape = head_out.shape + return head_out.reshape(*head_shape[:-1], -1, self.out_channels) + + # Insert a time axis at position 1 in the branch output: + # (B, *spatial, width) -> (B, 1, *spatial, width) + b1_out = b1_out.unsqueeze(1) + # Insert a batch axis at position 0 and ``dimension`` spatial + # singleton axes at position 2, giving (1, T, *1..*1, width). 
+ trunk_out = trunk_out.unsqueeze(0) + for _ in range(self.dimension): + trunk_out = trunk_out.unsqueeze(2) + + if self.has_branch2: + if b2_out.dim() == spatial_ndim: + b2_out = b2_out.unsqueeze(1) + else: + b2_out = b2_out.unsqueeze(1) + for _ in range(self.dimension): + b2_out = b2_out.unsqueeze(2) + combined = b1_out * b2_out * trunk_out + else: + combined = b1_out * trunk_out + + # ``combined`` is now (B, T, *spatial, width). + if self.decoder_type == "mlp": + # Decoder maps width -> out_channels: + # (B, T, *spatial, out_channels). Move T from position 1 + # to the second-to-last so result is + # (B, *spatial, T, out_channels). + return self.decoder(combined).permute(*self._mlp_decoder_permute) + + # ``conv`` decoder: needs channel-first (B*T, width, *spatial) + # input, returns (B*T, out_channels, *spatial). + shape = combined.shape + batch_size, n_t = shape[0], shape[1] + spatial_shape = shape[2:-1] + ch = shape[-1] + combined = combined.permute(*self._conv_decoder_in_permute).reshape( + batch_size * n_t, ch, *spatial_shape + ) + decoded = self.decoder(combined) + # decoded: (B*T, out_channels, *spatial) + decoded = decoded.reshape( + batch_size, n_t, self.out_channels, *spatial_shape + ) + # -> (B, *spatial, T, out_channels) + return decoded.permute(*self._conv_decoder_out_permute) + + # MLP branch + trunk path (no spatial axes). + b1_out = b1_out.unsqueeze(1) + trunk_out = trunk_out.unsqueeze(0) + if self.has_branch2: + combined = b1_out * b2_out.unsqueeze(1) * trunk_out + else: + combined = b1_out * trunk_out + return self.decoder(combined) + + def _apply_decoder_trunkless( + self, + branch_out: Float[Tensor, "..."], + ) -> Float[Tensor, "..."]: + """Apply the decoder to a trunkless spatial branch output. + + ``branch_out`` is channels-last ``(B, *spatial, width)``; the + returned tensor is channels-last ``(B, *spatial, out_channels)``. + """ + if self.decoder_type == "mlp": + # MLP decoder acts pointwise on the last axis; no permute needed. 
class Sin(nn.Module):
    """Sine nonlinearity applied to each element: :math:`y = \\sin(x)`.

    The module holds no parameters or buffers.

    Example
    -------
    >>> sin_func = physicsnemo.nn.Sin()
    >>> input = torch.tensor([0.0, 3.141592653589793 / 2])
    >>> torch.allclose(sin_func(input), torch.tensor([0.0, 1.0]), atol=1e-6)
    True
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return ``sin(x)`` computed elementwise, same shape as ``x``."""
        activated = torch.sin(x)
        return activated
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Regenerate the xDeepONet golden ``.pth`` fixtures. + +Run from the repository root:: + + python test/experimental/models/xdeeponet/data/_generate_xdeeponet_goldens.py + +Overwrites the committed fixtures with freshly-seeded model outputs. +Invoke this deliberately whenever model numerics intentionally change +(architecture edit, default-argument change, etc.) and commit the +resulting ``.pth`` files. + +The set of fixtures is driven by :data:`_FIXTURE_REGISTRY` in +``test_xdeeponet.py`` — adding a new scenario there automatically +extends this generator. + +Each fixture stores a dict with three keys: + +- ``"args"``: tuple of positional forward arguments +- ``"y"``: stored output for the non-regression assertion +- ``"state_dict"``: model parameters +""" + +from __future__ import annotations + +import sys +from pathlib import Path + +import torch + +_REPO_ROOT = Path(__file__).resolve().parents[5] +# Repo root: so ``import physicsnemo...`` resolves. +# xdeeponet test dir: so ``import test_xdeeponet`` resolves. 
def _write(path: Path, builder) -> None:
    """Build one scenario, run it, and persist its golden payload to ``path``.

    ``builder`` is called with no arguments and must return
    ``(model, args)``. The model is first run once through
    ``_init_lazy`` and then evaluated again under ``torch.no_grad()``;
    that second output is what gets stored. The saved dict has the keys
    ``"args"``, ``"y"`` and ``"state_dict"``. A one-line summary is
    printed afterwards.
    """
    # Make sure the target directory exists before any work is done.
    path.parent.mkdir(parents=True, exist_ok=True)

    model, args = builder()
    _init_lazy(model, *args)

    with torch.no_grad():
        y = model(*args)

    payload = {"args": tuple(args), "y": y, "state_dict": model.state_dict()}
    torch.save(payload, path)

    arg_shapes = [tuple(a.shape) for a in args]
    summary = (
        f"wrote {path.relative_to(_REPO_ROOT)} "
        f"args={arg_shapes} y={tuple(y.shape)} "
        f"size={path.stat().st_size}B"
    )
    print(summary)
b/test/experimental/models/xdeeponet/data/xdeeponet_packed_2d_kitchen_sink_v1.pth new file mode 100644 index 0000000000..9b772a0399 Binary files /dev/null and b/test/experimental/models/xdeeponet/data/xdeeponet_packed_2d_kitchen_sink_v1.pth differ diff --git a/test/experimental/models/xdeeponet/data/xdeeponet_packed_2d_mionet_v1.pth b/test/experimental/models/xdeeponet/data/xdeeponet_packed_2d_mionet_v1.pth new file mode 100644 index 0000000000..bee12c5a91 Binary files /dev/null and b/test/experimental/models/xdeeponet/data/xdeeponet_packed_2d_mionet_v1.pth differ diff --git a/test/experimental/models/xdeeponet/data/xdeeponet_packed_2d_multichannel_v1.pth b/test/experimental/models/xdeeponet/data/xdeeponet_packed_2d_multichannel_v1.pth new file mode 100644 index 0000000000..d3eb7091a5 Binary files /dev/null and b/test/experimental/models/xdeeponet/data/xdeeponet_packed_2d_multichannel_v1.pth differ diff --git a/test/experimental/models/xdeeponet/data/xdeeponet_packed_2d_temporal_v1.pth b/test/experimental/models/xdeeponet/data/xdeeponet_packed_2d_temporal_v1.pth new file mode 100644 index 0000000000..a7c900f07a Binary files /dev/null and b/test/experimental/models/xdeeponet/data/xdeeponet_packed_2d_temporal_v1.pth differ diff --git a/test/experimental/models/xdeeponet/data/xdeeponet_packed_2d_v1.pth b/test/experimental/models/xdeeponet/data/xdeeponet_packed_2d_v1.pth new file mode 100644 index 0000000000..e85c422792 Binary files /dev/null and b/test/experimental/models/xdeeponet/data/xdeeponet_packed_2d_v1.pth differ diff --git a/test/experimental/models/xdeeponet/data/xdeeponet_packed_3d_v1.pth b/test/experimental/models/xdeeponet/data/xdeeponet_packed_3d_v1.pth new file mode 100644 index 0000000000..be4b74b45c Binary files /dev/null and b/test/experimental/models/xdeeponet/data/xdeeponet_packed_3d_v1.pth differ diff --git a/test/experimental/models/xdeeponet/data/xdeeponet_xfno_packed_3d_extend_v1.pth 
b/test/experimental/models/xdeeponet/data/xdeeponet_xfno_packed_3d_extend_v1.pth new file mode 100644 index 0000000000..22824f1610 Binary files /dev/null and b/test/experimental/models/xdeeponet/data/xdeeponet_xfno_packed_3d_extend_v1.pth differ diff --git a/test/experimental/models/xdeeponet/data/xdeeponet_xfno_packed_3d_v1.pth b/test/experimental/models/xdeeponet/data/xdeeponet_xfno_packed_3d_v1.pth new file mode 100644 index 0000000000..4041058de7 Binary files /dev/null and b/test/experimental/models/xdeeponet/data/xdeeponet_xfno_packed_3d_v1.pth differ diff --git a/test/experimental/models/xdeeponet/test_xdeeponet.py b/test/experimental/models/xdeeponet/test_xdeeponet.py new file mode 100644 index 0000000000..7212e83edf --- /dev/null +++ b/test/experimental/models/xdeeponet/test_xdeeponet.py @@ -0,0 +1,1212 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test suite for the xDeepONet family. + +Covers, per `MOD-008a/b/c <../../CODING_STANDARDS/MODELS_IMPLEMENTATION.md>`_: + +- **Constructor + public attributes** (MOD-008a) — default and custom configs. +- **Forward non-regression** (MOD-008b) — compare a single forward pass + against committed golden ``.pth`` fixtures. 
+- **Checkpoint round-trip** (MOD-008c) — ``save`` to ``.mdlus``, reload via + :meth:`physicsnemo.Module.from_checkpoint`, and verify the loaded model + reproduces the same output as the in-memory model. +- **Gradient flow** — backward pass produces non-None gradients on input + and parameters. +- **torch.compile smoke** — wrapping the model in :func:`torch.compile` + succeeds and produces shape-compatible output. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path + +import pytest +import torch + +from physicsnemo import Module +from physicsnemo.core.meta import ModelMetaData +from physicsnemo.experimental.models.xdeeponet import DeepONet, SpatialBranch +from physicsnemo.models.mlp import FullyConnected +from physicsnemo.nn import get_activation + +_DATA_DIR = Path(__file__).parent / "data" +_SEED = 0 + +# ----- Golden fixture paths ------------------------------------------------ +# +# One ``.pth`` per scenario. The fixture filenames are versioned (``_v1``) +# so a new ``v2`` can land alongside an older fixture during a numerics +# transition. + +# Packed-input (auto_pad=True) scenarios. +_GOLDEN_PACKED_2D = _DATA_DIR / "xdeeponet_packed_2d_v1.pth" +_GOLDEN_PACKED_3D = _DATA_DIR / "xdeeponet_packed_3d_v1.pth" +_GOLDEN_PACKED_2D_FOURIER = _DATA_DIR / "xdeeponet_packed_2d_fourier_v1.pth" +_GOLDEN_PACKED_2D_MIONET = _DATA_DIR / "xdeeponet_packed_2d_mionet_v1.pth" +_GOLDEN_PACKED_2D_TEMPORAL = _DATA_DIR / "xdeeponet_packed_2d_temporal_v1.pth" +_GOLDEN_PACKED_2D_MULTICHANNEL = _DATA_DIR / "xdeeponet_packed_2d_multichannel_v1.pth" +# Kitchen-sink scenarios: every major code path turned on simultaneously. +# - 2D variant: packed-input mode + ``temporal_projection`` decoder + +# ``output_window > 1`` + ``trunk_input="grid"`` + multi-layer lift + +# ``coord_features`` asymmetry across branches. 
class _MLPWithTrailingActivation(Module):
    """Test-only :class:`physicsnemo.Module` standing in for
    ``nn.Sequential(FullyConnected, get_activation(...))``.

    :meth:`Module._save_process` rejects plain ``torch.nn.Module``
    instances in ``_args``, so an ``nn.Sequential`` cannot be handed to
    the DeepONet constructor. This thin subclass satisfies that contract
    while computing the same thing: a :class:`FullyConnected` followed by
    one more activation. Production code wanting the same pattern should
    define its own :class:`physicsnemo.Module` subclass (or use
    :meth:`Module.from_torch` on a custom class).

    Parameters
    ----------
    in_features : int
        Input feature count, forwarded to :class:`FullyConnected`.
    layer_size : int
        Hidden width, forwarded to :class:`FullyConnected`.
    out_features : int
        Output feature count, forwarded to :class:`FullyConnected`.
    num_layers : int
        Hidden-layer count, forwarded to :class:`FullyConnected`.
    activation_fn : str
        Activation name; used for the hidden layers of the
        :class:`FullyConnected` and again for the trailing activation.

    Forward
    -------
    x : torch.Tensor
        Input tensor of shape ``(..., in_features)``.

    Outputs
    -------
    torch.Tensor
        Tensor of shape ``(..., out_features)`` after the trailing
        activation.
    """

    def __init__(
        self,
        *,
        in_features: int,
        layer_size: int,
        out_features: int,
        num_layers: int,
        activation_fn: str,
    ):
        super().__init__(meta=_MLPWithTrailingActivationMeta())
        fc_config = {
            "in_features": in_features,
            "layer_size": layer_size,
            "out_features": out_features,
            "num_layers": num_layers,
            "activation_fn": activation_fn,
        }
        # Registration order (fc, then activation) is kept as-is.
        self.fc = FullyConnected(**fc_config)
        self.activation = get_activation(activation_fn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.fc(x)
        return self.activation(projected)
def _make_trunk(
    *,
    in_features: int = 1,
    out_features: int,
    hidden_width: int = 16,
    num_layers: int = 2,
    activation_fn: str = "tanh",
    output_activation: bool = True,
) -> Module:
    """Build the trunk MLP used by the test fixtures.

    The same keyword set (``in_features``, ``layer_size=hidden_width``,
    ``out_features``, ``num_layers``, ``activation_fn``) is passed to one
    of two classes:

    - ``output_activation=True``: :class:`_MLPWithTrailingActivation`,
      which applies one extra activation to the projection output.
    - ``output_activation=False``: a bare :class:`FullyConnected`.

    Either way the result is a :class:`physicsnemo.Module`, so the trunk
    survives :meth:`DeepONet.save`.
    """
    mlp_kwargs = {
        "in_features": in_features,
        "layer_size": hidden_width,
        "out_features": out_features,
        "num_layers": num_layers,
        "activation_fn": activation_fn,
    }
    trunk_cls = _MLPWithTrailingActivation if output_activation else FullyConnected
    return trunk_cls(**mlp_kwargs)
def _wrapper_2d_temporal() -> tuple[DeepONet, tuple[torch.Tensor, ...]]:
    """Packed-input 2D builder exercising the ``temporal_projection`` decoder.

    Seeds the global RNG with ``_SEED``, builds a UNet spatial branch and
    a trunk (in that order), assembles the model with
    ``output_window=3``, and finally draws the random input. Returns
    ``(model, (x,))``.
    """
    torch.manual_seed(_SEED)

    width = 8
    # Construction order matters for RNG reproducibility: branch, then trunk.
    branch = _make_unet_spatial_branch(dimension=2, width=width)
    trunk = _make_trunk(out_features=width)

    model = DeepONet(
        branch1=branch,
        trunk=trunk,
        dimension=2,
        width=width,
        decoder_type="temporal_projection",
        decoder_width=8,
        decoder_layers=1,
        output_window=3,
        auto_pad=True,
        padding=8,
        trunk_input="time",
    )

    sample = torch.randn(1, 8, 8, 2, 2)
    return model, (sample,)
def _xfno_packed_3d_extend() -> tuple[DeepONet, tuple[torch.Tensor, ...]]:
    """Packed-input trunkless 3D operator with ``time_modes`` set.

    Only the input tensor is returned in the ``args`` tuple. Tests that
    drive the time-axis-extend feature supply ``target_times`` as a
    keyword argument at call time; it is not part of the standard
    fixture-registry contract.
    """
    torch.manual_seed(_SEED)

    branch_config = {
        "dimension": 3,
        "in_channels": 2,
        "width": 8,
        "num_fourier_layers": 2,
        "num_unet_layers": 0,
        "modes1": 2,
        "modes2": 2,
        "modes3": 2,
        "kernel_size": 3,
        "activation_fn": "relu",
        "coord_features": True,
    }
    branch1 = SpatialBranch(**branch_config)

    model = DeepONet(
        branch1=branch1,
        trunk=None,  # trunkless operator
        dimension=3,
        width=8,
        out_channels=1,
        decoder_type="mlp",
        decoder_width=8,
        decoder_layers=1,
        auto_pad=True,
        padding=8,
        time_modes=2,
    )

    # (B, H, W, T_in=4, C)
    sample = torch.randn(1, 8, 8, 4, 2)
    return model, (sample,)
def _packed_2d_kitchen_sink() -> tuple[DeepONet, tuple[torch.Tensor, ...]]:
    """Kitchen-sink 2D builder exercising every major code path.

    Both branches enable all three :class:`SpatialBranch` sub-stacks
    (Fourier + UNet + Conv). The model combines them as a dual-branch
    (mionet-style) configuration with the ``temporal_projection``
    decoder, ``output_window=3``, ``out_channels=2`` and
    ``trunk_input="grid"``. Branch1 uses a two-layer lift with
    ``coord_features=True`` and ``gelu``; branch2 uses a single-layer
    lift without coordinate features and ``silu``. The trunk uses the
    ``sin`` activation.

    Returns ``(model, (x, x_branch2))``.
    """
    torch.manual_seed(_SEED)

    # ``trunk_input="grid"`` in 2D reads the last ``dim + 1 = 3`` channels
    # of ``x`` (the (x, y, t) coords) to build the trunk input, so both
    # branches take ``in_channels=3``.
    shared_branch_kwargs = {
        "dimension": 2,
        "in_channels": 3,
        "width": 8,
        "num_fourier_layers": 1,
        "num_unet_layers": 1,
        "num_conv_layers": 1,
        "modes1": 2,
        "modes2": 2,
        "kernel_size": 3,
    }
    # Construction order (branch1, branch2, trunk, model, inputs) is kept
    # so the seeded RNG stream is consumed identically.
    branch1 = SpatialBranch(
        **shared_branch_kwargs,
        activation_fn="gelu",
        coord_features=True,
        lift_layers=2,
    )
    branch2 = SpatialBranch(
        **shared_branch_kwargs,
        activation_fn="silu",
        coord_features=False,
        lift_layers=1,
    )
    trunk = _make_trunk(
        in_features=3,
        out_features=8,
        hidden_width=16,
        num_layers=2,
        activation_fn="sin",
    )

    model = DeepONet(
        branch1=branch1,
        branch2=branch2,
        trunk=trunk,
        dimension=2,
        width=8,
        out_channels=2,
        decoder_type="temporal_projection",
        decoder_width=8,
        decoder_layers=2,
        decoder_activation_fn="tanh",
        output_window=3,
        auto_pad=True,
        padding=8,
        trunk_input="grid",
    )

    input_shape = (1, 8, 8, 2, 3)
    x = torch.randn(*input_shape)
    x_branch2 = torch.randn(*input_shape)
    return model, (x, x_branch2)
+ - ``lift_layers=3`` with ``lift_hidden_width`` set explicitly on + branch1 — exercises the multi-layer pointwise lift network with + a custom hidden width. + - Trunk with ``num_layers=3`` and ``output_activation=False`` — no + other fixture builds a 3-layer trunk or skips the trailing + activation wrapper. + - Activation palette: ``celu`` (branch1), ``leaky_relu`` (branch2), + ``tanh`` (trunk), ``elu`` (decoder). None of these appear in + another fixture. + + The 8x8x8 spatial input is chosen so the UNet sub-stack's pool + chain doesn't collapse to a 1x1x1 BatchNorm input (training mode + forbids that with batch_size=1). + """ + torch.manual_seed(_SEED) + branch1 = SpatialBranch( + dimension=3, + in_channels=2, + width=8, + num_fourier_layers=1, + num_unet_layers=1, + num_conv_layers=1, + modes1=2, + modes2=2, + modes3=2, + kernel_size=3, + activation_fn="celu", + coord_features=True, + lift_layers=3, + lift_hidden_width=12, + ) + branch2 = SpatialBranch( + dimension=3, + in_channels=2, + width=8, + num_fourier_layers=1, + num_unet_layers=1, + num_conv_layers=1, + modes1=2, + modes2=2, + modes3=2, + kernel_size=3, + activation_fn="leaky_relu", + coord_features=True, + lift_layers=2, + ) + trunk = _make_trunk( + in_features=1, + out_features=8, + hidden_width=12, + num_layers=3, + activation_fn="tanh", + output_activation=False, + ) + model = DeepONet( + branch1=branch1, + branch2=branch2, + trunk=trunk, + dimension=3, + width=8, + out_channels=2, + decoder_type="conv", + decoder_width=16, + decoder_layers=2, + decoder_activation_fn="elu", + auto_pad=False, + ) + x_branch1 = torch.randn(1, 8, 8, 8, 2) + x_time = torch.linspace(0, 1, 3).unsqueeze(-1) + x_branch2 = torch.randn(1, 8, 8, 8, 2) + return model, (x_branch1, x_time, x_branch2) + + +def _core_2d_mlpbranch() -> tuple[DeepONet, tuple[torch.Tensor, ...]]: + """Core-mode 2D builder exercising the MLP-branch (non-spatial) code path. 
def _init_lazy(model, *args) -> None:
    """Call ``model(*args)`` once with autograd disabled.

    Used to materialise ``nn.LazyLinear`` parameters before the model is
    evaluated or its ``state_dict`` is saved. The forward result is
    discarded.
    """
    with torch.set_grad_enabled(False):
        model(*args)
# Each entry is ``(scenario name, builder callable, golden-file path)``.
_FIXTURE_REGISTRY = [
    ("u_deeponet_packed_2d", _wrapper_2d, _GOLDEN_PACKED_2D),
    ("u_deeponet_packed_3d", _wrapper_3d, _GOLDEN_PACKED_3D),
    ("fourier_packed_2d", _wrapper_2d_fourier, _GOLDEN_PACKED_2D_FOURIER),
    ("mionet_packed_2d", _wrapper_2d_mionet, _GOLDEN_PACKED_2D_MIONET),
    ("temporal_packed_2d", _wrapper_2d_temporal, _GOLDEN_PACKED_2D_TEMPORAL),
    ("packed_2d_multichannel", _packed_2d_multichannel, _GOLDEN_PACKED_2D_MULTICHANNEL),
    ("kitchen_sink_packed_2d", _packed_2d_kitchen_sink, _GOLDEN_PACKED_2D_KITCHEN_SINK),
    ("kitchen_sink_core_3d", _core_3d_kitchen_sink, _GOLDEN_CORE_3D_KITCHEN_SINK),
    ("xfno_packed_3d", _xfno_packed_3d, _GOLDEN_XFNO_PACKED_3D),
    ("xfno_packed_3d_extend", _xfno_packed_3d_extend, _GOLDEN_XFNO_PACKED_3D_EXTEND),
    ("mlpbranch_core_2d", _core_2d_mlpbranch, _GOLDEN_CORE_2D_MLPBRANCH),
]


# ----------------------------------------------------------------------
# Constructor + public attributes (MOD-008a)
# ----------------------------------------------------------------------


class TestDeepONetConstructor:
    """Check that ``DeepONet`` keeps its constructor arguments as public attrs."""

    # The 2D and 3D variants of each test share one parametrisation, so
    # the decorators are built once and applied to both.
    _core_cases = pytest.mark.parametrize(
        "config",
        [
            {"width": 8, "decoder_type": "mlp"},
            {"width": 16, "decoder_type": "conv"},
        ],
        ids=["default-ish", "custom"],
    )
    _packed_cases = pytest.mark.parametrize(
        "config",
        [
            {"padding": 8, "trunk_input": "time"},
            {"padding": 16, "trunk_input": "grid"},
        ],
        ids=["default-ish", "custom"],
    )

    @staticmethod
    def _build_core(dimension, config):
        """Build a core-mode model (``auto_pad`` not passed) from ``config``."""
        width = config["width"]
        return DeepONet(
            branch1=_make_unet_spatial_branch(dimension=dimension, width=width),
            trunk=_make_trunk(out_features=width),
            dimension=dimension,
            width=width,
            decoder_type=config["decoder_type"],
            decoder_width=width,
            decoder_layers=1,
        )

    @staticmethod
    def _build_packed(dimension, config):
        """Build an ``auto_pad=True`` model of fixed width 8 from ``config``."""
        return DeepONet(
            branch1=_make_unet_spatial_branch(dimension=dimension, width=8),
            trunk=_make_trunk(out_features=8),
            dimension=dimension,
            width=8,
            decoder_type="mlp",
            decoder_width=8,
            decoder_layers=1,
            auto_pad=True,
            padding=config["padding"],
            trunk_input=config["trunk_input"],
        )

    @_core_cases
    def test_deeponet_2d_core(self, config):
        """2D core model exposes the values it was constructed with."""
        model = self._build_core(2, config)
        assert model.dimension == 2
        assert model.width == config["width"]
        assert model.decoder_type == config["decoder_type"]
        assert model.decoder_activation_fn == "relu"
        assert model.trunk is not None

    @_core_cases
    def test_deeponet_3d_core(self, config):
        """3D core model exposes the values it was constructed with."""
        model = self._build_core(3, config)
        assert model.dimension == 3
        assert model.width == config["width"]
        assert model.decoder_type == config["decoder_type"]
        assert model.decoder_activation_fn == "relu"
        assert model.trunk is not None

    @_packed_cases
    def test_packed_2d(self, config):
        """2D ``auto_pad=True`` model exposes ``padding`` / ``trunk_input``."""
        model = self._build_packed(2, config)
        assert model.auto_pad is True
        assert model.padding == config["padding"]
        assert model.trunk_input == config["trunk_input"]

    @_packed_cases
    def test_packed_3d(self, config):
        """3D ``auto_pad=True`` model exposes ``padding`` / ``trunk_input``."""
        model = self._build_packed(3, config)
        assert model.dimension == 3
        assert model.auto_pad is True
        assert model.padding == config["padding"]
        assert model.trunk_input == config["trunk_input"]

    def test_simple_fourier_construction(self):
        """Hand-composed Fourier branch + ``FullyConnected`` trunk.

        Builds the sub-modules directly and injects them into
        :class:`DeepONet`, then checks the resulting attributes and that
        the exact instances passed in are the ones held by the model
        (identity, not a copy or a rebuilt module).
        """
        torch.manual_seed(_SEED)
        fourier_branch = SpatialBranch(
            dimension=2,
            in_channels=2,
            width=8,
            num_fourier_layers=1,
            modes1=2,
            modes2=2,
            activation_fn="relu",
        )
        mlp_trunk = FullyConnected(
            in_features=1,
            layer_size=16,
            out_features=8,
            num_layers=2,
            activation_fn="tanh",
        )
        model = DeepONet(
            branch1=fourier_branch,
            trunk=mlp_trunk,
            dimension=2,
            width=8,
            decoder_type="mlp",
            decoder_width=8,
            decoder_layers=1,
            decoder_activation_fn="relu",
        )
        assert model.dimension == 2
        assert model.width == 8
        assert model.auto_pad is False
        # A SpatialBranch must not be flagged as the MLP-branch path.
        assert model._branch1_is_mlp is False
        # Injected modules are stored as-is.
        assert model.trunk is mlp_trunk
        assert model.branch1 is fourier_branch
3 + assert model.auto_pad is True + assert model.padding == config["padding"] + assert model.trunk_input == config["trunk_input"] + + def test_simple_fourier_construction(self): + """Direct DI construction with a Fourier branch + custom trunk. + + Sanity-check that hand-composing :class:`SpatialBranch` and + :class:`physicsnemo.models.mlp.FullyConnected` modules into a + :class:`DeepONet` produces a model with the expected attributes + and that the passed-in module instances are preserved as + submodules (not copied or rebuilt). + """ + torch.manual_seed(_SEED) + branch1 = SpatialBranch( + dimension=2, + in_channels=2, + width=8, + num_fourier_layers=1, + modes1=2, + modes2=2, + activation_fn="relu", + ) + trunk = FullyConnected( + in_features=1, + layer_size=16, + out_features=8, + num_layers=2, + activation_fn="tanh", + ) + model = DeepONet( + branch1=branch1, + trunk=trunk, + dimension=2, + width=8, + decoder_type="mlp", + decoder_width=8, + decoder_layers=1, + decoder_activation_fn="relu", + ) + assert model.dimension == 2 + assert model.width == 8 + assert model.auto_pad is False + # branch1 is a SpatialBranch -> not the MLP-branch path + assert model._branch1_is_mlp is False + # trunk is preserved as the passed-in instance (not rebuilt) + assert model.trunk is trunk + assert model.branch1 is branch1 + + +# ---------------------------------------------------------------------- +# Forward non-regression against committed golden files (MOD-008b) +# ---------------------------------------------------------------------- + + +def _golden_args(golden: dict) -> tuple[torch.Tensor, ...]: + """Read positional forward arguments from a golden payload. 
class TestDeepONetNonRegression:
    """Forward output agrees with the committed golden fixture.

    Every entry of :data:`_FIXTURE_REGISTRY` becomes one test case, so a
    new scenario needs only a registry entry and a regenerated ``.pth``.
    """

    @pytest.mark.parametrize(
        "name, builder, golden_path",
        _FIXTURE_REGISTRY,
        ids=[scenario for scenario, *_ in _FIXTURE_REGISTRY],
    )
    def test_matches_golden(self, name, builder, golden_path):
        """Forward output reproduces the golden output within tolerance.

        The model is rebuilt by ``builder``, its lazy parameters are
        materialised, the golden ``state_dict`` is loaded, and the
        forward result is compared with the stored ``y`` using
        ``rtol=1e-5`` / ``atol=1e-6``.
        """
        del name  # used only for the test ID
        payload = _load_golden(golden_path)
        forward_args = _golden_args(payload)
        # The builder's own sample inputs are ignored; the stored ones
        # are used instead.
        net, _ = builder()
        # Lazy parameters must exist before the state dict is loaded.
        _init_lazy(net, *forward_args)
        net.load_state_dict(payload["state_dict"])
        with torch.no_grad():
            produced = net(*forward_args)
        torch.testing.assert_close(produced, payload["y"], rtol=1e-5, atol=1e-6)
class TestDeepONetTimeAxisExtend:
    """Time-axis-extend (xFNO-style autoregressive bundling).

    Drives the model returned by ``_xfno_packed_3d_extend`` with a
    ``target_times`` keyword at forward time and checks the output
    shape: the last spatial axis must equal the requested horizon ``K``
    while the remaining spatial axes keep the input extent.
    """

    @staticmethod
    def _forecast(start, stop, num_steps):
        """Build the scenario, initialise it, and forecast ``num_steps`` times."""
        model, (x,) = _xfno_packed_3d_extend()
        _init_lazy(model, x)
        target_times = torch.linspace(start, stop, num_steps)
        with torch.no_grad():
            return model(x, target_times=target_times)

    def test_predicts_K_future_steps(self):
        """``K=6`` differs from the input time length, so the axis is extended."""
        # Per the original test's notes the builder input is
        # (1, 8, 8, 4, 2), i.e. T_in = 4; K=6 therefore takes the
        # time-extend path and yields (1, 8, 8, K=6, out_channels=1).
        y = self._forecast(0.5, 1.0, 6)
        assert y.shape == (1, 8, 8, 6, 1)

    def test_K_equals_T_in_no_extend(self):
        """``K == T_in`` (4): the time axis keeps its original length."""
        y = self._forecast(0.0, 1.0, 4)
        assert y.shape == (1, 8, 8, 4, 1)
class TestDeepONetCheckpoint:
    """``Module.save`` + ``Module.from_checkpoint`` round-trip.

    The committed golden ``state_dict`` is loaded into the freshly built
    model *before* it is saved. The reloaded model's forward output is
    then compared with the golden ``y`` (``rtol=1e-5``, ``atol=1e-6``).
    Comparing against the fixture, rather than against a second forward
    of the in-memory model, makes the test fail if the serialized state
    is incomplete, corrupted, or silently re-initialised on load.

    Pinning the weights from the fixture (as ``test_matches_golden``
    does) keeps the check independent of whether seeded initialisation
    happens to reproduce the weights the golden was generated with.

    No ``_init_lazy`` call is made on the reloaded model: it is run
    directly after ``Module.from_checkpoint``.

    Only the 2D and 3D U-DeepONet wrappers are exercised; save/load is
    implemented at the ``Module`` level rather than per variant.
    """

    def _roundtrip(self, model, args, tmp_path, state_dict=None):
        """Save ``model`` to ``tmp_path`` and reload it.

        Parameters
        ----------
        model
            Freshly built model; its lazy parameters are materialised here.
        args
            Positional forward arguments.
        tmp_path
            Directory in which ``model.mdlus`` is written.
        state_dict
            Optional weights to load into ``model`` before saving. When
            ``None`` the model is saved with whatever weights it has
            after lazy initialisation.

        Returns
        -------
        tuple
            ``(loaded_model, loaded_model(*args))``.
        """
        # Lazy parameters must exist before a state dict can be loaded
        # or the model can be serialized.
        _init_lazy(model, *args)
        if state_dict is not None:
            model.load_state_dict(state_dict)
        ckpt = tmp_path / "model.mdlus"
        model.save(str(ckpt))
        loaded = Module.from_checkpoint(str(ckpt))
        with torch.no_grad():
            y_loaded = loaded(*args)
        return loaded, y_loaded

    def _check_against_golden(self, builder, golden_path, tmp_path):
        """Round-trip ``builder``'s model and compare with the golden output."""
        golden = _load_golden(golden_path)
        args = _golden_args(golden)
        model, _ = builder()
        loaded, y_loaded = self._roundtrip(
            model, args, tmp_path, state_dict=golden["state_dict"]
        )
        assert type(loaded).__name__ == type(model).__name__
        assert loaded.padding == model.padding
        assert loaded.trunk_input == model.trunk_input
        torch.testing.assert_close(y_loaded, golden["y"], rtol=1e-5, atol=1e-6)

    def test_wrapper_2d_roundtrip(self, tmp_path):
        """2D wrapper: reloaded output matches the committed golden."""
        self._check_against_golden(_wrapper_2d, _GOLDEN_PACKED_2D, tmp_path)

    def test_wrapper_3d_roundtrip(self, tmp_path):
        """3D wrapper: reloaded output matches the committed golden."""
        self._check_against_golden(_wrapper_3d, _GOLDEN_PACKED_3D, tmp_path)
class TestDeepONetGradientFlow:
    """Backward pass reaches the input and the parameters.

    Both wrappers are checked: the original authors note that the 3D
    forward path uses different reshapes from the 2D one, so gradient
    propagation in 2D does not by itself vouch for 3D.
    """

    @staticmethod
    def _assert_gradients_flow(builder):
        """Backprop ``sum(model(x))`` and check input / parameter grads."""
        model, (x,) = builder()
        _init_lazy(model, x)
        # Re-leaf the input so autograd records a gradient for it.
        leaf = x.detach().requires_grad_(True)
        model(leaf).sum().backward()
        assert leaf.grad is not None
        trainable = [p for p in model.parameters() if p.requires_grad]
        assert trainable, "model has no trainable parameters"
        assert any(p.grad is not None for p in trainable)

    def test_wrapper_2d_gradients(self):
        """Gradients flow through the 2D wrapper."""
        self._assert_gradients_flow(_wrapper_2d)

    def test_wrapper_3d_gradients(self):
        """Gradients flow through the 3D wrapper."""
        self._assert_gradients_flow(_wrapper_3d)
class TestDeepONetCompile:
    """``torch.compile`` wraps the model without raising.

    Two variants per dimensionality:

    - ``fullgraph=False``: graph breaks are tolerated; the compiled
      output must match eager in shape and value (``rtol=1e-4``,
      ``atol=1e-5``).
    - ``fullgraph=True``: probes whether the whole forward can be
      captured without any graph break. According to the original
      authors, the jaxtyping shape decorators and the dynamic padding in
      :func:`~physicsnemo.experimental.models.xdeeponet._padding.pad_spatial_right`
      sit behind ``torch.compiler.is_compiling()`` guards, and both the
      2D and 3D paths currently compile without breaks. Should a future
      torch release reintroduce breaks, these tests will raise; re-add
      ``@pytest.mark.xfail(strict=False)`` until the breaks are fixed.
    """

    @staticmethod
    def _assert_matches_eager(builder):
        """Compile with graph breaks allowed and compare against eager."""
        model, (x,) = builder()
        _init_lazy(model, x)
        with torch.no_grad():
            eager_out = model(x)
        compiled_model = torch.compile(model, fullgraph=False)
        with torch.no_grad():
            compiled_out = compiled_model(x)
        assert compiled_out.shape == eager_out.shape
        torch.testing.assert_close(compiled_out, eager_out, rtol=1e-4, atol=1e-5)

    @staticmethod
    def _run_fullgraph(builder):
        """Compile with ``fullgraph=True`` and run one forward."""
        model, (x,) = builder()
        _init_lazy(model, x)
        compiled_model = torch.compile(model, fullgraph=True)
        with torch.no_grad():
            compiled_model(x)

    def test_wrapper_2d_compile(self):
        """2D compiled model matches eager output in shape and value."""
        self._assert_matches_eager(_wrapper_2d)

    def test_wrapper_3d_compile(self):
        """3D compiled model matches eager output in shape and value."""
        self._assert_matches_eager(_wrapper_3d)

    def test_wrapper_2d_compile_fullgraph(self):
        """2D model compiles cleanly with ``fullgraph=True``."""
        self._run_fullgraph(_wrapper_2d)

    def test_wrapper_3d_compile_fullgraph(self):
        """3D model compiles cleanly with ``fullgraph=True``."""
        self._run_fullgraph(_wrapper_3d)
class TestDeepONetStress:
    """Stress-test the kitchen-sink configurations end-to-end.

    Numerics of ``kitchen_sink_packed_2d`` and ``kitchen_sink_core_3d``
    are already pinned against committed goldens by
    :class:`TestDeepONetNonRegression`. This class adds three
    dynamic-behaviour checks per dimensionality, following the layout of
    :class:`TestDeepONetGradientFlow` and :class:`TestDeepONetCompile`:

    - the forward output has the expected shape,
    - backward populates a gradient on every input tensor and on at
      least one trainable parameter,
    - ``torch.compile(fullgraph=False)`` output matches eager
      (full-graph compilation is only probed on the simpler wrappers in
      :class:`TestDeepONetCompile`, to keep runtime reasonable).

    The 2D kitchen-sink combines: Fourier + UNet + Conv sub-stacks on
    both branches, the mionet dual-branch Hadamard product, the
    ``temporal_projection`` decoder with ``output_window > 1``,
    multi-channel output, ``trunk_input="grid"``, a multi-layer
    pointwise lift on branch1, asymmetric ``coord_features`` and
    activation functions across branches, and the Sin trunk activation.

    The 3D kitchen-sink covers a deliberately disjoint set of code
    paths: ``auto_pad=False`` (core mode) with a 3D spatial branch,
    ``decoder_type="conv"``, 3D mionet, all three sub-stacks on both
    branches, ``lift_layers=3`` with ``lift_hidden_width`` set, a
    3-layer trunk with ``output_activation=False``, and a non-default
    activation palette (celu / leaky_relu / tanh / elu).

    A regression in any of these paths is flagged here independently of
    fixture-numerics drift.
    """

    @staticmethod
    def _forward_shape(builder):
        """Return the shape of one no-grad forward of ``builder``'s model."""
        model, args = builder()
        _init_lazy(model, *args)
        with torch.no_grad():
            return model(*args).shape

    @staticmethod
    def _assert_gradients(builder):
        """Backprop ``sum(model(*args))``; check every input and the params."""
        model, args = builder()
        _init_lazy(model, *args)
        # Re-leaf every input so each one records its own gradient.
        leaves = tuple(a.detach().requires_grad_(True) for a in args)
        model(*leaves).sum().backward()
        for i, a in enumerate(leaves):
            assert a.grad is not None, f"input arg[{i}] has no gradient"
        trainable = [p for p in model.parameters() if p.requires_grad]
        assert trainable, "model has no trainable parameters"
        assert any(p.grad is not None for p in trainable)

    @staticmethod
    def _assert_compile_parity(builder):
        """Compare ``torch.compile(fullgraph=False)`` output with eager."""
        model, args = builder()
        _init_lazy(model, *args)
        with torch.no_grad():
            eager_out = model(*args)
        compiled_model = torch.compile(model, fullgraph=False)
        with torch.no_grad():
            compiled_out = compiled_model(*args)
        assert compiled_out.shape == eager_out.shape
        torch.testing.assert_close(compiled_out, eager_out, rtol=1e-4, atol=1e-5)

    def test_forward_shape_2d(self):
        """2D kitchen-sink: output shape matches ``(B, H, W, K, oc)``."""
        assert self._forward_shape(_packed_2d_kitchen_sink) == (1, 8, 8, 3, 2)

    def test_forward_shape_3d(self):
        """3D kitchen-sink: output shape matches ``(B, X, Y, Z, T, oc)``."""
        assert self._forward_shape(_core_3d_kitchen_sink) == (1, 8, 8, 8, 3, 2)

    def test_gradients_2d(self):
        """2D kitchen-sink: backward populates gradients on inputs and params."""
        self._assert_gradients(_packed_2d_kitchen_sink)

    def test_gradients_3d(self):
        """3D kitchen-sink: backward populates gradients on inputs and params."""
        self._assert_gradients(_core_3d_kitchen_sink)

    def test_compile_2d(self):
        """2D kitchen-sink: ``torch.compile(fullgraph=False)`` parity."""
        self._assert_compile_parity(_packed_2d_kitchen_sink)

    def test_compile_3d(self):
        """3D kitchen-sink: ``torch.compile(fullgraph=False)`` parity."""
        self._assert_compile_parity(_core_3d_kitchen_sink)
# Allow running this module directly as a script: hand the file to pytest
# in verbose mode.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])