diff --git a/configs/data/butterfly_aef.yaml b/configs/data/butterfly_aef.yaml
index c5aad63..ba9f3087 100644
--- a/configs/data/butterfly_aef.yaml
+++ b/configs/data/butterfly_aef.yaml
@@ -10,7 +10,7 @@ dataset:
     size: 256
     format: tif
   use_target_data: true # default true
-  use_aux_data: false # default false
+  use_aux_data: none # default all
   seed: ${seed}
   cache_dir: ${paths.cache_dir}
diff --git a/configs/data/butterfly_coords.yaml b/configs/data/butterfly_coords.yaml
index 7bd82e7..f4acb51 100644
--- a/configs/data/butterfly_coords.yaml
+++ b/configs/data/butterfly_coords.yaml
@@ -7,7 +7,7 @@ dataset:
   modalities:
     coords:
   use_target_data: true
-  use_aux_data: false
+  use_aux_data: "none"
   seed: ${seed}
   cache_dir: ${paths.cache_dir}
diff --git a/configs/data/butterfly_coords_text.yaml b/configs/data/butterfly_coords_text.yaml
index c478acf..92b6bd7 100644
--- a/configs/data/butterfly_coords_text.yaml
+++ b/configs/data/butterfly_coords_text.yaml
@@ -7,7 +7,7 @@ dataset:
   modalities:
     coords:
   use_target_data: false
-  use_aux_data: true
+  use_aux_data: "all"
   seed: ${seed}
   cache_dir: ${paths.cache_dir}
diff --git a/configs/data/butterfly_full_param_example.yaml b/configs/data/butterfly_full_param_example.yaml
index 40e327a..541f4c2 100644
--- a/configs/data/butterfly_full_param_example.yaml
+++ b/configs/data/butterfly_full_param_example.yaml
@@ -16,7 +16,7 @@ dataset:
     format: npy
     year: 2019
   use_target_data: true
-  use_aux_data: false
+  use_aux_data: none
   seed: ${seed}
   cache_dir: ${paths.cache_dir}
diff --git a/configs/data/butterfly_s2_rgb.yaml b/configs/data/butterfly_s2_rgb.yaml
index 47955da..1b4a627 100644
--- a/configs/data/butterfly_s2_rgb.yaml
+++ b/configs/data/butterfly_s2_rgb.yaml
@@ -12,7 +12,7 @@ dataset:
     channels: rgb
     preprocessing: zscored
   use_target_data: true # default true
-  use_aux_data: false # default false
+  use_aux_data: none # default all
   seed: ${seed}
   cache_dir: ${paths.cache_dir}
diff --git a/configs/data/heat_guatemala.yaml b/configs/data/heat_guatemala.yaml
index 271fd65..af8715e 100644
--- a/configs/data/heat_guatemala.yaml
+++ b/configs/data/heat_guatemala.yaml
@@ -7,7 +7,7 @@ dataset:
     coords: {}
   use_target_data: true
   use_features: true
-  use_aux_data: false
+  use_aux_data: none
   seed: ${seed}
   cache_dir: ${paths.cache_dir}
diff --git a/configs/data/satbird-Kenya_coords.yaml b/configs/data/satbird-Kenya_coords.yaml
index c3300a3..9f16ced 100644
--- a/configs/data/satbird-Kenya_coords.yaml
+++ b/configs/data/satbird-Kenya_coords.yaml
@@ -7,7 +7,7 @@ dataset:
   modalities:
     coords:
   use_target_data: true
-  use_aux_data: false
+  use_aux_data: none
   seed: ${seed}
   cache_dir: ${paths.cache_dir}
   study_site: Kenya
diff --git a/configs/data/satbird_USA-summer_coords.yaml b/configs/data/satbird_USA-summer_coords.yaml
index 683f54a..bf27ec2 100644
--- a/configs/data/satbird_USA-summer_coords.yaml
+++ b/configs/data/satbird_USA-summer_coords.yaml
@@ -7,7 +7,7 @@ dataset:
   modalities:
     coords:
   use_target_data: true
-  use_aux_data: false
+  use_aux_data: none
   seed: ${seed}
   cache_dir: ${paths.cache_dir}
   study_site: USA-summer
diff --git a/src/data/base_caption_builder.py b/src/data/base_caption_builder.py
index 207c38a..3ca4b77 100644
--- a/src/data/base_caption_builder.py
+++ b/src/data/base_caption_builder.py
@@ -52,33 +52,53 @@ def _fill(template: str, fillers: Dict[str, str]) -> str:
         return template
 
     @abstractmethod
