diff --git a/CHANGELOG.md b/CHANGELOG.md index 7c1b8fb70c..380cb0bf59 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -206,6 +206,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 combined-workflow and from-checkpoint round-trip tests. Most tests run with `fullgraph=True` and `error_on_recompile` to catch `torch.compile` regressions. +- Internal weight initialization in the distributed AFNO layers and the + `EarthAttention` blocks of `physicsnemo.nn.module.attention_layers` now + dispatches to `torch.nn.init.trunc_normal_` directly instead of going + through frozen in-tree copies of the pre-PyTorch-2.12 inverse-CDF + implementation. PyTorch 2.12 reimplemented `trunc_normal_` as a + rejection-sampling loop on top of `normal_()` (see + [pytorch/pytorch#174997](https://github.com/pytorch/pytorch/pull/174997)), + so seeded from-scratch initialization consumes the RNG stream + differently on 2.12+ vs older versions. Existing trained checkpoints + are unaffected (loading bypasses init). Forward-accuracy reference + outputs for `AFNO`, `ModAFNO`, `Transolver`, `FLARE`, and `Pangu` were + regenerated against the new algorithm. Rather than wiring per-model + skips, `test.common.validate_forward_accuracy` now uniformly skips on + `torch < 2.12` (the reference data is locked to that floor via a single + `_REFERENCE_DATA_MIN_TORCH` constant; bump it when a PyTorch + release next changes an init/RNG algorithm any forward-accuracy model + depends on, and regenerate the `.pth` files at the same time). ### Deprecated @@ -213,6 +230,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 isosurface extraction, use `physicsnemo.mesh.generate.marching_cubes` instead of `sdf_to_stl`. For VTP/OBJ/STL file conversion (`combine_vtp_files`, `convert_tesselated_files_in_directory`), use VTK or PyVista directly. 
+- `physicsnemo.nn.module.utils.trunc_normal_` (and its submodule path + `physicsnemo.nn.module.utils.weight_init.trunc_normal_`) is deprecated + and will be removed in v2.2.0. It is now a thin wrapper around + `torch.nn.init.trunc_normal_` that emits a `DeprecationWarning` on + call, replacing the frozen in-tree copy of the legacy inverse-CDF + implementation. Use `torch.nn.init.trunc_normal_` directly. + +### Removed + +- The legacy in-tree `trunc_normal_` implementation that lived in + `physicsnemo/models/afno/distributed/layers.py` (`_trunc_normal_` / + `_no_grad_trunc_normal_`) is removed. These names were private; all + in-tree call sites now use `torch.nn.init.trunc_normal_`. ### Fixed diff --git a/physicsnemo/models/afno/distributed/afno.py b/physicsnemo/models/afno/distributed/afno.py index 7e218cd922..528015de9b 100644 --- a/physicsnemo/models/afno/distributed/afno.py +++ b/physicsnemo/models/afno/distributed/afno.py @@ -44,7 +44,6 @@ DistributedMLP, DistributedPatchEmbed, DropPath, - _trunc_normal_, ) from physicsnemo.nn.module.layer_norm import get_layer_norm_class @@ -361,7 +360,7 @@ def __init__( self.synchronized_head = False # init weights - _trunc_normal_(self.pos_embed, std=0.02) + torch.nn.init.trunc_normal_(self.pos_embed, std=0.02) self.apply(self._init_weights) def _init_weights(self, m: nn.Module) -> None: @@ -373,7 +372,7 @@ def _init_weights(self, m: nn.Module) -> None: Module to initialize. 
""" if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): - _trunc_normal_(m.weight, std=0.02) + torch.nn.init.trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, _LayerNormClass): diff --git a/physicsnemo/models/afno/distributed/layers.py b/physicsnemo/models/afno/distributed/layers.py index 308bf160ee..8e2c2341e7 100644 --- a/physicsnemo/models/afno/distributed/layers.py +++ b/physicsnemo/models/afno/distributed/layers.py @@ -22,8 +22,6 @@ from __future__ import annotations -import math -import warnings from typing import Tuple import torch @@ -43,84 +41,6 @@ from physicsnemo.distributed.utils import compute_split_shapes -def _no_grad_trunc_normal_(tensor, mean, std, a, b): - # Cut & paste from PyTorch official master until it's in a few official releases - # Method based on - # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - def norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - - if (mean < a - 2 * std) or (mean > b + 2 * std): - warnings.warn( - "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " - "The distribution of values may be incorrect.", - stacklevel=2, - ) - - with torch.no_grad(): - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - low = norm_cdf((a - mean) / std) - up = norm_cdf((b - mean) / std) - - # Uniformly fill tensor with values from [low, up], then translate to - # [2low-1, 2up-1]. 
- tensor.uniform_(2 * low - 1, 2 * up - 1) - - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - tensor.erfinv_() - - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - - # Clamp to ensure it's in the proper range - tensor.clamp_(min=a, max=b) - return tensor - - -def _trunc_normal_( - tensor: Tensor, - mean: float = 0.0, - std: float = 1.0, - a: float = -2.0, - b: float = 2.0, -) -> Tensor: - r"""Fill the input tensor with values from a truncated normal distribution. - - The values are drawn from :math:`\mathcal{N}(\text{mean}, \text{std}^2)` - with values outside :math:`[a, b]` redrawn until they are within the bounds. - The method works best when :math:`a \leq \text{mean} \leq b`. - - Parameters - ---------- - tensor : torch.Tensor - An n-dimensional tensor to fill. - mean : float, optional, default=0.0 - Mean of the normal distribution. - std : float, optional, default=1.0 - Standard deviation of the normal distribution. - a : float, optional, default=-2.0 - Minimum cutoff value. - b : float, optional, default=2.0 - Maximum cutoff value. - - Returns - ------- - torch.Tensor - The input tensor filled with truncated normal values. 
- - Examples - -------- - >>> w = torch.empty(3, 5) - >>> o = _trunc_normal_(w) - """ - return _no_grad_trunc_normal_(tensor, mean, std, a, b) - - @torch.compile def drop_path( x: Float[Tensor, "*dims"], drop_prob: float = 0.0, training: bool = False @@ -294,9 +214,9 @@ def __init__( def _init_weights(self) -> None: r"""Initialize weights using truncated normal distribution.""" - _trunc_normal_(self.w1, std=0.02) + torch.nn.init.trunc_normal_(self.w1, std=0.02) nn.init.constant_(self.b1, 0.0) - _trunc_normal_(self.w2, std=0.02) + torch.nn.init.trunc_normal_(self.w2, std=0.02) nn.init.constant_(self.b2, 0.0) def forward(self, x: Float[Tensor, "B C H W"]) -> Float[Tensor, "B C_out H W"]: diff --git a/physicsnemo/nn/module/attention_layers.py b/physicsnemo/nn/module/attention_layers.py index 9a8682eee1..7e8d555ea8 100644 --- a/physicsnemo/nn/module/attention_layers.py +++ b/physicsnemo/nn/module/attention_layers.py @@ -23,7 +23,7 @@ from physicsnemo.nn.module.conv_layers import Conv2d from physicsnemo.nn.module.group_norm import get_group_norm -from physicsnemo.nn.module.utils import get_earth_position_index, trunc_normal_ +from physicsnemo.nn.module.utils import get_earth_position_index class AttentionOp(torch.autograd.Function): @@ -237,9 +237,7 @@ def __init__( self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) - self.earth_position_bias_table = trunc_normal_( - self.earth_position_bias_table, std=0.02 - ) + torch.nn.init.trunc_normal_(self.earth_position_bias_table, std=0.02) self.softmax = nn.Softmax(dim=-1) def forward(self, x: torch.Tensor, mask=None): @@ -345,9 +343,7 @@ def __init__( self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) - self.earth_position_bias_table = trunc_normal_( - self.earth_position_bias_table, std=0.02 - ) + torch.nn.init.trunc_normal_(self.earth_position_bias_table, std=0.02) self.softmax = nn.Softmax(dim=-1) def forward(self, x: torch.Tensor, mask=None): diff --git 
a/physicsnemo/nn/module/dit_layers.py b/physicsnemo/nn/module/dit_layers.py index c069c04025..5efb5719ff 100644 --- a/physicsnemo/nn/module/dit_layers.py +++ b/physicsnemo/nn/module/dit_layers.py @@ -388,8 +388,8 @@ class Natten2DSelfAttention(AttentionModuleBase): -------- >>> import torch >>> from physicsnemo.nn.module.dit_layers import Natten2DSelfAttention - >>> attn = Natten2DSelfAttention(hidden_size=64, num_heads=4, attn_kernel=3) - >>> x = torch.randn(2, 16, 64) + >>> attn = Natten2DSelfAttention(hidden_size=64, num_heads=4, attn_kernel=3).cuda() + >>> x = torch.randn(2, 16, 64, device="cuda") >>> out = attn(x, latent_hw=(4, 4)) >>> out.shape torch.Size([2, 16, 64]) diff --git a/physicsnemo/nn/module/utils/weight_init.py b/physicsnemo/nn/module/utils/weight_init.py index a0a0470ca6..10af9b7561 100644 --- a/physicsnemo/nn/module/utils/weight_init.py +++ b/physicsnemo/nn/module/utils/weight_init.py @@ -14,66 +14,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math import warnings import numpy as np import torch +from torch.nn.init import trunc_normal_ as _torch_trunc_normal_ -def _trunc_normal_(tensor, mean, std, a, b): - # Cut & paste from PyTorch official master until it's in a few official releases - RW - # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - def norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 +def trunc_normal_(*args, **kwargs): + """Deprecated alias for :func:`torch.nn.init.trunc_normal_`. - if (mean < a - 2 * std) or (mean > b + 2 * std): - warnings.warn( - "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " - "The distribution of values may be incorrect.", - stacklevel=2, - ) - - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. 
- # Get upper and lower cdf values - u1 = norm_cdf((a - mean) / std) - u2 = norm_cdf((b - mean) / std) - - # Uniformly fill tensor with values from [u1, u2], then translate to - # [2u1-1, 2u2-1]. - tensor.uniform_(2 * u1 - 1, 2 * u2 - 1) - - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - tensor.erfinv_() - - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - - # Clamp to ensure it's in the proper range - tensor.clamp_(min=a, max=b) - return tensor - - -def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): - # type: (Tensor, float, float, float, float) -> Tensor - r"""Cut & paste from timm master - Fills the input Tensor with values drawn from a truncated - normal distribution. The values are effectively drawn from the - normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` - with values outside :math:`[a, b]` redrawn until they are within - the bounds. The method used for generating the random values works - best when :math:`a \leq \text{mean} \leq b`. - - NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are - applied while sampling the normal with mean/std applied, therefore a, b args - should be adjusted to match the range of mean, std args. + This re-export exists only to preserve backward compatibility for code + that imported ``trunc_normal_`` from ``physicsnemo.nn.module.utils`` (or + its ``weight_init`` submodule path) prior to v2.1. It will be removed in + v2.2.0; new code should call :func:`torch.nn.init.trunc_normal_` + directly. """ - with torch.no_grad(): - return _trunc_normal_(tensor, mean, std, a, b) + warnings.warn( + "`physicsnemo.nn.module.utils.trunc_normal_` is deprecated and will " + "be removed in v2.2.0. 
Use `torch.nn.init.trunc_normal_` directly.", + DeprecationWarning, + stacklevel=2, + ) + return _torch_trunc_normal_(*args, **kwargs) def _weight_init(shape: tuple, mode: str, fan_in: int, fan_out: int): diff --git a/test/common/fwdaccuracy.py b/test/common/fwdaccuracy.py index 39a3c3898b..02d47105dd 100644 --- a/test/common/fwdaccuracy.py +++ b/test/common/fwdaccuracy.py @@ -18,6 +18,7 @@ from pathlib import Path from typing import Tuple, Union +import pytest import torch import physicsnemo @@ -102,7 +103,27 @@ def validate_forward_accuracy( ------ IOError Target output tensor file for this model was not found + + Notes + ----- + On ``torch < _REFERENCE_DATA_MIN_TORCH`` (currently ``"2.12"``) the + test is skipped via :func:`pytest.skip` instead of compared: the + saved ``.pth`` reference data is locked to post-2.12 RNG/init + algorithms (notably ``torch.nn.init.trunc_normal_``'s + rejection-sampling rewrite, `pytorch/pytorch#174997 + <https://github.com/pytorch/pytorch/pull/174997>`_). """ + # Bump when a PyTorch RNG/init algo change breaks reference .pth bit-stability + # (most recent: trunc_normal_ in 2.12; pytorch/pytorch#174997). Regen at bump. + _REFERENCE_DATA_MIN_TORCH = "2.12" + if torch.__version__ < _REFERENCE_DATA_MIN_TORCH: + pytest.skip( + f"Forward-accuracy reference data requires torch >= " + f"{_REFERENCE_DATA_MIN_TORCH} (got {torch.__version__}); " + f"init/RNG algorithms changed (e.g. trunc_normal_, " + f"pytorch/pytorch#174997)." 
+ ) + # Perform a foward pass of the model output = model.forward(*in_args) # Always use tuples for this comparison / saving diff --git a/test/experimental/models/flare/data/flare_2d_output.pth b/test/experimental/models/flare/data/flare_2d_output.pth index 107fbc5779..1f8a653ab8 100644 Binary files a/test/experimental/models/flare/data/flare_2d_output.pth and b/test/experimental/models/flare/data/flare_2d_output.pth differ diff --git a/test/experimental/models/flare/data/flare_irregular_output.pth b/test/experimental/models/flare/data/flare_irregular_output.pth index 90fcb6f7fa..2962639497 100644 Binary files a/test/experimental/models/flare/data/flare_irregular_output.pth and b/test/experimental/models/flare/data/flare_irregular_output.pth differ diff --git a/test/models/afno/data/afno_output.pth b/test/models/afno/data/afno_output.pth index 98b3ca8fc0..e2bed6f364 100644 Binary files a/test/models/afno/data/afno_output.pth and b/test/models/afno/data/afno_output.pth differ diff --git a/test/models/afno/data/modafno_output.pth b/test/models/afno/data/modafno_output.pth index 256453cb7f..1df22e4f2f 100644 Binary files a/test/models/afno/data/modafno_output.pth and b/test/models/afno/data/modafno_output.pth differ diff --git a/test/models/pangu/data/pangu_output.pth b/test/models/pangu/data/pangu_output.pth index 5f6329bece..f01e2307a3 100644 Binary files a/test/models/pangu/data/pangu_output.pth and b/test/models/pangu/data/pangu_output.pth differ diff --git a/test/models/transolver/data/transolver2d_output.pth b/test/models/transolver/data/transolver2d_output.pth index 031955ce2a..78861aba6f 100644 Binary files a/test/models/transolver/data/transolver2d_output.pth and b/test/models/transolver/data/transolver2d_output.pth differ diff --git a/test/models/transolver/data/transolver_irregular_output.pth b/test/models/transolver/data/transolver_irregular_output.pth index 8f3d852e6b..8e8c09150d 100644 Binary files 
a/test/models/transolver/data/transolver_irregular_output.pth and b/test/models/transolver/data/transolver_irregular_output.pth differ