Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 50 additions & 13 deletions climanet/dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

import numpy as np
from .utils import add_month_day_dims
import xarray as xr
Expand All @@ -16,12 +18,10 @@ def __init__(
land_mask: xr.DataArray = None,
time_dim: str = "time",
spatial_dims: Tuple[str, str] = ("lat", "lon"),
patch_size: Tuple[int, int] = (16, 16),
overlap: int = 0,
patch_size: Tuple[int, int] = (16, 16), # (lat, lon)
):
self.spatial_dims = spatial_dims
self.patch_size = patch_size
self.overlap = overlap

# Check that the input data has the expected dimensions
if time_dim not in daily_da.dims or time_dim not in monthly_da.dims:
Expand All @@ -30,6 +30,13 @@ def __init__(
if dim not in daily_da.dims or dim not in monthly_da.dims:
raise ValueError(f"Spatial dimension '{dim}' not found in input data")

if (
patch_size[0] > daily_da.sizes[spatial_dims[0]] or patch_size[1] > daily_da.sizes[spatial_dims[1]]
):
raise ValueError(
f"Patch size {patch_size} is larger than data dimensions {daily_da.sizes[spatial_dims]}"
)

# Reshape daily → (M, T=31, H, W), monthly → (M, H, W),
# and get padded_days_mask → (M, T=31)
daily_mt, monthly_m, padded_days_mask = add_month_day_dims(
Expand All @@ -41,6 +48,10 @@ def __init__(
self.monthly_np = monthly_m.to_numpy().copy() # (M, H, W) float
self.padded_mask_np = padded_days_mask.to_numpy().copy() # (M, T=31) bool

# Store coordinate arrays
self.lat_coords = daily_da[spatial_dims[0]].to_numpy().copy()
self.lon_coords = daily_da[spatial_dims[1]].to_numpy().copy()

if land_mask is not None:
lm = land_mask.to_numpy().copy()
if lm.ndim == 3:
Expand All @@ -60,25 +71,45 @@ def __init__(
self.padded_days_tensor = torch.from_numpy(self.padded_mask_np).bool()

# Precompute lazy index mapping for patches
self.stride = self.patch_size[0] - self.overlap
H, W = self.daily_np.shape[2], self.daily_np.shape[3]
self.n_i = (H - self.patch_size[0]) // self.stride + 1
self.n_j = (W - self.patch_size[1]) // self.stride + 1
self.patch_indices = self._compute_patch_indices(H, W)

def _compute_patch_indices(self, H: int, W: int) -> list:
    """Return the top-left start index of every full, non-overlapping patch.

    The (H, W) grid is tiled with patches of size ``self.patch_size``
    (height, width). Only complete patches are produced; trailing rows or
    columns that cannot fill a whole patch are left uncovered and reported
    through a ``UserWarning``.

    Args:
        H: Grid height in pixels.
        W: Grid width in pixels.

    Returns:
        List of ``(i, j)`` start indices in row-major order (all column
        starts for the first row of patches, then the next row, ...). The
        list is empty when the patch is larger than the grid in either
        dimension.

    Raises:
        ValueError: If either patch dimension is not a positive number.
    """
    ph, pw = self.patch_size

    # A zero patch dimension would otherwise surface as a bare
    # ZeroDivisionError, and a negative one would silently produce an empty
    # index list; fail early with the same exception type the constructor
    # uses for its other input checks.
    if ph <= 0 or pw <= 0:
        raise ValueError(
            f"Patch size {self.patch_size} must be positive in both dimensions"
        )

    # Pixels left over after placing as many full patches as fit.
    remainder_h = H % ph
    remainder_w = W % pw

    if remainder_h > 0 or remainder_w > 0:
        warnings.warn(
            f"Patch size {self.patch_size} does not evenly divide image dimensions (H={H}, W={W}). "
            f"Uncovered pixels: {remainder_h} in height, {remainder_w} in width. "
            f"Consider adjusting patch_size or image dimensions for full coverage.",
            UserWarning,
            # Frame 1 is this helper and frame 2 is __init__ (its visible
            # caller), so level 3 attributes the warning to the code that
            # constructs the dataset.
            stacklevel=3,
        )

    # Stepping by the patch size up to the last fully covered pixel yields
    # exactly H // ph row starts and W // pw column starts.
    i_starts = range(0, H - remainder_h, ph)
    j_starts = range(0, W - remainder_w, pw)

    return [(i, j) for i in i_starts for j in j_starts]

# Total length is only spatial patches (all months included in each sample)
self.total_len = self.n_i * self.n_j

def __len__(self):
    """Return the number of samples, i.e. the number of spatial patches.

    The length is taken from ``self.patch_indices`` so that it always
    agrees with the list ``__getitem__`` indexes into; the superseded
    ``total_len`` attribute is no longer consulted.
    """
    return len(self.patch_indices)

def __getitem__(self, idx):
"""Get a spatiotemporal patch sample based on the index."""
if idx < 0 or idx >= self.total_len:
if idx < 0 or idx >= len(self.patch_indices):
raise IndexError("Index out of range")

i_idx, j_idx = divmod(idx, self.n_j)
i = i_idx * self.stride
j = j_idx * self.stride
i, j = self.patch_indices[idx]
ph, pw = self.patch_size

# Extract spatial patch via numpy slicing — faster than xarray indexing
Expand Down Expand Up @@ -108,6 +139,10 @@ def __getitem__(self, idx):
~land_tensor.unsqueeze(0).unsqueeze(0).unsqueeze(0)
)

# Extract lat/lon coordinates for this patch
lat_patch = self.lat_coords[i : i + ph]
lon_patch = self.lon_coords[j : j + pw]

# Convert to tensors
return {
"daily_patch": daily_tensor, # (C=1, M, T=31, H, W)
Expand All @@ -116,4 +151,6 @@ def __getitem__(self, idx):
"land_mask_patch": land_tensor, # (H,W) True=Land
"padded_days_mask": self.padded_days_tensor, # (M, T=31) True=padded
"coords": (i, j),
"lat_patch": lat_patch, # (H,)
"lon_patch": lon_patch, # (W,)
}
97 changes: 56 additions & 41 deletions climanet/st_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def __init__(self, embed_dim=128, max_days=31, max_months=12):
def forward(self, x, M, T, H, W, padded_days_mask=None):
"""
Args:
x: (B, M*T*H*W, C) containing spatio-temporal tokens, where C is the embedding dimension.
x: (B, M, T, H, W, C) containing spatio-temporal tokens, where C is the embedding dimension.
M: number of months
T: number of temporal tokens per month after temporal patching (Tp)
H: spatial height after spatial patching
Expand All @@ -194,9 +194,12 @@ def forward(self, x, M, T, H, W, padded_days_mask=None):
True indicating which day tokens are padded (because some months
have fewer days). This is used to mask out padded tokens in attention computation.
Returns:
Tensor of shape (B, M*H*W, C) with one temporally aggregated, where C is the embedding dimension.
Tensor of shape (B, M, H*W, C) holding one temporally aggregated token per month and spatial location, where C is the embedding dimension.
"""
seq = rearrange(x, "b (m t h w) c -> b (h w) m t c", m=M, t=T, h=H, w=W)
B, M, Tp, Hp, Wp, C = x.shape

# Reshape to (B, Hp*Wp, M, Tp, C) for temporal processing
seq = x.permute(0, 3, 4, 1, 2, 5).reshape(B, Hp * Wp, M, Tp, C)

pe_days = self.pos_days(T).to(seq.device).to(seq.dtype) # (T, C)
pe_months = self.pos_months(M).to(seq.device).to(seq.dtype) # (M, C)
Expand All @@ -209,10 +212,10 @@ def forward(self, x, M, T, H, W, padded_days_mask=None):

# padded_days_mask is (B, M, T) true=padded, -> (B, HW, M, T)
if padded_days_mask is not None:
pad = padded_days_mask[:, None, :, :].expand(x.shape[0], H * W, M, T)
pad = padded_days_mask[:, None, :, :].expand(B, H * W, M, T)
day_logits = day_logits.masked_fill(pad, float("-inf"))

day_w = torch.softmax(day_logits, dim=-1)
day_w = torch.softmax(day_logits, dim=-1) # turns inf to 0
month_tokens = (seq * day_w.unsqueeze(-1)).sum(dim=3) # (B, HW, M, C)

# Cross-month attention at each spatial location
Expand All @@ -222,10 +225,10 @@ def forward(self, x, M, T, H, W, padded_days_mask=None):
z = z + attn_out
z = z + self.month_ffn(z)

# Back to flattened tokens with month kept
z = rearrange(z, "(b s) m c -> b s m c", b=x.shape[0], s=H * W)
out = rearrange(z, "b (h w) m c -> b (m h w) c", h=H, w=W)
return out # (B, M*H*W, C) C: embedding dimension
# Back to (B, M, Hp*Wp, C)
z = z.view(B, Hp * Wp, M, C)
out = z.permute(0, 2, 1, 3) # (B, M, Hp*Wp, C)
return out # (B, M, H*W, C) C: embedding dimension


class MonthlyConvDecoder(nn.Module):
Expand Down Expand Up @@ -293,10 +296,10 @@ def __init__(
# Refinement block: a small conv layers to smooth patch boundaries
self.refine = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.GroupNorm(num_groups=8, num_channels=out_channels),
nn.GELU(),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.GroupNorm(num_groups=8, num_channels=out_channels),
nn.GELU(),
)

Expand All @@ -314,18 +317,18 @@ def forward(self, latent, M, out_H, out_W, land_mask=None):
M: Number of months (temporal patches)
out_H: Target output height (must be divisible by patch_h)
out_W: Target output width (must be divisible by patch_w)
land_mask: Optional boolean tensor of shape (out_H, out_W). Values set to True
land_mask: Optional boolean tensor of shape (B, out_H, out_W). Values set to True
will be masked out (set to 0) in the output (only ocean pixels exist).
Returns:
Tensor of shape (B, M, out_H, out_W) representing the monthly variable e.g. SST.
"""
B, MHW, C = latent.shape
B, M, Np, C = latent.shape
Hp = out_H // self.patch_h
Wp = out_W // self.patch_w
assert MHW == M * Hp * Wp, f"Token mismatch: got {MHW}, expected {M * Hp * Wp}"
assert Np == Hp * Wp, f"Token mismatch: got {Np}, expected {Hp * Wp}"

# transforms the latent tensor from sequence format to image format for
# convolution operations; (B, M*Hp*Wp, C) -> (B*M, C, Hp, Wp)
# convolution operations;
out = latent.view(B, M, Hp, Wp, C).permute(0, 1, 4, 2, 3).contiguous()
out = out.view(B * M, C, Hp, Wp)

Expand All @@ -349,7 +352,7 @@ def forward(self, latent, M, out_H, out_W, land_mask=None):

# Mask out land areas if land_mask is provided
if land_mask is not None:
out = out.masked_fill(land_mask.bool()[None, None, :, :], 0.0)
out = out.masked_fill(land_mask.bool()[:, None, :, :], 0.0)
return out # (B, M, out_H, out_W)


Expand Down Expand Up @@ -500,10 +503,11 @@ def __init__(
patch_size=(1, 4, 4),
max_days=31,
max_months=12,
hidden=128,
num_months=12,
hidden=256,
overlap=1,
max_H=1024,
max_W=1024,
max_H=256,
max_W=256,
spatial_depth=2,
spatial_heads=4,
):
Expand All @@ -515,6 +519,7 @@ def __init__(
patch_size: Tuple of (T, H, W) patch sizes for temporal and spatial patching
max_days: Maximum number of days for temporal positional encoding
max_months: Maximum number of months for temporal positional encoding
num_months: Number of months to predict (output channels in decoder)
hidden: Hidden dimension used in the decoder
overlap: Overlap for deconvolution in the decoder
max_H: Maximum spatial height for 2D positional encoding
Expand All @@ -541,7 +546,7 @@ def __init__(
patch_w=patch_size[2],
hidden=hidden,
overlap=overlap,
num_months=max_months,
num_months=num_months,
)
self.patch_size = patch_size

Expand All @@ -552,7 +557,7 @@ def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None
daily_data: Tensor of shape (B, C, M, T, H, W) containing daily
data, where C is the number of channels (e.g., 1 for SST)
daily_mask: Boolean tensor of same shape as daily_data indicating missing values
land_mask_patch: Boolean tensor of shape (H, W) to mask land areas in the output
land_mask_patch: Boolean tensor of shape (B, H, W) to mask land areas in the output
padded_days_mask: Optional boolean tensor of shape (B, M, T) indicating which day tokens are padded
(True for padded tokens). Used to mask out padded tokens in temporal attention.
Returns:
Expand All @@ -574,18 +579,6 @@ def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None
)
assert T % self.patch_size[0] == 0, "T must be divisible by patch size"

# Step 1: Encode spatio-temporal patches
# each month independently by folding M into batch
daily_data_reshaped = daily_data.reshape(B * M, C, T, H, W)
daily_mask_reshaped = daily_mask.reshape(B * M, C, T, H, W)
latent = self.encoder(
daily_data_reshaped, daily_mask_reshaped
) # (B*M, N_patches, embed_dim)

# Step 2: Aggregate temporal information for each spatial patch
# latent -> (B, M*Np, embed_dim) to match the aggregator input x: (B, M*Tp*Hp*Wp, embed_dim)
latent = latent.reshape(B, M * Np, -1)

if padded_days_mask is not None and self.patch_size[0] > 1:
B, M, T_days = padded_days_mask.shape
if T_days % self.patch_size[0] != 0:
Expand All @@ -596,23 +589,45 @@ def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None
B, M, T_days // self.patch_size[0], self.patch_size[0]
).all(dim=-1) # (B, M, Tp)

# Step 1: Encode spatio-temporal patches
# each month independently by folding M into batch
# encoder input shape = (B, C, T, H, W) where C is channel.
# encoder output shape = (B, N_patches, embed_dim)
# so M is folded into B, and T, H, W are the spatio-temporal dimensions to be patched.
daily_data_reshaped = daily_data.reshape(B * M, C, T, H, W)
daily_mask_reshaped = daily_mask.reshape(B * M, C, T, H, W)

latent = self.encoder(
daily_data_reshaped, daily_mask_reshaped
) # (B*M, N_patches, embed_dim)

# Step 2: Aggregate temporal information for each spatial patch
# temporal input shape = (B, M, Tp, Hp, Wp, C), C: embedding dimension
# temporal output shape = (B, M, H*W, C) C: embedding dimension
embed_dim = latent.shape[-1]
latent = latent.view(B, M, Tp, Hp, Wp, embed_dim)

agg_latent = self.temporal(
latent, M, Tp, Hp, Wp, padded_days_mask=padded_days_mask
) # (B, M*Hp*Wp, embed_dim)
) # (B, M, Hp*Wp, embed_dim)

# Step 3: Add spatial positional encodings and mix spatial features
E = agg_latent.shape[-1]
agg_latent = agg_latent.view(B, M, Hp * Wp, E)
# spatial PE output shape = (Hp, Wp, embed_dim)
pe = (
self.spatial_pe(Hp, Wp).to(agg_latent.device).to(agg_latent.dtype)
) # (Hp*Wp, E)
x = agg_latent + pe[None, None, :, :]
) # (Hp, Wp, E)
x = agg_latent + pe[None, None, :, :] # (B, M, Hp*Wp, E)

# Step 4: Spatial mixing with Transformer
x = x.view(B * M, Hp * Wp, E)
x = self.spatial_tr(x) # (B*M, Hp*Wp, E)
x = x.view(B, M * Hp * Wp, E) # back to (B, M*Hp*Wp, E)
# spatial transformer input shape = (B, N, C), output shape = (B, N, C) C: embedding dimension
# M is folded in B.
C = x.shape[-1]
x = x.reshape(B * M, Hp * Wp, C)
x = self.spatial_tr(x)
x = x.view(B, M, Hp * Wp, C)

# Step 5: Decode to full-resolution 2D map
# decoder input shape is (B, M, Hp*Wp, C), C: embedding dimension
# decoder output shape is (B, M, H, W)
monthly_pred = self.decoder(x, M, H, W, land_mask_patch) # (B, M, H, W)
return monthly_pred
6 changes: 4 additions & 2 deletions climanet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def pred_to_numpy(pred, orig_H=None, orig_W=None, land_mask=None):
"""
pred: (B, M, H_pad,W_pad) or (B, H, W) torch tensor
orig_H/W: original sizes before padding (optional)
land_mask: (H_pad,W_pad) or (H,W) bool; if given, land will be set to NaN
land_mask: (B, H_pad,W_pad) or (B, H,W) bool; if given, land will be set to NaN
returns: (H,W) numpy array
"""
# crop to original size if provided
Expand All @@ -145,6 +145,8 @@ def pred_to_numpy(pred, orig_H=None, orig_W=None, land_mask=None):
# set land to NaN (broadcast mask across batch)
if land_mask is not None:
pred = pred.clone().to(torch.float32)
pred[:, :, land_mask.bool()] = float("nan")
land_mask = land_mask.bool()
land_mask = land_mask.unsqueeze(1) # (B, H,W) -> (B, 1, H, W) for broadcasting
pred = torch.where(land_mask, torch.full_like(pred, float("nan")), pred)

return pred.detach().cpu().numpy()
Loading