Version
2.1.0a0
On which installation method(s) does this occur?
Pip
Describe the issue
Summary
get_shift_window_mask in physicsnemo/nn/module/utils/shift_window_mask.py
applies an attention mask along the longitude dimension of the global
latitude-longitude grid. Because longitude is cyclic (the grid wraps around
the globe), this masking is incorrect: it prevents tokens that are physically
adjacent across the dateline from attending to one another inside the same
shifted window. The Pangu-Weather paper explicitly states that wrap-around
longitude windows should be merged into one window, not separated by a
mask.
This bug affects both the 3D path used by Pangu-Weather and the 2D path used
by FengWu.
Affected models and files
| Model | Attention path | File |
| --- | --- | --- |
| Pangu-Weather | Transformer3DBlock (3D) | physicsnemo/nn/module/utils/shift_window_mask.py |
| FengWu | Transformer2DBlock (2D) | physicsnemo/nn/module/utils/shift_window_mask.py |
Background
Shifted-window self-attention (SW-MSA, Swin Transformer) uses a cyclic shift
(torch.roll) to allow attention across window boundaries, followed by a mask
that prevents tokens from different non-adjacent regions from attending to one
another within the same window.
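For reference, here is a minimal sketch of the standard Swin mask along a single non-cyclic axis. The helper name `shifted_window_mask_1d` is hypothetical and not part of physicsnemo. It assumes the forward shift is a roll by `-shift` and uses the additive `-100.0` convention of the reference Swin code. Region IDs are assigned in rolled coordinates, and tokens in the same window with different IDs are masked from each other.

```python
import torch


def shifted_window_mask_1d(length: int, win: int, shift: int) -> torch.Tensor:
    """Swin-style additive mask for one non-cyclic axis, after a roll by -shift.

    Returns (num_windows, win, win): 0.0 where two tokens may attend,
    -100.0 where attention is blocked.
    """
    region = torch.zeros(length)
    # Region 0: positions in windows untouched by the wrap.
    # Region 1: the rest of the last window (tokens from the end of the axis).
    # Region 2: the last `shift` positions, filled by tokens from the start.
    for cnt, s in enumerate((slice(0, -win), slice(-win, -shift), slice(-shift, None))):
        region[s] = cnt
    windows = region.reshape(length // win, win)
    diff = windows.unsqueeze(1) - windows.unsqueeze(2)
    return diff.masked_fill(diff != 0, -100.0).masked_fill(diff == 0, 0.0)


mask = shifted_window_mask_1d(length=12, win=4, shift=2)
# Windows 0 and 1 are fully unmasked; window 2 splits into {0, 1} and {2, 3}.
print(mask[-1])
```

For a non-cyclic axis this split is intended, because the two halves of the last window came from opposite ends of the axis.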
For a non-cyclic spatial dimension (e.g., latitude, pressure level),
masking the shifted boundary is correct: the tokens brought together by the
shift were originally far apart and should not attend to each other.
For a cyclic spatial dimension (longitude on a global grid), the situation
is the opposite: after a cyclic roll, the tokens at both ends of the longitude
axis that end up sharing a window are physically adjacent — they are
neighbours across the dateline. They should be allowed to attend freely.
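The following torch-only sketch shows which original longitude columns share the last window after the shift. It assumes the forward shift is `torch.roll` by `-shift_lon`, as in the Swin reference code. The sizes are the default Pangu longitude settings used elsewhere in this issue.

```python
import torch

Lon, win_lon, shift_lon = 48, 12, 6
lon_index = torch.arange(Lon)  # original longitude column indices
rolled = torch.roll(lon_index, shifts=-shift_lon, dims=0)  # forward cyclic shift
last_window = rolled[-win_lon:]  # columns that share the final longitude window
print(last_window.tolist())
# [42, 43, 44, 45, 46, 47, 0, 1, 2, 3, 4, 5]
```

Columns 47 and 0 end up next to each other in the same window. On a global grid they are physical neighbours.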
The Pangu-Weather paper (Bi et al., arXiv:2211.02556 / Nature 2023)
states this explicitly:
"Along the longitude dimension, the leftmost and rightmost indices are
actually close to each other. If half windows appear at both leftmost and
rightmost positions, they are directly merged into one window."
The same paper omits M_lon from the Earth-specific position bias because
"the longitude indices are cyclic and spacing is evenly distributed along
this axis." The physicsnemo implementation of get_earth_position_index
(shared by both EarthAttention3D and EarthAttention2D) reproduces this
cyclic design and is explicitly annotated as derived from the official
Pangu-Weather pseudocode.
What the code currently does (incorrect)
In get_shift_window_mask for the 3D path:
```python
img_mask = torch.zeros((1, Pl, Lat, Lon + shift_lon, 1))  # padded in Lon
...
lon_slices = (
    slice(0, -win_lon),
    slice(-win_lon, -shift_lon),
    slice(-shift_lon, None),
)
for pl in pl_slices:
    for lat in lat_slices:
        for lon in lon_slices:  # ← iterates Lon as non-cyclic
            img_mask[:, pl, lat, lon, :] = cnt
            cnt += 1
img_mask = img_mask[:, :, :, :Lon, :]  # ← crops away the padded region
```
This assigns 27 distinct region IDs (3 Pl × 3 Lat × 3 Lon). After cropping
back to Lon, the third longitude region (slice(-shift_lon, None)) is
removed, leaving 18 effective region IDs (3 × 3 × 2). Any window that
straddles the dateline after the cyclic roll ends up with two non-attending
sub-groups — the opposite of "merged into one window."
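A longitude-only mirror of the construction quoted above shows where the two sub-groups come from. This sketch isolates one axis in plain PyTorch and does not call the library function. The sizes are the default Pangu longitude settings.

```python
import torch

Lon, win_lon, shift_lon = 48, 12, 6

# Pad by shift_lon, assign the three Swin region IDs, then crop back to Lon.
lon_ids = torch.zeros(Lon + shift_lon)
lon_slices = (slice(0, -win_lon), slice(-win_lon, -shift_lon), slice(-shift_lon, None))
for cnt, s in enumerate(lon_slices):
    lon_ids[s] = cnt
lon_ids = lon_ids[:Lon]

# Region IDs inside the last (wrap-around) longitude window of the rolled grid.
print(lon_ids.reshape(-1, win_lon)[-1].tolist())
# [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
```

Assume the forward roll is by `-shift_lon`. Then the first six positions of that window hold original columns 42 to 47, and the last six hold columns 0 to 5. The two halves that meet across the wrap therefore receive different IDs and are masked from each other.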
The same structure is present in the 2D path (FengWu), producing 6 effective
region IDs (3 Lat × 2 Lon) instead of 3.
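The 2D counts can be checked with a torch-only sketch. It assumes the 2D path has the same pad/partition/crop structure as the 3D excerpt, as stated above, and reimplements that logic rather than calling physicsnemo. The helper `three_slices` is hypothetical. The grid and window sizes are borrowed from the 3D example for illustration, and FengWu's actual configuration may differ.

```python
import torch

Lat, Lon = 24, 48
win_lat, win_lon = 6, 12
shift_lat, shift_lon = 3, 6


def three_slices(win: int, shift: int) -> tuple[slice, slice, slice]:
    # Standard Swin split of one rolled axis into three regions.
    return (slice(0, -win), slice(-win, -shift), slice(-shift, None))


# Current 2D behaviour: pad Lon, partition both axes, crop back to Lon.
current = torch.zeros((1, Lat, Lon + shift_lon, 1))
cnt = 0
for lat in three_slices(win_lat, shift_lat):
    for lon in three_slices(win_lon, shift_lon):
        current[:, lat, lon, :] = cnt
        cnt += 1
current = current[:, :, :Lon, :]

# Proposed 2D behaviour: partition Lat only and leave Lon whole.
proposed = torch.zeros((1, Lat, Lon, 1))
for cnt, lat in enumerate(three_slices(win_lat, shift_lat)):
    proposed[:, lat, :, :] = cnt

print(len(torch.unique(current)), len(torch.unique(proposed)))  # 6 3
```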
What the code should do (correct)
The mask should iterate only over the non-cyclic axes — (Pl, Lat) in 3D
and Lat in 2D — and leave the longitude dimension unpartitioned:
```python
img_mask = torch.zeros((1, Pl, Lat, Lon, 1))  # no Lon padding
...
# no lon_slices
for pl in pl_slices:
    for lat in lat_slices:
        img_mask[:, pl, lat, :, :] = cnt  # ← full Lon range unmasked
        cnt += 1
# no :Lon crop
```
This assigns 9 region IDs (3 Pl × 3 Lat), and wrap-around longitude windows
attend freely.
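Below is a self-contained sketch of the proposed 3D mask, going from the region-ID map to an additive attention mask. All three helpers (`window_partition_3d`, `region_ids_cyclic_lon`, `attn_mask_from_regions`) are hypothetical and are not the physicsnemo implementation. The sketch assumes a `(B, Pl, Lat, Lon, C)` layout, orders windows (Pl, Lat, Lon) row-major, and uses the `-100.0` masking convention from the Swin reference code. The library's internal window ordering may differ.

```python
import torch


def window_partition_3d(x: torch.Tensor, win: tuple[int, int, int]) -> torch.Tensor:
    """Split (B, Pl, Lat, Lon, C) into (B * num_windows, tokens_per_window, C)."""
    B, Pl, Lat, Lon, C = x.shape
    wp, wa, wo = win
    x = x.reshape(B, Pl // wp, wp, Lat // wa, wa, Lon // wo, wo, C)
    # Bring the window-grid axes forward, then flatten each window's tokens.
    return x.permute(0, 1, 3, 5, 2, 4, 6, 7).reshape(-1, wp * wa * wo, C)


def region_ids_cyclic_lon(input_resolution, window_size, shift_size) -> torch.Tensor:
    """Region-ID map that partitions Pl and Lat only; Lon is left whole."""
    Pl, Lat, Lon = input_resolution
    win_pl, win_lat, _ = window_size
    shift_pl, shift_lat, _ = shift_size
    img_mask = torch.zeros((1, Pl, Lat, Lon, 1))  # no Lon padding, no crop
    pl_slices = (slice(0, -win_pl), slice(-win_pl, -shift_pl), slice(-shift_pl, None))
    lat_slices = (slice(0, -win_lat), slice(-win_lat, -shift_lat), slice(-shift_lat, None))
    cnt = 0
    for pl in pl_slices:
        for lat in lat_slices:
            img_mask[:, pl, lat, :, :] = cnt
            cnt += 1
    return img_mask


def attn_mask_from_regions(img_mask: torch.Tensor, window_size) -> torch.Tensor:
    """Additive mask per window: 0.0 within a region, -100.0 across regions."""
    windows = window_partition_3d(img_mask, window_size).squeeze(-1)  # (nW, N)
    diff = windows.unsqueeze(1) - windows.unsqueeze(2)  # (nW, N, N)
    return diff.masked_fill(diff != 0, -100.0).masked_fill(diff == 0, 0.0)


regions = region_ids_cyclic_lon((8, 24, 48), (2, 6, 12), (1, 3, 6))
mask = attn_mask_from_regions(regions, (2, 6, 12))
print(len(torch.unique(regions)), tuple(mask.shape))  # 9 (64, 144, 144)
```

With this ordering, window 3 is the wrap-around longitude window in the first Pl and Lat window row. It comes out fully unmasked. Windows touching the shifted Pl or Lat boundaries are still masked along those axes.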
Relationship to PR #1492
PR #1492 fixed a shift_lat-vs-shift_lon typo that caused the forward
cyclic roll to shift the longitude axis by the wrong amount. That fix makes
the forward and reverse rolls consistent again. However, the attention mask
was built independently and still partitions the longitude axis as if it were
non-cyclic. The two bugs compound each other: even with the corrected roll,
the mask incorrectly blocks attention at the dateline.
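The roll fix and the mask fix address different things. The roll only has to be undone exactly by the reverse roll. The torch-only sketch below checks this for a deliberately mismatched forward shift. That mismatch is a hypothetical stand-in for this class of typo and does not reproduce the actual PR #1492 diff. Passing this check says nothing about which tokens the mask allows to attend.

```python
import torch

shift_pl, shift_lat, shift_lon = 1, 3, 6
x = torch.arange(8 * 24 * 48).reshape(1, 8, 24, 48, 1)

# Forward roll that mistakenly uses shift_lat on the longitude axis,
# followed by a correct reverse roll: the net Lon shift is +3, not 0.
buggy = torch.roll(x, shifts=(-shift_pl, -shift_lat, -shift_lat), dims=(1, 2, 3))
buggy_back = torch.roll(buggy, shifts=(shift_pl, shift_lat, shift_lon), dims=(1, 2, 3))

# Matching forward and reverse rolls restore the original tensor.
fixed = torch.roll(x, shifts=(-shift_pl, -shift_lat, -shift_lon), dims=(1, 2, 3))
fixed_back = torch.roll(fixed, shifts=(shift_pl, shift_lat, shift_lon), dims=(1, 2, 3))

print(torch.equal(buggy_back, x), torch.equal(fixed_back, x))  # False True
```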
Observed effect
With the default Pangu window configuration (window_size=(2, 6, 12),
shift_size=(1, 3, 6)), the region-ID map behind the attention mask contains
18 distinct region IDs instead of the correct 9. Every attention block in the
shifted (rolled) layer therefore silently drops cross-dateline context for all
tokens that fall in a wrap-around longitude window.
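The size of the effect can be estimated with the following sketch. It reimplements both constructions in plain PyTorch and counts the query/key pairs that are blocked only because longitude was partitioned. `region_map` and `blocked_pairs` are hypothetical helpers. The sizes follow the reproducible example, and windows are ordered (Pl, Lat, Lon) row-major.

```python
import torch


def region_map(res, win, shift, cyclic_lon):
    """Region IDs on the rolled grid; cyclic_lon=False mirrors the current code."""
    Pl, Lat, Lon = res
    pad = 0 if cyclic_lon else shift[2]
    m = torch.zeros((Pl, Lat, Lon + pad))
    pl_s, lat_s, lon_s = [
        (slice(0, -w), slice(-w, -s), slice(-s, None)) for w, s in zip(win, shift)
    ]
    if cyclic_lon:
        lon_s = (slice(None),)  # one region spanning the whole longitude axis
    cnt = 0
    for pl in pl_s:
        for lat in lat_s:
            for lon in lon_s:
                m[pl, lat, lon] = cnt
                cnt += 1
    return m[:, :, :Lon]  # crop (a no-op when cyclic_lon=True)


def blocked_pairs(m, win):
    """Boolean (num_windows, N, N): True where a query/key pair is masked."""
    Pl, Lat, Lon = m.shape
    wp, wa, wo = win
    w = m.reshape(Pl // wp, wp, Lat // wa, wa, Lon // wo, wo)
    w = w.permute(0, 2, 4, 1, 3, 5).reshape(-1, wp * wa * wo)
    return w.unsqueeze(1) != w.unsqueeze(2)


res, win, shift = (8, 24, 48), (2, 6, 12), (1, 3, 6)
current = blocked_pairs(region_map(res, win, shift, cyclic_lon=False), win)
proposed = blocked_pairs(region_map(res, win, shift, cyclic_lon=True), win)
extra = current & ~proposed  # pairs blocked only because Lon was partitioned
print(int(extra.flatten(1).any(dim=1).sum()), "windows affected,", int(extra.sum()), "pairs")
# 16 windows affected, 127008 pairs
```

The 16 affected windows are exactly those in the last longitude column of the 4 × 4 × 4 window grid, that is, the wrap-around windows.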
Minimal reproducible example
```python
from physicsnemo.nn.module.utils.shift_window_mask import get_shift_window_mask
import torch

# Attention mask as built by the library (default Pangu window configuration);
# returned for reference, the count below uses the reconstruction
mask = get_shift_window_mask(
    input_resolution=(8, 24, 48),
    window_size=(2, 6, 12),
    shift_size=(1, 3, 6),
    ndim=3,
)

# Reconstruct the underlying region-ID map, following the same steps as the
# library code quoted above, to count distinct IDs
img_mask = torch.zeros((1, 8, 24, 48 + 6, 1))  # Lon padded by shift_lon = 6
pl_slices = (slice(0, -2), slice(-2, -1), slice(-1, None))
lat_slices = (slice(0, -6), slice(-6, -3), slice(-3, None))
lon_slices = (slice(0, -12), slice(-12, -6), slice(-6, None))
cnt = 0
for pl in pl_slices:
    for lat in lat_slices:
        for lon in lon_slices:
            img_mask[:, pl, lat, lon, :] = cnt
            cnt += 1
img_mask = img_mask[:, :, :, :48, :]  # crop back to Lon = 48
n_regions = len(torch.unique(img_mask))
print(f"Region IDs in mask: {n_regions}")  # prints 18; should be 9
```
Relevant log output
Environment details
conda 25.7.0
Python 3.12.13
physicsnemo version: `2.1.0a0`