diff --git a/tests/test_datasets_and_datamodules.py b/tests/test_datasets_and_datamodules.py index b6d0f29..6fc7bce 100644 --- a/tests/test_datasets_and_datamodules.py +++ b/tests/test_datasets_and_datamodules.py @@ -1,10 +1,12 @@ +from src.data.base_dataset import BaseDataset from src.data.butterfly_dataset import ButterflyDataset +from src.data.heat_guatemala_dataset import HeatGuatemalaDataset from src.data.satbird_dataset import SatBirdDataset def test_datasets_generic_properties(request, tmp_path, sample_csv): """This test checks that all datasets implement the basic properties and methods.""" - list_datasets = [ButterflyDataset, SatBirdDataset] + list_datasets = [ButterflyDataset, SatBirdDataset, HeatGuatemalaDataset] use_mock = request.config.getoption("--use-mock") if use_mock: csv_dir = sample_csv @@ -44,6 +46,22 @@ def test_datasets_generic_properties(request, tmp_path, sample_csv): dataset, "dataset_name" ), f"'dataset_name' attribute missing in {ds_class.__name__}." assert hasattr(dataset, "mode"), f"'mode' attribute missing in {ds_class.__name__}." + assert hasattr( + dataset, "use_features" + ), f"'use_features' attribute missing in {ds_class.__name__}." + assert hasattr( + dataset, "use_aux_data" + ), f"'use_aux_data' attribute missing in {ds_class.__name__}." + assert hasattr( + dataset, "use_target_data" + ), f"'use_target_data' attribute missing in {ds_class.__name__}." + assert hasattr( + dataset, "tabular_dim" + ), f"'tabular_dim' attribute missing in {ds_class.__name__}." + assert hasattr(dataset, "setup"), f"'setup' method missing in {ds_class.__name__}." + assert hasattr( + dataset, "get_records" + ), f"'get_records' method missing in {ds_class.__name__}." def test_datamodule_random_split_and_loaders(create_butterfly_dataset): diff --git a/tests/test_eo_encoders.py b/tests/test_eo_encoders.py index e91d9da..635919d 100644 --- a/tests/test_eo_encoders.py +++ b/tests/test_eo_encoders.py @@ -9,6 +9,7 @@ from src.models.components.eo_encoders.base_eo_encoder import BaseEOEncoder from src.models.components.eo_encoders.cnn_encoder import CNNEncoder from src.models.components.eo_encoders.geoclip import GeoClipCoordinateEncoder +from src.models.components.eo_encoders.multimodal_encoder import MultiModalEncoder # @pytest.mark.slow @@ -18,6 +19,7 @@ def test_eo_encoder_generic_properties(create_butterfly_dataset): "geoclip_coords": GeoClipCoordinateEncoder, "cnn": CNNEncoder, "average": AverageEncoder, + "multimodal_coords": MultiModalEncoder, } ds, dm = create_butterfly_dataset batch = next(iter(dm.train_dataloader())) diff --git a/tests/test_pred_heads.py b/tests/test_pred_heads.py index 13809fa..09f8a9a 100644 --- a/tests/test_pred_heads.py +++ b/tests/test_pred_heads.py @@ -9,6 +9,9 @@ from src.models.components.pred_heads.base_pred_head import BasePredictionHead from src.models.components.pred_heads.linear_pred_head import LinearPredictionHead from src.models.components.pred_heads.mlp_pred_head import MLPPredictionHead +from src.models.components.pred_heads.mlp_regression_head import ( + MLPRegressionPredictionHead, +) # @pytest.mark.slow @@ -18,7 +21,7 @@ def test_pred_head_generic_properties(create_butterfly_dataset): eo_encoder = GeoClipCoordinateEncoder() feats = eo_encoder.forward(batch) - list_pred_heads = [LinearPredictionHead, MLPPredictionHead] + list_pred_heads = [LinearPredictionHead, MLPPredictionHead, MLPRegressionPredictionHead] for pred_head_class in list_pred_heads: pred_head = pred_head_class() assert hasattr( diff --git a/tests/test_train.py b/tests/test_train.py index d4e0333..a60a1e6 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -103,8 +103,8 @@ def test_train_resume(tmp_path: Path, cfg_train: DictConfig) -> None: metric_dict_2, _ = train(cfg_train) files = os.listdir(tmp_path / "checkpoints") - assert "epoch_001.ckpt" in files - assert "epoch_002.ckpt" not in files + assert len([x for x in files if x.endswith(".ckpt")]) == 2 + assert "last.ckpt" in files # assert metric_dict_1["train/acc"] < metric_dict_2["train/acc"] # assert metric_dict_1["val/acc"] < metric_dict_2["val/acc"]