diff --git a/README.md b/README.md index 6ef1978..8cd772d 100644 --- a/README.md +++ b/README.md @@ -73,7 +73,8 @@ Data folders should follow the following directory structure within `DATA_DIR`: ├── s2bms/ <- Dataset folder. │ ├── model_ready_s2bms.csv <- Csv file with "name_loc" id, locations, aux data and target data. │ ├── aux_classes.csv <- Csv file with explanations for aux data class names. -│ ├── caption_templates.json <- Json file with list of caption templates (referencing aux column names). +│ ├── caption_templates/ <- Caption templates folder +│ │ ├── v1.json <- Json file with list of caption templates (referencing aux column names). │ ├── splits/ <- Torch data splits │ ├── source/ <- Optional: source data used to create model_ready csv. │ ├── eo/ <- EO data modalities @@ -195,6 +196,16 @@ We follow the directory structure from the [Hydra-Lightning template](https://gi └── README.md ``` + +## Architecture Overview + +The diagram below shows how configuration files, datasets, and model components +relate to each other in the AETHER framework. + +![Framework Architecture](docs/figures/diagram.png) + + + ## Attribution This repo is based on the [Hydra-Lightning template](https://github.com/ashleve/lightning-hydra-template). 
diff --git a/configs/data/heat_guatemala.yaml b/configs/data/heat_guatemala.yaml new file mode 100644 index 0000000..eb42022 --- /dev/null +++ b/configs/data/heat_guatemala.yaml @@ -0,0 +1,24 @@ +_target_: src.data.base_datamodule.BaseDataModule + +dataset: + _target_: src.data.heat_guatemala_dataset.HeatGuatemalaDataset + data_dir: ${paths.data_dir} + modalities: + coords: {} + use_target_data: true + use_features: true + use_aux_data: false + seed: ${seed} + cache_dir: ${paths.cache_dir} + +batch_size: 64 +num_workers: 8 +pin_memory: true + +split_mode: "from_file" +save_split: false +saved_split_file_name: "split_indices_heat_guatemala_2026-02-20-1148.pth" + + +train_val_test_split: [0.7, 0.15, 0.15] +seed: ${seed} \ No newline at end of file diff --git a/configs/experiment/heat_guatemala_coords_reg.yaml b/configs/experiment/heat_guatemala_coords_reg.yaml new file mode 100644 index 0000000..f335d38 --- /dev/null +++ b/configs/experiment/heat_guatemala_coords_reg.yaml @@ -0,0 +1,25 @@ +# @package _global_ +# configs/experiment/heat_guatemala_coords_reg.yaml +# Variant A: GeoClip coordinates only + + +defaults: + - override /model: heat_geoclip_reg + - override /data: heat_guatemala + +tags: ["heat_island", "guatemala", "coords_only", "regression"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 50 + +data: + batch_size: 64 + +logger: + wandb: + tags: ${tags} + group: "heat_island" + aim: + experiment: "heat_island" diff --git a/configs/experiment/heat_guatemala_fusion_reg.yaml b/configs/experiment/heat_guatemala_fusion_reg.yaml new file mode 100644 index 0000000..4ab69f5 --- /dev/null +++ b/configs/experiment/heat_guatemala_fusion_reg.yaml @@ -0,0 +1,25 @@ +# @package _global_ +# configs/experiment/heat_guatemala_fusion_reg.yaml +# Variant C: GeoClip + tabular fusion + + +defaults: + - override /model: heat_fusion_reg + - override /data: heat_guatemala + +tags: ["heat_island", "guatemala", "fusion", "regression"] +seed: 12345 + +trainer: + min_epochs: 1 + 
max_epochs: 50 + +data: + batch_size: 64 + +logger: + wandb: + tags: ${tags} + group: "heat_island" + aim: + experiment: "heat_island" diff --git a/configs/experiment/heat_guatemala_tabular_reg.yaml b/configs/experiment/heat_guatemala_tabular_reg.yaml new file mode 100644 index 0000000..740b125 --- /dev/null +++ b/configs/experiment/heat_guatemala_tabular_reg.yaml @@ -0,0 +1,25 @@ +# @package _global_ +# configs/experiment/heat_guatemala_tabular_reg.yaml +# Variant B: tabular features only + + +defaults: + - override /model: heat_tabular_reg + - override /data: heat_guatemala + +tags: ["heat_island", "guatemala", "tabular_only", "regression"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 50 + +data: + batch_size: 64 + +logger: + wandb: + tags: ${tags} + group: "heat_island" + aim: + experiment: "heat_island" diff --git a/configs/model/heat_fusion_reg.yaml b/configs/model/heat_fusion_reg.yaml new file mode 100644 index 0000000..6a554c5 --- /dev/null +++ b/configs/model/heat_fusion_reg.yaml @@ -0,0 +1,39 @@ +# configs/model/heat_fusion_reg.yaml +# +# Renamed from: predictive_geoclip_tabular_regression.yaml +# Reason: "fusion" is concise; "heat_" scopes to this use case. +# +# Variant: GeoClip (coords) + tabular fusion +# Encoder output = GeoClip (512) + tabular projection (64) = 576-dim + +_target_: src.models.predictive_model_regression.PredictiveRegressionModel + +eo_encoder: + _target_: src.models.components.eo_encoders.multimodal_encoder.MultiModalEncoder + use_coords: true + use_tabular: true + tab_embed_dim: 64 + +prediction_head: + _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead + nn_layers: 2 + hidden_dim: 256 + +# GeoClip frozen; tabular projection + head are trained. 
+trainable_modules: [eo_encoder, prediction_head] + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.001 + weight_decay: 0.0 + +scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + mode: min + factor: 0.1 + patience: 10 + +loss_fn: + _target_: torch.nn.MSELoss diff --git a/configs/model/heat_geoclip_reg.yaml b/configs/model/heat_geoclip_reg.yaml new file mode 100644 index 0000000..eeb5f9e --- /dev/null +++ b/configs/model/heat_geoclip_reg.yaml @@ -0,0 +1,38 @@ +# configs/model/heat_geoclip_reg.yaml +# +# Renamed from: predictive_geoclip_regression.yaml +# Reason: prefixing with "heat_" scopes it to this use case; +# "geoclip_reg" is concise and descriptive. +# +# Variant: coords only (GeoClip encodes lat/lon → 512-dim embedding) + +_target_: src.models.predictive_model_regression.PredictiveRegressionModel + +eo_encoder: + _target_: src.models.components.eo_encoders.multimodal_encoder.MultiModalEncoder + use_coords: true + use_tabular: false + +prediction_head: + _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead + nn_layers: 2 + hidden_dim: 256 + +# Only the prediction head is trained; GeoClip encoder is frozen. +trainable_modules: [prediction_head] + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.001 + weight_decay: 0.0 + +scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + mode: min + factor: 0.1 + patience: 10 + +loss_fn: + _target_: torch.nn.MSELoss diff --git a/configs/model/heat_tabular_reg.yaml b/configs/model/heat_tabular_reg.yaml new file mode 100644 index 0000000..edd8dd3 --- /dev/null +++ b/configs/model/heat_tabular_reg.yaml @@ -0,0 +1,45 @@ +# configs/model/heat_tabular_reg.yaml +# +# Renamed from: predictive_tabular_regression.yaml +# Reason: prefixed with "heat_" to scope to this use case. +# +# Variant: tabular only (feat_* CSV columns encoded by a small MLP) +# +# NOTE: tabular_dim is NOT hardcoded here. 
+# It is resolved automatically at runtime: +# 1. HeatGuatemalaDataset.tabular_dim reads len(feat_names) from the CSV. +# 2. BaseDataModule.tabular_dim delegates to the train dataset. +# 3. PredictiveRegressionModel.setup() calls +# self.eo_encoder.build_tabular_branch(self.trainer.datamodule.tabular_dim) + +_target_: src.models.predictive_model_regression.PredictiveRegressionModel + +eo_encoder: + _target_: src.models.components.eo_encoders.multimodal_encoder.MultiModalEncoder + use_coords: false + use_tabular: true + tab_embed_dim: 64 + +prediction_head: + _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead + nn_layers: 2 + hidden_dim: 256 + +# Both encoder and head have trainable parameters. +trainable_modules: [eo_encoder, prediction_head] + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.001 + weight_decay: 0.0001 + +scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + mode: min + factor: 0.1 + patience: 10 + +loss_fn: + _target_: torch.nn.MSELoss diff --git a/docs/figures/diagram.png b/docs/figures/diagram.png new file mode 100644 index 0000000..da8c4f9 Binary files /dev/null and b/docs/figures/diagram.png differ diff --git a/src/data/base_datamodule.py b/src/data/base_datamodule.py index 9a30883..2c94f6e 100644 --- a/src/data/base_datamodule.py +++ b/src/data/base_datamodule.py @@ -33,6 +33,7 @@ def __init__( saved_split_file_name: str | None = None, caption_builder: BaseCaptionBuilder = None, seed: int = 12345, + spatial_split_distance_m: int = 1000, ) -> None: """Datamodule class which handles dataset splits and batching. @@ -47,6 +48,8 @@ def __init__( :param save_split: if to save split file :param saved_split_file_name: file name to save split file :param caption_builder: instance of BaseCaptionBuilder for generating textual captions + :param spatial_split_distance_m: minimum distance in metres between clusters when + split_mode is 'spatial_clusters'. 
Default 1000 m. """ super().__init__() self.save_hyperparameters(logger=False) @@ -63,6 +66,14 @@ def __init__( self.split_data() + @property + def tabular_dim(self): + dataset = self.data_train + # Unwrap Subset wrappers (e.g. from random_split) + while hasattr(dataset, "dataset"): + dataset = dataset.dataset + return dataset.tabular_dim + @property def num_classes(self) -> int: """Number of classes in the dataset.""" @@ -120,7 +131,7 @@ def split_data(self) -> None: "Warning: DBSCAN clustering on more than 2000 samples may be slow. Maybe set n_jobs in DBScan?" ) # 4000 m distance between points. Use geodist to calculate true distance. - min_dist = 4000 + min_dist = self.hparams.spatial_split_distance_m clustering = DBSCAN( eps=min_dist, metric=lambda u, v: geodist(u, v).meters, @@ -244,19 +255,14 @@ def split_data(self) -> None: def save_split_indices(self, split_indices: dict[str, Any] | dict): """Save split indices into file.""" - assert ( - self.hparams.split_dir is not None - ), "split_dir must be provided when saving a new data split." - assert os.path.exists( - self.hparams.split_dir - ), f"Directory to save split indices does not exist: {self.hparams.split_dir}" - assert isinstance(split_indices, dict), "split_indices must be a dictionary to be saved." + self.split_dir = os.path.join(self.hparams.dataset.data_dir, "splits") + os.makedirs(self.split_dir, exist_ok=True) timestamp = create_timestamp() torch.save( split_indices, os.path.join( - self.hparams.split_dir, + self.split_dir, f"split_indices_{self.hparams.dataset_name}_{timestamp}.pth", ), ) diff --git a/src/data/base_dataset.py b/src/data/base_dataset.py index 2ab5c4f..258fb43 100644 --- a/src/data/base_dataset.py +++ b/src/data/base_dataset.py @@ -23,6 +23,7 @@ def __init__( cache_dir: str = None, implemented_mod: set[str] = None, mock: bool = False, + use_features: bool = True, ) -> None: """Interface for any use case dataset. 
@@ -45,6 +46,7 @@ def __init__( :param cache_dir: directory to save cached data :param implemented_mod: implemented modalities for each dataset :param mock: whether to mock csv file + :param use_features: if tabular feat_* columns should be included. Default True. """ if mock: @@ -79,6 +81,7 @@ def __init__( self.use_aux_data: bool = use_aux_data self.records: dict[str, Any] = self.get_records() self.pooch_cli = None + self.use_features = use_features @final def get_records(self) -> dict[str, Any]: @@ -112,6 +115,10 @@ def get_records(self) -> dict[str, Any]: if self.use_aux_data: self.aux_names = [c for c in self.df.columns if "aux_" in c] columns.extend(self.aux_names) + + # Include tabular features + self.feat_names = [c for c in self.df.columns if c.startswith("feat_")] + columns.extend(self.feat_names) return self.df.loc[:, columns].to_dict("records") @@ -120,6 +127,13 @@ def __len__(self) -> int: """Returns the length of the dataset.""" return len(self.records) + @final + @property + def tabular_dim(self) -> int: + if not self.use_features: + return 0 + return len(getattr(self, "feat_names", [])) + @abstractmethod def __getitem__(self, idx: int) -> Dict[str, Any]: """Returns a single item from the dataset.""" diff --git a/src/data/heat_guatemala_dataset.py b/src/data/heat_guatemala_dataset.py new file mode 100644 index 0000000..44e81dc --- /dev/null +++ b/src/data/heat_guatemala_dataset.py @@ -0,0 +1,111 @@ +""" +Heat Guatemala LST dataset. + +Location: src/data/heat_guatemala_dataset.py + +Changes vs original: + - tabular_dim property added so the datamodule (and model) can read it + without hardcoding anything. + - implemented_mod stays {"coords"} because tabular data arrives + automatically through feat_* CSV columns, not through the modalities dict. + This is documented explicitly below. + - Minor: __getitem__ guard tightened (tabular only added when feat_names exist + and modality logic is cleaner). 
+""" + +from typing import Any, Dict, override + +import torch + +from src.data.base_dataset import BaseDataset + + +class HeatGuatemalaDataset(BaseDataset): + """ + Dataset for the urban heat island use case (Guatemala City, LST regression). + + CSV layout expected (produced by scripts/make_model_ready_heat_guatemala.py): + - name_loc : unique location identifier + - lat, lon : WGS84 coordinates + - target_lst : Land Surface Temperature [°C] + - feat_* : tabular features (numeric + one-hot categorical) + + Modality design note + -------------------- + `implemented_mod = {"coords"}` because in this framework a "modality" refers + to data loaded from a separate file (e.g. a GeoTIFF or .npy embedding). + Tabular features live directly in the model-ready CSV and are picked up + automatically by BaseDataset.get_records() via the `feat_` column prefix. + They do NOT need to be listed in `modalities`. + """ + + def __init__( + self, + data_dir: str, + modalities: dict, + use_target_data: bool = True, + use_aux_data: bool = False, + seed: int = 12345, + cache_dir: str = None, + mock: bool = False, + use_features: bool = True, + ) -> None: + super().__init__( + data_dir=data_dir, + modalities=modalities, + use_target_data=use_target_data, + use_aux_data=use_aux_data, + dataset_name="heat_guatemala", + seed=seed, + cache_dir=cache_dir, + implemented_mod={"coords"}, + mock=mock, + use_features=use_features, + ) + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def tabular_dim(self) -> int: + """Number of tabular features (feat_* columns). 
0 if none.""" + return len(self.feat_names) + + # ------------------------------------------------------------------ + # Required overrides + # ------------------------------------------------------------------ + + def setup(self) -> None: + """No files to download / prepare for this dataset.""" + return + + @override + def __getitem__(self, idx: int) -> Dict[str, Any]: + row = self.records[idx] + sample: Dict[str, Any] = {"eo": {}} + + # --- EO modalities --- + for modality in self.modalities: + if modality == "coords": + sample["eo"]["coords"] = torch.tensor( + [row["lat"], row["lon"]], dtype=torch.float32 + ) + + # --- Tabular features (always included if present in CSV) --- + if self.use_features and self.feat_names: + sample["eo"]["tabular"] = torch.tensor( + [row[k] for k in self.feat_names], dtype=torch.float32 + ) + + # --- Target --- + if self.use_target_data: + sample["target"] = torch.tensor( + [row[k] for k in self.target_names], dtype=torch.float32 + ) + + # --- Auxiliary data --- + if self.use_aux_data: + sample["aux"] = [row[k] for k in self.aux_names] + + return sample diff --git a/src/data_preprocessing/make_model_ready_heat_guatemala.py b/src/data_preprocessing/make_model_ready_heat_guatemala.py new file mode 100644 index 0000000..cfc1da9 --- /dev/null +++ b/src/data_preprocessing/make_model_ready_heat_guatemala.py @@ -0,0 +1,174 @@ +""" +Build model-ready CSV for the Heat Guatemala use case (data/heat_guatemala/model_ready_heat_guatemala.csv). 
+""" + +import argparse +import re + +import numpy as np +import pandas as pd + +# ----------------------------------------------------------------------- +# Continuous / numeric columns → kept as float feat_* columns +# ----------------------------------------------------------------------- +NUMERIC_COLS = [ + "AREA_M2", + "AREA_Ha", + "UTM_areaM2", + "UTM_perimeterMeter", + # Vegetation & water indices (continuous floats) + "NDVI_mean2022", + "NDVI_min2022", + "MODIS_NDVIchange_20002020", + "NDWI_mean2022", + "NDWI_minimum2022", + # Urban structure + "CopenicusMSZ_BuildingHeightM", + "BUA_GAIA_Age_Mean", + # Socio-demographic + "PopulationDensityPerKm2", + "SocioEconomicQuality", + # Terrain + "DEM5m_Slope%", + "DEM5m_TerrainRuggednessIndex_MeanTRI", + "DEM5m_MeanAspect", + "DEM5m_TopographicPositionIndex_MeanTPI", + # Forest cover + "Hansen_ForestCover_sumHa", + "Hansen_ForestCover_meanPerc", + "HansenLoss_Ha", +] + +# ----------------------------------------------------------------------- +# Genuinely categorical / nominal columns → one-hot encoded feat_* columns +# ----------------------------------------------------------------------- +CATEGORICAL_COLS = [ + "BlockType", + "BlockTypeIndustry", + "BlockMAGADominantLanduse", + "IntrZon", + "DISTRITOS", + "ZONAM", +] + +# Columns where >this fraction of values are NaN → drop column entirely +NAN_DROP_THRESHOLD = 0.30 + + +def clean_token(x: str) -> str: + """Make a string safe for use as a column name.""" + x = str(x).strip() + x = re.sub(r"[^\w]+", "_", x) + x = re.sub(r"_+", "_", x).strip("_") + return x if x else "NA" + + +def main(source_csv: str, out_csv: str, drop_zero_lst: bool = True) -> None: + df = pd.read_csv(source_csv, encoding="latin-1") + print(f"Loaded: {source_csv} → {df.shape[0]} rows, {df.shape[1]} cols") + + # ------------------------------------------------------------------ # + # 1. 
Clean target: drop zero and NaN LST rows # + # ------------------------------------------------------------------ # + target_col = "LST_°C_mean_predictor" + + if drop_zero_lst: + n_before = len(df) + df = df[df[target_col] != 0].copy().reset_index(drop=True) + print(f"Dropped {n_before - len(df)} rows with LST == 0") + + n_before = len(df) + df = df.dropna(subset=[target_col]).reset_index(drop=True) + if len(df) < n_before: + print(f"Dropped {n_before - len(df)} rows with NaN LST target") + else: + print("No NaN LST targets found — good.") + + # ------------------------------------------------------------------ # + # 2. Build output skeleton # + # ------------------------------------------------------------------ # + out = pd.DataFrame({ + "name_loc": [f"heat_{i:06d}" for i in range(len(df))], + "lat": df["LAT"].astype(float), + "lon": df["LONG"].astype(float), + "target_lst": df[target_col].astype(float), + }) + + # Verify target is clean + assert out["target_lst"].isna().sum() == 0, "BUG: NaN in target after cleaning" + print(f"Target stats: mean={out['target_lst'].mean():.2f} " + f"std={out['target_lst'].std():.2f} " + f"min={out['target_lst'].min():.2f} " + f"max={out['target_lst'].max():.2f}") + + # ------------------------------------------------------------------ # + # 3. 
Numeric features — impute or drop # + # ------------------------------------------------------------------ # + numeric_feat_cols = [] + dropped_cols = [] + + for c in NUMERIC_COLS: + if c not in df.columns: + print(f" [WARN] numeric column not found, skipping: {c}") + continue + col_name = f"feat_{clean_token(c).lower()}" + series = pd.to_numeric(df[c], errors="coerce").astype(float) + nan_frac = series.isna().sum() / len(series) + + if nan_frac > NAN_DROP_THRESHOLD: + print(f" [DROP] {col_name}: {nan_frac:.0%} NaN — exceeds threshold, dropped") + dropped_cols.append(col_name) + continue + + if nan_frac > 0: + col_mean = series.mean() + print(f" [IMPUTE] {col_name}: {nan_frac:.1%} NaN → filled with mean ({col_mean:.4f})") + series = series.fillna(col_mean) + + out[col_name] = series + numeric_feat_cols.append(col_name) + + # ------------------------------------------------------------------ # + # 4. Categorical features (one-hot) # + # ------------------------------------------------------------------ # + for c in CATEGORICAL_COLS: + if c not in df.columns: + print(f" [WARN] categorical column not found, skipping: {c}") + continue + cats = df[c].astype(str).fillna("NA").map(clean_token) + prefix = f"feat_{clean_token(c).lower()}" + dummies = pd.get_dummies(cats, prefix=prefix, prefix_sep="__") + out = pd.concat([out, dummies.astype(np.float32)], axis=1) + + # ------------------------------------------------------------------ # + # 5. Final NaN check — should be zero # + # ------------------------------------------------------------------ # + total_nan = out.isna().sum().sum() + if total_nan > 0: + print("\n[ERROR] NaN values remain in output:") + print(out.isna().sum()[out.isna().sum() > 0]) + raise ValueError(f"{total_nan} NaN values remain — fix before training.") + else: + print("\nNaN check passed — output is clean.") + + # ------------------------------------------------------------------ # + # 6. 
Save and report # + # ------------------------------------------------------------------ # + out.to_csv(out_csv, index=False) + + feat_cols = [c for c in out.columns if c.startswith("feat_")] + print(f"\nWrote: {out_csv}") + print(f"Shape: {out.shape}") + print(f"tabular_dim (feat_* columns): {len(feat_cols)}") + print(f" numeric features kept: {len(numeric_feat_cols)}") + print(f" numeric features dropped:{len(dropped_cols)}") + print(f" one-hot features: {len(feat_cols) - len(numeric_feat_cols)}") + + +if __name__ == "__main__": + ap = argparse.ArgumentParser() + ap.add_argument("--source_csv", required=True) + ap.add_argument("--out_csv", required=True) + ap.add_argument("--drop_zero_lst", type=lambda x: x.lower() != "false", default=True) + args = ap.parse_args() + main(args.source_csv, args.out_csv, args.drop_zero_lst) \ No newline at end of file diff --git a/src/models/components/eo_encoders/multimodal_encoder.py b/src/models/components/eo_encoders/multimodal_encoder.py new file mode 100644 index 0000000..c8f6603 --- /dev/null +++ b/src/models/components/eo_encoders/multimodal_encoder.py @@ -0,0 +1,100 @@ +""" +Unified multimodal encoder for EO data. 
+ + +Controlled entirely via constructor flags: + - use_coords: encode lat/lon with GeoClip + - use_tabular: encode feat_* tabular columns + +""" + +from typing import Dict, override + +import torch +from torch import nn + +from src.models.components.eo_encoders.base_eo_encoder import BaseEOEncoder +from src.models.components.eo_encoders.geoclip import GeoClipCoordinateEncoder + + +class MultiModalEncoder(BaseEOEncoder): + """ + - coords only (use_coords=True, use_tabular=False) + - tabular only (use_coords=False, use_tabular=True) + - coords + tabular (use_coords=True, use_tabular=True) + """ + + def __init__( + self, + use_coords: bool = True, + use_tabular: bool = False, + tab_embed_dim: int = 64, + tabular_dim: int = None, + ) -> None: + super().__init__() + + assert use_coords or use_tabular, ( + "At least one of use_coords or use_tabular must be True." + ) + + self.use_coords = use_coords + self.use_tabular = use_tabular + self.tab_embed_dim = tab_embed_dim + self._tabular_ready = False + + coords_dim = 0 + if use_coords: + self.coords_encoder = GeoClipCoordinateEncoder() + coords_dim = self.coords_encoder.output_dim # 512 + + self._coords_dim = coords_dim + + # Built only if dim is already known + if use_tabular and tabular_dim is not None: + self.build_tabular_branch(tabular_dim) + elif use_tabular: + self.tabular_proj = None + else: + self.output_dim = coords_dim + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def build_tabular_branch(self, tabular_dim: int) -> None: + """ + Build (or rebuild) the tabular projection layer. 
+ """ + if self._tabular_ready and hasattr(self, "_last_tabular_dim"): + if self._last_tabular_dim == tabular_dim: + return # already built with correct dim + + self.tabular_proj = nn.Sequential( + nn.LayerNorm(tabular_dim), + nn.Linear(tabular_dim, self.tab_embed_dim), + nn.ReLU(), + ) + self._last_tabular_dim = tabular_dim + self._tabular_ready = True + self.output_dim = self._coords_dim + self.tab_embed_dim + + # ------------------------------------------------------------------ + # Forward + # ------------------------------------------------------------------ + + @override + def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + parts = [] + + if self.use_coords: + parts.append(self.coords_encoder(batch)) # (B, 512) + + if self.use_tabular: + assert self._tabular_ready, ( + "Tabular branch not built yet. Call build_tabular_branch(tabular_dim) first, " + "or pass tabular_dim to the constructor." + ) + tab = batch["eo"]["tabular"].float() # (B, tabular_dim) + parts.append(self.tabular_proj(tab)) # (B, tab_embed_dim) + + return torch.cat(parts, dim=-1) diff --git a/src/models/components/pred_heads/mlp_regression_head.py b/src/models/components/pred_heads/mlp_regression_head.py new file mode 100644 index 0000000..e835b5a --- /dev/null +++ b/src/models/components/pred_heads/mlp_regression_head.py @@ -0,0 +1,46 @@ +""" +MLP regression prediction head. 
+ +Renamed from: mlp_reg_pred_head.py → mlp_regression_head.py +Location: src/models/components/pred_heads/mlp_regression_head.py + +Changes vs original: + - File/class name made more readable (mlp_reg_pred_head → mlp_regression_head) + - No logic changes; class name kept as MLPRegressionPredictionHead for clarity +""" + +from typing import override + +import torch +from torch import nn + +from src.models.components.pred_heads.base_pred_head import BasePredictionHead + + +class MLPRegressionPredictionHead(BasePredictionHead): + """MLP prediction head for regression tasks (outputs a continuous value).""" + + def __init__(self, nn_layers: int = 2, hidden_dim: int = 256) -> None: + super().__init__() + self.nn_layers = nn_layers + self.hidden_dim = hidden_dim + + @override + def forward(self, feats: torch.Tensor) -> torch.Tensor: + return self.net(feats) + + @override + def configure_nn(self) -> None: + assert isinstance(self.input_dim, int), self.input_dim + assert isinstance(self.output_dim, int), self.output_dim + + layers = [] + in_dim = self.input_dim + + for _ in range(self.nn_layers - 1): + layers.append(nn.Linear(in_dim, self.hidden_dim)) + layers.append(nn.ReLU()) + in_dim = self.hidden_dim + + layers.append(nn.Linear(in_dim, self.output_dim)) + self.net = nn.Sequential(*layers) diff --git a/src/models/predictive_model_regression.py b/src/models/predictive_model_regression.py new file mode 100644 index 0000000..764f472 --- /dev/null +++ b/src/models/predictive_model_regression.py @@ -0,0 +1,123 @@ +""" +Regression variant of the predictive model (MSE / MAE / RMSE / R²). + +Location: src/models/predictive_model_regression.py + +Key changes vs original: + - setup(stage) injects tabular_dim into MultiModalEncoder automatically, + so tabular_dim never needs to be hardcoded in any config. + - num_classes defaults to 1 (single LST value). + - All regression metrics (MSE, RMSE, MAE, R²) are logged per split. 
+""" + +from typing import Dict, override + +import torch +import torch.nn.functional as F + +from src.models.base_model import BaseModel +from src.models.components.eo_encoders.multimodal_encoder import MultiModalEncoder + + +class PredictiveRegressionModel(BaseModel): + + def __init__( + self, + eo_encoder, + prediction_head, + trainable_modules, + optimizer, + scheduler, + loss_fn, + num_classes: int = 1, + **kwargs, + ): + super().__init__( + trainable_modules=trainable_modules, + optimizer=optimizer, + scheduler=scheduler, + loss_fn=loss_fn, + num_classes=num_classes, + ) + + self.eo_encoder = eo_encoder + self.prediction_head = prediction_head + + # Prediction head wiring happens AFTER setup() resolves tabular_dim. + # If the encoder does NOT need tabular data, we can wire immediately. + if not ( + isinstance(self.eo_encoder, MultiModalEncoder) + and self.eo_encoder.use_tabular + ): + self._wire_head() + self.freezer() + + # ------------------------------------------------------------------ + # Lightning hooks + # ------------------------------------------------------------------ + + def setup(self, stage: str) -> None: + """ + Called by Lightning after the datamodule is ready. + Injects tabular_dim into the encoder if it needs it, + then wires the prediction head dimensions. + """ + if ( + isinstance(self.eo_encoder, MultiModalEncoder) + and self.eo_encoder.use_tabular + and not self.eo_encoder._tabular_ready + ): + # Pull tabular_dim from the datamodule — no hardcoding needed! 
+ tabular_dim = self.trainer.datamodule.tabular_dim + self.eo_encoder.build_tabular_branch(tabular_dim) + self._wire_head() + + self.freezer() + + # ------------------------------------------------------------------ + # Internals + # ------------------------------------------------------------------ + + def _wire_head(self) -> None: + """Connect encoder output_dim → head input_dim, then build head layers.""" + self.prediction_head.set_dim( + input_dim=self.eo_encoder.output_dim, + output_dim=self.num_classes, + ) + self.prediction_head.configure_nn() + if "prediction_head" not in self.trainable_modules: + self.trainable_modules.append("prediction_head") + + # ------------------------------------------------------------------ + # Forward & step + # ------------------------------------------------------------------ + + @override + def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + feats = self.eo_encoder(batch) + feats = F.normalize(feats, dim=-1) + return self.prediction_head(feats) + + @override + def _step(self, batch: Dict[str, torch.Tensor], mode: str = "train") -> torch.Tensor: + y_hat = self.forward(batch) + y = batch["target"] + + loss = self.loss_fn(y_hat, y) + + mse = F.mse_loss(y_hat, y) + rmse = torch.sqrt(mse) + mae = F.l1_loss(y_hat, y) + + ss_res = torch.sum((y - y_hat) ** 2) + ss_tot = torch.sum((y - torch.mean(y)) ** 2) + 1e-12 + r2 = 1.0 - ss_res / ss_tot + + log_kwargs = dict(on_step=False, on_epoch=True) + self.log(f"{mode}_loss", loss, prog_bar=True, **log_kwargs) + self.log(f"{mode}_mse", mse, **log_kwargs) + self.log(f"{mode}_rmse", rmse, **log_kwargs) + self.log(f"{mode}_mae", mae, **log_kwargs) + self.log(f"{mode}_r2", r2, **log_kwargs) + + return loss