Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
90d3dcc
Update minor README
vdplasthijs Jan 26, 2026
3090c73
UrbanHeatIslands use case: Guatemala case
BachirNILU Feb 19, 2026
7299da0
Update base data (add feat_ + add parameter for spatial splitting)
BachirNILU Feb 19, 2026
a813376
Add regression model/head
BachirNILU Feb 19, 2026
dc0d261
Add multimodel encoder (tabular+coords)
BachirNILU Feb 19, 2026
84fe08e
Readme: add diagram
BachirNILU Feb 19, 2026
79a79e2
update comments
BachirNILU Feb 19, 2026
d5d996d
docs: move diagram to docs/figures/
BachirNILU Feb 19, 2026
ab32d62
docs: move diagram to docs/figures/
BachirNILU Feb 19, 2026
a816825
Removing output_normalization='l2'
BachirNILU Feb 20, 2026
27bc9ec
data config + split_dir fixed
BachirNILU Feb 20, 2026
03a202d
Fix the need and remove split_dir parameter
gabrieletijunaityte Feb 23, 2026
2bc0b8e
fix: wire head, trainable modules, F.normalize, use_features param
BachirNILU Feb 23, 2026
adecb1d
fix: remove redundant _tabular_ready check in __init__
BachirNILU Feb 23, 2026
d7a2254
Delete split_indices_heat_guatemala_2026-02-20-1148.pth
gabrieletijunaityte Feb 24, 2026
e34a830
Introduce BaseMetrics and MetricsWrapper classes
gabrieletijunaityte Feb 24, 2026
8661e5e
Introduce BaseMetrics and MetricsWrapper classes
gabrieletijunaityte Feb 24, 2026
9514a83
New metrics
gabrieletijunaityte Feb 24, 2026
96ab9bb
Adapt metrics, losses for metrics wrapper
gabrieletijunaityte Feb 24, 2026
14e8cc7
Introduce metrics wrapper as config
gabrieletijunaityte Feb 24, 2026
b7f094a
Format hooks
gabrieletijunaityte Feb 24, 2026
524958a
Format hooks
gabrieletijunaityte Feb 24, 2026
a95f0b2
Format hooks
gabrieletijunaityte Feb 24, 2026
3e6967a
De-duplicate regression predictive model
gabrieletijunaityte Feb 24, 2026
f096f20
Metrics as configs, use predictive model instead of predictive regres…
gabrieletijunaityte Feb 24, 2026
6882407
Tabular dimensions introduced earlier to remove need for wiring
gabrieletijunaityte Feb 24, 2026
5d91521
Merge branch 'feature/contrastive_setup' of github.com:WUR-AI/aether …
gabrieletijunaityte Mar 2, 2026
539e609
Restructure aux handling in the dataset
gabrieletijunaityte Mar 2, 2026
b819531
Merge branch 'develop' into feature/contrastive_setup
gabrieletijunaityte Mar 2, 2026
810d45b
Fix dataset argument expected types
gabrieletijunaityte Mar 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion configs/data/butterfly_aef.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ dataset:
size: 256
format: tif
use_target_data: true # default true
use_aux_data: false # default false
use_aux_data: none # default false
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it none or "none"? I see both (in this file and butterfly_coords.yaml)

seed: ${seed}
cache_dir: ${paths.cache_dir}

Expand Down
2 changes: 1 addition & 1 deletion configs/data/butterfly_coords.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ dataset:
modalities:
coords:
use_target_data: true
use_aux_data: false
use_aux_data: "none"
seed: ${seed}
cache_dir: ${paths.cache_dir}

Expand Down
2 changes: 1 addition & 1 deletion configs/data/butterfly_coords_text.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ dataset:
modalities:
coords:
use_target_data: false
use_aux_data: true
use_aux_data: "all"
seed: ${seed}
cache_dir: ${paths.cache_dir}

Expand Down
2 changes: 1 addition & 1 deletion configs/data/butterfly_full_param_example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dataset:
format: npy
year: 2019
use_target_data: true
use_aux_data: false
use_aux_data: none
seed: ${seed}
cache_dir: ${paths.cache_dir}

Expand Down
2 changes: 1 addition & 1 deletion configs/data/butterfly_s2_rgb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dataset:
channels: rgb
preprocessing: zscored
use_target_data: true # default true
use_aux_data: false # default false
use_aux_data: none # default false
seed: ${seed}
cache_dir: ${paths.cache_dir}

Expand Down
2 changes: 1 addition & 1 deletion configs/data/heat_guatemala.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ dataset:
coords: {}
use_target_data: true
use_features: true
use_aux_data: false
use_aux_data: none
seed: ${seed}
cache_dir: ${paths.cache_dir}

