From e3f8afaed7a8037adc3a55dc8cc4362248acaa13 Mon Sep 17 00:00:00 2001 From: "hugo.varenne" Date: Mon, 20 Oct 2025 14:15:54 +0200 Subject: [PATCH] Fix: Issue 261 - Prediction irregularity format for LoadedModel --- ctlearn/tools/predict_model.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/ctlearn/tools/predict_model.py b/ctlearn/tools/predict_model.py index 562bf2ff..bd933021 100644 --- a/ctlearn/tools/predict_model.py +++ b/ctlearn/tools/predict_model.py @@ -473,6 +473,7 @@ def _predict_with_model(self, model_path): prediction_colname = ( "type" if isinstance(model.layers[-1], keras.layers.Softmax) else model.layers[-1].name ) + self.log.info("Checking column name for prediction : %s", prediction_colname) backbone_model, feature_vectors = None, None if self.dl1_features: # Get the backbone model which is the second layer of the model @@ -525,21 +526,23 @@ def _predict_with_model(self, model_path): # which returns the probabilities for each class in an array, while # the regression tasks have output neurons which returns the # predicted value for the task in a dictionary. - if prediction_colname == "type": - predict_data = Table({prediction_colname: predict_data}) + if isinstance(predict_data, dict) and prediction_colname in predict_data: + predict_data = Table({prediction_colname: predict_data[prediction_colname]}) else: - predict_data = Table(predict_data) + predict_data = Table({prediction_colname: predict_data}) # Predict the last batch and stack the results to the prediction data if data_loader_last_batch is not None: predict_data_last_batch = model.predict( data_loader_last_batch, verbose=self.keras_verbose ) - if model.layers[-1].name == "type": + if isinstance(predict_data_last_batch, dict) and prediction_colname in predict_data_last_batch: predict_data_last_batch = Table( - {prediction_colname: predict_data_last_batch} + {prediction_colname: predict_data_last_batch[prediction_colname]} ) else: - predict_data_last_batch = Table(predict_data_last_batch) + predict_data_last_batch = Table( + {prediction_colname: predict_data_last_batch} + ) predict_data = vstack([predict_data, predict_data_last_batch]) return predict_data, feature_vectors