Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -206,13 +206,43 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
combined-workflow and from-checkpoint round-trip tests. Most tests
run with `fullgraph=True` and `error_on_recompile` to catch
`torch.compile` regressions.
- Internal weight initialization in the distributed AFNO layers and the
`EarthAttention` blocks of `physicsnemo.nn.module.attention_layers` now
dispatches to `torch.nn.init.trunc_normal_` directly instead of going
through frozen in-tree copies of the pre-PyTorch-2.12 inverse-CDF
implementation. PyTorch 2.12 reimplemented `trunc_normal_` as a
rejection-sampling loop on top of `normal_()` (see
[pytorch/pytorch#174997](https://github.com/pytorch/pytorch/pull/174997)),
so seeded from-scratch initialization consumes the RNG stream
differently on 2.12+ vs older versions. Existing trained checkpoints
are unaffected (loading bypasses init). Forward-accuracy reference
outputs for `AFNO`, `ModAFNO`, `Transolver`, `FLARE`, and `Pangu` were
regenerated against the new algorithm. Rather than wiring per-model
skips, `test.common.validate_forward_accuracy` now uniformly skips on
`torch < 2.12` (the reference data is locked to that floor via a single
`_REFERENCE_DATA_MIN_TORCH` constant; bump it when a PyTorch
release next changes an init/RNG algorithm any forward-accuracy model
depends on, and regenerate the `.pth` files at the same time).

### Deprecated

- `physicsnemo.utils.mesh` is deprecated and will be removed in v2.2.0. For
isosurface extraction, use `physicsnemo.mesh.generate.marching_cubes` instead
of `sdf_to_stl`. For VTP/OBJ/STL file conversion (`combine_vtp_files`,
`convert_tesselated_files_in_directory`), use VTK or PyVista directly.
- `physicsnemo.nn.module.utils.trunc_normal_` (and its submodule path
`physicsnemo.nn.module.utils.weight_init.trunc_normal_`) is deprecated
and will be removed in v2.2.0. It is now a thin wrapper around
`torch.nn.init.trunc_normal_` that emits a `DeprecationWarning` on
call, replacing the frozen in-tree copy of the legacy inverse-CDF
implementation. Use `torch.nn.init.trunc_normal_` directly.

### Removed

- The legacy in-tree `trunc_normal_` implementation that lived in
`physicsnemo/models/afno/distributed/layers.py` (`_trunc_normal_` /
`_no_grad_trunc_normal_`) is removed. These names were private; all
in-tree call sites now use `torch.nn.init.trunc_normal_`.

### Fixed

Expand Down
5 changes: 2 additions & 3 deletions physicsnemo/models/afno/distributed/afno.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
DistributedMLP,
DistributedPatchEmbed,
DropPath,
_trunc_normal_,
)
from physicsnemo.nn.module.layer_norm import get_layer_norm_class

Expand Down Expand Up @@ -361,7 +360,7 @@ def __init__(
self.synchronized_head = False

# init weights
_trunc_normal_(self.pos_embed, std=0.02)
torch.nn.init.trunc_normal_(self.pos_embed, std=0.02)
self.apply(self._init_weights)

def _init_weights(self, m: nn.Module) -> None:
Expand All @@ -373,7 +372,7 @@ def _init_weights(self, m: nn.Module) -> None:
Module to initialize.
"""
if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
_trunc_normal_(m.weight, std=0.02)
torch.nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, _LayerNormClass):
Expand Down
84 changes: 2 additions & 82 deletions physicsnemo/models/afno/distributed/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@

from __future__ import annotations

import math
import warnings
from typing import Tuple

import torch
Expand All @@ -43,84 +41,6 @@
from physicsnemo.distributed.utils import compute_split_shapes


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases
# Method based on
# https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)

with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
low = norm_cdf((a - mean) / std)
up = norm_cdf((b - mean) / std)

# Uniformly fill tensor with values from [low, up], then translate to
# [2low-1, 2up-1].
tensor.uniform_(2 * low - 1, 2 * up - 1)

# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()

# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)

# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor


def _trunc_normal_(
tensor: Tensor,
mean: float = 0.0,
std: float = 1.0,
a: float = -2.0,
b: float = 2.0,
) -> Tensor:
r"""Fill the input tensor with values from a truncated normal distribution.

The values are drawn from :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within the bounds.
The method works best when :math:`a \leq \text{mean} \leq b`.

Parameters
----------
tensor : torch.Tensor
An n-dimensional tensor to fill.
mean : float, optional, default=0.0
Mean of the normal distribution.
std : float, optional, default=1.0
Standard deviation of the normal distribution.
a : float, optional, default=-2.0
Minimum cutoff value.
b : float, optional, default=2.0
Maximum cutoff value.

Returns
-------
torch.Tensor
The input tensor filled with truncated normal values.

Examples
--------
>>> w = torch.empty(3, 5)
>>> o = _trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)


@torch.compile
def drop_path(
x: Float[Tensor, "*dims"], drop_prob: float = 0.0, training: bool = False
Expand Down Expand Up @@ -294,9 +214,9 @@ def __init__(

def _init_weights(self) -> None:
r"""Initialize weights using truncated normal distribution."""
_trunc_normal_(self.w1, std=0.02)
torch.nn.init.trunc_normal_(self.w1, std=0.02)
nn.init.constant_(self.b1, 0.0)
_trunc_normal_(self.w2, std=0.02)
torch.nn.init.trunc_normal_(self.w2, std=0.02)
nn.init.constant_(self.b2, 0.0)

def forward(self, x: Float[Tensor, "B C H W"]) -> Float[Tensor, "B C_out H W"]:
Expand Down
10 changes: 3 additions & 7 deletions physicsnemo/nn/module/attention_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from physicsnemo.nn.module.conv_layers import Conv2d
from physicsnemo.nn.module.group_norm import get_group_norm
from physicsnemo.nn.module.utils import get_earth_position_index, trunc_normal_
from physicsnemo.nn.module.utils import get_earth_position_index


class AttentionOp(torch.autograd.Function):
Expand Down Expand Up @@ -237,9 +237,7 @@ def __init__(
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)

self.earth_position_bias_table = trunc_normal_(
self.earth_position_bias_table, std=0.02
)
torch.nn.init.trunc_normal_(self.earth_position_bias_table, std=0.02)
self.softmax = nn.Softmax(dim=-1)

def forward(self, x: torch.Tensor, mask=None):
Expand Down Expand Up @@ -345,9 +343,7 @@ def __init__(
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)

self.earth_position_bias_table = trunc_normal_(
self.earth_position_bias_table, std=0.02
)
torch.nn.init.trunc_normal_(self.earth_position_bias_table, std=0.02)
self.softmax = nn.Softmax(dim=-1)

def forward(self, x: torch.Tensor, mask=None):
Expand Down
4 changes: 2 additions & 2 deletions physicsnemo/nn/module/dit_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,8 @@ class Natten2DSelfAttention(AttentionModuleBase):
--------
>>> import torch
>>> from physicsnemo.nn.module.dit_layers import Natten2DSelfAttention
>>> attn = Natten2DSelfAttention(hidden_size=64, num_heads=4, attn_kernel=3)
>>> x = torch.randn(2, 16, 64)
>>> attn = Natten2DSelfAttention(hidden_size=64, num_heads=4, attn_kernel=3).cuda()
>>> x = torch.randn(2, 16, 64, device="cuda")
>>> out = attn(x, latent_hw=(4, 4))
>>> out.shape
torch.Size([2, 16, 64])
Expand Down
67 changes: 15 additions & 52 deletions physicsnemo/nn/module/utils/weight_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,66 +14,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import warnings

import numpy as np
import torch
from torch.nn.init import trunc_normal_ as _torch_trunc_normal_


def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
def trunc_normal_(*args, **kwargs):
"""Deprecated alias for :func:`torch.nn.init.trunc_normal_`.

if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)

# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
u1 = norm_cdf((a - mean) / std)
u2 = norm_cdf((b - mean) / std)

# Uniformly fill tensor with values from [u1, u2], then translate to
# [2u1-1, 2u2-1].
tensor.uniform_(2 * u1 - 1, 2 * u2 - 1)

# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()

# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)

# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
# type: (Tensor, float, float, float, float) -> Tensor
r"""Cut & paste from timm master
Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.

NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
applied while sampling the normal with mean/std applied, therefore a, b args
should be adjusted to match the range of mean, std args.
This re-export exists only to preserve backward compatibility for code
that imported ``trunc_normal_`` from ``physicsnemo.nn.module.utils`` (or
its ``weight_init`` submodule path) prior to v2.1. It will be removed in
v2.2.0; new code should call :func:`torch.nn.init.trunc_normal_`
directly.
"""
with torch.no_grad():
return _trunc_normal_(tensor, mean, std, a, b)
warnings.warn(
"`physicsnemo.nn.module.utils.trunc_normal_` is deprecated and will "
"be removed in v2.2.0. Use `torch.nn.init.trunc_normal_` directly.",
DeprecationWarning,
stacklevel=2,
)
return _torch_trunc_normal_(*args, **kwargs)


def _weight_init(shape: tuple, mode: str, fan_in: int, fan_out: int):
Expand Down
21 changes: 21 additions & 0 deletions test/common/fwdaccuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pathlib import Path
from typing import Tuple, Union

import pytest
import torch

import physicsnemo
Expand Down Expand Up @@ -102,7 +103,27 @@ def validate_forward_accuracy(
------
IOError
Target output tensor file for this model was not found

Notes
-----
On ``torch < _REFERENCE_DATA_MIN_TORCH`` (currently ``"2.12"``) the
test is skipped via :func:`pytest.skip` instead of compared: the
saved ``.pth`` reference data is locked to post-2.12 RNG/init
algorithms (notably ``torch.nn.init.trunc_normal_``'s
rejection-sampling rewrite, `pytorch/pytorch#174997
<https://github.com/pytorch/pytorch/pull/174997>`_).
"""
# Bump when a PyTorch RNG/init algo change breaks reference .pth bit-stability
# (most recent: trunc_normal_ in 2.12; pytorch/pytorch#174997). Regen at bump.
_REFERENCE_DATA_MIN_TORCH = "2.12"
if torch.__version__ < _REFERENCE_DATA_MIN_TORCH:
Comment thread
peterdsharpe marked this conversation as resolved.
pytest.skip(
f"Forward-accuracy reference data requires torch >= "
f"{_REFERENCE_DATA_MIN_TORCH} (got {torch.__version__}); "
f"init/RNG algorithms changed (e.g. trunc_normal_, "
f"pytorch/pytorch#174997)."
)

# Perform a foward pass of the model
output = model.forward(*in_args)
# Always use tuples for this comparison / saving
Expand Down
Binary file modified test/experimental/models/flare/data/flare_2d_output.pth
Binary file not shown.
Binary file modified test/experimental/models/flare/data/flare_irregular_output.pth
Binary file not shown.
Binary file modified test/models/afno/data/afno_output.pth
Binary file not shown.
Binary file modified test/models/afno/data/modafno_output.pth
Binary file not shown.
Binary file modified test/models/pangu/data/pangu_output.pth
Binary file not shown.
Binary file modified test/models/transolver/data/transolver2d_output.pth
Binary file not shown.
Binary file modified test/models/transolver/data/transolver_irregular_output.pth
Binary file not shown.