Merged

22 commits
2b3b2ee
Merge pull request #1 from WUR-AI/restructuring-to-hydra
vdplasthijs Dec 10, 2025
de9cbff
Merge pull request #14 from WUR-AI/develop
vdplasthijs Jan 9, 2026
065c8dc
Merge pull request #30 from WUR-AI/develop
gabrieletijunaityte Jan 23, 2026
db8c814
Update minor README
vdplasthijs Jan 26, 2026
b786e15
Merge pull request #31 from WUR-AI/vdplasthijs-patch-1
gabrieletijunaityte Jan 26, 2026
b8a86c5
Merge pull request #39 from WUR-AI/develop
gabrieletijunaityte Jan 29, 2026
ad0cd34
UrbanHeatIslands use case: Guatemala case
BachirNILU Feb 19, 2026
6062168
Update base data (add feat_ + add parameter for spatial splitting)
BachirNILU Feb 19, 2026
3f5705b
Add regression model/head
BachirNILU Feb 19, 2026
b4704c6
Add multimodel encoder (tabular+coords)
BachirNILU Feb 19, 2026
6c20ebc
Readme: add diagram
BachirNILU Feb 19, 2026
d268a29
update comments
BachirNILU Feb 19, 2026
0a12fbe
docs: move diagram to docs/figures/
BachirNILU Feb 19, 2026
c0ea0dd
docs: move diagram to docs/figures/
BachirNILU Feb 19, 2026
d020d27
Merge remote-tracking branch 'origin/develop' into feature/urban-heat…
BachirNILU Feb 19, 2026
4ce0ed8
Removing output_normalization='l2'
BachirNILU Feb 20, 2026
109d307
data config + split_dir fixed
BachirNILU Feb 20, 2026
114b75d
FIx the need and remove split_dir parameter
gabrieletijunaityte Feb 23, 2026
e10e5ad
Merge branch 'develop' into feature/urban-heat-islands
vdplasthijs Feb 23, 2026
30a11fb
fix: wire head, trainable modules, F.normalize, use_features param
BachirNILU Feb 23, 2026
e24ff52
fix: remove redundant _tabular_ready check in __init__
BachirNILU Feb 23, 2026
2044081
Delete split_indices_heat_guatemala_2026-02-20-1148.pth
gabrieletijunaityte Feb 24, 2026
13 changes: 12 additions & 1 deletion README.md
@@ -73,7 +73,8 @@ Data folders should follow this directory structure within `DATA_DIR`:
 ├── s2bms/                      <- Dataset folder.
 │   ├── model_ready_s2bms.csv   <- Csv file with "name_loc" id, locations, aux data and target data.
 │   ├── aux_classes.csv         <- Csv file with explanations for aux data class names.
-│   ├── caption_templates.json  <- Json file with list of caption templates (referencing aux column names).
+│   ├── caption_templates/      <- Caption templates folder.
+│   │   ├── v1.json             <- Json file with list of caption templates (referencing aux column names).
 │   ├── splits/                 <- Torch data splits
 │   ├── source/                 <- Optional: source data used to create model_ready csv.
 │   ├── eo/                     <- EO data modalities
@@ -195,6 +196,16 @@ We follow the directory structure from the [Hydra-Lightning template](https://gi
└── README.md
```


## Architecture Overview

The diagram below shows how configuration files, datasets, and model components
relate to each other in the AETHER framework.

![Framework Architecture](docs/figures/diagram.png)

## Attribution

This repo is based on the [Hydra-Lightning template](https://github.com/ashleve/lightning-hydra-template).
24 changes: 24 additions & 0 deletions configs/data/heat_guatemala.yaml
@@ -0,0 +1,24 @@
_target_: src.data.base_datamodule.BaseDataModule

dataset:
  _target_: src.data.heat_guatemala_dataset.HeatGuatemalaDataset
  data_dir: ${paths.data_dir}
  modalities:
    coords: {}
  use_target_data: true
  use_features: true
  use_aux_data: false
  seed: ${seed}
  cache_dir: ${paths.cache_dir}

batch_size: 64
num_workers: 8
pin_memory: true

split_mode: "from_file"
save_split: false
saved_split_file_name: "split_indices_heat_guatemala_2026-02-20-1148.pth"

train_val_test_split: [0.7, 0.15, 0.15]
seed: ${seed}
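The `train_val_test_split: [0.7, 0.15, 0.15]` fractions above partition the sample indices; below is a minimal pure-Python sketch of that idea. It is illustrative only — the repo's actual logic lives in `BaseDataModule.split_data`, and `split_indices` here is a made-up helper name.

```python
import random

def split_indices(n, fractions, seed):
    """Shuffle range(n) and partition it by the given fractions."""
    rng = random.Random(seed)  # fixed seed -> reproducible split
    idx = list(range(n))
    rng.shuffle(idx)
    n_train = round(fractions[0] * n)
    n_val = round(fractions[1] * n)
    return {
        "train": idx[:n_train],
        "val": idx[n_train:n_train + n_val],
        "test": idx[n_train + n_val:],  # remainder
    }

splits = split_indices(100, [0.7, 0.15, 0.15], seed=12345)
print({k: len(v) for k, v in splits.items()})  # {'train': 70, 'val': 15, 'test': 15}
```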
25 changes: 25 additions & 0 deletions configs/experiment/heat_guatemala_coords_reg.yaml
@@ -0,0 +1,25 @@
# @package _global_
# configs/experiment/heat_guatemala_coords_reg.yaml
# Variant A: GeoClip coordinates only

defaults:
  - override /model: heat_geoclip_reg
  - override /data: heat_guatemala

tags: ["heat_island", "guatemala", "coords_only", "regression"]
seed: 12345

trainer:
  min_epochs: 1
  max_epochs: 50

data:
  batch_size: 64

logger:
  wandb:
    tags: ${tags}
    group: "heat_island"
  aim:
    experiment: "heat_island"
25 changes: 25 additions & 0 deletions configs/experiment/heat_guatemala_fusion_reg.yaml
@@ -0,0 +1,25 @@
# @package _global_
# configs/experiment/heat_guatemala_fusion_reg.yaml
# Variant C: GeoClip + tabular fusion

defaults:
  - override /model: heat_fusion_reg
  - override /data: heat_guatemala

tags: ["heat_island", "guatemala", "fusion", "regression"]
seed: 12345

trainer:
  min_epochs: 1
  max_epochs: 50

data:
  batch_size: 64

logger:
  wandb:
    tags: ${tags}
    group: "heat_island"
  aim:
    experiment: "heat_island"
25 changes: 25 additions & 0 deletions configs/experiment/heat_guatemala_tabular_reg.yaml
@@ -0,0 +1,25 @@
# @package _global_
# configs/experiment/heat_guatemala_tabular_reg.yaml
# Variant B: tabular features only

defaults:
  - override /model: heat_tabular_reg
  - override /data: heat_guatemala

tags: ["heat_island", "guatemala", "tabular_only", "regression"]
seed: 12345

trainer:
  min_epochs: 1
  max_epochs: 50

data:
  batch_size: 64

logger:
  wandb:
    tags: ${tags}
    group: "heat_island"
  aim:
    experiment: "heat_island"
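Assuming the Hydra-Lightning template's standard `src/train.py` entrypoint (an assumption carried over from the template, not shown in this diff), the three experiment variants would typically be launched with Hydra's `experiment=` override:

```shell
# Entrypoint path assumed from the Hydra-Lightning template
python src/train.py experiment=heat_guatemala_coords_reg
python src/train.py experiment=heat_guatemala_tabular_reg
python src/train.py experiment=heat_guatemala_fusion_reg

# Or sweep all three variants in one Hydra multirun
python src/train.py -m experiment=heat_guatemala_coords_reg,heat_guatemala_tabular_reg,heat_guatemala_fusion_reg
```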
39 changes: 39 additions & 0 deletions configs/model/heat_fusion_reg.yaml
@@ -0,0 +1,39 @@
# configs/model/heat_fusion_reg.yaml
#
# Renamed from: predictive_geoclip_tabular_regression.yaml
# Reason: "fusion" is concise; "heat_" scopes to this use case.
#
# Variant: GeoClip (coords) + tabular fusion
# Encoder output = GeoClip (512) + tabular projection (64) = 576-dim

_target_: src.models.predictive_model_regression.PredictiveRegressionModel

eo_encoder:
  _target_: src.models.components.eo_encoders.multimodal_encoder.MultiModalEncoder
  use_coords: true
  use_tabular: true
  tab_embed_dim: 64

prediction_head:
  _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead
  nn_layers: 2
  hidden_dim: 256

# GeoClip frozen; tabular projection + head are trained.
trainable_modules: [eo_encoder, prediction_head]

optimizer:
  _target_: torch.optim.Adam
  _partial_: true
  lr: 0.001
  weight_decay: 0.0

scheduler:
  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
  _partial_: true
  mode: min
  factor: 0.1
  patience: 10

loss_fn:
  _target_: torch.nn.MSELoss
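The "512 + 64 = 576" comment in the config above is plain feature concatenation; a toy sketch with lists standing in for tensors (the real encoder would concatenate torch tensors along the feature dimension):

```python
coords_emb = [0.0] * 512  # stand-in for the frozen GeoClip coordinate embedding
tab_emb = [0.1] * 64      # stand-in for the trainable tabular projection (tab_embed_dim: 64)

fused = coords_emb + tab_emb  # concatenation along the feature dimension
print(len(fused))  # 576
```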
38 changes: 38 additions & 0 deletions configs/model/heat_geoclip_reg.yaml
@@ -0,0 +1,38 @@
# configs/model/heat_geoclip_reg.yaml
#
# Renamed from: predictive_geoclip_regression.yaml
# Reason: prefixing with "heat_" scopes it to this use case;
# "geoclip_reg" is concise and descriptive.
#
# Variant: coords only (GeoClip encodes lat/lon → 512-dim embedding)

_target_: src.models.predictive_model_regression.PredictiveRegressionModel

eo_encoder:
  _target_: src.models.components.eo_encoders.multimodal_encoder.MultiModalEncoder
  use_coords: true
  use_tabular: false

prediction_head:
  _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead
  nn_layers: 2
  hidden_dim: 256

# Only the prediction head is trained; GeoClip encoder is frozen.
trainable_modules: [prediction_head]

optimizer:
  _target_: torch.optim.Adam
  _partial_: true
  lr: 0.001
  weight_decay: 0.0

scheduler:
  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
  _partial_: true
  mode: min
  factor: 0.1
  patience: 10

loss_fn:
  _target_: torch.nn.MSELoss
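A minimal sketch of what a `trainable_modules` allow-list implies, assuming a freeze-by-default convention (the `Param`/`Module` stand-ins are illustrative, not the repo's classes):

```python
class Param:
    """Stand-in for a tensor parameter with a requires_grad flag."""
    def __init__(self):
        self.requires_grad = True

class Module:
    """Stand-in for an nn.Module holding a few parameters."""
    def __init__(self, n_params):
        self._params = [Param() for _ in range(n_params)]
    def parameters(self):
        return self._params

model = {"eo_encoder": Module(4), "prediction_head": Module(2)}
trainable_modules = ["prediction_head"]  # mirrors the config above

# Freeze every parameter whose parent module is not in the allow-list
for name, module in model.items():
    for p in module.parameters():
        p.requires_grad = name in trainable_modules

n_trainable = sum(p.requires_grad for m in model.values() for p in m.parameters())
print(n_trainable)  # 2
```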
45 changes: 45 additions & 0 deletions configs/model/heat_tabular_reg.yaml
@@ -0,0 +1,45 @@
# configs/model/heat_tabular_reg.yaml
#
# Renamed from: predictive_tabular_regression.yaml
# Reason: prefixed with "heat_" to scope to this use case.
#
# Variant: tabular only (feat_* CSV columns encoded by a small MLP)
#
# NOTE: tabular_dim is NOT hardcoded here.
# It is resolved automatically at runtime:
#   1. HeatGuatemalaDataset.tabular_dim reads len(feat_names) from the CSV.
#   2. BaseDataModule.tabular_dim delegates to the train dataset.
#   3. PredictiveRegressionModel.setup() calls
#      self.eo_encoder.build_tabular_branch(self.trainer.datamodule.tabular_dim)

_target_: src.models.predictive_model_regression.PredictiveRegressionModel

eo_encoder:
  _target_: src.models.components.eo_encoders.multimodal_encoder.MultiModalEncoder
  use_coords: false
  use_tabular: true
  tab_embed_dim: 64

prediction_head:
  _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead
  nn_layers: 2
  hidden_dim: 256

# Both encoder and head have trainable parameters.
trainable_modules: [eo_encoder, prediction_head]

optimizer:
  _target_: torch.optim.Adam
  _partial_: true
  lr: 0.001
  weight_decay: 0.0001

scheduler:
  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
  _partial_: true
  mode: min
  factor: 0.1
  patience: 10

loss_fn:
  _target_: torch.nn.MSELoss
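The three-step resolution chain in the NOTE above can be mimicked with stand-in classes (names mirror the repo's, but these are illustrative sketches, not the actual implementations):

```python
class HeatGuatemalaDataset:
    """Stand-in: the real class reads feat_* column names from the CSV."""
    def __init__(self, feat_names, use_features=True):
        self.feat_names = feat_names
        self.use_features = use_features

    @property
    def tabular_dim(self):
        return len(self.feat_names) if self.use_features else 0

class BaseDataModule:
    """Stand-in: delegates tabular_dim to the (possibly wrapped) train dataset."""
    def __init__(self, data_train):
        self.data_train = data_train

    @property
    def tabular_dim(self):
        dataset = self.data_train
        while hasattr(dataset, "dataset"):  # unwrap Subset-style wrappers
            dataset = dataset.dataset
        return dataset.tabular_dim

dm = BaseDataModule(HeatGuatemalaDataset(["feat_ndvi", "feat_imperv"]))
print(dm.tabular_dim)  # 2 -> handed to build_tabular_branch at model setup
```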
Binary file added docs/figures/diagram.png
24 changes: 15 additions & 9 deletions src/data/base_datamodule.py
@@ -33,6 +33,7 @@ def __init__(
         saved_split_file_name: str | None = None,
         caption_builder: BaseCaptionBuilder = None,
         seed: int = 12345,
+        spatial_split_distance_m: int = 1000,
     ) -> None:
         """Datamodule class which handles dataset splits and batching.
 
@@ -47,6 +48,8 @@
         :param save_split: if to save split file
         :param saved_split_file_name: file name to save split file
         :param caption_builder: instance of BaseCaptionBuilder for generating textual captions
+        :param spatial_split_distance_m: minimum distance in metres between clusters when
+            split_mode is 'spatial_clusters'. Default 1000 m.
         """
         super().__init__()
         self.save_hyperparameters(logger=False)
@@ -63,6 +66,14 @@
 
         self.split_data()
 
+    @property
+    def tabular_dim(self):
+        dataset = self.data_train
+        # Unwrap Subset wrappers (e.g. from random_split)
+        while hasattr(dataset, "dataset"):
+            dataset = dataset.dataset
+        return dataset.tabular_dim
+
     @property
     def num_classes(self) -> int:
         """Number of classes in the dataset."""
@@ -120,7 +131,7 @@ def split_data(self) -> None:
                 "Warning: DBSCAN clustering on more than 2000 samples may be slow. Maybe set n_jobs in DBScan?"
             )
             # 4000 m distance between points. Use geodist to calculate true distance.
-            min_dist = 4000
+            min_dist = self.hparams.spatial_split_distance_m
             clustering = DBSCAN(
                 eps=min_dist,
                 metric=lambda u, v: geodist(u, v).meters,
@@ -244,19 +255,14 @@
     def save_split_indices(self, split_indices: dict[str, Any] | dict):
         """Save split indices into file."""
-        assert (
-            self.hparams.split_dir is not None
-        ), "split_dir must be provided when saving a new data split."
-        assert os.path.exists(
-            self.hparams.split_dir
-        ), f"Directory to save split indices does not exist: {self.hparams.split_dir}"
         assert isinstance(split_indices, dict), "split_indices must be a dictionary to be saved."
+        self.split_dir = os.path.join(self.hparams.dataset.data_dir, "splits")
+        os.makedirs(self.split_dir, exist_ok=True)
 
         timestamp = create_timestamp()
         torch.save(
             split_indices,
             os.path.join(
-                self.hparams.split_dir,
+                self.split_dir,
                 f"split_indices_{self.hparams.dataset_name}_{timestamp}.pth",
             ),
         )
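The spatial split above hands geopy's geodesic distance to DBSCAN as a custom metric; as a rough sketch of what that metric computes, here is a pure-Python haversine approximation (`haversine_m` is an illustrative name — the repo itself uses `geodist(u, v).meters`):

```python
import math

def haversine_m(u, v):
    """Approximate great-circle distance in metres between (lat, lon) pairs."""
    lat1, lon1, lat2, lon2 = map(math.radians, (*u, *v))
    dlat, dlon = lat2 - lat1, lon2 - lon1
    a = math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2
    return 2 * 6371000 * math.asin(math.sqrt(a))  # mean Earth radius ~6371 km

# Two points near Guatemala City, 0.01 degrees of latitude apart (~1.1 km),
# i.e. joinable by DBSCAN when spatial_split_distance_m exceeds that distance
d = haversine_m((14.62, -90.52), (14.63, -90.52))
print(round(d))
```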
14 changes: 14 additions & 0 deletions src/data/base_dataset.py
@@ -23,6 +23,7 @@
         cache_dir: str = None,
         implemented_mod: set[str] = None,
         mock: bool = False,
+        use_features: bool = True,
     ) -> None:
         """Interface for any use case dataset.
 
@@ -45,6 +46,7 @@
         :param cache_dir: directory to save cached data
         :param implemented_mod: implemented modalities for each dataset
         :param mock: whether to mock csv file
+        :param use_features: if tabular feat_* columns should be included. Default True.
         """
 
         if mock:
@@ -79,6 +81,7 @@
         self.use_aux_data: bool = use_aux_data
         self.records: dict[str, Any] = self.get_records()
         self.pooch_cli = None
+        self.use_features = use_features
 
     @final
     def get_records(self) -> dict[str, Any]:
@@ -112,6 +115,10 @@ def get_records(self) -> dict[str, Any]:
         if self.use_aux_data:
             self.aux_names = [c for c in self.df.columns if "aux_" in c]
             columns.extend(self.aux_names)
+
+        # Include tabular features
+        self.feat_names = [c for c in self.df.columns if c.startswith("feat_")]
+        columns.extend(self.feat_names)
 
         return self.df.loc[:, columns].to_dict("records")
@@ -120,6 +127,13 @@ def __len__(self) -> int:
         """Returns the length of the dataset."""
         return len(self.records)
 
+    @final
+    @property
+    def tabular_dim(self) -> int:
+        if not self.use_features:
+            return 0
+        return len(getattr(self, "feat_names", []))
+
     @abstractmethod
     def __getitem__(self, idx: int) -> Dict[str, Any]:
         """Returns a single item from the dataset."""