Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,16 +196,13 @@ We follow the directory structure from the [Hydra-Lightning template](https://gi
└── README.md
```


## Architecture Overview

The diagram below shows how configuration files, datasets, and model components
relate to each other in the AETHER framework.

![Framework Architecture](docs/figures/diagram.png)



## Attribution

This repo is based on the [Hydra-Lightning template](https://github.com/ashleve/lightning-hydra-template).
Expand Down
3 changes: 1 addition & 2 deletions configs/data/heat_guatemala.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,5 @@ split_mode: "from_file"
save_split: false
saved_split_file_name: "split_indices_heat_guatemala_2026-02-20-1148.pth"


train_val_test_split: [0.7, 0.15, 0.15]
seed: ${seed}
seed: ${seed}
1 change: 1 addition & 0 deletions configs/experiment/alignment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
defaults:
- override /model: geoclip_alignment
- override /data: butterfly_coords_text
- override /metrics: contrastive_similarities

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
Expand Down
1 change: 1 addition & 0 deletions configs/experiment/alignment_llm2clip.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
defaults:
- override /model: geoclip_llm2clip_alignment
- override /data: butterfly_coords_text
- override /metrics: contrastive_similarities

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
Expand Down
2 changes: 1 addition & 1 deletion configs/experiment/heat_guatemala_coords_reg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
# configs/experiment/heat_guatemala_coords_reg.yaml
# Variant A: GeoClip coordinates only


defaults:
- override /model: heat_geoclip_reg
- override /data: heat_guatemala
- override /metrics: guatemala_regression

tags: ["heat_island", "guatemala", "coords_only", "regression"]
seed: 12345
Expand Down
2 changes: 1 addition & 1 deletion configs/experiment/heat_guatemala_fusion_reg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
# configs/experiment/heat_guatemala_fusion_reg.yaml
# Variant C: GeoClip + tabular fusion


defaults:
- override /model: heat_fusion_reg
- override /data: heat_guatemala
- override /metrics: guatemala_regression

tags: ["heat_island", "guatemala", "fusion", "regression"]
seed: 12345
Expand Down
2 changes: 1 addition & 1 deletion configs/experiment/heat_guatemala_tabular_reg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
# configs/experiment/heat_guatemala_tabular_reg.yaml
# Variant B: tabular features only


defaults:
- override /model: heat_tabular_reg
- override /data: heat_guatemala
- override /metrics: guatemala_regression

tags: ["heat_island", "guatemala", "tabular_only", "regression"]
seed: 12345
Expand Down
1 change: 1 addition & 0 deletions configs/experiment/prediction.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
defaults:
- override /model: predictive_geoclip
- override /data: butterfly_coords
- override /metrics: butterfly_predictive

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
Expand Down
6 changes: 6 additions & 0 deletions configs/metrics/butterfly_predictive.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
_target_: src.models.components.metrics.metrics_wrapper.MetricsWrapper

metrics:
- _target_: src.models.components.loss_fns.mse_loss.MSELoss
- _target_: src.models.components.metrics.top_k_accuracy.TopKAccuracy
k_list: [1, 5, 10]
4 changes: 4 additions & 0 deletions configs/metrics/contrastive_similarities.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
_target_: src.models.components.metrics.metrics_wrapper.MetricsWrapper

metrics:
- _target_: src.models.components.metrics.contrastive_similarities.CosineSimilarities
7 changes: 7 additions & 0 deletions configs/metrics/guatemala_regression.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
_target_: src.models.components.metrics.metrics_wrapper.MetricsWrapper

metrics:
- _target_: src.models.components.loss_fns.mse_loss.MSELoss
- _target_: src.models.components.loss_fns.rmse_loss.RMSELoss
- _target_: src.models.components.loss_fns.mae_loss.MAELoss
- _target_: src.models.components.metrics.r2.RSquared
2 changes: 2 additions & 0 deletions configs/model/geoclip_alignment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ text_encoder:

trainable_modules: [text_encoder.projector, loss_fn.log_temp]

metrics: ${metrics}

optimizer:
_target_: torch.optim.Adam
_partial_: true
Expand Down
2 changes: 2 additions & 0 deletions configs/model/geoclip_llm2clip_alignment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ text_encoder:

trainable_modules: [text_encoder.projector, loss_fn.log_temp]

metrics: ${metrics}

optimizer:
_target_: torch.optim.Adam
_partial_: true
Expand Down
6 changes: 4 additions & 2 deletions configs/model/heat_fusion_reg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
# Variant: GeoClip (coords) + tabular fusion
# Encoder output = GeoClip (512) + tabular projection (64) = 576-dim

_target_: src.models.predictive_model_regression.PredictiveRegressionModel
_target_: src.models.predictive_model.PredictiveModel

eo_encoder:
_target_: src.models.components.eo_encoders.multimodal_encoder.MultiModalEncoder
use_coords: true
use_tabular: true
tab_embed_dim: 64
# tab_embed_dim: 64

prediction_head:
_target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead
Expand All @@ -22,6 +22,8 @@ prediction_head:
# GeoClip frozen; tabular projection + head are trained.
trainable_modules: [eo_encoder, prediction_head]

metrics: ${metrics}

optimizer:
_target_: torch.optim.Adam
_partial_: true
Expand Down
4 changes: 3 additions & 1 deletion configs/model/heat_geoclip_reg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#
# Variant: coords only (GeoClip encodes lat/lon → 512-dim embedding)

_target_: src.models.predictive_model_regression.PredictiveRegressionModel
_target_: src.models.predictive_model.PredictiveModel

eo_encoder:
_target_: src.models.components.eo_encoders.multimodal_encoder.MultiModalEncoder
Expand All @@ -21,6 +21,8 @@ prediction_head:
# Only the prediction head is trained; GeoClip encoder is frozen.
trainable_modules: [prediction_head]

metrics: ${metrics}

optimizer:
_target_: torch.optim.Adam
_partial_: true
Expand Down
4 changes: 3 additions & 1 deletion configs/model/heat_tabular_reg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# 3. PredictiveRegressionModel.setup() calls
# self.eo_encoder.build_tabular_branch(self.trainer.datamodule.tabular_dim)

_target_: src.models.predictive_model_regression.PredictiveRegressionModel
_target_: src.models.predictive_model.PredictiveModel

metrics: ${metrics}

eo_encoder:
_target_: src.models.components.eo_encoders.multimodal_encoder.MultiModalEncoder
Expand Down
2 changes: 2 additions & 0 deletions configs/model/predictive_cnn_s2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ prediction_head:

trainable_modules: [eo_encoder, prediction_head]

metrics: ${metrics}

optimizer:
_target_: torch.optim.Adam
_partial_: true
Expand Down
2 changes: 2 additions & 0 deletions configs/model/predictive_geoclip.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ prediction_head:

trainable_modules: [prediction_head]

metrics: ${metrics}

optimizer:
_target_: torch.optim.Adam
_partial_: true
Expand Down
3 changes: 2 additions & 1 deletion configs/train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
# order of defaults determines the order in which configs override each other
defaults:
- _self_
- data: butterfly_coords_text
- data: butterfly_coords
- model: predictive_geoclip
- callbacks: default
- logger: ${oc.env:LOGGER,wandb}
- trainer: ${oc.env:TRAINER_PROFILE,default}
- paths: ${oc.env:STORAGE_MODE,local}
- extras: default
- hydra: default
- metrics: butterfly_predictive

# experiment configs allow for version control of specific hyperparameters
# e.g. best hyperparameters for given model and datamodule
Expand Down
6 changes: 1 addition & 5 deletions src/data/base_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,7 @@ def __init__(

@property
def tabular_dim(self):
dataset = self.data_train
# Unwrap Subset wrappers (e.g. from random_split)
while hasattr(dataset, "dataset"):
dataset = dataset.dataset
return dataset.tabular_dim
return self.dataset.tabular_dim

@property
def num_classes(self) -> int:
Expand Down
25 changes: 11 additions & 14 deletions src/data/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,17 @@ def __init__(
self.df: pd.DataFrame = pd.read_csv(path_csv)

# Other attributes or placeholders
self.seed = seed
self.pooch_cli = None
self.num_classes = None
self.mode: str = mode # 'train', 'val', 'test'
self.tabular_dim = None
self.seed = seed
self.use_target_data: bool = use_target_data
self.use_aux_data: bool = use_aux_data
self.records: dict[str, Any] = self.get_records()
self.pooch_cli = None
self.use_features = use_features

self.mode: str = mode # 'train', 'val', 'test'
self.records: dict[str, Any] = self.get_records()

@final
def get_records(self) -> dict[str, Any]:
"""Gets record dictionary from the dataframe based on what is needed for the model (aux,
Expand Down Expand Up @@ -115,10 +117,12 @@ def get_records(self) -> dict[str, Any]:
if self.use_aux_data:
self.aux_names = [c for c in self.df.columns if "aux_" in c]
columns.extend(self.aux_names)

# Include tabular features
self.feat_names = [c for c in self.df.columns if c.startswith("feat_")]
columns.extend(self.feat_names)
if self.use_features:
self.feat_names = [c for c in self.df.columns if c.startswith("feat_")]
columns.extend(self.feat_names)
self.tabular_dim = len(self.feat_names)

return self.df.loc[:, columns].to_dict("records")

Expand All @@ -127,13 +131,6 @@ def __len__(self) -> int:
"""Returns the length of the dataset."""
return len(self.records)

@final
@property
def tabular_dim(self) -> int:
if not self.use_features:
return 0
return len(getattr(self, "feat_names", []))

@abstractmethod
def __getitem__(self, idx: int) -> Dict[str, Any]:
"""Returns a single item from the dataset."""
Expand Down
15 changes: 2 additions & 13 deletions src/data/heat_guatemala_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""
Heat Guatemala LST dataset.
"""Heat Guatemala LST dataset.

Location: src/data/heat_guatemala_dataset.py

Expand All @@ -21,8 +20,7 @@


class HeatGuatemalaDataset(BaseDataset):
"""
Dataset for the urban heat island use case (Guatemala City, LST regression).
"""Dataset for the urban heat island use case (Guatemala City, LST regression).

CSV layout expected (produced by scripts/make_model_ready_heat_guatemala.py):
- name_loc : unique location identifier
Expand Down Expand Up @@ -63,15 +61,6 @@ def __init__(
use_features=use_features,
)

# ------------------------------------------------------------------
# Properties
# ------------------------------------------------------------------

@property
def tabular_dim(self) -> int:
"""Number of tabular features (feat_* columns). 0 if none."""
return len(self.feat_names)

# ------------------------------------------------------------------
# Required overrides
# ------------------------------------------------------------------
Expand Down
33 changes: 18 additions & 15 deletions src/data_preprocessing/make_model_ready_heat_guatemala.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""
Build model-ready CSV for the Heat Guatemala use case (data/heat_guatemala/model_ready_heat_guatemala.csv).
"""
"""Build model-ready CSV for the Heat Guatemala use case
(data/heat_guatemala/model_ready_heat_guatemala.csv)."""

import argparse
import re
Expand Down Expand Up @@ -87,19 +86,23 @@ def main(source_csv: str, out_csv: str, drop_zero_lst: bool = True) -> None:
# ------------------------------------------------------------------ #
# 2. Build output skeleton #
# ------------------------------------------------------------------ #
out = pd.DataFrame({
"name_loc": [f"heat_{i:06d}" for i in range(len(df))],
"lat": df["LAT"].astype(float),
"lon": df["LONG"].astype(float),
"target_lst": df[target_col].astype(float),
})
out = pd.DataFrame(
{
"name_loc": [f"heat_{i:06d}" for i in range(len(df))],
"lat": df["LAT"].astype(float),
"lon": df["LONG"].astype(float),
"target_lst": df[target_col].astype(float),
}
)

# Verify target is clean
assert out["target_lst"].isna().sum() == 0, "BUG: NaN in target after cleaning"
print(f"Target stats: mean={out['target_lst'].mean():.2f} "
f"std={out['target_lst'].std():.2f} "
f"min={out['target_lst'].min():.2f} "
f"max={out['target_lst'].max():.2f}")
print(
f"Target stats: mean={out['target_lst'].mean():.2f} "
f"std={out['target_lst'].std():.2f} "
f"min={out['target_lst'].min():.2f} "
f"max={out['target_lst'].max():.2f}"
)

# ------------------------------------------------------------------ #
# 3. Numeric features — impute or drop #
Expand Down Expand Up @@ -168,7 +171,7 @@ def main(source_csv: str, out_csv: str, drop_zero_lst: bool = True) -> None:
if __name__ == "__main__":
ap = argparse.ArgumentParser()
ap.add_argument("--source_csv", required=True)
ap.add_argument("--out_csv", required=True)
ap.add_argument("--out_csv", required=True)
ap.add_argument("--drop_zero_lst", type=lambda x: x.lower() != "false", default=True)
args = ap.parse_args()
main(args.source_csv, args.out_csv, args.drop_zero_lst)
main(args.source_csv, args.out_csv, args.drop_zero_lst)
Loading