Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions ctlearn/tools/predict_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ def _predict_with_model(self, model_path):
prediction_colname = (
"type" if isinstance(model.layers[-1], keras.layers.Softmax) else model.layers[-1].name
)
self.log.info("Checking column name for prediction : %s", prediction_colname)
backbone_model, feature_vectors = None, None
if self.dl1_features:
# Get the backbone model which is the second layer of the model
Expand Down Expand Up @@ -525,21 +526,23 @@ def _predict_with_model(self, model_path):
# which returns the probabilities for each class in an array, while
# the regression tasks have output neurons which returns the
# predicted value for the task in a dictionary.
if prediction_colname == "type":
predict_data = Table({prediction_colname: predict_data})
if isinstance(predict_data, dict) and prediction_colname in predict_data:
predict_data = Table({prediction_colname: predict_data[prediction_colname]})
else:
predict_data = Table(predict_data)
predict_data = Table({prediction_colname: predict_data})
# Predict the last batch and stack the results to the prediction data
if data_loader_last_batch is not None:
predict_data_last_batch = model.predict(
data_loader_last_batch, verbose=self.keras_verbose
)
if model.layers[-1].name == "type":
if isinstance(predict_data_last_batch, dict) and prediction_colname in predict_data_last_batch:
predict_data_last_batch = Table(
{prediction_colname: predict_data_last_batch}
{prediction_colname: predict_data_last_batch[prediction_colname]}
)
else:
predict_data_last_batch = Table(predict_data_last_batch)
predict_data_last_batch = Table(
{prediction_colname: predict_data_last_batch}
)
predict_data = vstack([predict_data, predict_data_last_batch])
return predict_data, feature_vectors

Expand Down