From 8f327e55cb667cc95a55d4dd98d982b24d964341 Mon Sep 17 00:00:00 2001 From: Keerthi Reddy Date: Tue, 8 Apr 2025 11:54:11 +0530 Subject: [PATCH 1/2] dont run on_validation_epoch_end if val-accuracy-interval is greater than max epochs. --- src/deepforest/main.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/deepforest/main.py b/src/deepforest/main.py index 44a08dce4..524e1da41 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -806,7 +806,10 @@ def on_validation_epoch_end(self): self.predictions_df = pd.concat(self.predictions) #Evaluate every n epochs - if self.current_epoch % self.config["validation"]["val_accuracy_interval"] == 0: + if (self.config["validation"]["val_accuracy_interval"] + <= self.config["train"]["epochs"] and + self.current_epoch % self.config["validation"]["val_accuracy_interval"] + == 0): #Create a geospatial column ground_df = utilities.read_file(self.config["validation"]["csv_file"]) ground_df["label"] = ground_df.label.apply(lambda x: self.label_dict[x]) From c0557b0dee37eb84ea52a6742744779e928f29c6 Mon Sep 17 00:00:00 2001 From: Keerthi Reddy Date: Mon, 14 Apr 2025 11:06:39 +0530 Subject: [PATCH 2/2] Added test --- tests/test_main.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_main.py b/tests/test_main.py index d466e3b2c..151648206 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -970,3 +970,14 @@ def test_set_labels_invalid_length(m): # Expect a ValueError when setting an inv invalid_mapping = {"Object": 0, "Extra": 1} with pytest.raises(ValueError): m.set_labels(invalid_mapping) + +def test_validation_interval_greater_than_epochs(m): + # Set interval higher than max_epochs to disable evaluation + m.config["validation"]["val_accuracy_interval"] = 3 + m.config["train"]["epochs"] = 2 + m.create_trainer() + m.trainer.fit(m) + + assert "box_precision" not in m.trainer.logged_metrics + assert "box_recall" not in m.trainer.logged_metrics + assert "empty_frame_accuracy" not in m.trainer.logged_metrics