
PyTorch 2.12 compatibility #1648

Merged
ktangsali merged 4 commits into 2.1.0-rc from torch-2-12
May 20, 2026

peterdsharpe (Collaborator) commented May 15, 2026

PhysicsNeMo Pull Request

Description

Hey folks! Noticed on a sibling PR (#1633) that some tests are failing with the recently-released PyTorch 2.12. Thought I'd start a branch on the main repo that we (maintainers) can collaborate on to iron out the minor CI issues.

This PR currently has just one commit, which is the result of `uv sync --upgrade`.

Synced version diffs:
- absl-py==2.4.0
- aiobotocore==3.4.0
+ aiobotocore==3.7.0
- botocore==1.42.84
+ botocore==1.43.0
- build==1.4.3
+ build==1.5.0
- cachetools==7.0.6
+ cachetools==7.1.1
- certifi==2026.2.25
+ certifi==2026.4.22
- click==8.3.2
+ click==8.3.3
- contourpy==1.3.3
- coverage==7.13.5
+ coverage==7.14.0
- cuda-pathfinder==1.5.3
+ cuda-pathfinder==1.5.4
- cuda-toolkit==13.2.1
+ cuda-toolkit==13.0.2
- cupy-cuda13x==14.0.1
+ cupy-cuda13x==13.6.0
- cycler==0.12.1
- cyclopts==4.10.2
- dm-tree==0.1.9
- docstring-parser==0.18.0
- docutils==0.22.4
- fonttools==4.62.1
- fsspec==2026.3.0
+ fsspec==2026.4.0
- gitpython==3.1.46
+ gitpython==3.1.50
- hf-xet==1.4.3
+ hf-xet==1.5.0
- huggingface-hub==1.11.0
+ huggingface-hub==1.15.0
- idna==3.11
+ idna==3.15
- importlib-metadata==8.9.0
+ importlib-metadata==8.7.1
- ipython==9.12.0
+ ipython==9.13.0
- jedi==0.19.2
+ jedi==0.20.0
- kiwisolver==1.5.0
- markdown-it-py==4.0.0
+ markdown-it-py==4.2.0
- matplotlib==3.10.8
- matplotlib-inline==0.2.1
+ matplotlib-inline==0.2.2
- numpy==2.4.4
+ numpy==2.4.5
- nvidia-cublas==13.4.0.1
+ nvidia-cublas==13.1.0.3
- nvidia-cuda-cccl==13.2.75
+ nvidia-cuda-cccl==13.0.85
- nvidia-cuda-nvcc==13.2.78
+ nvidia-cuda-nvcc==13.0.88
- nvidia-cuda-nvrtc==13.2.78
+ nvidia-cuda-nvrtc==13.0.88
- nvidia-cuda-runtime==13.2.75
+ nvidia-cuda-runtime==13.0.96
- nvidia-cudnn-cu13==9.19.0.56
+ nvidia-cudnn-cu13==9.20.0.48
- nvidia-cufft==12.2.0.46
+ nvidia-cufft==12.0.0.61
- nvidia-curand==10.4.2.55
+ nvidia-curand==10.4.0.35
- nvidia-cusolver==12.2.0.1
+ nvidia-cusolver==12.0.4.66
- nvidia-cusparse==12.7.10.1
+ nvidia-cusparse==12.6.3.3
- nvidia-cusparselt-cu13==0.8.0
+ nvidia-cusparselt-cu13==0.8.1
- nvidia-dali-cuda130==2.0.0
+ nvidia-dali-cuda130==2.1.0
- nvidia-nccl-cu13==2.30.4
+ nvidia-nccl-cu13==2.29.7
- nvidia-nvjitlink==13.2.78
+ nvidia-nvjitlink==13.0.88
- nvidia-nvvm==13.2.78
+ nvidia-nvvm==13.0.88
- nvidia-physicsnemo==2.1.0a0 (from file:///home/psharpe/gh/physicsnemo)
+ nvidia-physicsnemo==2.1.0a0 (from file:///home/psharpe/gh/physicsnemo3)
+ optree==0.19.1
- orjson==3.11.8
+ orjson==3.11.9
- packaging==26.2
+ packaging==26.0
- parso==0.8.6
+ parso==0.8.7
- pooch==1.9.0
- pre-commit==4.5.1
+ pre-commit==4.6.0
- propcache==0.4.1
+ propcache==0.5.2
- pyacvd==0.3.3
- pyarrow==24.0.0
+ pyarrow==23.0.1
- pykdtree==1.4.3
- pyparsing==3.3.2
- python-discovery==1.2.2
+ python-discovery==1.3.1
- pytz==2026.1.post1
+ pytz==2026.2
- pyvista==0.47.3
- requests==2.33.1
+ requests==2.34.2
- rich-rst==1.3.2
- ruff==0.15.11
+ ruff==0.15.13
- s3fs==2026.3.0
+ s3fs==2026.4.0
- scooby==0.11.0
- tensordict==0.12.2
+ tensordict==0.12.3
- timm==1.0.26
+ timm==1.0.27
- torch==2.11.0+cu130
+ torch==2.12.0+cu130
- torchvision==0.26.0+cu130
+ torchvision==0.27.0+cu130
- traitlets==5.14.3
+ traitlets==5.15.0
- triton==3.6.0
+ triton==3.7.0
- typer==0.24.1
+ typer==0.25.1
- urllib3==2.6.3
+ urllib3==2.7.0
- virtualenv==21.2.4
+ virtualenv==21.3.3
- vtk==9.6.1
- warp-lang==1.12.1
+ warp-lang==1.13.0
- wcwidth==0.6.0
+ wcwidth==0.7.0
- wheel==0.46.3
+ wheel==0.47.0

Failing tests:

FAILED test/experimental/models/flare/test_flare.py::test_flare_2d_forward[cpu] - AssertionError: assert False
 +  where False = validate_forward_accuracy(FLARE(\n  (preprocess): _TransolverMlp(\n    (layers): Sequential(\n      (0): Linear(in_features=2, out_features=128, bi... elementwise_affine=True, bias=True)\n        (1): Linear(in_features=64, out_features=1, bias=True)\n      )\n    )\n  )\n), (tensor([[[ 0.5711],\n         [ 0.1995],\n         [ 1.1276],\n         ...,\n         [-1.9894],\n         [-0.1758],\n   ...1, -0.0180,  ...,  0.3017, -0.3337,  0.4747],\n         [-2.5247,  0.8545,  0.1970,  ..., -0.4494,  0.8349, -1.0158]]])), file_name='experimental/models/flare/data/flare_2d_output.pth', atol=0.002)
FAILED test/experimental/models/flare/test_flare.py::test_flare_irregular_forward[cpu] - AssertionError: assert False
 +  where False = validate_forward_accuracy(FLARE(\n  (preprocess): _TransolverMlp(\n    (layers): Sequential(\n      (0): Linear(in_features=5, out_features=128, bi... elementwise_affine=True, bias=True)\n        (1): Linear(in_features=64, out_features=1, bias=True)\n      )\n    )\n  )\n), (tensor([[[ 0.9846, -1.1268,  1.4789],\n         [-0.3570, -0.4284, -1.1272],\n         [-0.4014, -0.4351,  0.5843],\n   ...5969, -2.5630],\n         ...,\n         [ 0.2311, -0.4634],\n         [ 1.1963,  0.8382],\n         [ 0.0543,  2.0379]]])), file_name='experimental/models/flare/data/flare_irregular_output.pth', atol=0.001)
FAILED test/models/afno/test_afno.py::test_afno_forward[cpu] - AssertionError: assert False
 +  where False = <function validate_forward_accuracy at 0x7f6a2fb3fce0>(AFNO(\n  (patch_embed): AFNOPatchEmbed(\n    (proj): Conv2d(2, 16, kernel_size=(8, 8), stride=(8, 8))\n  )\n  (pos_drop): ...(drop): Dropout(p=0.0, inplace=False)\n      )\n    )\n  )\n  (head): Linear(in_features=16, out_features=64, bias=False)\n), (tensor([[[[-4.7943e-01,  4.0143e-01, -3.7925e-01,  ...,  5.0208e-01,\n           -1.7293e+00,  2.5536e+00],\n          ...3e-01],\n          [-1.5788e-01,  2.5727e-01,  8.7710e-02,  ...,  4.1561e-01,\n            1.1829e+00, -1.4620e+00]]]]),), file_name='models/afno/data/afno_output.pth')
 +    where <function validate_forward_accuracy at 0x7f6a2fb3fce0> = common.validate_forward_accuracy
FAILED test/models/afno/test_modafno.py::test_modafno_forward[cpu] - AssertionError: assert False
 +  where False = <function validate_forward_accuracy at 0x7f6a2fb3fce0>(ModAFNO(\n  (patch_embed): AFNOPatchEmbed(\n    (proj): Conv2d(2, 16, kernel_size=(8, 8), stride=(8, 8))\n  )\n  (pos_drop...quential(\n      (0): Linear(in_features=64, out_features=64, bias=True)\n      (1): GELU(approximate='none')\n    )\n  )\n), (tensor([[[[-1.1156, -1.1069,  0.2088,  ..., -0.3796,  0.7840, -0.5480],\n          [-1.0344, -1.1430,  0.4686,  ..., -...667],\n          [-1.4664, -1.3720, -0.9867,  ...,  0.0991, -0.8756, -0.2588]]]]), tensor([[0.5000],\n        [0.5000]])), file_name='models/afno/data/modafno_output.pth')
 +    where <function validate_forward_accuracy at 0x7f6a2fb3fce0> = common.validate_forward_accuracy
FAILED test/models/transolver/test_transolver.py::test_transolver2d_forward[cpu] - AssertionError: assert False
 +  where False = validate_forward_accuracy(Transolver(\n  (preprocess): _TransolverMlp(\n    (layers): Sequential(\n      (0): Linear(in_features=2, out_features=12... elementwise_affine=True, bias=True)\n        (1): Linear(in_features=64, out_features=1, bias=True)\n      )\n    )\n  )\n), (tensor([[[-0.5897],\n         [ 1.0317],\n         [-0.2884],\n         ...,\n         [ 0.8900],\n         [ 0.2958],\n   ....5102e-01],\n         [ 3.6605e-01,  1.3253e-02, -1.0016e+00,  ...,  1.2162e+00,\n          -4.3989e-01, -1.1570e+00]]])), file_name='models/transolver/data/transolver2d_output.pth', atol=0.002)
FAILED test/models/transolver/test_transolver.py::test_transolver_irregular_forward[cpu] - AssertionError: assert False
 +  where False = validate_forward_accuracy(Transolver(\n  (preprocess): _TransolverMlp(\n    (layers): Sequential(\n      (0): Linear(in_features=5, out_features=12... elementwise_affine=True, bias=True)\n        (1): Linear(in_features=64, out_features=1, bias=True)\n      )\n    )\n  )\n), (tensor([[[ 0.3444,  0.5315,  1.3686],\n         [-0.6105,  0.8779,  0.0485],\n         [-2.2503, -0.5010, -1.1503],\n   ...1530, -0.1226],\n         ...,\n         [ 0.6356, -0.5943],\n         [ 0.2441, -0.2087],\n         [ 0.4954, -0.3232]]])), file_name='models/transolver/data/transolver_irregular_output.pth', atol=0.001)

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is it an indication that the PR will be accepted or rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

copy-pr-bot (Bot) commented May 15, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

peterdsharpe (Collaborator, Author) commented

/blossom-ci

- Updated internal weight initialization in distributed AFNO layers and EarthAttention blocks to utilize `torch.nn.init.trunc_normal_` instead of legacy implementations.
- Deprecated `trunc_normal_` wrapper in `physicsnemo.nn.module.utils` and removed the in-tree legacy implementation.
- Regenerated forward-accuracy reference outputs for several models to align with the new initialization method.
- Updated tests to skip on PyTorch versions below 2.12 due to changes in RNG algorithms affecting output consistency.
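The need to regenerate reference outputs follows from how seeded forward-accuracy tests work: a stored output only matches if the initializer draws the same random numbers. The sketch below is a minimal illustration, not the repository's `validate_forward_accuracy`; `make_model` and `outputs_match` are hypothetical helpers, and an in-memory tensor stands in for a stored `.pth` reference file.

```python
import torch


def make_model(seed: int) -> torch.nn.Linear:
    """Build a tiny layer whose weights depend only on the RNG seed."""
    torch.manual_seed(seed)
    layer = torch.nn.Linear(4, 3)
    # In-place initializer; it consumes random numbers from the global RNG.
    torch.nn.init.trunc_normal_(layer.weight, std=0.02)
    return layer


def outputs_match(model, x, reference, atol=1e-3) -> bool:
    """Hypothetical stand-in for a forward-accuracy check against a reference."""
    with torch.no_grad():
        return torch.allclose(model(x), reference, atol=atol)


x = torch.ones(2, 4)
with torch.no_grad():
    reference = make_model(seed=0)(x)  # stands in for a stored .pth reference

# Same seed and same initializer: the output reproduces the reference.
same = outputs_match(make_model(seed=0), x, reference)
# A different random stream (here simulated with another seed) breaks the match.
different = outputs_match(make_model(seed=1), x, reference)
print(same, different)
```

A change in the sampling algorithm behind an initializer has the same effect as the changed seed here: the weights differ, so the old reference no longer applies.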
@peterdsharpe peterdsharpe reopened this May 19, 2026
@peterdsharpe peterdsharpe requested a review from coreyjadams May 19, 2026 15:25
@peterdsharpe peterdsharpe requested review from ktangsali and mnabian May 19, 2026 15:25
greptile-apps (Bot, Contributor) commented May 19, 2026

Greptile Summary

This PR updates PhysicsNeMo for PyTorch 2.12 compatibility by replacing frozen in-tree trunc_normal_ implementations with torch.nn.init.trunc_normal_ (which changed to rejection-sampling in 2.12), regenerating forward-accuracy reference .pth files, and adding a version gate that skips those tests on older PyTorch installs.

  • Removes private _trunc_normal_ / _no_grad_trunc_normal_ from layers.py and the deprecated in-tree copy in weight_init.py; replaces all call sites with torch.nn.init.trunc_normal_ directly.
  • Deprecates the public physicsnemo.nn.module.utils.trunc_normal_ re-export with a DeprecationWarning; scheduled for removal in v2.2.0.
  • Adds _REFERENCE_DATA_MIN_TORCH = "2.12" guard in validate_forward_accuracy and regenerates reference .pth files for AFNO, ModAFNO, Transolver, FLARE, and Pangu.
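The deprecation pattern described above can be sketched roughly as follows. This is an illustration under assumptions, not the code merged in this PR: the wrapper's signature is assumed to mirror `torch.nn.init.trunc_normal_`, and the warning text is made up for the example.

```python
import warnings

import torch


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Deprecated alias (assumed signature); delegates to PyTorch's initializer."""
    warnings.warn(
        "trunc_normal_ is deprecated and scheduled for removal in v2.2.0; "
        "use torch.nn.init.trunc_normal_ instead.",
        DeprecationWarning,
        stacklevel=2,  # attribute the warning to the caller, not the wrapper
    )
    return torch.nn.init.trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)


weight = torch.empty(64, 64)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = trunc_normal_(weight, std=0.02)

# The initializer works in place and returns the same tensor object,
# so call sites do not need to reassign the result.
print(result is weight)
print(any(issubclass(w.category, DeprecationWarning) for w in caught))
```

Because the PyTorch initializer modifies its argument in place, call sites can drop any assignment of the return value, which matches the note on attention_layers.py in the file overview.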

Important Files Changed

| Filename | Overview |
| --- | --- |
| test/common/fwdaccuracy.py | Adds a PyTorch version gate to skip forward-accuracy tests on torch < 2.12; the gate uses lexicographic string comparison which silently misfires for versions 2.0–2.9 (P1 bug). |
| physicsnemo/models/afno/distributed/layers.py | Removes the private in-tree _trunc_normal_ / _no_grad_trunc_normal_ implementations and replaces all call sites with torch.nn.init.trunc_normal_. Clean removal. |
| physicsnemo/models/afno/distributed/afno.py | Updates weight-init call sites from the private _trunc_normal_ to torch.nn.init.trunc_normal_ and removes the now-unused import. |
| physicsnemo/nn/module/attention_layers.py | Replaces deprecated trunc_normal_ wrapper calls with direct torch.nn.init.trunc_normal_ in EarthAttention3D and EarthAttention2D; assignment is correctly dropped since the PyTorch init is in-place. |
| physicsnemo/nn/module/utils/weight_init.py | Replaces the in-tree inverse-CDF trunc_normal_ implementation with a thin deprecation-warning wrapper that delegates to torch.nn.init.trunc_normal_; numpy import retained for _weight_init. |
| CHANGELOG.md | Well-documented changelog entries covering the deprecation, removal, and test-versioning change. |
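The string-comparison pitfall flagged for test/common/fwdaccuracy.py can be illustrated with a short sketch. The actual gate code is not shown in this thread, so `version_tuple`, `older_than`, and `MIN_TORCH` below are hypothetical names; the sketch only shows why comparing version strings character by character is unreliable and one way to compare numerically using the standard library.

```python
import re


def version_tuple(version: str) -> tuple[int, ...]:
    """Turn '2.12.0+cu130' into (2, 12, 0), ignoring any local or pre-release suffix."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unparseable version: {version!r}")
    return tuple(int(part) for part in match.group().split("."))


def older_than(installed: str, minimum: str) -> bool:
    """Numeric comparison of the leading release components."""
    return version_tuple(installed) < version_tuple(minimum)


MIN_TORCH = "2.12"  # hypothetical constant, named after _REFERENCE_DATA_MIN_TORCH

# Character-by-character comparison: '9' sorts after '1', so 2.9.1 looks "newer".
print("2.9.1" < MIN_TORCH)
# Numeric comparison: (2, 9, 1) < (2, 12), so 2.9.1 is correctly treated as older.
print(older_than("2.9.1", MIN_TORCH))
print(older_than("2.12.0+cu130", MIN_TORCH))
```

Where the `packaging` library is available, `packaging.version.parse` offers a fuller PEP 440 comparison; the tuple approach here is just a dependency-free approximation.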

Reviews (1): Last reviewed commit: "Refactor weight initialization to use Py..."

Comment thread test/common/fwdaccuracy.py
ktangsali (Collaborator) left a comment

Changes to afno path look good. Thanks @peterdsharpe

peterdsharpe (Collaborator, Author) commented

/blossom-ci

ktangsali (Collaborator) then posted five identical comments:

/blossom-ci

@ktangsali ktangsali merged commit 1f0798d into 2.1.0-rc May 20, 2026
@ktangsali ktangsali deleted the torch-2-12 branch May 20, 2026 18:13
4 participants