-    def _build_from_template(self, template_idx: int, row: List[Any]) -> str:
+    def _build_from_template(
+        self, template_idx: int, aux: torch.Tensor, top: List[str] | None = None
+    ) -> str:
         """Build caption text from template and row of auxiliary data."""
         pass
 
-    def random(self, aux_values: List[Any]) -> List[str]:
+    def random(self, aux_values) -> List[str]:
         """Return a caption from a randomly sampled template for each data point."""
         formatted_rows = []
-        template_idx = random.choices(
+
+        batch_size = len(aux_values["aux"])
+
+        template_ids = random.choices(
             range(len(self.templates)),
-            k=len(aux_values),
+            k=batch_size,
         )
-        for idx, row in zip(template_idx, aux_values):
-            formatted_rows.append(self._build_from_template(idx, row))
+        for (
+            i,
+            template_idx,
+        ) in enumerate(template_ids):
+            row_aux = aux_values["aux"][i]
+            row_top = aux_values.get("top")[i] if aux_values.get("top") else None
+            formatted_rows.append(
+                self._build_from_template(template_idx, aux=row_aux, top=row_top)
+            )
         return formatted_rows
 
-    def all(self, aux_values: List[Any]) -> List[str]:
+    def all(self, aux_values) -> List[str]:
        """Return a list of captions from all available templates."""
        formatted_rows = []
-        for row in aux_values:
+        for i in range(0, len(aux_values["aux"])):
             descriptions = []
+            row_aux = aux_values["aux"][i]
+            row_top = aux_values.get("top")[i] if aux_values.get("top") else None
+
             for template_idx in range(0, len(self)):
-                descriptions.append(self._build_from_template(template_idx, row))
+                descriptions.append(
+                    self._build_from_template(template_idx, aux=row_aux, top=row_top)
+                )
             formatted_rows.append(descriptions)
         return formatted_rows
 
+    def build_concepts(self, aux_values) -> List[str]:
+        pass
+
 
 class DummyCaptionBuilder(BaseCaptionBuilder):
     """Dummy caption builder for testing purposes."""
@@ -89,8 +109,10 @@ def __init__(self, templates_fname: str, data_dir: str, seed: int) -> None:
     def sync_with_dataset(self, dataset) -> None:
         pass
 
-    def _build_from_template(self, template_idx: int, row: List[Any]) -> str:
-        first_val = row[0].item() if torch.is_tensor(row) else row[0]
+    def _build_from_template(
+        self, template_idx: int, aux: torch.Tensor, top: List[str] | None = None
+    ) -> str:
+        first_val = aux[0].item()
         return f"Location with value {first_val}"
 
 
diff --git a/src/data/base_dataset.py b/src/data/base_dataset.py
index ae669d2..bd727d7 100644
--- a/src/data/base_dataset.py
+++ b/src/data/base_dataset.py
@@ -1,6 +1,7 @@
 import os
+import re
 from abc import ABC, abstractmethod
-from typing import Any, Dict, final
+from typing import Any, Dict, List, final
 
 import numpy as np
 import pandas as pd
@@ -16,7 +17,7 @@ def __init__(
         data_dir: str,
         modalities: dict,
         use_target_data: bool = True,
-        use_aux_data: bool = False,
+        use_aux_data: Dict[str, List[str] | str] | str | None = None,
         dataset_name: str = "BaseDataset",
         seed: int = 12345,
         mode: str = "train",
@@ -78,10 +79,26 @@ def __init__(
         self.num_classes = None
         self.tabular_dim = None
         self.seed = seed
-        self.use_target_data: bool = use_target_data
-        self.use_aux_data: bool = use_aux_data
+        self.use_target_data = use_target_data
         self.use_features = use_features
 
+        if use_aux_data is None or use_aux_data == "all":
+            self.use_aux_data = {
+                "aux": {
+                    "pattern": "^aux_(?!.*top).*",
+                    # 'columns' : []
+                },
+                "top": {
+                    "pattern": "^aux_.*top.*",
+                    # 'columns' : []
+                },
+            }
+
+        elif type(use_aux_data) is dict:
+            self.use_aux_data = use_aux_data
+        else:
+            self.use_aux_data = None
+
         self.mode: str = mode  # 'train', 'val', 'test'
 
         self.records: dict[str, Any] = self.get_records()
@@ -107,16 +124,25 @@ def get_records(self) -> dict[str, Any]:
             )
             columns.append(f"{modality}_path")
 
-        # Include targets
+        # Include targets TODO: this could be moved under geo-modalities
         if self.use_target_data:
             self.target_names = [c for c in self.df.columns if "target_" in c]
             columns.extend(self.target_names)
             self.num_classes = len(self.target_names)
 
         # Include aux data
-        if self.use_aux_data:
-            self.aux_names = [c for c in self.df.columns if "aux_" in c]
-            columns.extend(self.aux_names)
+        if self.use_aux_data is not None:
+            for k, val in self.use_aux_data.items():
+                if val.get("pattern"):
+                    pattern = re.compile(val["pattern"])
+                    aux_names = [x for x in self.df.columns if pattern.match(x)]
+                else:
+                    aux_names = val.get(
+                        "columns",
+                        ValueError('use_aux_data should have "pattern" or "columns" defined'),
+                    )
+                self.use_aux_data[k] = aux_names
+                columns.extend(aux_names)
 
         # Include tabular features
         if self.use_features:
diff --git a/src/data/butterfly_caption_builder.py b/src/data/butterfly_caption_builder.py
index 444b483..66b60a3 100644
--- a/src/data/butterfly_caption_builder.py
+++ b/src/data/butterfly_caption_builder.py
@@ -1,8 +1,8 @@
-import math
 import os
 from typing import Any, List, override
 
 import pandas as pd
+import torch
 
 from src.data.base_caption_builder import (
     BaseCaptionBuilder,
@@ -27,19 +27,20 @@ def sync_with_dataset(self, dataset: BaseDataset) -> None:
         corine_columns = self.get_corine_column_keys()
         humanfootprint_columns = self.get_humanfootprint_column_keys()
         aux_columns = {**bioclim_columns, **corine_columns, **humanfootprint_columns}
-        self.column_to_metadata_map = {}
+        self.column_to_metadata_map = {k: {} for k in dataset.use_aux_data.keys()}
 
-        for id, key in enumerate(dataset.aux_names):
-            if "aux_corine_frac" in key and "top" in key:  # to avoid assert statement
-                description, units = None, None
-            else:
-                description, units = aux_columns.get(key) or (None, None)
-                assert description is not None, f"Key {key} not found in aux columns"
-            self.column_to_metadata_map[key] = {
-                "id": id,
-                "description": description,
-                "units": units,
-            }
+        for aux_cat, cols in dataset.use_aux_data.items():
+            for i, c in enumerate(cols):
+                if "top" in aux_cat:
+                    description, units = None, None
+                else:
+                    description, units = aux_columns.get(c) or (None, None)
+
+                self.column_to_metadata_map[aux_cat][c] = {
+                    "id": i,
+                    "description": description,
+                    "units": units,
+                }
 
     def get_corine_column_keys(self):
         """Returns metadata for corine columns."""
@@ -99,7 +100,8 @@ def get_humanfootprint_column_keys(self):
     def _build_from_template(
         self,
         template_idx: int,
-        row: List[Any],
+        aux: torch.Tensor,
+        top: List[str] | None = None,
         convert_corine_perc: bool = True,
     ) -> str:
         """Create caption from template and row of auxiliary data."""
@@ -107,25 +109,19 @@
         tokens = self.tokens_in_template[template_idx]
         replacements = {}
         for token in tokens:
-            if "aux_corine_frac" in token and "top" in token:
-                try:
-                    values_dict_top = self.column_to_metadata_map[token]
-                except KeyError:
-                    raise KeyError(
-                        f"Token {token} not found in column_to_metadata_map {self.column_to_metadata_map}. Check if the token in the template matches the column names in the dataset."
-                    )
-                idx_top = values_dict_top["id"]
-                referral_token = row[
-                    idx_top
-                ]  # e.g., token 'aux_corine_frac_lowlevel_top_1' might refer to 'corine_frac_211' in this row
-                referral_token = (
-                    "aux_" + referral_token if "aux_" not in referral_token else referral_token
+            init_token = token
+            if "top" in token:
+                idx = self.column_to_metadata_map["top"][token]["id"]
+                token = f"aux_{top[idx]}"
+            try:
+                values_dict = self.column_to_metadata_map["aux"][token]
+            except KeyError:
+                raise KeyError(
+                    f"Token {token} not found in column_to_metadata_map {self.column_to_metadata_map}. Check if the token in the template matches the column names in the dataset."
                 )
-                values_dict = self.column_to_metadata_map[referral_token]
-            else:
-                values_dict = self.column_to_metadata_map[token]
+            idx = values_dict["id"]
-            value = row[idx]
+            value = aux[idx].item()
 
             formatted_desc = values_dict["description"].lower() or ""
             units = values_dict["units"]
 
@@ -139,7 +135,7 @@
                 formatted_desc = formatted_desc + f' ({round(value)} {units if units else ""})'
             else:
                 formatted_desc = formatted_desc + f' of {round(value)} {units if units else ""}'
-            replacements[token] = formatted_desc
+            replacements[init_token] = formatted_desc
 
         template = self._fill(template, replacements)
         return template
diff --git a/src/data/butterfly_dataset.py b/src/data/butterfly_dataset.py
index 4b0344a..422f8c2 100644
--- a/src/data/butterfly_dataset.py
+++ b/src/data/butterfly_dataset.py
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Dict, override
+from typing import Any, Dict, List, override
 
 import numpy as np
 import pooch
@@ -17,7 +17,7 @@ def __init__(
         data_dir: str,
         modalities: dict,
         use_target_data: bool = True,
-        use_aux_data: bool = False,
+        use_aux_data: Dict[str, List[str] | str] | None = None,
         seed: int = 12345,
         cache_dir: str = None,
         mock: bool = False,
@@ -165,7 +165,14 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
             )
 
         if self.use_aux_data:
-            formatted_row["aux"] = [row[i] for i in self.aux_names]
+            formatted_row["aux"] = {}
+            for aux_cat, vals in self.use_aux_data.items():
+                if aux_cat == "aux":
+                    formatted_row["aux"][aux_cat] = torch.tensor(
+                        [row[v] for v in vals], dtype=torch.float32
+                    )
+                else:
+                    formatted_row["aux"][aux_cat] = [row[v] for v in vals]
 
         return formatted_row
diff --git a/src/data/collate_fns.py b/src/data/collate_fns.py
index cce80a3..4e37844 100644
--- a/src/data/collate_fns.py
+++ b/src/data/collate_fns.py
@@ -5,6 +5,15 @@
 from src.data.base_caption_builder import BaseCaptionBuilder
 
 
+def smart_stack(values):
+    first = values[0]
+
+    if isinstance(first, (torch.Tensor, int, float)):
+        return torch.stack(values, dim=0)
+
+    return values
+
+
 def collate_fn(
     batch: List[Any],
     mode: str = "train",
@@ -12,32 +21,28 @@ def collate_fn(
 ) -> Dict[str, torch.Tensor]:
     """Collates batch into stacked tensors and label lists."""
 
-    # map of all keys present in the batch
-    keys = batch[0].keys()
-    eo_keys = batch[0].get("eo", {}).keys()
-    collected = {k: ([] if k != "eo" else {k_1: [] for k_1 in eo_keys}) for k in keys}
-
-    # fill-in collected items into batch dict
-    for item in batch:
-        for k in keys:
-            if k == "eo":
-                for k_1 in eo_keys:
-                    collected[k][k_1].append(item[k][k_1])
-            else:
-                collected[k].append(item[k])
-
-    # stack tensors
-    for k in keys:
-        if k == "eo":
-            for k_1 in eo_keys:
-                collected[k][k_1] = torch.stack(collected[k][k_1], dim=0)
-        elif isinstance(collected[k][0], torch.Tensor):
-            collected[k] = torch.stack(collected[k], dim=0)
+    batch_collected = {}
+
+    if "eo" in batch[0]:
+        batch_collected["eo"] = {
+            k: torch.stack([item["eo"][k] for item in batch]) for k in batch[0]["eo"].keys()
+        }
+
+    if batch[0].get("aux") is not None:
+        batch_collected["aux"] = {
+            k: smart_stack([item["aux"][k] for item in batch]) for k in batch[0]["aux"].keys()
+        }
+
+    if batch[0].get("target") is not None:
+        batch_collected["target"] = smart_stack([item["target"] for item in batch])
 
     # convert aux into captions
     if mode == "train":
-        collected["text"] = caption_builder.random(collected["aux"])
+        batch_collected["text"] = caption_builder.random(batch_collected["aux"])
+    elif mode == "val":
+        batch_collected["text"] = caption_builder.all(batch_collected["aux"])
     else:
-        collected["text"] = caption_builder.all(collected["aux"])
+        batch_collected["text"] = caption_builder.all(batch_collected["aux"])
+    # batch_collected['concepts'] = caption_builder.build_concepts(batch_collected["aux"])
 
-    return collected
+    return batch_collected
diff --git a/src/data/heat_guatemala_dataset.py b/src/data/heat_guatemala_dataset.py
index 3099b44..c20b686 100644
--- a/src/data/heat_guatemala_dataset.py
+++ b/src/data/heat_guatemala_dataset.py
@@ -42,7 +42,7 @@ def __init__(
         data_dir: str,
         modalities: dict,
         use_target_data: bool = True,
-        use_aux_data: bool = False,
+        use_aux_data: Dict[str, Any] | str = "all",
         seed: int = 12345,
         cache_dir: str = None,
         mock: bool = False,
@@ -95,6 +95,13 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
 
         # --- Auxiliary data ---
         if self.use_aux_data:
-            sample["aux"] = [row[k] for k in self.aux_names]
+            sample["aux"] = {}
+            for aux_cat, vals in self.use_aux_data.items():
+                if aux_cat == "aux":
+                    sample["aux"][aux_cat] = torch.tensor(
+                        [row[v] for v in vals], dtype=torch.float32
+                    )
+                else:
+                    sample["aux"][aux_cat] = [row[v] for v in vals]
 
         return sample
diff --git a/src/data/satbird_dataset.py b/src/data/satbird_dataset.py
index 8f96dfb..e26fe51 100644
--- a/src/data/satbird_dataset.py
+++ b/src/data/satbird_dataset.py
@@ -1,5 +1,5 @@
 import os
-from typing import override
+from typing import Any, Dict, override
 
 import torch
 from rasterio import open as ropen
@@ -15,8 +15,8 @@ def __init__(
         data_dir: str,
         modalities: dict,
         use_target_data: bool,
-        use_aux_data: bool,
         seed: int,
+        use_aux_data: Dict[str, Any] | str = "all",
         study_site: str = "Kenya",
         cache_dir: str = None,
         mock: bool = False,
@@ -108,7 +108,14 @@ def __getitem__(self, idx):
             )
 
         if self.use_aux_data:
-            formatted_row["aux"] = [row[i] for i in self.aux_names]
+            formatted_row["aux"] = {}
+            for aux_cat, vals in self.use_aux_data.items():
+                if aux_cat == "aux":
+                    formatted_row["aux"][aux_cat] = torch.tensor(
+                        [row[v] for v in vals], dtype=torch.float32
+                    )
+                else:
+                    formatted_row["aux"][aux_cat] = [row[v] for v in vals]
 
         return formatted_row
diff --git a/tests/conftest.py b/tests/conftest.py
index 5e2f1c2..e4e9b2d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -160,7 +160,7 @@ def create_butterfly_dataset(request, sample_csv, tmp_path):
         cache_dir=str(tmp_path),
         modalities={"coords": None},
         use_target_data=True,
-        use_aux_data=True,
+        use_aux_data="all",
         seed=0,
         mock=use_mock,
     )
diff --git a/tests/test_captions.py b/tests/test_captions.py
index 0e8c6fa..66b132f 100644
--- a/tests/test_captions.py
+++ b/tests/test_captions.py
@@ -24,7 +24,7 @@ def test_datamodule_uses_collate_when_aux_data(request, sample_csv, tmp_path):
         cache_dir=str(tmp_path),
         modalities={"coords": None},
         use_target_data=True,
-        use_aux_data=True,
+        use_aux_data="all",
         seed=0,
         mock=use_mock,
    )
diff --git a/tests/test_datasets_and_datamodules.py b/tests/test_datasets_and_datamodules.py
index 6fc7bce..2141987 100644
--- a/tests/test_datasets_and_datamodules.py
+++ b/tests/test_datasets_and_datamodules.py
@@ -19,7 +19,7 @@ def test_datasets_generic_properties(request, tmp_path, sample_csv):
         cache_dir=str(tmp_path),
         modalities={"coords": None},
         use_target_data=True,
-        use_aux_data=True,
+        use_aux_data="all",
         seed=0,
         mock=use_mock,
     )
@@ -38,9 +38,6 @@ def test_datasets_generic_properties(request, tmp_path, sample_csv):
     assert hasattr(
         dataset, "target_names"
     ), f"'target_names' attribute missing in {ds_class.__name__}."
-    assert hasattr(
-        dataset, "aux_names"
-    ), f"'aux_names' attribute missing in {ds_class.__name__}."
     assert hasattr(dataset, "records"), f"'records' attribute missing in {ds_class.__name__}."
     assert hasattr(
         dataset, "dataset_name"