🐛[BUG]: Shifted-window attention mask incorrectly partitions the cyclic longitude axis (Pangu, FengWu) #1599

@awikner

Version

2.1.0a0

On which installation method(s) does this occur?

Pip

Describe the issue

Summary

get_shift_window_mask in physicsnemo/nn/module/utils/shift_window_mask.py
applies an attention mask along the longitude dimension of the global
latitude-longitude grid. Because longitude is cyclic (the grid wraps around
the globe), this masking is incorrect: it prevents tokens that are physically
adjacent across the dateline from attending to one another inside the same
shifted window. The Pangu-Weather paper explicitly states that wrap-around
longitude windows should be merged into one window, not separated by a
mask.

This bug affects both the 3D path used by Pangu-Weather and the 2D path used
by FengWu.

Affected models and files

| Model | Attention path | File |
| --- | --- | --- |
| Pangu-Weather | Transformer3DBlock (3D) | physicsnemo/nn/module/utils/shift_window_mask.py |
| FengWu | Transformer2DBlock (2D) | physicsnemo/nn/module/utils/shift_window_mask.py |

Background

Shifted-window self-attention (SW-MSA, Swin Transformer) uses a cyclic shift
(torch.roll) to allow attention across window boundaries, followed by a mask
that prevents tokens from different non-adjacent regions from attending to one
another within the same window.

For a non-cyclic spatial dimension (e.g., latitude, pressure level),
masking the shifted boundary is correct: the tokens brought together by the
shift were originally far apart and should not attend to each other.

For a cyclic spatial dimension (longitude on a global grid), the situation
is the opposite: after a cyclic roll, the tokens at both ends of the longitude
axis that end up sharing a window are physically adjacent — they are
neighbours across the dateline. They should be allowed to attend freely.
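To make the adjacency argument concrete, here is a minimal pure-Python sketch of which original longitude indices end up sharing the last window after the roll. It assumes a roll by -shift_lon (the convention in the Swin Transformer reference implementation) and borrows the longitude sizes from the Pangu configuration quoted in this issue; `cyclic_roll` and `circular_distance` are hypothetical helpers written for this illustration, not physicsnemo functions.

```python
# Longitude sizes borrowed from the Pangu example in this issue (illustrative).
LON, WIN_LON, SHIFT_LON = 48, 12, 6


def cyclic_roll(seq, shift):
    """Cyclically shift a 1D list, using the same sign convention as torch.roll."""
    shift %= len(seq)
    return seq[-shift:] + seq[:-shift] if shift else list(seq)


def circular_distance(a, b, n):
    """Shortest index distance between a and b on a ring of n points."""
    d = abs(a - b) % n
    return min(d, n - d)


lon_index = list(range(LON))
# Assumption: the shifted layer rolls by -shift before window partitioning.
rolled = cyclic_roll(lon_index, -SHIFT_LON)
windows = [rolled[i:i + WIN_LON] for i in range(0, LON, WIN_LON)]

wrap_window = windows[-1]
print(wrap_window)                     # [42, 43, 44, 45, 46, 47, 0, 1, 2, 3, 4, 5]
print(circular_distance(47, 0, LON))   # 1
```

The last window holds indices 42..47 together with 0..5. On a cyclic axis, index 47 and index 0 are one grid step apart, so the two halves form one contiguous patch of the globe rather than two distant regions.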

The Pangu-Weather paper (Bi et al., arXiv:2211.02556 / Nature 2023)
states this explicitly:

"Along the longitude dimension, the leftmost and rightmost indices are
actually close to each other. If half windows appear at both leftmost and
rightmost positions, they are directly merged into one window."

The same paper omits M_lon from the Earth-specific position bias because
"the longitude indices are cyclic and spacing is evenly distributed along
this axis."
The physicsnemo implementation of get_earth_position_index
(shared by both EarthAttention3D and EarthAttention2D) reproduces this
cyclic design and is explicitly annotated as derived from the official
Pangu-Weather pseudocode.

What the code currently does (incorrect)

In get_shift_window_mask for the 3D path:

img_mask = torch.zeros((1, Pl, Lat, Lon + shift_lon, 1))   # padded in Lon
...
lon_slices = (
    slice(0, -win_lon),
    slice(-win_lon, -shift_lon),
    slice(-shift_lon, None),
)
for pl in pl_slices:
    for lat in lat_slices:
        for lon in lon_slices:                  # ← iterates Lon as non-cyclic
            img_mask[:, pl, lat, lon, :] = cnt
            cnt += 1
img_mask = img_mask[:, :, :, :Lon, :]           # ← crops away the padded region

This assigns 27 distinct region IDs (3 Pl × 3 Lat × 3 Lon). After cropping
back to Lon, the third longitude region (slice(-shift_lon, None)) is
removed, leaving 18 effective region IDs (3 × 3 × 2). Any window that
straddles the dateline after the cyclic roll ends up with two non-attending
sub-groups — the opposite of "merged into one window."

The same structure is present in the 2D path (FengWu), producing 3 Lat × 2 Lon = 6
effective region IDs instead of the 3 that a Lat-only partition would give.
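The 2D count can be checked with a small pure-Python sketch that mirrors the pad, slice, and crop steps described above. The grid and window sizes are simply the last two axes of the 3D example in this issue, not FengWu's actual configuration, and `count_region_ids` is a hypothetical helper rather than part of physicsnemo.

```python
def slices_for(win, shift):
    """Three-way split used by Swin-style masks along one axis."""
    return (slice(0, -win), slice(-win, -shift), slice(-shift, None))


def count_region_ids(lat, lon, win, shift, partition_lon):
    """Count distinct region IDs in a 2D (Lat, Lon) region map.

    partition_lon=True mimics the behaviour described in this issue:
    pad Lon by shift_lon, slice it three ways, then crop back to Lon.
    partition_lon=False leaves the longitude axis unpartitioned.
    """
    win_lat, win_lon = win
    shift_lat, shift_lon = shift
    width = lon + shift_lon if partition_lon else lon
    img = [[0] * width for _ in range(lat)]
    lon_slices = slices_for(win_lon, shift_lon) if partition_lon else (slice(None),)
    cnt = 0
    for lat_s in slices_for(win_lat, shift_lat):
        for lon_s in lon_slices:
            for row in img[lat_s]:          # rows are shared, so this edits img
                row[lon_s] = [cnt] * len(row[lon_s])
            cnt += 1
    cropped = [row[:lon] for row in img]    # crop away any Lon padding
    return len({v for row in cropped for v in row})


# Hypothetical sizes: Lat=24, Lon=48, window (6, 12), shift (3, 6).
print(count_region_ids(24, 48, (6, 12), (3, 6), partition_lon=True))   # 6
print(count_region_ids(24, 48, (6, 12), (3, 6), partition_lon=False))  # 3
```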

What the code should do (correct)

The mask should iterate only over the non-cyclic axes — (Pl, Lat) in 3D
and Lat in 2D — and leave the longitude dimension unpartitioned:

img_mask = torch.zeros((1, Pl, Lat, Lon, 1))   # no Lon padding
...
# no lon_slices
for pl in pl_slices:
    for lat in lat_slices:
        img_mask[:, pl, lat, :, :] = cnt        # ← full Lon range unmasked
        cnt += 1
# no :Lon crop

This assigns 9 region IDs (3 Pl × 3 Lat), and wrap-around longitude windows
attend freely.
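As a rough illustration of the difference, the sketch below looks at the longitude axis in isolation and counts, per window, how many ordered token pairs would be blocked because their region IDs differ. It is pure Python, uses the longitude sizes from the Pangu example in this issue, and all three helper names are hypothetical. In the real mask the Pl and Lat partitions contribute additional (legitimate) blocking that is not modelled here.

```python
LON, WIN_LON, SHIFT_LON = 48, 12, 6


def lon_ids_current(lon, win, shift):
    """Region IDs along Lon as described in this issue: pad, 3-way slice, crop."""
    ids = [0] * (lon + shift)
    lon_slices = (slice(0, -win), slice(-win, -shift), slice(-shift, None))
    for region, s in enumerate(lon_slices):
        ids[s] = [region] * len(ids[s])
    return ids[:lon]


def lon_ids_proposed(lon):
    """Proposed behaviour: a single region along the cyclic axis."""
    return [0] * lon


def blocked_pairs(ids, win):
    """Per window, count ordered token pairs whose region IDs differ."""
    counts = []
    for start in range(0, len(ids), win):
        w = ids[start:start + win]
        counts.append(sum(a != b for a in w for b in w))
    return counts


print(blocked_pairs(lon_ids_current(LON, WIN_LON, SHIFT_LON), WIN_LON))  # [0, 0, 0, 72]
print(blocked_pairs(lon_ids_proposed(LON), WIN_LON))                     # [0, 0, 0, 0]
```

Under the current slicing, half of the 144 ordered pairs in the wrap-around window (2 × 6 × 6 = 72) are blocked along longitude alone. With the axis left unpartitioned, none are.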

Relationship to PR #1492

PR #1492 fixed a shift_lat-vs-shift_lon typo that caused the forward
cyclic roll to shift the longitude axis by the wrong amount. That fix makes
the forward and reverse rolls consistent again. However, the attention mask
is built independently of the roll and still partitions the longitude axis as
if it were non-cyclic. The two bugs are therefore separate: even with the
corrected roll, the mask incorrectly blocks attention at the dateline.

Observed effect

With the default Pangu window configuration (window_size=(2, 6, 12),
shift_size=(1, 3, 6)), the region-ID map behind the attention mask contains
18 distinct region IDs instead of the correct 9. Every attention block in a
shifted (rolled) layer therefore silently drops cross-dateline context for all
tokens that fall in a wrap-around longitude window.

Minimum reproducible example

python
from physicsnemo.nn.module.utils.shift_window_mask import get_shift_window_mask
import torch

mask = get_shift_window_mask(
    input_resolution=(8, 24, 48),
    window_size=(2, 6, 12),
    shift_size=(1, 3, 6),
    ndim=3,
)

# Reconstruct the underlying region-ID map (same slicing as the 3D path shown
# above) to count distinct IDs
img_mask = torch.zeros((1, 8, 24, 48 + 6, 1))
pl_slices  = (slice(0, -2), slice(-2, -1), slice(-1, None))
lat_slices = (slice(0, -6), slice(-6, -3), slice(-3, None))
lon_slices = (slice(0, -12), slice(-12, -6), slice(-6, None))
cnt = 0
for pl in pl_slices:
    for lat in lat_slices:
        for lon in lon_slices:
            img_mask[:, pl, lat, lon, :] = cnt
            cnt += 1
img_mask = img_mask[:, :, :, :48, :]

n_regions = len(torch.unique(img_mask))
print(f"Region IDs in mask: {n_regions}")   # prints 18; should be 9

Environment details

conda 25.7.0
Python 3.12.13
physicsnemo version: `2.1.0a0`

Metadata

Labels: ? - Needs Triage, bug, external
