diff --git a/training_pipeline/predict.py b/training_pipeline/predict.py index 27886f7..de623c8 100644 --- a/training_pipeline/predict.py +++ b/training_pipeline/predict.py @@ -309,8 +309,7 @@ def main(): print(f"Checkpoint path: {checkpoint_path}") print(f"Processing {len(valid_images)} images...") - predictions = predict_batch(model, valid_images, device, args.threshold, - checkpoint_path, resume) + processed_images, predictions = predict_batch(model, valid_images, device, args.threshold, checkpoint_path, resume) # Save predictions for i, (image_path, pred_mask) in enumerate(zip(valid_images, predictions)):