From cbb400cc2331c4ce7af2bf46f26cd981b3af6291 Mon Sep 17 00:00:00 2001 From: Alexander Wikner Date: Tue, 28 Apr 2026 04:07:53 -0500 Subject: [PATCH 1/3] Fix over-masking of cyclic longitude axis in shifted-window attention mask get_shift_window_mask was partitioning the attention mask along all three spatial axes (Pl, Lat, Lon) in 3D, and along both axes (Lat, Lon) in 2D. Because longitude is cyclic on the global ERA5 grid, the cyclic torch.roll shift is sufficient to handle wrap-around windows; adding a hard longitude mask additionally prevents tokens that are physically adjacent across the dateline from attending to one another. The Pangu-Weather paper (Bi et al., arXiv:2211.02556) states explicitly that wrap-around longitude windows "are directly merged into one window" and treats longitude as cyclic throughout (M_lon is omitted from the position bias for this reason). The fix removes the Lon + shift_lon padding, the lon_slices loop, and the trailing :Lon crop, leaving region IDs assigned by a Pl x Lat double loop only. The 2D path (FengWu) receives the symmetric change: its attention module is directly derived from Pangu-Weather and inherits the same cyclic-longitude design. This is a follow-up to #1492, which fixed the shift_lat typo in the forward cyclic roll; this PR fixes the mask that roll was supposed to be working with. Closes #1599. Signed-off-by: Alexander Wikner Co-Authored-By: Claude Sonnet 4.6 (1M context) --- CHANGELOG.md | 5 + .../nn/module/utils/shift_window_mask.py | 27 +-- test/nn/module/test_shift_window_mask.py | 188 ++++++++++++++++++ 3 files changed, 204 insertions(+), 16 deletions(-) create mode 100644 test/nn/module/test_shift_window_mask.py diff --git a/CHANGELOG.md b/CHANGELOG.md index c1c799ee75..22fb17cacd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -83,6 +83,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed graph break caused by `FunctionSpec` dispatch (`max(key=)` is not supported by `torch.compile`) - Fixed bug in Pangu, FengWu attention window shift for asymmetric longitudes +- Fixed over-masking of the cyclic longitude axis in `get_shift_window_mask` for + Pangu/FengWu shifted-window attention. The mask now partitions only the + (pressure level, latitude) plane, matching the original Pangu-Weather paper's + treatment of longitude as a cyclic dimension whose wrap-around windows + "are directly merged into one window". - Fixed a bug in `mesh.sampling.find_nearest_cells`, where a mixup between L2 and L-inf norms could cause slightly incorrect nearest-neighbor assignments in highly skewed meshes. - Fixed TensorDict key-ordering bug in GLOBE's Barnes-Hut kernel that caused diff --git a/physicsnemo/nn/module/utils/shift_window_mask.py b/physicsnemo/nn/module/utils/shift_window_mask.py index 8303323435..796922dba9 100644 --- a/physicsnemo/nn/module/utils/shift_window_mask.py +++ b/physicsnemo/nn/module/utils/shift_window_mask.py @@ -90,7 +90,11 @@ def window_reverse(windows, window_size, Pl=1, Lat=1, Lon=1, ndim=3): def get_shift_window_mask(input_resolution, window_size, shift_size, ndim=3): """ Along the longitude dimension, the leftmost and rightmost indices are actually close to each other. - If half windows apper at both leftmost and rightmost positions, they are dircetly merged into one window. + If half windows appear at both leftmost and rightmost positions, they are directly merged into one window. + Because longitude is cyclic, no mask is applied along the longitude axis: the cyclic shift alone + (torch.roll on dim=Lon) is sufficient to handle wrap-around windows, and any additional longitude + masking would partition tokens that are physically adjacent across the dateline. Mask region IDs + are assigned by partitioning the (Pl, Lat) plane only. Args: input_resolution (tuple[int]): [pressure levels, latitude, longitude] or [latitude, longitude] window_size (tuple[int]): Window size [pressure levels, latitude, longitude] or [latitude, longitude] @@ -105,13 +109,13 @@ def get_shift_window_mask(input_resolution, window_size, shift_size, ndim=3): win_pl, win_lat, win_lon = window_size shift_pl, shift_lat, shift_lon = shift_size - img_mask = torch.zeros((1, Pl, Lat, Lon + shift_lon, 1)) + img_mask = torch.zeros((1, Pl, Lat, Lon, 1)) elif ndim == 2: Lat, Lon = input_resolution win_lat, win_lon = window_size shift_lat, shift_lon = shift_size - img_mask = torch.zeros((1, Lat, Lon + shift_lon, 1)) + img_mask = torch.zeros((1, Lat, Lon, 1)) if ndim == 3: pl_slices = ( @@ -124,26 +128,17 @@ def get_shift_window_mask(input_resolution, window_size, shift_size, ndim=3): slice(-win_lat, -shift_lat), slice(-shift_lat, None), ) - lon_slices = ( - slice(0, -win_lon), - slice(-win_lon, -shift_lon), - slice(-shift_lon, None), - ) cnt = 0 if ndim == 3: for pl in pl_slices: for lat in lat_slices: - for lon in lon_slices: - img_mask[:, pl, lat, lon, :] = cnt - cnt += 1 - img_mask = img_mask[:, :, :, :Lon, :] + img_mask[:, pl, lat, :, :] = cnt + cnt += 1 elif ndim == 2: for lat in lat_slices: - for lon in lon_slices: - img_mask[:, lat, lon, :] = cnt - cnt += 1 - img_mask = img_mask[:, :, :Lon, :] + img_mask[:, lat, :, :] = cnt + cnt += 1 mask_windows = window_partition( img_mask, window_size, ndim=ndim diff --git a/test/nn/module/test_shift_window_mask.py b/test/nn/module/test_shift_window_mask.py new file mode 100644 index 0000000000..1384beeb04 --- /dev/null +++ b/test/nn/module/test_shift_window_mask.py @@ -0,0 +1,188 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the Earth-specific shifted-window attention mask utilities. + +Covers get_shift_window_mask, window_partition, and window_reverse for both +the 3D (Pangu-Weather) and 2D (FengWu) attention paths. +""" + +import pytest +import torch + +from physicsnemo.nn.module.utils.shift_window_mask import ( + get_shift_window_mask, + window_partition, + window_reverse, +) + + +class TestGetShiftWindowMask3D: + """Tests for get_shift_window_mask with ndim=3 (Pangu-Weather path).""" + + @pytest.mark.parametrize( + "input_resolution, window_size, shift_size", + [ + ((8, 24, 48), (2, 6, 12), (1, 3, 6)), # default Pangu config + ((4, 12, 24), (2, 6, 12), (1, 3, 6)), # smaller resolution + ((8, 24, 48), (2, 6, 6), (1, 3, 3)), # square lon window + ], + ) + def test_output_shape(self, input_resolution, window_size, shift_size): + """Mask shape must be (n_lon, n_pl*n_lat, W, W).""" + Pl, Lat, Lon = input_resolution + win_pl, win_lat, win_lon = window_size + mask = get_shift_window_mask(input_resolution, window_size, shift_size, ndim=3) + n_lon = Lon // win_lon + n_pl_lat = (Pl // win_pl) * (Lat // win_lat) + W = win_pl * win_lat * win_lon + assert tuple(mask.shape) == (n_lon, n_pl_lat, W, W) + + @pytest.mark.parametrize( + "input_resolution, window_size, shift_size", + [ + ((8, 24, 48), (2, 6, 12), (1, 3, 6)), + ((4, 12, 24), (2, 6, 12), (1, 3, 6)), + ], + ) + def test_values_binary(self, input_resolution, window_size, shift_size): + """Mask must contain only 0.0 and -100.0.""" + mask = get_shift_window_mask(input_resolution, window_size, shift_size, ndim=3) + unique = sorted(torch.unique(mask).tolist()) + assert unique == [-100.0, 0.0] + + def test_longitude_unmasked_region_count(self): + """Longitude must not be partitioned: only Pl x Lat region IDs (9).""" + input_resolution = (8, 24, 48) + window_size = (2, 6, 12) + shift_size = (1, 3, 6) + Pl, Lat, Lon = input_resolution + win_pl, win_lat, win_lon = window_size + shift_pl, shift_lat, shift_lon = shift_size + + # Reconstruct the underlying region-ID map directly + img_mask = torch.zeros((1, Pl, Lat, Lon, 1)) + pl_slices = (slice(0, -win_pl), slice(-win_pl, -shift_pl), slice(-shift_pl, None)) + lat_slices = (slice(0, -win_lat), slice(-win_lat, -shift_lat), slice(-shift_lat, None)) + cnt = 0 + for pl in pl_slices: + for lat in lat_slices: + img_mask[:, pl, lat, :, :] = cnt + cnt += 1 + + n_regions = len(torch.unique(img_mask)) + # 3 Pl bands x 3 Lat bands = 9; longitude must NOT add more partitions + assert n_regions == 9, ( + f"Expected 9 region IDs (Pl x Lat only), got {n_regions}. " + "Longitude axis must not be partitioned in the mask." + ) + + def test_no_shift_produces_zero_mask(self): + """With shift_size=(0,0,0) no roll occurs; attn_mask should be None (not called), + but if called directly the mask should be all zeros (no region boundaries).""" + # shift_size of all-zeros means every token maps to region 0 + input_resolution = (8, 24, 48) + window_size = (2, 6, 12) + shift_size = (0, 0, 0) + # With zero shift, slices like slice(0, 0) produce empty ranges. + # The function should still return a valid tensor of the right shape. + # (This exercises the edge case; the Transformer block skips calling + # get_shift_window_mask when roll=False, but the function itself should + # not error.) + # We only check it does not raise. + try: + mask = get_shift_window_mask(input_resolution, window_size, shift_size, ndim=3) + assert mask is not None + except Exception as exc: + pytest.fail(f"get_shift_window_mask raised unexpectedly with zero shift: {exc}") + + +class TestGetShiftWindowMask2D: + """Tests for get_shift_window_mask with ndim=2 (FengWu path).""" + + @pytest.mark.parametrize( + "input_resolution, window_size, shift_size", + [ + ((24, 48), (6, 12), (3, 6)), # default FengWu config (scaled) + ((12, 24), (6, 12), (3, 6)), + ], + ) + def test_output_shape(self, input_resolution, window_size, shift_size): + """Mask shape must be (n_lon, n_lat, W, W).""" + Lat, Lon = input_resolution + win_lat, win_lon = window_size + mask = get_shift_window_mask(input_resolution, window_size, shift_size, ndim=2) + n_lon = Lon // win_lon + n_lat = Lat // win_lat + W = win_lat * win_lon + assert tuple(mask.shape) == (n_lon, n_lat, W, W) + + def test_values_binary(self): + """Mask must contain only 0.0 and -100.0.""" + mask = get_shift_window_mask((24, 48), (6, 12), (3, 6), ndim=2) + unique = sorted(torch.unique(mask).tolist()) + assert unique == [-100.0, 0.0] + + def test_longitude_unmasked_region_count(self): + """Longitude must not be partitioned: only Lat region IDs (3).""" + Lat, Lon = 24, 48 + win_lat, win_lon = 6, 12 + shift_lat, shift_lon = 3, 6 + + img_mask = torch.zeros((1, Lat, Lon, 1)) + lat_slices = (slice(0, -win_lat), slice(-win_lat, -shift_lat), slice(-shift_lat, None)) + cnt = 0 + for lat in lat_slices: + img_mask[:, lat, :, :] = cnt + cnt += 1 + + n_regions = len(torch.unique(img_mask)) + assert n_regions == 3, ( + f"Expected 3 region IDs (Lat only), got {n_regions}. " + "Longitude axis must not be partitioned in the mask." + ) + + +class TestWindowPartitionReverse: + """Round-trip tests: window_reverse(window_partition(x)) == x.""" + + @pytest.mark.parametrize( + "shape, window_size", + [ + ((2, 8, 24, 48, 16), (2, 6, 12)), # (B, Pl, Lat, Lon, C) 3D + ((2, 24, 48, 16), (6, 12)), # (B, Lat, Lon, C) 2D + ], + ) + def test_roundtrip(self, shape, window_size): + """window_reverse(window_partition(x)) must recover x exactly.""" + torch.manual_seed(0) + x = torch.randn(*shape) + ndim = len(window_size) + + partitioned = window_partition(x, window_size, ndim=ndim) + + if ndim == 3: + B, Pl, Lat, Lon, C = shape + recovered = window_reverse( + partitioned, window_size, Pl=Pl, Lat=Lat, Lon=Lon, ndim=ndim + ) + else: + B, Lat, Lon, C = shape + recovered = window_reverse( + partitioned, window_size, Lat=Lat, Lon=Lon, ndim=ndim + ) + + assert torch.allclose(x, recovered), "window_reverse did not invert window_partition" From a29c3e058827a143411b952f74ed7133524c90cc Mon Sep 17 00:00:00 2001 From: Alexander Wikner Date: Tue, 28 Apr 2026 05:09:45 -0500 Subject: [PATCH 2/3] Fix ruff lint errors in test_shift_window_mask Remove unused win_lon and shift_lon local variables from TestGetShiftWindowMask2D.test_longitude_unmasked_region_count; the test only needs win_lat and shift_lat to construct the latitude slices. Signed-off-by: Alexander Wikner Co-Authored-By: Claude Sonnet 4.6 (1M context) --- test/nn/module/test_shift_window_mask.py | 46 ++++++++++++++++-------- 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/test/nn/module/test_shift_window_mask.py b/test/nn/module/test_shift_window_mask.py index 1384beeb04..e567122f9b 100644 --- a/test/nn/module/test_shift_window_mask.py +++ b/test/nn/module/test_shift_window_mask.py @@ -36,9 +36,9 @@ class TestGetShiftWindowMask3D: @pytest.mark.parametrize( "input_resolution, window_size, shift_size", [ - ((8, 24, 48), (2, 6, 12), (1, 3, 6)), # default Pangu config - ((4, 12, 24), (2, 6, 12), (1, 3, 6)), # smaller resolution - ((8, 24, 48), (2, 6, 6), (1, 3, 3)), # square lon window + ((8, 24, 48), (2, 6, 12), (1, 3, 6)), # default Pangu config + ((4, 12, 24), (2, 6, 12), (1, 3, 6)), # smaller resolution + ((8, 24, 48), (2, 6, 6), (1, 3, 3)), # square lon window ], ) def test_output_shape(self, input_resolution, window_size, shift_size): @@ -75,8 +75,16 @@ def test_longitude_unmasked_region_count(self): # Reconstruct the underlying region-ID map directly img_mask = torch.zeros((1, Pl, Lat, Lon, 1)) - pl_slices = (slice(0, -win_pl), slice(-win_pl, -shift_pl), slice(-shift_pl, None)) - lat_slices = (slice(0, -win_lat), slice(-win_lat, -shift_lat), slice(-shift_lat, None)) + pl_slices = ( + slice(0, -win_pl), + slice(-win_pl, -shift_pl), + slice(-shift_pl, None), + ) + lat_slices = ( + slice(0, -win_lat), + slice(-win_lat, -shift_lat), + slice(-shift_lat, None), + ) cnt = 0 for pl in pl_slices: for lat in lat_slices: @@ -104,10 +112,14 @@ def test_no_shift_produces_zero_mask(self): # not error.) # We only check it does not raise. try: - mask = get_shift_window_mask(input_resolution, window_size, shift_size, ndim=3) + mask = get_shift_window_mask( + input_resolution, window_size, shift_size, ndim=3 + ) assert mask is not None except Exception as exc: - pytest.fail(f"get_shift_window_mask raised unexpectedly with zero shift: {exc}") + pytest.fail( + f"get_shift_window_mask raised unexpectedly with zero shift: {exc}" + ) class TestGetShiftWindowMask2D: @@ -116,7 +128,7 @@ class TestGetShiftWindowMask2D: @pytest.mark.parametrize( "input_resolution, window_size, shift_size", [ - ((24, 48), (6, 12), (3, 6)), # default FengWu config (scaled) + ((24, 48), (6, 12), (3, 6)), # default FengWu config (scaled) ((12, 24), (6, 12), (3, 6)), ], ) @@ -139,11 +151,15 @@ def test_values_binary(self): def test_longitude_unmasked_region_count(self): """Longitude must not be partitioned: only Lat region IDs (3).""" Lat, Lon = 24, 48 - win_lat, win_lon = 6, 12 - shift_lat, shift_lon = 3, 6 + win_lat = 6 + shift_lat = 3 img_mask = torch.zeros((1, Lat, Lon, 1)) - lat_slices = (slice(0, -win_lat), slice(-win_lat, -shift_lat), slice(-shift_lat, None)) + lat_slices = ( + slice(0, -win_lat), + slice(-win_lat, -shift_lat), + slice(-shift_lat, None), + ) cnt = 0 for lat in lat_slices: img_mask[:, lat, :, :] = cnt @@ -162,8 +178,8 @@ class TestWindowPartitionReverse: @pytest.mark.parametrize( "shape, window_size", [ - ((2, 8, 24, 48, 16), (2, 6, 12)), # (B, Pl, Lat, Lon, C) 3D - ((2, 24, 48, 16), (6, 12)), # (B, Lat, Lon, C) 2D + ((2, 8, 24, 48, 16), (2, 6, 12)), # (B, Pl, Lat, Lon, C) 3D + ((2, 24, 48, 16), (6, 12)), # (B, Lat, Lon, C) 2D ], ) def test_roundtrip(self, shape, window_size): @@ -185,4 +201,6 @@ def test_roundtrip(self, shape, window_size): partitioned, window_size, Lat=Lat, Lon=Lon, ndim=ndim ) - assert torch.allclose(x, recovered), "window_reverse did not invert window_partition" + assert torch.allclose(x, recovered), ( + "window_reverse did not invert window_partition" + ) From d9f9439f10a2994beb16895e95f53d0717853df3 Mon Sep 17 00:00:00 2001 From: Alexander Wikner Date: Tue, 28 Apr 2026 08:20:04 -0500 Subject: [PATCH 3/3] Address automated review comments on shift_window_mask MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two changes based on greptile review of #1600: 1. Mark shift_lon as intentionally unused in get_shift_window_mask (both 3D and 2D branches) by replacing the variable name with _ . Leaving it named shift_lon implied it still played a role in longitude masking, which is the opposite of what this PR intends. 2. Rewrite test_longitude_unmasked_region_count (3D and 2D) to drive the assertion through the public API rather than a hand-rolled copy of the implementation. The new tests call get_shift_window_mask and assert that the returned mask is identical for every longitude window index — the direct invariant of an unmasked cyclic axis. Signed-off-by: Alexander Wikner Co-Authored-By: Claude Sonnet 4.6 (1M context) --- .../nn/module/utils/shift_window_mask.py | 4 +- test/nn/module/test_shift_window_mask.py | 79 +++++++------------ 2 files changed, 32 insertions(+), 51 deletions(-) diff --git a/physicsnemo/nn/module/utils/shift_window_mask.py b/physicsnemo/nn/module/utils/shift_window_mask.py index 796922dba9..e5982fc401 100644 --- a/physicsnemo/nn/module/utils/shift_window_mask.py +++ b/physicsnemo/nn/module/utils/shift_window_mask.py @@ -107,13 +107,13 @@ def get_shift_window_mask(input_resolution, window_size, shift_size, ndim=3): if ndim == 3: Pl, Lat, Lon = input_resolution win_pl, win_lat, win_lon = window_size - shift_pl, shift_lat, shift_lon = shift_size + shift_pl, shift_lat, _ = shift_size # longitude is cyclic; lon shift unused img_mask = torch.zeros((1, Pl, Lat, Lon, 1)) elif ndim == 2: Lat, Lon = input_resolution win_lat, win_lon = window_size - shift_lat, shift_lon = shift_size + shift_lat, _ = shift_size # longitude is cyclic; lon shift unused img_mask = torch.zeros((1, Lat, Lon, 1)) diff --git a/test/nn/module/test_shift_window_mask.py b/test/nn/module/test_shift_window_mask.py index e567122f9b..d0550edf92 100644 --- a/test/nn/module/test_shift_window_mask.py +++ b/test/nn/module/test_shift_window_mask.py @@ -65,37 +65,22 @@ def test_values_binary(self, input_resolution, window_size, shift_size): assert unique == [-100.0, 0.0] def test_longitude_unmasked_region_count(self): - """Longitude must not be partitioned: only Pl x Lat region IDs (9).""" - input_resolution = (8, 24, 48) - window_size = (2, 6, 12) - shift_size = (1, 3, 6) - Pl, Lat, Lon = input_resolution - win_pl, win_lat, win_lon = window_size - shift_pl, shift_lat, shift_lon = shift_size - - # Reconstruct the underlying region-ID map directly - img_mask = torch.zeros((1, Pl, Lat, Lon, 1)) - pl_slices = ( - slice(0, -win_pl), - slice(-win_pl, -shift_pl), - slice(-shift_pl, None), - ) - lat_slices = ( - slice(0, -win_lat), - slice(-win_lat, -shift_lat), - slice(-shift_lat, None), + """Longitude must not be partitioned: mask must be identical for every n_lon index. + + If the longitude axis were masked, different longitude window positions would + have different attention patterns. With longitude unmasked, the mask depends + only on (Pl, Lat) region membership, so all n_lon slices must be equal. + """ + mask = get_shift_window_mask( + input_resolution=(8, 24, 48), + window_size=(2, 6, 12), + shift_size=(1, 3, 6), + ndim=3, ) - cnt = 0 - for pl in pl_slices: - for lat in lat_slices: - img_mask[:, pl, lat, :, :] = cnt - cnt += 1 - - n_regions = len(torch.unique(img_mask)) - # 3 Pl bands x 3 Lat bands = 9; longitude must NOT add more partitions - assert n_regions == 9, ( - f"Expected 9 region IDs (Pl x Lat only), got {n_regions}. " - "Longitude axis must not be partitioned in the mask." + # mask shape: (n_lon, n_pl*n_lat, W, W) + assert torch.all(mask == mask[0].unsqueeze(0)), ( + "Mask differs across longitude window indices — longitude axis is being " + "partitioned. All n_lon slices must be identical when longitude is cyclic." ) def test_no_shift_produces_zero_mask(self): @@ -149,26 +134,22 @@ def test_values_binary(self): assert unique == [-100.0, 0.0] def test_longitude_unmasked_region_count(self): - """Longitude must not be partitioned: only Lat region IDs (3).""" - Lat, Lon = 24, 48 - win_lat = 6 - shift_lat = 3 - - img_mask = torch.zeros((1, Lat, Lon, 1)) - lat_slices = ( - slice(0, -win_lat), - slice(-win_lat, -shift_lat), - slice(-shift_lat, None), + """Longitude must not be partitioned: mask must be identical for every n_lon index. + + If the longitude axis were masked, different longitude window positions would + have different attention patterns. With longitude unmasked, the mask depends + only on Lat region membership, so all n_lon slices must be equal. + """ + mask = get_shift_window_mask( + input_resolution=(24, 48), + window_size=(6, 12), + shift_size=(3, 6), + ndim=2, ) - cnt = 0 - for lat in lat_slices: - img_mask[:, lat, :, :] = cnt - cnt += 1 - - n_regions = len(torch.unique(img_mask)) - assert n_regions == 3, ( - f"Expected 3 region IDs (Lat only), got {n_regions}. " - "Longitude axis must not be partitioned in the mask." + # mask shape: (n_lon, n_lat, W, W) + assert torch.all(mask == mask[0].unsqueeze(0)), ( + "Mask differs across longitude window indices — longitude axis is being " + "partitioned. All n_lon slices must be identical when longitude is cyclic." )