Summary:
- Combines video encoder, temporal attention, spatial transformer, and decoder
- Encodes video into spatio-temporal patches
- Aggregates temporal information per spatial patch
- Mixes spatial features across patches
- Decodes back to original spatial resolution
Detailed process:
The model takes daily SST (or similar) data in video format, x ∈ ℝ^{B × 1 × T × H × W}, together with a daily_mask indicating missing pixels and a land_mask_patch indicating land regions in the output.
# 1. Patch embedding:
X (VideoEncoder)---------> X_patch
# 2. Add temporal encoding +
# 3. Temporal aggregation:
X_patch + PE (TemporalAttentionAggregator)---------> X_temp_agg
# 4. Add spatial encoding +
# 5. Spatial transformer:
X_temp_agg + PE (SpatialTransformer) ---------> X_mixed
# 6. Decode to original resolution:
X_mixed (MonthlyConvDecoder)---------> Output
Summary:
- Missing (masked) pixels are zeroed out, and their locations are kept in a separate validity channel
- Video is split into 3D patches (time × height × width)
- Each patch becomes a vector (embedding)
- Output is a sequence suitable for Transformer-based video models (e.g. VideoMAE)
Detailed process:
We start with an input video x ∈ ℝ^{B × 1 × T × H × W} (batch, channel, time, height, width) and a mask ∈ {0,1}^{B × 1 × T × H × W}, where 1 (True) marks a missing (masked) ocean pixel. We define a validity indicator valid = 1 − mask, so valid = 1 for an observed pixel and valid = 0 for a missing one. Missing values are then zeroed out: x_valid = x ⊙ valid.
We then concatenate the validity mask as a second channel: x_cat = concat(x_valid, valid), so x_cat ∈ ℝ^{B × 2 × T × H × W}. This tells the model which values were observed and which were missing; without it, a zeroed-out missing pixel would be indistinguishable from an observed value of 0.
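The masking and channel-concatenation step can be sketched in PyTorch as follows. The function name add_validity_channel is hypothetical; it only illustrates valid = 1 − mask, x ⊙ valid and the concatenation along the channel axis.

```python
import torch

def add_validity_channel(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: zero out missing pixels and append a validity channel.

    x:    (B, 1, T, H, W) float observations
    mask: (B, 1, T, H, W) bool or 0/1 tensor, 1 (True) = missing pixel
    returns x_cat of shape (B, 2, T, H, W)
    """
    valid = 1.0 - mask.float()                  # 1 = observed, 0 = missing
    x_valid = x * valid                         # element-wise product x ⊙ valid
    return torch.cat([x_valid, valid], dim=1)   # concatenate along the channel axis

# Usage: a one-step, 1x2 "video" whose second pixel is missing
x = torch.tensor([[[[[5.0, 7.0]]]]])            # shape (1, 1, 1, 1, 2)
mask = torch.tensor([[[[[False, True]]]]])
out = add_validity_channel(x, mask)
print(out.shape)                   # torch.Size([1, 2, 1, 1, 2])
print(out[0, 0].flatten().tolist())  # [5.0, 0.0]  (observed values, missing zeroed)
print(out[0, 1].flatten().tolist())  # [1.0, 0.0]  (validity channel)
```

One caveat of the plain product: if missing pixels are stored as NaN in x, NaN × 0 is still NaN, so they would need to be replaced first (for example with torch.nan_to_num or torch.where).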
We split the video into non-overlapping spatio-temporal patches using a 3D convolution (see torch.nn.Conv3d).
Let the patch size be (Pt, Ph, Pw). The convolution uses kernel size = (Pt, Ph, Pw) and stride = (Pt, Ph, Pw), so each convolution output corresponds to exactly one patch and patches do not overlap. The resulting shape is z ∈ ℝ^{B × D × T' × H' × W'}, where D = embed_dim, T' = T / Pt, H' = H / Ph and W' = W / Pw (this assumes T, H and W are divisible by the patch size; otherwise the convolution drops the remainder). Each (t', h', w') location holds a patch embedding vector of length D.
We flatten the 3D grid of patches into a sequence of length N = T' × H' × W', so each video becomes a sequence of patch embeddings, just like tokens in a Transformer.
Layer normalization is then applied to each patch embedding; normalizing across the embedding dimension stabilizes training (see torch.nn.LayerNorm).
Dropout then randomly zeroes elements during training as regularization, which helps prevent overfitting (see torch.nn.Dropout).
The final output has shape (B, N_patches, embed_dim). Each token represents one spatio-temporal video patch and carries both its visual information and knowledge of which pixels were valid or missing.
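The whole patch-embedding path (validity channel, Conv3d with kernel = stride, flattening, LayerNorm, Dropout) can be put together in a small module. This is a minimal sketch, not the original VideoEncoder: the class name PatchEmbed3D and the defaults patch_size=(2, 8, 8) and dropout=0.1 are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn

class PatchEmbed3D(nn.Module):
    """Hypothetical sketch of the patch-embedding step (not the original class)."""

    def __init__(self, patch_size=(2, 8, 8), embed_dim=128, dropout=0.1):
        super().__init__()
        # kernel_size == stride: one output per non-overlapping (Pt, Ph, Pw) patch
        self.proj = nn.Conv3d(2, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        valid = 1.0 - mask.float()
        x_cat = torch.cat([x * valid, valid], dim=1)   # (B, 2, T, H, W)
        z = self.proj(x_cat)                           # (B, D, T', H', W')
        z = z.flatten(2).transpose(1, 2)               # (B, N, D) with N = T'*H'*W'
        return self.drop(self.norm(z))

# Usage: 30 days of 64x64 maps, patches of 2 days x 8 x 8 pixels
enc = PatchEmbed3D(patch_size=(2, 8, 8), embed_dim=128)
x = torch.randn(4, 1, 30, 64, 64)
mask = torch.rand(4, 1, 30, 64, 64) > 0.7
tokens = enc(x, mask)
print(tokens.shape)  # torch.Size([4, 960, 128]) since T'*H'*W' = 15 * 8 * 8 = 960
```

Because flatten(2) flattens (T', H', W') in row-major order, token index n corresponds to (t', h', w') with n = (t'·H' + h')·W' + w'.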
Summary:
- Each time step gets a unique vector
- Encodings are deterministic and fixed
- No learnable parameters
- Based on the Transformer positional encoding design
Detailed process:
Temporal positional encoding generates fixed position vectors so that the model knows at which time index a feature occurs. The encoding depends only on the time index, not on the data.
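Consider a temporal sequence of length T with indices t = 0, 1, 2, ..., T−1. Each time index t is assigned a vector of length D = embed_dim. For time index t and embedding dimension index i, even dimensions use sine and odd dimensions use cosine, following the standard Transformer form:
PE(t, 2i) = sin(t / 10000^{2i/D}), PE(t, 2i+1) = cos(t / 10000^{2i/D}).
Because each sine/cosine pair oscillates at a different frequency, this produces a unique, smooth encoding for each time step.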
For a maximum supported temporal length max_len, the class precomputes a matrix pe ∈ ℝ^{max_len × embed_dim} whose row t contains the encoding for time index t. This matrix is fixed, not trainable, and stored as a buffer. In the forward method, given a requested temporal length T, the output is pe[0:T], with shape (T, embed_dim). No parameters are learned and no computation depends on the input data.
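A minimal sketch of such a module, assuming the standard Transformer formula with base 10000 and an even embed_dim. The class name SinusoidalTemporalPE is hypothetical; register_buffer is what makes pe part of the module state without being a trainable parameter.

```python
import math
import torch
import torch.nn as nn

class SinusoidalTemporalPE(nn.Module):
    """Hypothetical sketch: fixed sinusoidal encodings, assumes embed_dim is even."""

    def __init__(self, embed_dim: int, max_len: int = 512):
        super().__init__()
        t = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)   # (max_len, 1)
        # div[i] = 1 / 10000^(2i / D) for each even dimension index 2i
        div = torch.exp(torch.arange(0, embed_dim, 2, dtype=torch.float32)
                        * (-math.log(10000.0) / embed_dim))
        pe = torch.zeros(max_len, embed_dim)
        pe[:, 0::2] = torch.sin(t * div)   # even dimensions: sine
        pe[:, 1::2] = torch.cos(t * div)   # odd dimensions: cosine
        self.register_buffer("pe", pe)     # saved in state_dict, never trained

    def forward(self, T: int) -> torch.Tensor:
        return self.pe[:T]                 # (T, embed_dim)

# Usage
pe_mod = SinusoidalTemporalPE(embed_dim=8, max_len=100)
out = pe_mod(31)
print(out.shape)                       # torch.Size([31, 8])
print(len(list(pe_mod.parameters())))  # 0: nothing is learnable
```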
TemporalAttentionAggregator performs temporal attention pooling over sequences of tokens.
Summary:
For each spatial patch (h, w):
- Collect its T temporal tokens
- Add temporal positional encoding for days within a month and for months within a year.
- Compute a learned scalar score per time step
- Apply softmax over time: it ensures the weights form a probability distribution over time.
- Compute a weighted temporal sum: output is a temporal summary vector for each patch, suitable for downstream tasks.
Detailed process:
We start with x ∈ ℝ^{B × (M·T·H·W) × C}, where B = batch size, M = number of months, T = number of temporal (daily) tokens per month for each spatial patch, H and W = number of spatial patches along height and width, and C = embedding dimension. We reshape x to group tokens by spatial patch, so that each spatial patch has, for each month, its own temporal sequence of length T.
We then add the temporal positional encoding pe from TemporalPositionalEncoding to each temporal token: seq = seq + pe. This injects time information (day within the month and month within the year) into each patch's token sequence.
Next we compute temporal attention weights by passing each token through an nn.Sequential (see torch.nn.Sequential) that outputs a scalar score. Its modules are:
- LayerNorm: normalizes the features across the embedding dimension, which helps stabilize training; this is common practice before computing attention.
- Linear(embed_dim, embed_dim): learns which features are important for temporal weighting.
- GELU: allows the network to learn non-linear relationships.
- Linear(embed_dim, 1): projects each token to a single scalar score.
See torch.nn for details on each module.
We then apply padded_days_mask, which masks out padded days (days beyond the actual number of days in a month), so the model does not attend to padded tokens that do not correspond to real data. The scores are converted into attention weights with a softmax over time; each weight represents the importance of one temporal token for this patch. Finally, the temporal tokens are aggregated by a weighted sum over the temporal dimension, giving one token per spatial patch (for each month).
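The scoring, masking, softmax and weighted sum can be sketched as below. Assumptions, not taken from the original code: padded days are masked by setting their scores to −∞ before the softmax (so their weight is exactly 0), tokens are ordered (month, day, h, w) as produced by flattening the patch grid, and attention_pool is a hypothetical name.

```python
import torch
import torch.nn as nn

C = 16  # embedding dimension (small for the example)

# Score network: LayerNorm -> Linear(C, C) -> GELU -> Linear(C, 1)
score_net = nn.Sequential(nn.LayerNorm(C), nn.Linear(C, C), nn.GELU(), nn.Linear(C, 1))

def attention_pool(seq: torch.Tensor, padded_days_mask: torch.Tensor):
    """Hypothetical sketch.
    seq: (S, T, C), one length-T sequence per (spatial patch, month).
    padded_days_mask: (S, T) bool, True = padded day with no real data.
    Returns pooled tokens (S, C) and attention weights (S, T)."""
    scores = score_net(seq).squeeze(-1)                           # (S, T)
    scores = scores.masked_fill(padded_days_mask, float("-inf"))  # padded days -> weight 0
    weights = torch.softmax(scores, dim=1)                        # sums to 1 over time
    pooled = (weights.unsqueeze(-1) * seq).sum(dim=1)             # weighted temporal sum
    return pooled, weights

# Usage: B=2 samples, M=3 months padded to T=31 days, a 4x4 patch grid
B, M, T, H, W = 2, 3, 31, 4, 4
x = torch.randn(B, M * T * H * W, C)                 # assumed order: (month, day, h, w)
seq = (x.view(B, M, T, H, W, C)
        .permute(0, 1, 3, 4, 2, 5)                   # -> (B, M, H, W, T, C)
        .reshape(B * M * H * W, T, C))
days = torch.tensor([31, 28, 30])                    # real days in each month
mask_mt = torch.arange(T)[None, :] >= days[:, None]  # (M, T), True beyond month end
padded_days_mask = mask_mt[None, :, None, None, :].expand(B, M, H, W, T).reshape(-1, T)

with torch.no_grad():
    pooled, weights = attention_pool(seq, padded_days_mask)
print(pooled.shape)  # torch.Size([96, 16]): one token per (sample, month, patch)
```

Masking before the softmax (rather than zeroing weights afterwards) keeps the remaining weights normalized to 1 over the real days of each month.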
Then we apply cross-month mixing using nn.MultiheadAttention. This lets the tokens of different months attend to each other, which can capture seasonal patterns. Attention is computed across the temporal (month) dimension, separately for each spatial patch. See torch.nn.MultiheadAttention for details on how multi-head attention works.
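Cross-month mixing amounts to regrouping the pooled tokens so that each spatial patch becomes a length-M sequence, then running self-attention over that sequence. In this sketch num_heads=4, the sizes, and the absence of any residual connection are assumptions for illustration.

```python
import torch
import torch.nn as nn

B, M, H, W, C = 2, 12, 4, 4, 16
# One pooled token per (sample, month, spatial patch)
monthly = torch.randn(B, M, H, W, C)

# Group by spatial patch: each patch becomes a sequence of its M monthly tokens
seq = monthly.permute(0, 2, 3, 1, 4).reshape(B * H * W, M, C)   # (B*H*W, M, C)

cross_month = nn.MultiheadAttention(embed_dim=C, num_heads=4, batch_first=True)
mixed, attn = cross_month(seq, seq, seq)   # self-attention across months only

print(mixed.shape)  # torch.Size([32, 12, 16])
print(attn.shape)   # torch.Size([32, 12, 12]): month-to-month weights, averaged over heads
```

Because patches are folded into the batch dimension, a token only attends to other months of the same spatial patch; mixing across space is left to the spatial transformer.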
Then we apply nn.Sequential again to the output of attention to get a final scalar score for each time step. The explanation for each module is the same as before: LayerNorm, Linear, GELU, Linear to 1 scalar. This gives us a score for each time step in the temporal sequence of each spatial patch.
Summary:
- Generate fixed sinusoidal positional encodings for 2D spatial grid
- Encodings are not learnable
- Intended to be added to spatial tokens
- Sine and cosine functions of different frequencies allow the model to distinguish positions along height and width.
Detailed process:
The module generates fixed 2D positional encodings for a grid of size (H, W) and
embedding dimension embed_dim. Each spatial location (h, w) gets a unique vector
of length embed_dim. Encodings are deterministic (not learned) and based on
sine/cosine functions, similar to the Temporal positional encoding. The
encoding for each spatial location is a combination of sine and cosine functions
of different frequencies along height and width. This allows the model to know
the spatial position of each token when added to the spatial tokens.
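The text does not say exactly how the height and width encodings are combined. A common convention, assumed here, is to spend half of the channels on a 1D sinusoidal encoding of the row index h and the other half on the column index w. The helper names sincos_1d and sincos_2d are hypothetical, and embed_dim must be divisible by 4 in this sketch.

```python
import math
import torch

def sincos_1d(positions: torch.Tensor, dim: int) -> torch.Tensor:
    """Standard 1D sinusoidal encoding, shape (len(positions), dim); dim must be even."""
    div = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim))
    angles = positions.float().unsqueeze(1) * div   # (L, dim / 2)
    pe = torch.zeros(len(positions), dim)
    pe[:, 0::2] = torch.sin(angles)
    pe[:, 1::2] = torch.cos(angles)
    return pe

def sincos_2d(H: int, W: int, embed_dim: int) -> torch.Tensor:
    """Hypothetical 2D encoding: first half encodes h, second half encodes w."""
    assert embed_dim % 4 == 0, "embed_dim must be divisible by 4 in this sketch"
    half = embed_dim // 2
    pe_h = sincos_1d(torch.arange(H), half)   # (H, half): row index
    pe_w = sincos_1d(torch.arange(W), half)   # (W, half): column index
    grid = torch.cat([pe_h[:, None, :].expand(H, W, half),
                      pe_w[None, :, :].expand(H, W, half)], dim=-1)   # (H, W, embed_dim)
    return grid.reshape(H * W, embed_dim)     # row-major: token index = h * W + w

# Usage: add the fixed encoding to (B, H*W, embed_dim) spatial tokens
pe = sincos_2d(8, 8, 128)
tokens = torch.randn(2, 64, 128) + pe         # broadcasts over the batch
print(pe.shape)  # torch.Size([64, 128])
```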
The SpatialTransformer mixes spatial patch tokens using multi-head self-attention.
Summary:
- Applies a standard Transformer encoder to spatial tokens
- Mixes information across spatial locations: each patch token can attend to all other patches, allowing global spatial context
- Output is a sequence of spatially mixed tokens of the same shape as input
Detailed process:
The basic building block is a single Transformer encoder layer, nn.TransformerEncoderLayer, which performs self-attention followed by feedforward processing. In multi-head self-attention, every token attends to every other token; with num_heads=4, this happens in 4 parallel "perspectives" (heads). In the feedforward network (MLP), each token is processed independently through a small neural network to allow complex feature interactions. For spatial data this allows:
- Global context: Every patch can "see" every other patch
- Spatial mixing: Information flows across the entire image
- Learning relationships: Model learns which patches are relevant to each other
Then, nn.TransformerEncoder stacks depth copies of this encoder layer sequentially. Each copy has the same architecture but its own weights (the layer is cloned, not shared).
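A minimal sketch of this spatial mixer, assuming batch-first (B, H·W, D) tokens. The class name SpatialMixer and the choices depth=2, dim_feedforward = 4·embed_dim and dropout=0.1 are assumptions; only num_heads=4 and the use of nn.TransformerEncoderLayer / nn.TransformerEncoder come from the text.

```python
import torch
import torch.nn as nn

class SpatialMixer(nn.Module):
    """Hypothetical sketch of the SpatialTransformer block."""

    def __init__(self, embed_dim=128, num_heads=4, depth=2, dropout=0.1):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=4 * embed_dim,   # assumed MLP width
            dropout=dropout,
            batch_first=True,                # tokens are (B, N, D)
        )
        # depth independent copies of `layer` (weights are not shared)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.encoder(tokens)          # same shape as input: (B, H*W, D)

# Usage: an 8x8 grid of spatial tokens
mixer = SpatialMixer(embed_dim=128, num_heads=4, depth=2)
tokens = torch.randn(2, 64, 128)
out = mixer(tokens)
print(out.shape)  # torch.Size([2, 64, 128])
```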
Summary:
- Reshape latent tokens
- Apply 1x1 convolution to mix features
- Use transposed convolution to upsample to original spatial size
- Apply a convolutional refinement block to smooth patch boundaries
- Apply convolutional head for final output
- Optionally mask out land regions
- Apply scale and bias to the output
Detailed process:
First, the latent token sequence (B, H'·W', embed_dim) is reshaped back into a 2D feature map (B, embed_dim, H', W'). We then transform the embedding dimension (default embed_dim=128) into the hidden dimension (default hidden=128) at each spatial location independently, using a 1×1 convolution (also called a "pointwise convolution"), nn.Conv2d(..., kernel_size=1). Even though both dimensions are 128 by default, this layer learns a linear transformation that mixes and re-weights the channel features.
Then a transposed convolution (nn.ConvTranspose2d, often called a deconvolution) maps each patch back to its original pixel grid, converting the low-resolution patch grid to the original image resolution. Its padding is set to the overlap, which gives smooth upsampling with some overlap between neighboring patches.
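The reshape, pointwise convolution and overlapping upsampling can be sketched as follows. The text only says that padding equals the overlap; the choice kernel_size = P + 2·overlap with stride = P is an assumption that makes the output size come out exactly H'·P, since ConvTranspose2d gives H_out = (H' − 1)·P − 2·overlap + (P + 2·overlap − 1) + 1 = H'·P. The values P = 8 and overlap = 2 are hypothetical.

```python
import torch
import torch.nn as nn

B, Hp, Wp, D, hidden = 2, 8, 8, 128, 128   # Hp, Wp = patch-grid size H', W'
P, overlap = 8, 2                           # hypothetical patch size and overlap

tokens = torch.randn(B, Hp * Wp, D)                  # spatially mixed tokens
grid = tokens.transpose(1, 2).reshape(B, D, Hp, Wp)  # (B, D, H', W'), row-major token order

pointwise = nn.Conv2d(D, hidden, kernel_size=1)      # per-location channel mixing
upsample = nn.ConvTranspose2d(
    hidden, hidden,
    kernel_size=P + 2 * overlap,   # each patch paints a (P+2o) x (P+2o) footprint
    stride=P,                      # patch centres are P pixels apart
    padding=overlap,               # trims o pixels at each border
)
feat = upsample(pointwise(grid))
print(feat.shape)  # torch.Size([2, 128, 64, 64]) = (B, hidden, H'*P, W'*P)
```

With these settings neighbouring footprints overlap by 2·overlap pixels, which is what blends adjacent patches instead of stitching them edge to edge.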
Then we apply a convolutional refinement block to smooth out artifacts from the deconvolution. Deconvolution can produce blocky outputs, so this refinement helps smooth patch boundaries and improve spatial coherence. The block is an nn.Sequential containing:
- Conv2d: convolution to mix features and reduce artifacts.
- BatchNorm2d: normalizes the features to stabilize training.
- GELU: non-linear activation allowing complex feature interactions.
- Conv2d (1×1): final 1×1 convolution for deeper feature refinement.
- BatchNorm2d: re-normalizes the features to stabilize training.
- GELU: another non-linear activation allowing complex feature interactions.
A single Conv2d(..., 1, kernel_size=1) could work, but the extra layers provide:
- Spatial refinement after deconvolution (which can produce artifacts)
- Non-linearity for more expressive power
- Better gradient flow during training
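The refinement block described above could look like this. The first Conv2d is assumed to be 3×3 with padding=1 (the text does not give its kernel size), which keeps the spatial size unchanged while letting each pixel see its neighbours across patch seams; make_refine_block is a hypothetical name.

```python
import torch
import torch.nn as nn

def make_refine_block(hidden: int) -> nn.Sequential:
    """Hypothetical sketch of the refinement block."""
    return nn.Sequential(
        # assumed 3x3 kernel: mixes neighbouring pixels across patch seams
        nn.Conv2d(hidden, hidden, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(hidden),   # conv bias is redundant before BatchNorm
        nn.GELU(),
        nn.Conv2d(hidden, hidden, kernel_size=1, bias=False),   # per-pixel channel refinement
        nn.BatchNorm2d(hidden),
        nn.GELU(),
    )

# Usage on an upsampled feature map
refine = make_refine_block(64)
feat = torch.randn(2, 64, 32, 32)
out = refine(feat)
print(out.shape)  # torch.Size([2, 64, 32, 32]): padding=1 preserves H and W
```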
Finally, a small convolutional head produces the single-channel output. Scale, bias, and the (optional) land mask are then applied to give the final output: the reconstructed 2D map (e.g., SST) for each sample in the batch.
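A sketch of this last stage under several assumptions not stated in the text: the head is a 3×3 conv, GELU, then a 1×1 conv to one channel; scale and bias are learnable scalars (they could equally be fixed normalization statistics); and the land mask is a boolean tensor with True = land whose pixels are set to 0. OutputHead and land_mask are hypothetical names.

```python
import torch
import torch.nn as nn

class OutputHead(nn.Module):
    """Hypothetical final stage: conv head -> scale/bias -> optional land mask."""

    def __init__(self, hidden: int = 128):
        super().__init__()
        self.head = nn.Sequential(
            nn.Conv2d(hidden, hidden // 2, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(hidden // 2, 1, kernel_size=1),   # single output channel
        )
        # assumption: learnable scalar scale and bias
        self.scale = nn.Parameter(torch.ones(1))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, feat: torch.Tensor, land_mask: torch.Tensor = None) -> torch.Tensor:
        out = self.head(feat) * self.scale + self.bias   # (B, 1, H, W)
        if land_mask is not None:
            # assumption: land_mask is (B, 1, H, W) bool, True = land; land pixels set to 0
            out = out.masked_fill(land_mask, 0.0)
        return out

# Usage: the top four rows of a 16x16 map are land
head = OutputHead(hidden=64)
feat = torch.randn(2, 64, 16, 16)
land = torch.zeros(2, 1, 16, 16, dtype=torch.bool)
land[:, :, :4, :] = True
y = head(feat, land)
print(y.shape)  # torch.Size([2, 1, 16, 16])
```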