diff --git a/ctlearn/tools/predict_model.py b/ctlearn/tools/predict_model.py index 562bf2ff..bd933021 100644 --- a/ctlearn/tools/predict_model.py +++ b/ctlearn/tools/predict_model.py @@ -473,6 +473,7 @@ def _predict_with_model(self, model_path): prediction_colname = ( "type" if isinstance(model.layers[-1], keras.layers.Softmax) else model.layers[-1].name ) + self.log.info("Checking column name for prediction : %s", prediction_colname) backbone_model, feature_vectors = None, None if self.dl1_features: # Get the backbone model which is the second layer of the model @@ -525,21 +526,23 @@ def _predict_with_model(self, model_path): # which returns the probabilities for each class in an array, while # the regression tasks have output neurons which returns the # predicted value for the task in a dictionary. - if prediction_colname == "type": - predict_data = Table({prediction_colname: predict_data}) + if isinstance(predict_data, dict) and prediction_colname in predict_data: + predict_data = Table({prediction_colname: predict_data[prediction_colname]}) else: - predict_data = Table(predict_data) + predict_data = Table({prediction_colname: predict_data}) # Predict the last batch and stack the results to the prediction data if data_loader_last_batch is not None: predict_data_last_batch = model.predict( data_loader_last_batch, verbose=self.keras_verbose ) - if model.layers[-1].name == "type": + if isinstance(predict_data_last_batch, dict) and prediction_colname in predict_data_last_batch: predict_data_last_batch = Table( - {prediction_colname: predict_data_last_batch} + {prediction_colname: predict_data_last_batch[prediction_colname]} ) else: - predict_data_last_batch = Table(predict_data_last_batch) + predict_data_last_batch = Table( + {prediction_colname: predict_data_last_batch} + ) predict_data = vstack([predict_data, predict_data_last_batch]) return predict_data, feature_vectors