diff --git a/src/deepforest/main.py b/src/deepforest/main.py index 44a08dce4..524e1da41 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -806,7 +806,10 @@ def on_validation_epoch_end(self): self.predictions_df = pd.concat(self.predictions) #Evaluate every n epochs - if self.current_epoch % self.config["validation"]["val_accuracy_interval"] == 0: + if (self.config["validation"]["val_accuracy_interval"] + <= self.config["train"]["epochs"] and + self.current_epoch % self.config["validation"]["val_accuracy_interval"] + == 0): #Create a geospatial column ground_df = utilities.read_file(self.config["validation"]["csv_file"]) ground_df["label"] = ground_df.label.apply(lambda x: self.label_dict[x]) diff --git a/tests/test_main.py b/tests/test_main.py index d466e3b2c..151648206 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -970,3 +970,14 @@ def test_set_labels_invalid_length(m): # Expect a ValueError when setting an inv invalid_mapping = {"Object": 0, "Extra": 1} with pytest.raises(ValueError): m.set_labels(invalid_mapping) + +def test_validation_interval_greater_than_epochs(m): + # Set interval higher than max_epochs to disable evaluation + m.config["validation"]["val_accuracy_interval"] = 3 + m.config["train"]["epochs"] = 2 + m.create_trainer() + m.trainer.fit(m) + + assert "box_precision" not in m.trainer.logged_metrics + assert "box_recall" not in m.trainer.logged_metrics + assert "empty_frame_accuracy" not in m.trainer.logged_metrics