Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion tests/test_datasets_and_datamodules.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from src.data.base_dataset import BaseDataset
from src.data.butterfly_dataset import ButterflyDataset
from src.data.heat_guatemala_dataset import HeatGuatemalaDataset
from src.data.satbird_dataset import SatBirdDataset


def test_datasets_generic_properties(request, tmp_path, sample_csv):
"""This test checks that all datasets implement the basic properties and methods."""
list_datasets = [ButterflyDataset, SatBirdDataset]
list_datasets = [ButterflyDataset, SatBirdDataset, HeatGuatemalaDataset]
use_mock = request.config.getoption("--use-mock")
if use_mock:
csv_dir = sample_csv
Expand Down Expand Up @@ -44,6 +46,22 @@ def test_datasets_generic_properties(request, tmp_path, sample_csv):
dataset, "dataset_name"
), f"'dataset_name' attribute missing in {ds_class.__name__}."
assert hasattr(dataset, "mode"), f"'mode' attribute missing in {ds_class.__name__}."
assert hasattr(
dataset, "use_features"
), f"'use_features' attribute missing in {ds_class.__name__}."
assert hasattr(
dataset, "use_aux_data"
), f"'use_aux_data' attribute missing in {ds_class.__name__}."
assert hasattr(
dataset, "use_target_data"
), f"'use_target_data' attribute missing in {ds_class.__name__}."
assert hasattr(
dataset, "tabular_dim"
), f"'tabular_dim' attribute missing in {ds_class.__name__}."
assert hasattr(dataset, "setup"), f"'setup' method missing in {ds_class.__name__}."
assert hasattr(
dataset, "get_records"
), f"'get_records' method missing in {ds_class.__name__}."


def test_datamodule_random_split_and_loaders(create_butterfly_dataset):
Expand Down
2 changes: 2 additions & 0 deletions tests/test_eo_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from src.models.components.eo_encoders.base_eo_encoder import BaseEOEncoder
from src.models.components.eo_encoders.cnn_encoder import CNNEncoder
from src.models.components.eo_encoders.geoclip import GeoClipCoordinateEncoder
from src.models.components.eo_encoders.multimodal_encoder import MultiModalEncoder


# @pytest.mark.slow
Expand All @@ -18,6 +19,7 @@ def test_eo_encoder_generic_properties(create_butterfly_dataset):
"geoclip_coords": GeoClipCoordinateEncoder,
"cnn": CNNEncoder,
"average": AverageEncoder,
"multimodal_coords": MultiModalEncoder,
}
ds, dm = create_butterfly_dataset
batch = next(iter(dm.train_dataloader()))
Expand Down
5 changes: 4 additions & 1 deletion tests/test_pred_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from src.models.components.pred_heads.base_pred_head import BasePredictionHead
from src.models.components.pred_heads.linear_pred_head import LinearPredictionHead
from src.models.components.pred_heads.mlp_pred_head import MLPPredictionHead
from src.models.components.pred_heads.mlp_regression_head import (
MLPRegressionPredictionHead,
)


# @pytest.mark.slow
Expand All @@ -18,7 +21,7 @@ def test_pred_head_generic_properties(create_butterfly_dataset):
eo_encoder = GeoClipCoordinateEncoder()
feats = eo_encoder.forward(batch)

list_pred_heads = [LinearPredictionHead, MLPPredictionHead]
list_pred_heads = [LinearPredictionHead, MLPPredictionHead, MLPRegressionPredictionHead]
for pred_head_class in list_pred_heads:
pred_head = pred_head_class()
assert hasattr(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ def test_train_resume(tmp_path: Path, cfg_train: DictConfig) -> None:
metric_dict_2, _ = train(cfg_train)

files = os.listdir(tmp_path / "checkpoints")
assert "epoch_001.ckpt" in files
assert "epoch_002.ckpt" not in files
assert len([x for x in files if x.endswith(".ckpt")]) == 2
assert "last.ckpt" in files

# assert metric_dict_1["train/acc"] < metric_dict_2["train/acc"]
# assert metric_dict_1["val/acc"] < metric_dict_2["val/acc"]