From 5f38906b2c58d5dfbdbca3b2c0179c43a644ade9 Mon Sep 17 00:00:00 2001 From: gabriele Date: Tue, 24 Feb 2026 10:47:34 +0100 Subject: [PATCH 01/11] Introduce BaseMetrics and MetricsWrapper classes --- src/models/components/metrics/base_metrics.py | 19 ++++++++++++ .../components/metrics/metrics_wrapper.py | 30 +++++++++++++++++++ 2 files changed, 49 insertions(+) create mode 100644 src/models/components/metrics/base_metrics.py create mode 100644 src/models/components/metrics/metrics_wrapper.py diff --git a/src/models/components/metrics/base_metrics.py b/src/models/components/metrics/base_metrics.py new file mode 100644 index 0000000..0e6b96d --- /dev/null +++ b/src/models/components/metrics/base_metrics.py @@ -0,0 +1,19 @@ +from abc import ABC, abstractmethod +from typing import Dict + +import torch +from torch import nn + + +class BaseMetrics(nn.Module, ABC): + def __init__(self) -> None: + super().__init__() + + @abstractmethod + def forward( + self, pred: torch.Tensor, + labels: torch.Tensor, + batch: Dict[str, torch.Tensor] | None = None, + **kwargs + ) -> Dict[str, torch.float]: + pass diff --git a/src/models/components/metrics/metrics_wrapper.py b/src/models/components/metrics/metrics_wrapper.py new file mode 100644 index 0000000..0ceb3eb --- /dev/null +++ b/src/models/components/metrics/metrics_wrapper.py @@ -0,0 +1,30 @@ +from typing import Dict, List + +import torch +from torch import nn + +from src.models.components.loss_fns.base_loss_fn import BaseLossFn +from src.models.components.metrics.base_metrics import BaseMetrics + + +class MetricsWrapper(nn.Module): + def __init__(self, metrics: List[BaseMetrics | BaseLossFn]) -> None: + super().__init__() + self.metrics = metrics + + def forward( + self, + pred: torch.Tensor, + batch: torch.Tensor, + mode='train', + **kwargs + ) -> Dict[str, torch.float]: + """Calculates all metrics and adds all the results into one dictionary for logging""" + compiled_dict = {} + + for metric in self.metrics: + metric_results = metric(pred, batch, return_label=True, **kwargs) + for k, v in metric_results.items(): + compiled_dict[f'{mode}_{k}'] = v + + return compiled_dict From 19189dbbeac6bebf8ef4568477bbc383d8e5942b Mon Sep 17 00:00:00 2001 From: gabriele Date: Tue, 24 Feb 2026 10:47:39 +0100 Subject: [PATCH 02/11] Introduce BaseMetrics and MetricsWrapper classes --- src/models/components/metrics/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/models/components/metrics/__init__.py diff --git a/src/models/components/metrics/__init__.py b/src/models/components/metrics/__init__.py new file mode 100644 index 0000000..e69de29 From 543c48ca6dd84d8e18bd5c5c836c58de18631270 Mon Sep 17 00:00:00 2001 From: gabriele Date: Tue, 24 Feb 2026 15:19:07 +0100 Subject: [PATCH 03/11] New metrics --- src/models/components/loss_fns/mae_loss.py | 29 ++++++++++++ src/models/components/loss_fns/mse_loss.py | 29 ++++++++++++ src/models/components/loss_fns/rmse_loss.py | 29 ++++++++++++ src/models/components/metrics/base_metrics.py | 8 ++-- .../metrics/contrastive_similarities.py | 47 +++++++++++++++++++ .../components/metrics/metrics_wrapper.py | 14 ++---- src/models/components/metrics/r2.py | 28 +++++++++++ 7 files changed, 170 insertions(+), 14 deletions(-) create mode 100644 src/models/components/loss_fns/mae_loss.py create mode 100644 src/models/components/loss_fns/mse_loss.py create mode 100644 src/models/components/loss_fns/rmse_loss.py create mode 100644 src/models/components/metrics/contrastive_similarities.py create mode 100644 
src/models/components/metrics/r2.py diff --git a/src/models/components/loss_fns/mae_loss.py b/src/models/components/loss_fns/mae_loss.py new file mode 100644 index 0000000..a0019e2 --- /dev/null +++ b/src/models/components/loss_fns/mae_loss.py @@ -0,0 +1,29 @@ +from typing import Dict, override + +import torch + +from src.models.components.loss_fns.base_loss_fn import BaseLossFn + + +class MAELoss(BaseLossFn): + def __init__(self) -> None: + super().__init__() + self.criterion = torch.nn.L1Loss() + self.name = "mae_loss" + + @override + def forward( + self, + pred: torch.Tensor, + labels: torch.Tensor | None = None, + batch: Dict[str, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor | Dict[str, torch.Tensor]: + + labels = labels if labels is not None else batch.get("target") + loss = self.criterion(pred, labels) + + if "return_label" in kwargs: + return {self.name: loss} + else: + return loss diff --git a/src/models/components/loss_fns/mse_loss.py b/src/models/components/loss_fns/mse_loss.py new file mode 100644 index 0000000..748fd31 --- /dev/null +++ b/src/models/components/loss_fns/mse_loss.py @@ -0,0 +1,29 @@ +from typing import Dict, override + +import torch + +from src.models.components.loss_fns.base_loss_fn import BaseLossFn + + +class MSELoss(BaseLossFn): + def __init__(self) -> None: + super().__init__() + self.criterion = torch.nn.MSELoss() + self.name = "mse_loss" + + @override + def forward( + self, + pred: torch.Tensor, + labels: torch.Tensor | None = None, + batch: Dict[str, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor | Dict[str, torch.Tensor]: + + labels = labels if labels is not None else batch.get("target") + mse_loss = self.criterion(pred, labels) + + if "return_label" in kwargs: + return {self.name: mse_loss} + else: + return mse_loss diff --git a/src/models/components/loss_fns/rmse_loss.py b/src/models/components/loss_fns/rmse_loss.py new file mode 100644 index 0000000..45e98a2 --- /dev/null +++ b/src/models/components/loss_fns/rmse_loss.py @@ -0,0 +1,29 @@ +from typing import Dict, override + +import torch + +from src.models.components.loss_fns.base_loss_fn import BaseLossFn + + +class RMSELoss(BaseLossFn): + def __init__(self) -> None: + super().__init__() + self.criterion = torch.nn.MSELoss() + self.name = "rmse_loss" + + @override + def forward( + self, + pred: torch.Tensor, + labels: torch.Tensor | None = None, + batch: Dict[str, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor | Dict[str, torch.Tensor]: + + labels = labels if labels is not None else batch.get("target") + loss = torch.sqrt(self.criterion(pred, labels)) + + if "return_label" in kwargs: + return {self.name: loss} + else: + return loss diff --git a/src/models/components/metrics/base_metrics.py b/src/models/components/metrics/base_metrics.py index 0e6b96d..1b373bf 100644 --- a/src/models/components/metrics/base_metrics.py +++ b/src/models/components/metrics/base_metrics.py @@ -11,9 +11,9 @@ def __init__(self) -> None: @abstractmethod def forward( - self, pred: torch.Tensor, - labels: torch.Tensor, - batch: Dict[str, torch.Tensor] | None = None, - **kwargs + self, + pred: torch.Tensor, + batch: Dict[str, torch.Tensor], + **kwargs, ) -> Dict[str, torch.float]: pass diff --git a/src/models/components/metrics/contrastive_similarities.py b/src/models/components/metrics/contrastive_similarities.py new file mode 100644 index 0000000..c3aae7e --- /dev/null +++ b/src/models/components/metrics/contrastive_similarities.py @@ -0,0 +1,47 @@ +from typing import override + 
+import torch
+import torch.nn.functional as F
+
+from src.models.components.metrics.base_metrics import BaseMetrics
+
+
+class CosineSimilarities(BaseMetrics):
+    @override
+    def forward(
+        self,
+        eo_feats: torch.Tensor,
+        text_feats: torch.Tensor,
+        **kwargs,
+    ):
+        """Calculate cosine similarities between EO and text embeddings.
+
+        Keys are returned unprefixed; MetricsWrapper adds the "{mode}_" prefix.
+        """
+
+        # Similarity matrix
+        cos_sim_matrix = F.cosine_similarity(eo_feats[:, None, :], text_feats[None, :, :], dim=-1)
+
+        # Average for positive and negative pairs
+        # TODO change label option if we change what gets treated to be pos/neg
+        id_matrix = torch.eye(
+            cos_sim_matrix.shape[0], dtype=torch.bool, device=cos_sim_matrix.device
+        )
+        pos_sim = cos_sim_matrix[id_matrix]
+        neg_sim = cos_sim_matrix[~id_matrix]
+
+        # Average
+        avr_sim = torch.mean(cos_sim_matrix)
+        sub_neg_sim = neg_sim[
+            torch.randperm(len(neg_sim))[: len(pos_sim)]
+        ]  # pick same amount of negatives as positives
+        balanced_sim = torch.cat([pos_sim, sub_neg_sim], dim=0)
+        balanced_avr_sim = torch.mean(balanced_sim)
+
+        return {
+            "avr_sim": avr_sim,
+            "avr_sim_balanced": balanced_avr_sim,
+            "pos_sim": torch.mean(pos_sim),
+            "neg_sim": torch.mean(neg_sim),
+        }
diff --git a/src/models/components/metrics/metrics_wrapper.py b/src/models/components/metrics/metrics_wrapper.py
index 0ceb3eb..c15f395 100644
--- a/src/models/components/metrics/metrics_wrapper.py
+++ b/src/models/components/metrics/metrics_wrapper.py
@@ -12,19 +12,13 @@ def __init__(self, metrics: List[BaseMetrics | BaseLossFn]) -> None:
         super().__init__()
         self.metrics = metrics
 
-    def forward(
-        self,
-        pred: torch.Tensor,
-        batch: torch.Tensor,
-        mode='train',
-        **kwargs
-    ) -> Dict[str, torch.float]:
-        """Calculates all metrics and adds all the results into one dictionary for logging"""
+    def forward(self, mode="train", **kwargs) -> Dict[str, torch.float]:
+        """Calculates all metrics and adds all the results into one dictionary for logging."""
         compiled_dict = {}
 
         for metric in self.metrics:
-            metric_results = metric(pred, batch, return_label=True, **kwargs)
+            metric_results = metric(mode=mode, return_label=True, **kwargs)
             for k, v in metric_results.items():
-                compiled_dict[f'{mode}_{k}'] = v
+                compiled_dict[f"{mode}_{k}"] = v
 
         return compiled_dict
diff --git a/src/models/components/metrics/r2.py b/src/models/components/metrics/r2.py
new file mode 100644
index 0000000..cd2d38e
--- /dev/null
+++ b/src/models/components/metrics/r2.py
@@ -0,0 +1,28 @@
+from typing import Dict, override
+
+import torch
+
+from src.models.components.metrics.base_metrics import BaseMetrics
+
+
+class RSquared(BaseMetrics):
+    def __init__(self) -> None:
+        super().__init__()
+        self.name = "r2"
+
+    @override
+    def forward(
+        self,
+        pred: torch.Tensor,
+        labels: torch.Tensor | None = None,
+        batch: Dict[str, torch.Tensor] | None = None,
+        **kwargs,
+    ) -> torch.Tensor | Dict[str, torch.Tensor]:
+
+        labels = labels if labels is not None else batch.get("target")
+
+        ss_res = torch.sum((labels - pred) ** 2)
+        ss_tot = torch.sum((labels - torch.mean(labels)) ** 2) + 1e-12
+        r2 = 1.0 - ss_res / ss_tot
+
+        return {self.name: r2}

From d64ae84fe32de8b2b0578c8837b28c5e3ad1e770 Mon Sep 17 00:00:00 2001
From: gabriele
Date: Tue, 24 Feb 2026 15:19:46 +0100
Subject: [PATCH 04/11] Adapt metrics, losses for metrics wrapper

---
 .../components/loss_fns/base_loss_fn.py      | 10 ++++++-
 src/models/components/loss_fns/bce_loss.py   | 20 
+++++++++++-- src/models/components/loss_fns/clip_loss.py | 12 ++++++-- .../{loss_fns => metrics}/top_k_accuracy.py | 28 ++++++++++--------- 4 files changed, 50 insertions(+), 20 deletions(-) rename src/models/components/{loss_fns => metrics}/top_k_accuracy.py (68%) diff --git a/src/models/components/loss_fns/base_loss_fn.py b/src/models/components/loss_fns/base_loss_fn.py index 2f993f8..c9a07f9 100644 --- a/src/models/components/loss_fns/base_loss_fn.py +++ b/src/models/components/loss_fns/base_loss_fn.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from typing import Dict import torch from torch import nn @@ -8,7 +9,14 @@ class BaseLossFn(nn.Module, ABC): def __init__(self) -> None: super().__init__() self.criterion: nn.Module | None = None + self.name: str | None = None @abstractmethod - def forward(self, pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + def forward( + self, + pred: torch.Tensor, + labels: torch.Tensor | None = None, + batch: Dict[str, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor: pass diff --git a/src/models/components/loss_fns/bce_loss.py b/src/models/components/loss_fns/bce_loss.py index cd378aa..93ccd07 100644 --- a/src/models/components/loss_fns/bce_loss.py +++ b/src/models/components/loss_fns/bce_loss.py @@ -1,4 +1,4 @@ -from typing import override +from typing import Dict, override import torch from torch import nn @@ -10,10 +10,24 @@ class BCELoss(BaseLossFn): def __init__(self) -> None: super().__init__() self.criterion = nn.BCELoss(reduction="mean") + self.name: str = "bce_loss" @override - def forward(self, pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: - return self.criterion(pred, labels) + def forward( + self, + pred: torch.Tensor, + labels: torch.Tensor | None = None, + batch: Dict[str, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor or Dict[str, torch.Tensor]: + + labels = labels if labels is not None else batch.get("target") + loss = self.criterion(pred, labels) + + if "return_label" in kwargs: + return {self.name: loss} + else: + return loss if __name__ == "__main__": diff --git a/src/models/components/loss_fns/clip_loss.py b/src/models/components/loss_fns/clip_loss.py index d451ac5..af958bb 100644 --- a/src/models/components/loss_fns/clip_loss.py +++ b/src/models/components/loss_fns/clip_loss.py @@ -1,4 +1,4 @@ -from typing import override +from typing import Dict, override import torch from torch import nn @@ -14,13 +14,15 @@ def __init__( ) -> None: super().__init__() self.log_temp = nn.Parameter(torch.log(torch.tensor(temperature))) + self.name = "CLIPLoss" @override def forward( self, eo_mod: torch.Tensor, text_mod: torch.Tensor, - ) -> torch.Tensor: + **kwargs, + ) -> torch.Tensor | Dict[str, torch.Tensor]: # Normalise inputs eo_mod = F.normalize(eo_mod, dim=-1) @@ -37,7 +39,11 @@ def forward( loss1 = F.cross_entropy(dot_product, targets) loss2 = F.cross_entropy(dot_product.T, targets) - return (loss1 + loss2) / 2 + loss = (loss1 + loss2) / 2 + if "return_label" in kwargs: + return {self.name: loss} + else: + return loss if __name__ == "__main__": diff --git a/src/models/components/loss_fns/top_k_accuracy.py b/src/models/components/metrics/top_k_accuracy.py similarity index 68% rename from src/models/components/loss_fns/top_k_accuracy.py rename to src/models/components/metrics/top_k_accuracy.py index 0a53e7a..bf9d0d4 100644 --- a/src/models/components/loss_fns/top_k_accuracy.py +++ b/src/models/components/metrics/top_k_accuracy.py @@ -1,18 +1,25 @@ -from typing import override +from 
typing import Dict, override import torch -import torch.nn.functional as F -from src.models.components.loss_fns.base_loss_fn import BaseLossFn +from src.models.components.metrics.base_metrics import BaseMetrics -class TopKAccuracy(BaseLossFn): - def __init__(self, k_list: list[int] = [1, 5, 10]) -> None: +class TopKAccuracy(BaseMetrics): + def __init__(self, k_list=None) -> None: super().__init__() - self.k_list = k_list + self.k_list = k_list or [1, 5, 10] @override - def forward(self, pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + def forward( + self, + pred: torch.Tensor, + batch: Dict[str, torch.Tensor], + **kwargs, + ) -> Dict[str, torch.float]: + + labels = batch.get("target") + inds_sorted_preds = torch.argsort( pred, dim=1, descending=True ) # dim =1; sort along 2nd dimension (ie per sample) @@ -32,10 +39,5 @@ def forward(self, pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: tmp_joint = tmp_pred_greater_th * tmp_label_greater_th n_present = torch.sum(tmp_joint, dim=1) # sum per batch sample top_k_acc = n_present.float() / k # accuracy per batch sample - accs[k] = top_k_acc.mean() - + accs[f"top_{k}_acc"] = top_k_acc.mean() return accs - - -if __name__ == "__main__": - _ = TopKAccuracy() From d4796d2e45a789fb402ec6d0e1d2175737f523a4 Mon Sep 17 00:00:00 2001 From: gabriele Date: Tue, 24 Feb 2026 15:20:01 +0100 Subject: [PATCH 05/11] Introduce metrics wrapper as config --- configs/train.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/configs/train.yaml b/configs/train.yaml index 58222ff..93b8f02 100644 --- a/configs/train.yaml +++ b/configs/train.yaml @@ -4,7 +4,7 @@ # order of defaults determines the order in which configs override each other defaults: - _self_ - - data: butterfly_coords_text + - data: butterfly_coords - model: predictive_geoclip - callbacks: default - logger: ${oc.env:LOGGER,wandb} @@ -12,6 +12,7 @@ defaults: - paths: ${oc.env:STORAGE_MODE,local} - extras: default - hydra: default + - metrics: butterfly_predictive # experiment configs allow for version control of specific hyperparameters # e.g. best hyperparameters for given model and datamodule From 890ba9f9732278709586914dc4af31c8d4e7046c Mon Sep 17 00:00:00 2001 From: gabriele Date: Tue, 24 Feb 2026 15:21:38 +0100 Subject: [PATCH 06/11] Format hooks --- src/data/heat_guatemala_dataset.py | 6 ++-- .../make_model_ready_heat_guatemala.py | 33 ++++++++++--------- .../pred_heads/mlp_regression_head.py | 3 +- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/data/heat_guatemala_dataset.py b/src/data/heat_guatemala_dataset.py index 44e81dc..5ed42bf 100644 --- a/src/data/heat_guatemala_dataset.py +++ b/src/data/heat_guatemala_dataset.py @@ -1,5 +1,4 @@ -""" -Heat Guatemala LST dataset. +"""Heat Guatemala LST dataset. Location: src/data/heat_guatemala_dataset.py @@ -21,8 +20,7 @@ class HeatGuatemalaDataset(BaseDataset): - """ - Dataset for the urban heat island use case (Guatemala City, LST regression). + """Dataset for the urban heat island use case (Guatemala City, LST regression). 
CSV layout expected (produced by scripts/make_model_ready_heat_guatemala.py): - name_loc : unique location identifier diff --git a/src/data_preprocessing/make_model_ready_heat_guatemala.py b/src/data_preprocessing/make_model_ready_heat_guatemala.py index cfc1da9..19d595e 100644 --- a/src/data_preprocessing/make_model_ready_heat_guatemala.py +++ b/src/data_preprocessing/make_model_ready_heat_guatemala.py @@ -1,6 +1,5 @@ -""" -Build model-ready CSV for the Heat Guatemala use case (data/heat_guatemala/model_ready_heat_guatemala.csv). -""" +"""Build model-ready CSV for the Heat Guatemala use case +(data/heat_guatemala/model_ready_heat_guatemala.csv).""" import argparse import re @@ -87,19 +86,23 @@ def main(source_csv: str, out_csv: str, drop_zero_lst: bool = True) -> None: # ------------------------------------------------------------------ # # 2. Build output skeleton # # ------------------------------------------------------------------ # - out = pd.DataFrame({ - "name_loc": [f"heat_{i:06d}" for i in range(len(df))], - "lat": df["LAT"].astype(float), - "lon": df["LONG"].astype(float), - "target_lst": df[target_col].astype(float), - }) + out = pd.DataFrame( + { + "name_loc": [f"heat_{i:06d}" for i in range(len(df))], + "lat": df["LAT"].astype(float), + "lon": df["LONG"].astype(float), + "target_lst": df[target_col].astype(float), + } + ) # Verify target is clean assert out["target_lst"].isna().sum() == 0, "BUG: NaN in target after cleaning" - print(f"Target stats: mean={out['target_lst'].mean():.2f} " - f"std={out['target_lst'].std():.2f} " - f"min={out['target_lst'].min():.2f} " - f"max={out['target_lst'].max():.2f}") + print( + f"Target stats: mean={out['target_lst'].mean():.2f} " + f"std={out['target_lst'].std():.2f} " + f"min={out['target_lst'].min():.2f} " + f"max={out['target_lst'].max():.2f}" + ) # ------------------------------------------------------------------ # # 3. Numeric features — impute or drop # @@ -168,7 +171,7 @@ def main(source_csv: str, out_csv: str, drop_zero_lst: bool = True) -> None: if __name__ == "__main__": ap = argparse.ArgumentParser() ap.add_argument("--source_csv", required=True) - ap.add_argument("--out_csv", required=True) + ap.add_argument("--out_csv", required=True) ap.add_argument("--drop_zero_lst", type=lambda x: x.lower() != "false", default=True) args = ap.parse_args() - main(args.source_csv, args.out_csv, args.drop_zero_lst) \ No newline at end of file + main(args.source_csv, args.out_csv, args.drop_zero_lst) diff --git a/src/models/components/pred_heads/mlp_regression_head.py b/src/models/components/pred_heads/mlp_regression_head.py index e835b5a..c52509a 100644 --- a/src/models/components/pred_heads/mlp_regression_head.py +++ b/src/models/components/pred_heads/mlp_regression_head.py @@ -1,5 +1,4 @@ -""" -MLP regression prediction head. +"""MLP regression prediction head. 
Renamed from: mlp_reg_pred_head.py → mlp_regression_head.py Location: src/models/components/pred_heads/mlp_regression_head.py From 3efc3d982a51e4b8a2d0db072a80ab5f12f02a2d Mon Sep 17 00:00:00 2001 From: gabriele Date: Tue, 24 Feb 2026 15:22:07 +0100 Subject: [PATCH 07/11] Format hooks --- README.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/README.md b/README.md index 8cd772d..c560cdf 100644 --- a/README.md +++ b/README.md @@ -196,7 +196,6 @@ We follow the directory structure from the [Hydra-Lightning template](https://gi └── README.md ``` - ## Architecture Overview The diagram below shows how configuration files, datasets, and model components @@ -204,8 +203,6 @@ relate to each other in the AETHER framework. ![Framework Architecture](docs/figures/diagram.png) - - ## Attribution This repo is based on the [Hydra-Lightning template](https://github.com/ashleve/lightning-hydra-template). From 2dc3858fc4cbca5ec404e263eee9703a15bd884a Mon Sep 17 00:00:00 2001 From: gabriele Date: Tue, 24 Feb 2026 15:22:28 +0100 Subject: [PATCH 08/11] Format hooks --- .../eo_encoders/multimodal_encoder.py | 31 +++++++------------ 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/src/models/components/eo_encoders/multimodal_encoder.py b/src/models/components/eo_encoders/multimodal_encoder.py index c8f6603..7dbf61d 100644 --- a/src/models/components/eo_encoders/multimodal_encoder.py +++ b/src/models/components/eo_encoders/multimodal_encoder.py @@ -1,11 +1,8 @@ -""" -Unified multimodal encoder for EO data. - +"""Unified multimodal encoder for EO data. Controlled entirely via constructor flags: - use_coords: encode lat/lon with GeoClip - use_tabular: encode feat_* tabular columns - """ from typing import Dict, override @@ -19,9 +16,9 @@ class MultiModalEncoder(BaseEOEncoder): """ - - coords only (use_coords=True, use_tabular=False) - - tabular only (use_coords=False, use_tabular=True) - - coords + tabular (use_coords=True, use_tabular=True) + - coords only (use_coords=True, use_tabular=False) + - tabular only (use_coords=False, use_tabular=True) + - coords + tabular (use_coords=True, use_tabular=True) """ def __init__( @@ -29,13 +26,11 @@ def __init__( use_coords: bool = True, use_tabular: bool = False, tab_embed_dim: int = 64, - tabular_dim: int = None, + tabular_dim: int = None, ) -> None: super().__init__() - assert use_coords or use_tabular, ( - "At least one of use_coords or use_tabular must be True." - ) + assert use_coords or use_tabular, "At least one of use_coords or use_tabular must be True." self.use_coords = use_coords self.use_tabular = use_tabular @@ -45,7 +40,7 @@ def __init__( coords_dim = 0 if use_coords: self.coords_encoder = GeoClipCoordinateEncoder() - coords_dim = self.coords_encoder.output_dim # 512 + coords_dim = self.coords_encoder.output_dim # 512 self._coords_dim = coords_dim @@ -62,12 +57,10 @@ def __init__( # ------------------------------------------------------------------ def build_tabular_branch(self, tabular_dim: int) -> None: - """ - Build (or rebuild) the tabular projection layer. 
- """ + """Build (or rebuild) the tabular projection layer.""" if self._tabular_ready and hasattr(self, "_last_tabular_dim"): if self._last_tabular_dim == tabular_dim: - return # already built with correct dim + return # already built with correct dim self.tabular_proj = nn.Sequential( nn.LayerNorm(tabular_dim), @@ -87,14 +80,14 @@ def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: parts = [] if self.use_coords: - parts.append(self.coords_encoder(batch)) # (B, 512) + parts.append(self.coords_encoder(batch)) # (B, 512) if self.use_tabular: assert self._tabular_ready, ( "Tabular branch not built yet. Call build_tabular_branch(tabular_dim) first, " "or pass tabular_dim to the constructor." ) - tab = batch["eo"]["tabular"].float() # (B, tabular_dim) - parts.append(self.tabular_proj(tab)) # (B, tab_embed_dim) + tab = batch["eo"]["tabular"].float() # (B, tabular_dim) + parts.append(self.tabular_proj(tab)) # (B, tab_embed_dim) return torch.cat(parts, dim=-1) From 74d6e6d09abb492ad389262b505aeb07116e097e Mon Sep 17 00:00:00 2001 From: gabriele Date: Tue, 24 Feb 2026 15:22:53 +0100 Subject: [PATCH 09/11] De-duplicate regression predictive model --- src/data/heat_guatemala_dataset.py | 9 -- src/models/base_model.py | 22 ++-- src/models/predictive_model.py | 35 ++++-- src/models/predictive_model_regression.py | 123 ---------------------- src/models/text_alignment_model.py | 119 +++++++-------------- src/train.py | 4 +- 6 files changed, 78 insertions(+), 234 deletions(-) delete mode 100644 src/models/predictive_model_regression.py diff --git a/src/data/heat_guatemala_dataset.py b/src/data/heat_guatemala_dataset.py index 5ed42bf..3099b44 100644 --- a/src/data/heat_guatemala_dataset.py +++ b/src/data/heat_guatemala_dataset.py @@ -61,15 +61,6 @@ def __init__( use_features=use_features, ) - # ------------------------------------------------------------------ - # Properties - # ------------------------------------------------------------------ - - @property - def tabular_dim(self) -> int: - """Number of tabular features (feat_* columns). 0 if none.""" - return len(self.feat_names) - # ------------------------------------------------------------------ # Required overrides # ------------------------------------------------------------------ diff --git a/src/models/base_model.py b/src/models/base_model.py index 3a7da11..407d08a 100644 --- a/src/models/base_model.py +++ b/src/models/base_model.py @@ -5,6 +5,7 @@ from lightning import LightningModule from src.models.components.loss_fns.base_loss_fn import BaseLossFn +from src.models.components.metrics.metrics_wrapper import MetricsWrapper class BaseModel(LightningModule, ABC): @@ -14,26 +15,29 @@ def __init__( optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler, loss_fn: BaseLossFn, + metrics: MetricsWrapper, num_classes: int | None = None, + tabular_dim: int | None = None, ) -> None: """Interface for any model. 
-        :param trainable_modules:
-        :param optimizer:
-        :param scheduler:
-        :param loss_fn:
-        :param num_classes:
+        :param trainable_modules: which modules to train
+        :param optimizer: optimizer for the model weight update
+        :param scheduler: scheduler for the model weight update
+        :param loss_fn: loss function
+        :param metrics: metrics to track for model performance estimation
+        :param num_classes: number of classes to predict
+        :param tabular_dim: number of tabular features
         """
         super().__init__()
         self.save_hyperparameters(
-            ignore=["loss_fn", "eo_encoder", "prediction_head", "text_encoder"]
+            ignore=["loss_fn", "eo_encoder", "prediction_head", "text_encoder", "metrics"]
         )
 
         self.trainable_modules = trainable_modules
-        self.num_classes: int = num_classes
-
-        # Loss
+        self.num_classes = num_classes
+        self.tabular_dim = tabular_dim
         self.loss_fn = loss_fn
+        self.metrics = metrics
 
     @final
     def freezer(self) -> None:
diff --git a/src/models/predictive_model.py b/src/models/predictive_model.py
index b9b1f5f..f46a4dc 100644
--- a/src/models/predictive_model.py
+++ b/src/models/predictive_model.py
@@ -5,8 +5,9 @@
 
 from src.models.base_model import BaseModel
 from src.models.components.eo_encoders.base_eo_encoder import BaseEOEncoder
+from src.models.components.eo_encoders.multimodal_encoder import MultiModalEncoder
 from src.models.components.loss_fns.base_loss_fn import BaseLossFn
-from src.models.components.loss_fns.top_k_accuracy import TopKAccuracy
+from src.models.components.metrics.metrics_wrapper import MetricsWrapper
 from src.models.components.pred_heads.linear_pred_head import (
     BasePredictionHead,
 )
@@ -21,7 +22,9 @@ def __init__(
         optimizer: torch.optim.Optimizer,
         scheduler: torch.optim.lr_scheduler,
         loss_fn: BaseLossFn,
+        metrics: MetricsWrapper,
         num_classes: int | None = None,
+        tabular_dim: int | None = None,
     ) -> None:
         """Implementation of the predictive model with replaceable EO encoder, and prediction head. 
@@ -31,12 +34,26 @@
         :param optimizer: optimizer to use for training
         :param scheduler: scheduler to use for training
         :param loss_fn: loss function to use
+        :param metrics: metrics to use for model performance evaluation
         :param num_classes: number of target classes
+        :param tabular_dim: number of tabular features
         """
-        super().__init__(trainable_modules, optimizer, scheduler, loss_fn, num_classes)
+
+        super().__init__(
+            trainable_modules, optimizer, scheduler, loss_fn, metrics, num_classes, tabular_dim
+        )
+
         # EO encoder configuration
         self.eo_encoder = eo_encoder
+        # TODO: move to multi-modal eo encoder
+        if (
+            isinstance(self.eo_encoder, MultiModalEncoder)
+            and self.eo_encoder.use_tabular
+            and not self.eo_encoder._tabular_ready
+        ):
+            self.eo_encoder.build_tabular_branch(tabular_dim)
+
         # Prediction head
         self.prediction_head = prediction_head
         self.prediction_head.set_dim(input_dim=self.eo_encoder.output_dim, output_dim=num_classes)
@@ -56,15 +73,15 @@ def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
 
     @override
     def _step(self, batch: Dict[str, torch.Tensor], mode: str = "train") -> torch.Tensor:
         feats = self.forward(batch)
+
+        log_kwargs = dict(
+            on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=feats.size(0)
+        )
         loss = self.loss_fn(feats, batch.get("target"))
+        self.log(f"{mode}_loss", loss, **log_kwargs)
 
-        self.log(f"{mode}_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
-        mse_loss = F.mse_loss(feats, batch.get("target"))
-        self.log(f"{mode}_mse_loss", mse_loss, on_step=False, on_epoch=True)
-        top_k_accs = TopKAccuracy(k_list=[1, 5, 10])(feats, batch.get("target"))
-        for k, acc in top_k_accs.items():
-            self.log(f"{mode}_top_{k}_acc", acc, on_step=False, on_epoch=True)
-        return loss
+
+        metrics = self.metrics(pred=feats, batch=batch, mode=mode)
+        self.log_dict(metrics, **log_kwargs)
+        return loss
 
 
 if __name__ == "__main__":
diff --git a/src/models/predictive_model_regression.py b/src/models/predictive_model_regression.py
deleted file mode 100644
index 764f472..0000000
--- a/src/models/predictive_model_regression.py
+++ /dev/null
@@ -1,123 +0,0 @@
-"""
-Regression variant of the predictive model (MSE / MAE / RMSE / R²).
-
-Location: src/models/predictive_model_regression.py
-
-Key changes vs original:
-  - setup(stage) injects tabular_dim into MultiModalEncoder automatically,
-    so tabular_dim never needs to be hardcoded in any config.
-  - num_classes defaults to 1 (single LST value).
-  - All regression metrics (MSE, RMSE, MAE, R²) are logged per split.
-"""
-
-from typing import Dict, override
-
-import torch
-import torch.nn.functional as F
-
-from src.models.base_model import BaseModel
-from src.models.components.eo_encoders.multimodal_encoder import MultiModalEncoder
-
-
-class PredictiveRegressionModel(BaseModel):
-
-    def __init__(
-        self,
-        eo_encoder,
-        prediction_head,
-        trainable_modules,
-        optimizer,
-        scheduler,
-        loss_fn,
-        num_classes: int = 1,
-        **kwargs,
-    ):
-        super().__init__(
-            trainable_modules=trainable_modules,
-            optimizer=optimizer,
-            scheduler=scheduler,
-            loss_fn=loss_fn,
-            num_classes=num_classes,
-        )
-
-        self.eo_encoder = eo_encoder
-        self.prediction_head = prediction_head
-
-        # Prediction head wiring happens AFTER setup() resolves tabular_dim.
-        # If the encoder does NOT need tabular data, we can wire immediately. 
- if not ( - isinstance(self.eo_encoder, MultiModalEncoder) - and self.eo_encoder.use_tabular - ): - self._wire_head() - self.freezer() - - # ------------------------------------------------------------------ - # Lightning hooks - # ------------------------------------------------------------------ - - def setup(self, stage: str) -> None: - """ - Called by Lightning after the datamodule is ready. - Injects tabular_dim into the encoder if it needs it, - then wires the prediction head dimensions. - """ - if ( - isinstance(self.eo_encoder, MultiModalEncoder) - and self.eo_encoder.use_tabular - and not self.eo_encoder._tabular_ready - ): - # Pull tabular_dim from the datamodule — no hardcoding needed! - tabular_dim = self.trainer.datamodule.tabular_dim - self.eo_encoder.build_tabular_branch(tabular_dim) - self._wire_head() - - self.freezer() - - # ------------------------------------------------------------------ - # Internals - # ------------------------------------------------------------------ - - def _wire_head(self) -> None: - """Connect encoder output_dim → head input_dim, then build head layers.""" - self.prediction_head.set_dim( - input_dim=self.eo_encoder.output_dim, - output_dim=self.num_classes, - ) - self.prediction_head.configure_nn() - if "prediction_head" not in self.trainable_modules: - self.trainable_modules.append("prediction_head") - - # ------------------------------------------------------------------ - # Forward & step - # ------------------------------------------------------------------ - - @override - def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: - feats = self.eo_encoder(batch) - feats = F.normalize(feats, dim=-1) - return self.prediction_head(feats) - - @override - def _step(self, batch: Dict[str, torch.Tensor], mode: str = "train") -> torch.Tensor: - y_hat = self.forward(batch) - y = batch["target"] - - loss = self.loss_fn(y_hat, y) - - mse = F.mse_loss(y_hat, y) - rmse = torch.sqrt(mse) - mae = F.l1_loss(y_hat, y) - - ss_res = torch.sum((y - y_hat) ** 2) - ss_tot = torch.sum((y - torch.mean(y)) ** 2) + 1e-12 - r2 = 1.0 - ss_res / ss_tot - - log_kwargs = dict(on_step=False, on_epoch=True) - self.log(f"{mode}_loss", loss, prog_bar=True, **log_kwargs) - self.log(f"{mode}_mse", mse, **log_kwargs) - self.log(f"{mode}_rmse", rmse, **log_kwargs) - self.log(f"{mode}_mae", mae, **log_kwargs) - self.log(f"{mode}_r2", r2, **log_kwargs) - - return loss diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index cb2e04e..21b476a 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -5,7 +5,9 @@ from src.models.base_model import BaseModel from src.models.components.eo_encoders.base_eo_encoder import BaseEOEncoder +from src.models.components.eo_encoders.multimodal_encoder import MultiModalEncoder from src.models.components.loss_fns.base_loss_fn import BaseLossFn +from src.models.components.metrics.metrics_wrapper import MetricsWrapper from src.models.components.pred_heads.linear_pred_head import ( BasePredictionHead, ) @@ -22,9 +24,11 @@ def __init__( optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler, loss_fn: BaseLossFn, - trainable_modules: list[str] | None = None, - prediction_head: BasePredictionHead | None = None, + trainable_modules: list[str], + metrics: MetricsWrapper, num_classes: int | None = None, + tabular_dim: int | None = None, + prediction_head: BasePredictionHead | None = None, ) -> None: """Implementation of contrastive text-eo modality alignment model. 
@@ -34,13 +38,25 @@ def __init__( :param scheduler: scheduler to use for training :param loss_fn: loss function to use (contrastive) :param trainable_modules: list of modules to train (parts/modules or modules, modules) - :param prediction_head: optional prediction head module + :param metrics: metrics to use for model performance evaluation :param num_classes: number of target classes + :param tabular_dim: number of tabular features + :param prediction_head: prediction head """ - super().__init__(trainable_modules, optimizer, scheduler, loss_fn, num_classes) + super().__init__( + trainable_modules, optimizer, scheduler, loss_fn, metrics, num_classes, tabular_dim + ) # Encoders configuration self.eo_encoder = eo_encoder + # TODO: move to multi-modal eo encoder + if ( + isinstance(self.eo_encoder, MultiModalEncoder) + and self.eo_encoder.use_tabular + and not self.eo_encoder._tabular_ready + ): + self.eo_encoder.build_tabular_branch(tabular_dim) + self.text_encoder = text_encoder # TODO: if eo==geoclip_img pass on shared mlp @@ -93,94 +109,31 @@ def _step(self, batch: Dict[str, torch.Tensor], mode: str = "train") -> torch.Te feats = feats.reshape(2, -1, feats.size(-1)) eo_feats, text_feats = feats[0], feats[1] + # Get loss + loss = self.loss_fn(eo_feats, text_feats) + # Get similarities with torch.no_grad(): - _ = self._cos_sim_calc(eo_feats, text_feats, mode) + metrics = self.metrics( + mode=mode, + eo_feats=eo_feats, + text_feats=text_feats, + local_batch_size=local_batch_size, + ) - # Get loss - loss = self.loss_fn(eo_feats, text_feats) - self.log( - f"{mode}_loss", - loss, + # Logging + log_kwargs = dict( on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=local_batch_size, ) - if self.loss_fn.__getattr__("log_temp") and mode == "train": - self.log( - "temp", - self.loss_fn.__getattr__("log_temp").exp(), - on_step=False, - on_epoch=True, - prog_bar=True, - sync_dist=True, - batch_size=local_batch_size, - ) + self.log(f"{mode}_loss", loss, **log_kwargs) - return loss - - def _cos_sim_calc(self, eo_feats, text_feats, mode, log=True): - """Calculate cosine similarity between eo and text embeddings and logs it.""" - # Similarity matrix - cos_sim_matrix = F.cosine_similarity(eo_feats[:, None, :], text_feats[None, :, :], dim=-1) - - local_batch_size = eo_feats.size(0) - - # Average for positive and negative pairs - # TODO change label option if we change what gets treated to be pos/neg - id_matrix = torch.eye(cos_sim_matrix.shape[0], dtype=torch.bool) - pos_sim = cos_sim_matrix[id_matrix] - neg_sim = cos_sim_matrix[~id_matrix] - - # Average - avr_sim = torch.mean(cos_sim_matrix) - sub_neg_sim = neg_sim[ - torch.randperm(len(neg_sim))[: len(pos_sim)] - ] # pick same amount of negatives as positives - balanced_sim = torch.cat([pos_sim, sub_neg_sim], dim=0) - balanced_avr_sim = torch.mean(balanced_sim) - - if log: - self.log( - f"{mode}_avr_sim", - avr_sim, - on_step=False, - on_epoch=True, - prog_bar=True, - sync_dist=True, - batch_size=local_batch_size, - ) - self.log( - f"{mode}_avr_sim_balanced", - balanced_avr_sim, - on_step=False, - on_epoch=True, - prog_bar=True, - sync_dist=True, - batch_size=local_batch_size, - ) - self.log( - f"{mode}_pos_sim", - torch.mean(pos_sim), - on_step=False, - on_epoch=True, - prog_bar=True, - sync_dist=True, - batch_size=local_batch_size, - ) - self.log( - f"{mode}_neg_sim", - torch.mean(neg_sim), - on_step=False, - on_epoch=True, - prog_bar=True, - sync_dist=True, - batch_size=local_batch_size, - ) - return avr_sim, pos_sim, neg_sim + if 
self.loss_fn.__getattr__("log_temp") and mode == "train": + self.log("temp", self.loss_fn.__getattr__("log_temp").exp(), **log_kwargs) + self.log_dict(metrics, **log_kwargs) -if __name__ == "__main__": - _ = TextAlignmentModel(None, None, None, None, None, None, None) + return loss diff --git a/src/train.py b/src/train.py index 4c09a8c..34f4a99 100644 --- a/src/train.py +++ b/src/train.py @@ -51,7 +51,9 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: datamodule: BaseDataModule = hydra.utils.instantiate(cfg.data) log.info(f"Instantiating model <{cfg.model._target_}>") - model: LightningModule = hydra.utils.instantiate(cfg.model, num_classes=datamodule.num_classes) + model: LightningModule = hydra.utils.instantiate( + cfg.model, num_classes=datamodule.num_classes, tabular_dim=datamodule.tabular_dim + ) log.info("Instantiating callbacks...") callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks")) From 254f5ed780ea070f26b19a1d027c081b663cfec8 Mon Sep 17 00:00:00 2001 From: gabriele Date: Tue, 24 Feb 2026 15:23:21 +0100 Subject: [PATCH 10/11] Metrics as configs, use predictive model instead of predictive regression model --- configs/data/heat_guatemala.yaml | 3 +-- configs/experiment/alignment.yaml | 1 + configs/experiment/alignment_llm2clip.yaml | 1 + configs/experiment/heat_guatemala_coords_reg.yaml | 2 +- configs/experiment/heat_guatemala_fusion_reg.yaml | 2 +- configs/experiment/heat_guatemala_tabular_reg.yaml | 2 +- configs/experiment/prediction.yaml | 1 + configs/metrics/butterfly_predictive.yaml | 6 ++++++ configs/metrics/contrastive_similarities.yaml | 4 ++++ configs/metrics/guatemala_regression.yaml | 7 +++++++ configs/model/geoclip_alignment.yaml | 2 ++ configs/model/geoclip_llm2clip_alignment.yaml | 2 ++ configs/model/heat_fusion_reg.yaml | 6 ++++-- configs/model/heat_geoclip_reg.yaml | 4 +++- configs/model/heat_tabular_reg.yaml | 4 +++- configs/model/predictive_cnn_s2.yaml | 2 ++ configs/model/predictive_geoclip.yaml | 2 ++ 17 files changed, 42 insertions(+), 9 deletions(-) create mode 100644 configs/metrics/butterfly_predictive.yaml create mode 100644 configs/metrics/contrastive_similarities.yaml create mode 100644 configs/metrics/guatemala_regression.yaml diff --git a/configs/data/heat_guatemala.yaml b/configs/data/heat_guatemala.yaml index eb42022..271fd65 100644 --- a/configs/data/heat_guatemala.yaml +++ b/configs/data/heat_guatemala.yaml @@ -19,6 +19,5 @@ split_mode: "from_file" save_split: false saved_split_file_name: "split_indices_heat_guatemala_2026-02-20-1148.pth" - train_val_test_split: [0.7, 0.15, 0.15] -seed: ${seed} \ No newline at end of file +seed: ${seed} diff --git a/configs/experiment/alignment.yaml b/configs/experiment/alignment.yaml index a649ca7..d6281f8 100644 --- a/configs/experiment/alignment.yaml +++ b/configs/experiment/alignment.yaml @@ -6,6 +6,7 @@ defaults: - override /model: geoclip_alignment - override /data: butterfly_coords_text + - override /metrics: contrastive_similarities # all parameters below will be merged with parameters from default configurations set above # this allows you to overwrite only specified parameters diff --git a/configs/experiment/alignment_llm2clip.yaml b/configs/experiment/alignment_llm2clip.yaml index 3311990..2f841ab 100644 --- a/configs/experiment/alignment_llm2clip.yaml +++ b/configs/experiment/alignment_llm2clip.yaml @@ -6,6 +6,7 @@ defaults: - override /model: geoclip_llm2clip_alignment - override /data: butterfly_coords_text + - override /metrics: contrastive_similarities # all 
parameters below will be merged with parameters from default configurations set above # this allows you to overwrite only specified parameters diff --git a/configs/experiment/heat_guatemala_coords_reg.yaml b/configs/experiment/heat_guatemala_coords_reg.yaml index f335d38..3789370 100644 --- a/configs/experiment/heat_guatemala_coords_reg.yaml +++ b/configs/experiment/heat_guatemala_coords_reg.yaml @@ -2,10 +2,10 @@ # configs/experiment/heat_guatemala_coords_reg.yaml # Variant A: GeoClip coordinates only - defaults: - override /model: heat_geoclip_reg - override /data: heat_guatemala + - override /metrics: guatemala_regression tags: ["heat_island", "guatemala", "coords_only", "regression"] seed: 12345 diff --git a/configs/experiment/heat_guatemala_fusion_reg.yaml b/configs/experiment/heat_guatemala_fusion_reg.yaml index 4ab69f5..1c42ffa 100644 --- a/configs/experiment/heat_guatemala_fusion_reg.yaml +++ b/configs/experiment/heat_guatemala_fusion_reg.yaml @@ -2,10 +2,10 @@ # configs/experiment/heat_guatemala_fusion_reg.yaml # Variant C: GeoClip + tabular fusion - defaults: - override /model: heat_fusion_reg - override /data: heat_guatemala + - override /metrics: guatemala_regression tags: ["heat_island", "guatemala", "fusion", "regression"] seed: 12345 diff --git a/configs/experiment/heat_guatemala_tabular_reg.yaml b/configs/experiment/heat_guatemala_tabular_reg.yaml index 740b125..4c14499 100644 --- a/configs/experiment/heat_guatemala_tabular_reg.yaml +++ b/configs/experiment/heat_guatemala_tabular_reg.yaml @@ -2,10 +2,10 @@ # configs/experiment/heat_guatemala_tabular_reg.yaml # Variant B: tabular features only - defaults: - override /model: heat_tabular_reg - override /data: heat_guatemala + - override /metrics: guatemala_regression tags: ["heat_island", "guatemala", "tabular_only", "regression"] seed: 12345 diff --git a/configs/experiment/prediction.yaml b/configs/experiment/prediction.yaml index 18ccd52..63839ca 100644 --- a/configs/experiment/prediction.yaml +++ b/configs/experiment/prediction.yaml @@ -6,6 +6,7 @@ defaults: - override /model: predictive_geoclip - override /data: butterfly_coords + - override /metrics: butterfly_predictive # all parameters below will be merged with parameters from default configurations set above # this allows you to overwrite only specified parameters diff --git a/configs/metrics/butterfly_predictive.yaml b/configs/metrics/butterfly_predictive.yaml new file mode 100644 index 0000000..5e9d2e4 --- /dev/null +++ b/configs/metrics/butterfly_predictive.yaml @@ -0,0 +1,6 @@ +_target_: src.models.components.metrics.metrics_wrapper.MetricsWrapper + +metrics: + - _target_: src.models.components.loss_fns.mse_loss.MSELoss + - _target_: src.models.components.metrics.top_k_accuracy.TopKAccuracy + k_list: [1, 5, 10] diff --git a/configs/metrics/contrastive_similarities.yaml b/configs/metrics/contrastive_similarities.yaml new file mode 100644 index 0000000..842dda2 --- /dev/null +++ b/configs/metrics/contrastive_similarities.yaml @@ -0,0 +1,4 @@ +_target_: src.models.components.metrics.metrics_wrapper.MetricsWrapper + +metrics: + - _target_: src.models.components.metrics.contrastive_similarities.CosineSimilarities diff --git a/configs/metrics/guatemala_regression.yaml b/configs/metrics/guatemala_regression.yaml new file mode 100644 index 0000000..79c441d --- /dev/null +++ b/configs/metrics/guatemala_regression.yaml @@ -0,0 +1,7 @@ +_target_: src.models.components.metrics.metrics_wrapper.MetricsWrapper + +metrics: + - _target_: 
src.models.components.loss_fns.mse_loss.MSELoss + - _target_: src.models.components.loss_fns.rmse_loss.RMSELoss + - _target_: src.models.components.loss_fns.mae_loss.MAELoss + - _target_: src.models.components.metrics.r2.RSquared diff --git a/configs/model/geoclip_alignment.yaml b/configs/model/geoclip_alignment.yaml index b021cb2..0753e79 100644 --- a/configs/model/geoclip_alignment.yaml +++ b/configs/model/geoclip_alignment.yaml @@ -9,6 +9,8 @@ text_encoder: trainable_modules: [text_encoder.projector, loss_fn.log_temp] +metrics: ${metrics} + optimizer: _target_: torch.optim.Adam _partial_: true diff --git a/configs/model/geoclip_llm2clip_alignment.yaml b/configs/model/geoclip_llm2clip_alignment.yaml index 45de482..76f8878 100644 --- a/configs/model/geoclip_llm2clip_alignment.yaml +++ b/configs/model/geoclip_llm2clip_alignment.yaml @@ -9,6 +9,8 @@ text_encoder: trainable_modules: [text_encoder.projector, loss_fn.log_temp] +metrics: ${metrics} + optimizer: _target_: torch.optim.Adam _partial_: true diff --git a/configs/model/heat_fusion_reg.yaml b/configs/model/heat_fusion_reg.yaml index 6a554c5..e507dd1 100644 --- a/configs/model/heat_fusion_reg.yaml +++ b/configs/model/heat_fusion_reg.yaml @@ -6,13 +6,13 @@ # Variant: GeoClip (coords) + tabular fusion # Encoder output = GeoClip (512) + tabular projection (64) = 576-dim -_target_: src.models.predictive_model_regression.PredictiveRegressionModel +_target_: src.models.predictive_model.PredictiveModel eo_encoder: _target_: src.models.components.eo_encoders.multimodal_encoder.MultiModalEncoder use_coords: true use_tabular: true - tab_embed_dim: 64 +# tab_embed_dim: 64 prediction_head: _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead @@ -22,6 +22,8 @@ prediction_head: # GeoClip frozen; tabular projection + head are trained. trainable_modules: [eo_encoder, prediction_head] +metrics: ${metrics} + optimizer: _target_: torch.optim.Adam _partial_: true diff --git a/configs/model/heat_geoclip_reg.yaml b/configs/model/heat_geoclip_reg.yaml index eeb5f9e..a33b976 100644 --- a/configs/model/heat_geoclip_reg.yaml +++ b/configs/model/heat_geoclip_reg.yaml @@ -6,7 +6,7 @@ # # Variant: coords only (GeoClip encodes lat/lon → 512-dim embedding) -_target_: src.models.predictive_model_regression.PredictiveRegressionModel +_target_: src.models.predictive_model.PredictiveModel eo_encoder: _target_: src.models.components.eo_encoders.multimodal_encoder.MultiModalEncoder @@ -21,6 +21,8 @@ prediction_head: # Only the prediction head is trained; GeoClip encoder is frozen. trainable_modules: [prediction_head] +metrics: ${metrics} + optimizer: _target_: torch.optim.Adam _partial_: true diff --git a/configs/model/heat_tabular_reg.yaml b/configs/model/heat_tabular_reg.yaml index edd8dd3..affadab 100644 --- a/configs/model/heat_tabular_reg.yaml +++ b/configs/model/heat_tabular_reg.yaml @@ -12,7 +12,9 @@ # 3. 
PredictiveModel receives tabular_dim from the datamodule (via train.py) and calls
 #    self.eo_encoder.build_tabular_branch(tabular_dim)
 
-_target_: src.models.predictive_model_regression.PredictiveRegressionModel
+_target_: src.models.predictive_model.PredictiveModel
+
+metrics: ${metrics}
 
 eo_encoder:
   _target_: src.models.components.eo_encoders.multimodal_encoder.MultiModalEncoder
diff --git a/configs/model/predictive_cnn_s2.yaml b/configs/model/predictive_cnn_s2.yaml
index df1db6d..628ba66 100644
--- a/configs/model/predictive_cnn_s2.yaml
+++ b/configs/model/predictive_cnn_s2.yaml
@@ -11,6 +11,8 @@ prediction_head:
 
 trainable_modules: [eo_encoder, prediction_head]
 
+metrics: ${metrics}
+
 optimizer:
   _target_: torch.optim.Adam
   _partial_: true
diff --git a/configs/model/predictive_geoclip.yaml b/configs/model/predictive_geoclip.yaml
index f9a2df0..ca7390f 100644
--- a/configs/model/predictive_geoclip.yaml
+++ b/configs/model/predictive_geoclip.yaml
@@ -8,6 +8,8 @@ prediction_head:
 
 trainable_modules: [prediction_head]
 
+metrics: ${metrics}
+
 optimizer:
   _target_: torch.optim.Adam
   _partial_: true

From f2df046492eac4115b2bc013fb3664f8d12de0ce Mon Sep 17 00:00:00 2001
From: gabriele
Date: Tue, 24 Feb 2026 15:23:51 +0100
Subject: [PATCH 11/11] Tabular dimensions introduced earlier to remove need for wiring

---
 src/data/base_datamodule.py |  6 +-----
 src/data/base_dataset.py    | 25 +++++++++++--------------
 2 files changed, 12 insertions(+), 19 deletions(-)

diff --git a/src/data/base_datamodule.py b/src/data/base_datamodule.py
index 2c94f6e..3b639ec 100644
--- a/src/data/base_datamodule.py
+++ b/src/data/base_datamodule.py
@@ -68,11 +68,7 @@ def __init__(
 
     @property
     def tabular_dim(self):
-        dataset = self.data_train
-        # Unwrap Subset wrappers (e.g. from random_split)
-        while hasattr(dataset, "dataset"):
-            dataset = dataset.dataset
-        return dataset.tabular_dim
+        return self.dataset.tabular_dim
 
     @property
     def num_classes(self) -> int:
diff --git a/src/data/base_dataset.py b/src/data/base_dataset.py
index 258fb43..ae669d2 100644
--- a/src/data/base_dataset.py
+++ b/src/data/base_dataset.py
@@ -74,15 +74,17 @@ def __init__(
         self.df: pd.DataFrame = pd.read_csv(path_csv)
 
         # Other attributes or placeholders
-        self.seed = seed
+        self.pooch_cli = None
         self.num_classes = None
-        self.mode: str = mode  # 'train', 'val', 'test'
+        self.tabular_dim = None
+        self.seed = seed
         self.use_target_data: bool = use_target_data
         self.use_aux_data: bool = use_aux_data
-        self.records: dict[str, Any] = self.get_records()
-        self.pooch_cli = None
         self.use_features = use_features
+
+        self.mode: str = mode  # 'train', 'val', 'test'
+        self.records: dict[str, Any] = self.get_records()
+
     @final
     def get_records(self) -> dict[str, Any]:
         """Gets record dictionary from the dataframe based on what is needed for the model (aux,
@@ -115,10 +117,12 @@
         if self.use_aux_data:
             self.aux_names = [c for c in self.df.columns if "aux_" in c]
             columns.extend(self.aux_names)
-
-        self.feat_names = [c for c in self.df.columns if c.startswith("feat_")]
-        columns.extend(self.feat_names)
+        # Include tabular features
+        if self.use_features:
+            self.feat_names = [c for c in self.df.columns if c.startswith("feat_")]
+            columns.extend(self.feat_names)
+            self.tabular_dim = len(self.feat_names)
 
         return self.df.loc[:, columns].to_dict("records")
 
     @final
     def __len__(self) -> int:
         """Returns the length of the dataset."""
         return len(self.records)
 
-    @final
-    @property
-    def tabular_dim(self) -> int:
-        if not self.use_features:
-            return 0
-        return 
len(getattr(self, "feat_names", [])) - @abstractmethod def __getitem__(self, idx: int) -> Dict[str, Any]: """Returns a single item from the dataset."""
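
A minimal usage sketch of the metrics stack introduced in this series. The
wiring mirrors configs/metrics/butterfly_predictive.yaml; the tensor shapes
and the batch layout below are illustrative assumptions, not taken from the
repo's datamodules:

    import torch

    from src.models.components.loss_fns.mse_loss import MSELoss
    from src.models.components.metrics.metrics_wrapper import MetricsWrapper
    from src.models.components.metrics.top_k_accuracy import TopKAccuracy

    # Compose the wrapper the same way the Hydra metrics config does.
    metrics = MetricsWrapper(metrics=[MSELoss(), TopKAccuracy(k_list=[1, 5, 10])])

    pred = torch.rand(8, 20)               # assumed (batch, num_classes) model output
    batch = {"target": torch.rand(8, 20)}  # labels are read from batch["target"]

    # The wrapper calls each metric with return_label=True, so every metric
    # returns a {name: value} dict, and the wrapper prefixes the keys with the mode.
    out = metrics(pred=pred, batch=batch, mode="val")
    # e.g. {"val_mse_loss": ..., "val_top_1_acc": ..., "val_top_5_acc": ..., "val_top_10_acc": ...}

Because BaseLossFn subclasses return a bare tensor unless return_label is
passed, the same MSELoss class can serve both as a model's loss_fn and, inside
the wrapper, as a logged metric.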