Expand Down
2 changes: 1 addition & 1 deletion configs/data/satbird-Kenya_coords.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ dataset:
modalities:
coords:
use_target_data: true
use_aux_data: false
use_aux_data: none
seed: ${seed}
cache_dir: ${paths.cache_dir}
study_site: Kenya
Expand Down
2 changes: 1 addition & 1 deletion configs/data/satbird_USA-summer_coords.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ dataset:
modalities:
coords:
use_target_data: true
use_aux_data: false
use_aux_data: none
seed: ${seed}
cache_dir: ${paths.cache_dir}
study_site: USA-summer
Expand Down
44 changes: 33 additions & 11 deletions src/data/base_caption_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,33 +52,53 @@ def _fill(template: str, fillers: Dict[str, str]) -> str:
return template

@abstractmethod
def _build_from_template(self, template_idx: int, row: List[Any]) -> str:
def _build_from_template(
self, template_idx: int, aux: torch.Tensor, top: List[str] | None = None
) -> str:
"""Build caption text from template and row of auxiliary data."""
pass

def random(self, aux_values: List[Any]) -> List[str]:
def random(self, aux_values) -> List[str]:
"""Return a caption from a randomly sampled template for each data point."""
formatted_rows = []
template_idx = random.choices(

batch_size = len(aux_values["aux"])

template_ids = random.choices(
range(len(self.templates)),
k=len(aux_values),
k=batch_size,
)
for idx, row in zip(template_idx, aux_values):
formatted_rows.append(self._build_from_template(idx, row))
for (
i,
template_idx,
) in enumerate(template_ids):
row_aux = aux_values["aux"][i]
row_top = aux_values.get("top")[i] if aux_values.get("top") else None
formatted_rows.append(
self._build_from_template(template_idx, aux=row_aux, top=row_top)
)

return formatted_rows

def all(self, aux_values: List[Any]) -> List[str]:
def all(self, aux_values) -> List[str]:
"""Return a list of captions from all available templates."""
formatted_rows = []
for row in aux_values:
for i in range(0, len(aux_values["aux"])):
descriptions = []
row_aux = aux_values["aux"][i]
row_top = aux_values.get("top")[i] if aux_values.get("top") else None

for template_idx in range(0, len(self)):
descriptions.append(self._build_from_template(template_idx, row))
descriptions.append(
self._build_from_template(template_idx, aux=row_aux, top=row_top)
)
formatted_rows.append(descriptions)

return formatted_rows

def build_concepts(self, aux_values) -> List[str]:
pass


class DummyCaptionBuilder(BaseCaptionBuilder):
"""Dummy caption builder for testing purposes."""
Expand All @@ -89,8 +109,10 @@ def __init__(self, templates_fname: str, data_dir: str, seed: int) -> None:
def sync_with_dataset(self, dataset) -> None:
pass

def _build_from_template(self, template_idx: int, row: List[Any]) -> str:
first_val = row[0].item() if torch.is_tensor(row) else row[0]
def _build_from_template(
self, template_idx: int, aux: torch.Tensor, top: List[str] | None = None
) -> str:
first_val = aux[0].item()
return f"Location with value {first_val}"


Expand Down
42 changes: 34 additions & 8 deletions src/data/base_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import re
from abc import ABC, abstractmethod
from typing import Any, Dict, final
from typing import Any, Dict, List, final

import numpy as np
import pandas as pd
Expand All @@ -16,7 +17,7 @@ def __init__(
data_dir: str,
modalities: dict,
use_target_data: bool = True,
use_aux_data: bool = False,
use_aux_data: Dict[str, List[str] | str] | str | None = None,
dataset_name: str = "BaseDataset",
seed: int = 12345,
mode: str = "train",
Expand Down Expand Up @@ -78,10 +79,26 @@ def __init__(
self.num_classes = None
self.tabular_dim = None
self.seed = seed
self.use_target_data: bool = use_target_data
self.use_aux_data: bool = use_aux_data
self.use_target_data = use_target_data
self.use_features = use_features

if use_aux_data is None or use_aux_data == "all":
self.use_aux_data = {
"aux": {
"pattern": "^aux_(?!.*top).*",
# 'columns' : []
},
"top": {
"pattern": "^aux_.*top.*",
# 'columns' : []
},
}

elif type(use_aux_data) is dict:
self.use_aux_data = use_aux_data
else:
self.use_aux_data = None

self.mode: str = mode # 'train', 'val', 'test'
self.records: dict[str, Any] = self.get_records()

Expand All @@ -107,16 +124,25 @@ def get_records(self) -> dict[str, Any]:
)
columns.append(f"{modality}_path")

# Include targets
# Include targets TODO: this could be moved under geo-modalities
if self.use_target_data:
self.target_names = [c for c in self.df.columns if "target_" in c]
columns.extend(self.target_names)
self.num_classes = len(self.target_names)

# Include aux data
if self.use_aux_data:
self.aux_names = [c for c in self.df.columns if "aux_" in c]
columns.extend(self.aux_names)
if self.use_aux_data is not None:
for k, val in self.use_aux_data.items():
if val.get("pattern"):
pattern = re.compile(val["pattern"])
aux_names = [x for x in self.df.columns if pattern.match(x)]
else:
aux_names = val.get(
"columns",
ValueError('use_aux_data should have "pattern" or "columns" defined'),
)
self.use_aux_data[k] = aux_names
columns.extend(aux_names)

# Include tabular features
if self.use_features:
Expand Down
60 changes: 28 additions & 32 deletions src/data/butterfly_caption_builder.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import math
import os
from typing import Any, List, override

import pandas as pd
import torch

from src.data.base_caption_builder import (
BaseCaptionBuilder,
Expand All @@ -27,19 +27,20 @@ def sync_with_dataset(self, dataset: BaseDataset) -> None:
corine_columns = self.get_corine_column_keys()
humanfootprint_columns = self.get_humanfootprint_column_keys()
aux_columns = {**bioclim_columns, **corine_columns, **humanfootprint_columns}
self.column_to_metadata_map = {}
self.column_to_metadata_map = {k: {} for k in dataset.use_aux_data.keys()}

for id, key in enumerate(dataset.aux_names):
if "aux_corine_frac" in key and "top" in key: # to avoid assert statement
description, units = None, None
else:
description, units = aux_columns.get(key) or (None, None)
assert description is not None, f"Key {key} not found in aux columns"
self.column_to_metadata_map[key] = {
"id": id,
"description": description,
"units": units,
}
for aux_cat, cols in dataset.use_aux_data.items():
for i, c in enumerate(cols):
if "top" in aux_cat:
description, units = None, None
else:
description, units = aux_columns.get(c) or (None, None)

self.column_to_metadata_map[aux_cat][c] = {
"id": i,
"description": description,
"units": units,
}

def get_corine_column_keys(self):
"""Returns metadata for corine columns."""
Expand Down Expand Up @@ -99,33 +100,28 @@ def get_humanfootprint_column_keys(self):
def _build_from_template(
self,
template_idx: int,
row: List[Any],
aux: torch.Tensor,
top: List[str] | None = None,
convert_corine_perc: bool = True,
) -> str:
"""Create caption from template and row of auxiliary data."""
template = self.templates[template_idx]
tokens = self.tokens_in_template[template_idx]
replacements = {}
for token in tokens:
if "aux_corine_frac" in token and "top" in token:
try:
values_dict_top = self.column_to_metadata_map[token]
except KeyError:
raise KeyError(
f"Token {token} not found in column_to_metadata_map {self.column_to_metadata_map}. Check if the token in the template matches the column names in the dataset."
)
idx_top = values_dict_top["id"]
referral_token = row[
idx_top
] # e.g., token 'aux_corine_frac_lowlevel_top_1' might refer to 'corine_frac_211' in this row
referral_token = (
"aux_" + referral_token if "aux_" not in referral_token else referral_token
init_token = token
if "top" in token:
idx = self.column_to_metadata_map["top"][token]["id"]
token = f"aux_{top[idx]}"
try:
values_dict = self.column_to_metadata_map["aux"][token]
except KeyError:
raise KeyError(
f"Token {token} not found in column_to_metadata_map {self.column_to_metadata_map}. Check if the token in the template matches the column names in the dataset."
)
values_dict = self.column_to_metadata_map[referral_token]
else:
values_dict = self.column_to_metadata_map[token]

idx = values_dict["id"]
value = row[idx]
value = aux[idx].item()

formatted_desc = values_dict["description"].lower() or ""
units = values_dict["units"]
Expand All @@ -139,7 +135,7 @@ def _build_from_template(
formatted_desc = formatted_desc + f' ({round(value)} {units if units else ""})'
else:
formatted_desc = formatted_desc + f' of {round(value)} {units if units else ""}'
replacements[token] = formatted_desc
replacements[init_token] = formatted_desc

template = self._fill(template, replacements)
return template
Expand Down
13 changes: 10 additions & 3 deletions src/data/butterfly_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Any, Dict, override
from typing import Any, Dict, List, override

import numpy as np
import pooch
Expand All @@ -17,7 +17,7 @@ def __init__(
data_dir: str,
modalities: dict,
use_target_data: bool = True,
use_aux_data: bool = False,
use_aux_data: Dict[str, List[str] | str] | None = None,
seed: int = 12345,
cache_dir: str = None,
mock: bool = False,
Expand Down Expand Up @@ -165,7 +165,14 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
)

if self.use_aux_data:
formatted_row["aux"] = [row[i] for i in self.aux_names]
formatted_row["aux"] = {}
for aux_cat, vals in self.use_aux_data.items():
if aux_cat == "aux":
formatted_row["aux"][aux_cat] = torch.tensor(
[row[v] for v in vals], dtype=torch.float32
)
else:
formatted_row["aux"][aux_cat] = [row[v] for v in vals]

return formatted_row

Expand Down
Loading