From 80467eb7b564a07d741f6d4ad14113070084d4c1 Mon Sep 17 00:00:00 2001 From: bw4sz Date: Tue, 21 Apr 2026 09:32:08 -0700 Subject: [PATCH 1/2] Update bird detector --- docs/user_guide/02_prebuilt.md | 30 +- src/deepforest/scripts/compare_bird_models.py | 598 +++++++++++++++ .../scripts/evaluate_deepwater_horizon.py | 527 +++++++++++++ .../evaluate_patch_size_sensitivity.py | 291 ++++++++ src/deepforest/scripts/evaluate_thresholds.py | 190 +++++ src/deepforest/scripts/prepare_birds.py | 691 ++++++++++++++++++ .../scripts/push_bird_model_to_hf.py | 86 +++ src/deepforest/scripts/submit_train_birds.sh | 24 + src/deepforest/scripts/train_birds.py | 266 +++++++ 9 files changed, 2680 insertions(+), 23 deletions(-) create mode 100644 src/deepforest/scripts/compare_bird_models.py create mode 100644 src/deepforest/scripts/evaluate_deepwater_horizon.py create mode 100644 src/deepforest/scripts/evaluate_patch_size_sensitivity.py create mode 100644 src/deepforest/scripts/evaluate_thresholds.py create mode 100644 src/deepforest/scripts/prepare_birds.py create mode 100644 src/deepforest/scripts/push_bird_model_to_hf.py create mode 100755 src/deepforest/scripts/submit_train_birds.sh create mode 100644 src/deepforest/scripts/train_birds.py diff --git a/docs/user_guide/02_prebuilt.md b/docs/user_guide/02_prebuilt.md index 109d2e7ab..3f423944b 100644 --- a/docs/user_guide/02_prebuilt.md +++ b/docs/user_guide/02_prebuilt.md @@ -34,6 +34,13 @@ The model was initially described in [Ecological Applications](https://esajourna Using over 250,000 annotations from 13 projects from around the world, we develop a general bird detection model that achieves over 65% recall and 50% precision on novel aerial data without any local training despite differences in species, habitat, and imaging methodology. Fine-tuning this model with only 1000 local annotations increases these values to an average of 84% recall and 69% precision by building on the general features learned from other data sources. > +The bird detection model has been updated and retrained from the original `weecology/deepforest-bird` model. The updated model was fine-tuned starting from the tree detection model (`weecology/deepforest-tree`) and trained on data from both Weinstein et al. 2022 as well as new additional bird detection data from multiple sources including https://lila.science/. The result is a dataset with over a million bird detections from around the world. Training details and metrics can be viewed on the [Comet dashboard](https://www.comet.com/bw4sz/bird-detector/6181df1ab7ac40f291b863a2a9b86024?&prevPath=%2Fbw4sz%2Fbird-detector%2Fview%2Fnew%2Fexperiments). + +### Example Predictions + +The following examples show predictions from the updated bird detection model: + +![Bird Prediction Example 1](../figures/bird_prediction_example_1.png) ### Citation > Weinstein, B.G., Garner, L., Saccomanno, V.R., Steinkraus, A., Ortega, A., Brush, K., Yenni, G., McKellar, A.E., Converse, R., Lippitt, C.D., Wegmann, A., Holmes, N.D., Edney, A.J., Hart, T., Jessopp, M.J., Clarke, R.H., Marchowski, D., Senyondo, H., Dotson, R., White, E.P., Frederick, P. and Ernest, S.K.M. (2022), A general deep learning model for bird detection in high resolution airborne imagery. Ecological Applications. Accepted Author Manuscript e2694. https://doi-org.lp.hscl.ufl.edu/10.1002/eap.2694 @@ -122,29 +129,6 @@ Table S1 Confusion matrix for the Alive/Dead model in Weinstein et al. 2023 Citation: Weinstein, Ben G., et al. "Capturing long‐tailed individual tree diversity using an airborne imaging and a multi‐temporal hierarchical model." Remote Sensing in Ecology and Conservation 9.5 (2023): 656-670. -### NEON Tree Species and Genus Classification - -Two ResNet-18 crop classifiers trained on RGB crown images from the National Ecological Observatory Network (NEON). The training data includes deduplicated hand-annotated tree crowns from 29 NEON sites across the US. - -- **Species model**: 148 species classes, trained on ~16k deduplicated crown crops. HuggingFace repo: `weecology/cropmodel-tree-species` -- **Genus model**: 54 genus classes, same training data aggregated to genus level. HuggingFace repo: `weecology/cropmodel-tree-genus` - -Both models use a torchvision ResNet-18 backbone pretrained on ImageNet and fine-tuned on NEON RGB data. Input images are resized to 224x224 using nearest-neighbor interpolation (`resize_interpolation: nearest` in the model config) and normalized with standard ImageNet statistics. The interpolation mode is loaded automatically from the HuggingFace config — no user action required. - -```python -from deepforest.model import CropModel - -# Load the species classifier -species_model = CropModel.load_model("weecology/cropmodel-tree-species") - -# Load the genus classifier -genus_model = CropModel.load_model("weecology/cropmodel-tree-genus") -``` - -Use these as a second stage after tree crown detection: detect crowns with a DeepForest model, then classify each crop. - -For more details on the training data and code, see [NeonTreeClassification](https://github.com/GatorSense/NeonTreeClassification). - ## Want more pretrained models? Please consider contributing your data to open source repositories, such as zenodo or lila.science. The more data we gather, the more we can combine the annotation and data collection efforts of hundreds of researchers to built models available to everyone. We welcome suggestions on what models and data are most urgently [needed](https://github.com/weecology/DeepForest/discussions). diff --git a/src/deepforest/scripts/compare_bird_models.py b/src/deepforest/scripts/compare_bird_models.py new file mode 100644 index 000000000..4c0da718f --- /dev/null +++ b/src/deepforest/scripts/compare_bird_models.py @@ -0,0 +1,598 @@ +"""Compare retrained bird model checkpoint with pretrained +weecology/deepforest-bird model. + +This script evaluates both models on the same test dataset and prints a comparison +of performance metrics. It also evaluates the checkpoint model at multiple score +thresholds and generates a precision-recall curve. + +Example usage: + python compare_bird_models.py --checkpoint_path /path/to/checkpoint.ckpt --data_dir /path/to/data +""" + +import argparse +import os + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +from deepforest import main + + +def compare_models(checkpoint_path, data_dir, iou_threshold=0.4): + """Compare checkpoint model with pretrained weecology/deepforest-bird + model. + + Args: + checkpoint_path: Path to the checkpoint file + data_dir: Directory containing test.csv and images + iou_threshold: IoU threshold for evaluation (default: 0.4) + + Returns: + dict: Dictionary containing results for both models + """ + test_csv = os.path.join(data_dir, "test.csv") + + # Read test set and make a tiny subset of 100 images + test_df = pd.read_csv(test_csv) + test_df = test_df[test_df.image_path.str.startswith("BDA")].head(10) + test_csv = os.path.join(data_dir, "test_subset.csv") + test_df.to_csv(test_csv, index=False) + + print("=" * 80) + print("Bird Detection Model Comparison") + print("=" * 80) + print(f"\nTest dataset: {test_csv}") + print(f"IoU threshold: {iou_threshold}\n") + + results = {} + + # Evaluate checkpoint model + print("-" * 80) + print("Evaluating retrained checkpoint model...") + print(f"Checkpoint: {checkpoint_path}") + print("-" * 80) + checkpoint_model = main.deepforest.load_from_checkpoint(checkpoint_path) + checkpoint_model.config.score_thresh = 0.25 + checkpoint_model.model.score_thresh = 0.25 + + # Set up validation configuration + checkpoint_model.config.validation.csv_file = test_csv + checkpoint_model.config.validation.root_dir = data_dir + checkpoint_model.config.validation.iou_threshold = iou_threshold + checkpoint_model.config.validation.val_accuracy_interval = 1 + checkpoint_model.create_trainer() + + # Evaluate using trainer.validate() + print("\n1. trainer.validate() results:") + validation_results = checkpoint_model.trainer.validate(checkpoint_model) + checkpoint_validate = validation_results[0] if validation_results else {} + results["checkpoint_validate"] = checkpoint_validate + + checkpoint_precision_validate = checkpoint_validate.get("box_precision") + checkpoint_recall_validate = checkpoint_validate.get("box_recall") + if checkpoint_precision_validate is not None: + print(f" Box Precision: {checkpoint_precision_validate:.4f}") + else: + print(" Box Precision: N/A") + if checkpoint_recall_validate is not None: + print(f" Box Recall: {checkpoint_recall_validate:.4f}") + else: + print(" Box Recall: N/A") + print( + f" Empty Frame Accuracy: {checkpoint_validate.get('empty_frame_accuracy', 'N/A')}" + ) + + # Evaluate using main.evaluate() + print("\n2. main.evaluate() results:") + checkpoint_evaluate = checkpoint_model.evaluate( + csv_file=test_csv, + root_dir=data_dir, + iou_threshold=iou_threshold, + ) + results["checkpoint_evaluate"] = checkpoint_evaluate + + checkpoint_precision_evaluate = checkpoint_evaluate.get("box_precision") + checkpoint_recall_evaluate = checkpoint_evaluate.get("box_recall") + if checkpoint_precision_evaluate is not None: + print(f" Box Precision: {checkpoint_precision_evaluate:.4f}") + else: + print(" Box Precision: N/A") + if checkpoint_recall_evaluate is not None: + print(f" Box Recall: {checkpoint_recall_evaluate:.4f}") + else: + print(" Box Recall: N/A") + print( + f" Empty Frame Accuracy: {checkpoint_evaluate.get('empty_frame_accuracy', 'N/A')}" + ) + + # Store both for backward compatibility + results["checkpoint"] = checkpoint_validate + + # Evaluate pretrained model + print("\n" + "-" * 80) + print("Evaluating pretrained weecology/deepforest-bird model...") + print("-" * 80) + pretrained_model = main.deepforest() + pretrained_model.load_model("weecology/deepforest-bird") + pretrained_model.config.score_thresh = 0.25 + pretrained_model.model.score_thresh = 0.25 + + # Set label dictionaries to match + pretrained_model.label_dict = {"Bird": 0} + pretrained_model.numeric_to_label_dict = {0: "Bird"} + pretrained_model.config.label_dict = {"Bird": 0} + pretrained_model.config.num_classes = 1 + + # Set up validation configuration + pretrained_model.config.validation.csv_file = test_csv + pretrained_model.config.validation.root_dir = data_dir + pretrained_model.config.validation.iou_threshold = iou_threshold + pretrained_model.config.validation.val_accuracy_interval = 1 + pretrained_model.create_trainer() + + # Evaluate using trainer.validate() + print("\n1. trainer.validate() results:") + validation_results = pretrained_model.trainer.validate(pretrained_model) + pretrained_validate = validation_results[0] if validation_results else {} + results["pretrained_validate"] = pretrained_validate + + pretrained_precision_validate = pretrained_validate.get("box_precision") + pretrained_recall_validate = pretrained_validate.get("box_recall") + if pretrained_precision_validate is not None: + print(f" Box Precision: {pretrained_precision_validate:.4f}") + else: + print(" Box Precision: N/A") + if pretrained_recall_validate is not None: + print(f" Box Recall: {pretrained_recall_validate:.4f}") + else: + print(" Box Recall: N/A") + print( + f" Empty Frame Accuracy: {pretrained_validate.get('empty_frame_accuracy', 'N/A')}" + ) + + # Evaluate using main.evaluate() + print("\n2. main.evaluate() results:") + pretrained_evaluate = pretrained_model.evaluate( + csv_file=test_csv, + root_dir=data_dir, + iou_threshold=iou_threshold, + ) + results["pretrained_evaluate"] = pretrained_evaluate + + pretrained_precision_evaluate = pretrained_evaluate.get("box_precision") + pretrained_recall_evaluate = pretrained_evaluate.get("box_recall") + if pretrained_precision_evaluate is not None: + print(f" Box Precision: {pretrained_precision_evaluate:.4f}") + else: + print(" Box Precision: N/A") + if pretrained_recall_evaluate is not None: + print(f" Box Recall: {pretrained_recall_evaluate:.4f}") + else: + print(" Box Recall: N/A") + print( + f" Empty Frame Accuracy: {pretrained_evaluate.get('empty_frame_accuracy', 'N/A')}" + ) + + # Store both for backward compatibility + results["pretrained"] = pretrained_validate + + # Print comparison + print("\n" + "=" * 80) + print("COMPARISON SUMMARY") + print("=" * 80) + + # Comparison using trainer.validate() results + print("\n" + "-" * 80) + print("Using trainer.validate() results:") + print("-" * 80) + + checkpoint_precision_validate = checkpoint_validate.get("box_precision") + checkpoint_recall_validate = checkpoint_validate.get("box_recall") + pretrained_precision_validate = pretrained_validate.get("box_precision") + pretrained_recall_validate = pretrained_validate.get("box_recall") + + if ( + checkpoint_precision_validate is not None + and pretrained_precision_validate is not None + ): + precision_diff = checkpoint_precision_validate - pretrained_precision_validate + print("\nBox Precision:") + print(f" Checkpoint: {checkpoint_precision_validate:.4f}") + print(f" Pretrained: {pretrained_precision_validate:.4f}") + if pretrained_precision_validate != 0: + print( + f" Difference: {precision_diff:+.4f} ({precision_diff / pretrained_precision_validate * 100:+.2f}%)" + ) + else: + print(f" Difference: {precision_diff:+.4f} (N/A%)") + else: + print("\nBox Precision: Unable to compute (missing values)") + + if checkpoint_recall_validate is not None and pretrained_recall_validate is not None: + recall_diff = checkpoint_recall_validate - pretrained_recall_validate + print("\nBox Recall:") + print(f" Checkpoint: {checkpoint_recall_validate:.4f}") + print(f" Pretrained: {pretrained_recall_validate:.4f}") + if pretrained_recall_validate != 0: + print( + f" Difference: {recall_diff:+.4f} ({recall_diff / pretrained_recall_validate * 100:+.2f}%)" + ) + else: + print(f" Difference: {recall_diff:+.4f} (N/A%)") + else: + print("\nBox Recall: Unable to compute (missing values)") + + if ( + "empty_frame_accuracy" in checkpoint_validate + and "empty_frame_accuracy" in pretrained_validate + ): + checkpoint_empty = checkpoint_validate["empty_frame_accuracy"] + pretrained_empty = pretrained_validate["empty_frame_accuracy"] + if checkpoint_empty is not None and pretrained_empty is not None: + empty_diff = checkpoint_empty - pretrained_empty + print("\nEmpty Frame Accuracy:") + print(f" Checkpoint: {checkpoint_empty:.4f}") + print(f" Pretrained: {pretrained_empty:.4f}") + print(f" Difference: {empty_diff:+.4f}") + else: + print("\nEmpty Frame Accuracy: Unable to compute (missing values)") + print(f" Checkpoint: {checkpoint_empty}") + print(f" Pretrained: {pretrained_empty}") + + # Comparison using main.evaluate() results + print("\n" + "-" * 80) + print("Using main.evaluate() results:") + print("-" * 80) + + checkpoint_precision_evaluate = checkpoint_evaluate.get("box_precision") + checkpoint_recall_evaluate = checkpoint_evaluate.get("box_recall") + pretrained_precision_evaluate = pretrained_evaluate.get("box_precision") + pretrained_recall_evaluate = pretrained_evaluate.get("box_recall") + + if ( + checkpoint_precision_evaluate is not None + and pretrained_precision_evaluate is not None + ): + precision_diff = checkpoint_precision_evaluate - pretrained_precision_evaluate + print("\nBox Precision:") + print(f" Checkpoint: {checkpoint_precision_evaluate:.4f}") + print(f" Pretrained: {pretrained_precision_evaluate:.4f}") + if pretrained_precision_evaluate != 0: + print( + f" Difference: {precision_diff:+.4f} ({precision_diff / pretrained_precision_evaluate * 100:+.2f}%)" + ) + else: + print(f" Difference: {precision_diff:+.4f} (N/A%)") + else: + print("\nBox Precision: Unable to compute (missing values)") + + if checkpoint_recall_evaluate is not None and pretrained_recall_evaluate is not None: + recall_diff = checkpoint_recall_evaluate - pretrained_recall_evaluate + print("\nBox Recall:") + print(f" Checkpoint: {checkpoint_recall_evaluate:.4f}") + print(f" Pretrained: {pretrained_recall_evaluate:.4f}") + if pretrained_recall_evaluate != 0: + print( + f" Difference: {recall_diff:+.4f} ({recall_diff / pretrained_recall_evaluate * 100:+.2f}%)" + ) + else: + print(f" Difference: {recall_diff:+.4f} (N/A%)") + else: + print("\nBox Recall: Unable to compute (missing values)") + + if ( + "empty_frame_accuracy" in checkpoint_evaluate + and "empty_frame_accuracy" in pretrained_evaluate + ): + checkpoint_empty = checkpoint_evaluate["empty_frame_accuracy"] + pretrained_empty = pretrained_evaluate["empty_frame_accuracy"] + if checkpoint_empty is not None and pretrained_empty is not None: + empty_diff = checkpoint_empty - pretrained_empty + print("\nEmpty Frame Accuracy:") + print(f" Checkpoint: {checkpoint_empty:.4f}") + print(f" Pretrained: {pretrained_empty:.4f}") + print(f" Difference: {empty_diff:+.4f}") + else: + print("\nEmpty Frame Accuracy: Unable to compute (missing values)") + print(f" Checkpoint: {checkpoint_empty}") + print(f" Pretrained: {pretrained_empty}") + + print("\n" + "=" * 80) + + return results + + +def evaluate_multiple_thresholds( + checkpoint_path, data_dir, iou_threshold=0.4, thresholds=None, output_path=None +): + """Evaluate checkpoint model at multiple score thresholds. + + Args: + checkpoint_path: Path to the checkpoint file + data_dir: Directory containing test.csv and images + iou_threshold: IoU threshold for evaluation (default: 0.4) + thresholds: List of score thresholds to evaluate (default: 0.1 to 0.5 in 0.05 steps) + output_path: Path to save the plot (default: data_dir/precision_recall_curve.png) + + Returns: + dict: Dictionary with thresholds, precision, and recall arrays + """ + if thresholds is None: + thresholds = np.arange(0.1, 0.55, 0.05).round(2).tolist() + + test_csv = os.path.join(data_dir, "test.csv") + + if not os.path.exists(test_csv): + raise FileNotFoundError(f"Test CSV not found: {test_csv}") + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + print("=" * 80) + print("Evaluating Checkpoint Model at Multiple Score Thresholds") + print("=" * 80) + print(f"\nTest dataset: {test_csv}") + print(f"IoU threshold: {iou_threshold}") + print(f"Score thresholds: {thresholds}\n") + + # Load model once + print("Loading checkpoint model...") + model = main.deepforest.load_from_checkpoint(checkpoint_path) + + precision_scores_validate = [] + recall_scores_validate = [] + precision_scores_evaluate = [] + recall_scores_evaluate = [] + + # Set up validation configuration once + model.config.validation.csv_file = test_csv + model.config.validation.root_dir = data_dir + model.config.validation.iou_threshold = iou_threshold + model.config.validation.val_accuracy_interval = 1 + + print("\nEvaluating at each threshold:") + print("-" * 80) + for i, threshold in enumerate(thresholds): + print( + f"\n[{i + 1}/{len(thresholds)}] Evaluating at score threshold: {threshold:.2f}" + ) + model.config.score_thresh = threshold + model.model.score_thresh = threshold + + # Evaluate using trainer.validate() + model.create_trainer() + validation_results = model.trainer.validate(model) + validate_results = validation_results[0] if validation_results else {} + + precision_validate = validate_results.get("box_precision", 0.0) + recall_validate = validate_results.get("box_recall", 0.0) + + precision_scores_validate.append(precision_validate) + recall_scores_validate.append(recall_validate) + + print( + f" trainer.validate() - Precision: {precision_validate:.4f}, Recall: {recall_validate:.4f}" + ) + + # Evaluate using main.evaluate() + evaluate_results = model.evaluate( + csv_file=test_csv, + root_dir=data_dir, + iou_threshold=iou_threshold, + ) + + precision_evaluate = evaluate_results.get("box_precision", 0.0) + recall_evaluate = evaluate_results.get("box_recall", 0.0) + + precision_scores_evaluate.append(precision_evaluate) + recall_scores_evaluate.append(recall_evaluate) + + print( + f" main.evaluate() - Precision: {precision_evaluate:.4f}, Recall: {recall_evaluate:.4f}" + ) + + # Create results dictionary + threshold_results = { + "thresholds": thresholds, + "precision_validate": precision_scores_validate, + "recall_validate": recall_scores_validate, + "precision_evaluate": precision_scores_evaluate, + "recall_evaluate": recall_scores_evaluate, + } + + # Print summary table + print("\n" + "=" * 80) + print("SUMMARY TABLE - trainer.validate()") + print("=" * 80) + print(f"\n{'Threshold':<12} {'Precision':<12} {'Recall':<12}") + print("-" * 40) + for thresh, prec, rec in zip( + thresholds, precision_scores_validate, recall_scores_validate, strict=True + ): + print(f"{thresh:<12.2f} {prec:<12.4f} {rec:<12.4f}") + + print("\n" + "=" * 80) + print("SUMMARY TABLE - main.evaluate()") + print("=" * 80) + print(f"\n{'Threshold':<12} {'Precision':<12} {'Recall':<12}") + print("-" * 40) + for thresh, prec, rec in zip( + thresholds, precision_scores_evaluate, recall_scores_evaluate, strict=True + ): + print(f"{thresh:<12.2f} {prec:<12.4f} {rec:<12.4f}") + + # Generate plot + if output_path is None: + output_path = os.path.join(data_dir, "precision_recall_curve.png") + + print(f"\nGenerating plot: {output_path}") + plt.figure(figsize=(14, 8)) + + # Plot trainer.validate() results + plt.plot( + thresholds, + precision_scores_validate, + "o-", + label="Precision (trainer.validate())", + linewidth=2, + markersize=8, + color="blue", + ) + plt.plot( + thresholds, + recall_scores_validate, + "s-", + label="Recall (trainer.validate())", + linewidth=2, + markersize=8, + color="blue", + linestyle="--", + ) + + # Plot main.evaluate() results + plt.plot( + thresholds, + precision_scores_evaluate, + "o-", + label="Precision (main.evaluate())", + linewidth=2, + markersize=8, + color="red", + ) + plt.plot( + thresholds, + recall_scores_evaluate, + "s-", + label="Recall (main.evaluate())", + linewidth=2, + markersize=8, + color="red", + linestyle="--", + ) + + plt.xlabel("Score Threshold", fontsize=12) + plt.ylabel("Score", fontsize=12) + plt.title( + "Precision and Recall vs Score Threshold\n(Retrained Bird Detection Model - Both Methods)", + fontsize=14, + ) + plt.legend(fontsize=11) + plt.grid(True, alpha=0.3) + plt.xlim(min(thresholds) - 0.02, max(thresholds) + 0.02) + max_score = max( + max(precision_scores_validate) if precision_scores_validate else 0, + max(recall_scores_validate) if recall_scores_validate else 0, + max(precision_scores_evaluate) if precision_scores_evaluate else 0, + max(recall_scores_evaluate) if recall_scores_evaluate else 0, + ) + plt.ylim(0, max_score * 1.1 if max_score > 0 else 1.0) + + # Add value labels on points for trainer.validate() + for thresh, prec, rec in zip( + thresholds, precision_scores_validate, recall_scores_validate, strict=True + ): + plt.annotate( + f"{prec:.3f}", + (thresh, prec), + textcoords="offset points", + xytext=(0, 10), + ha="center", + fontsize=7, + color="blue", + ) + plt.annotate( + f"{rec:.3f}", + (thresh, rec), + textcoords="offset points", + xytext=(0, -15), + ha="center", + fontsize=7, + color="blue", + ) + + # Add value labels on points for main.evaluate() + for thresh, prec, rec in zip( + thresholds, precision_scores_evaluate, recall_scores_evaluate, strict=True + ): + plt.annotate( + f"{prec:.3f}", + (thresh, prec), + textcoords="offset points", + xytext=(0, 20), + ha="center", + fontsize=7, + color="red", + ) + plt.annotate( + f"{rec:.3f}", + (thresh, rec), + textcoords="offset points", + xytext=(0, -25), + ha="center", + fontsize=7, + color="red", + ) + + plt.tight_layout() + plt.savefig(output_path, dpi=300, bbox_inches="tight") + print(f"Plot saved to: {output_path}") + + return threshold_results + + +def run(): + """Main function.""" + parser = argparse.ArgumentParser( + description="Compare retrained bird model with pretrained model" + ) + parser.add_argument( + "--checkpoint_path", + type=str, + required=True, + help="Path to the checkpoint file", + ) + parser.add_argument( + "--data_dir", + type=str, + required=True, + help="Directory containing test.csv and images", + ) + parser.add_argument( + "--iou_threshold", + type=float, + default=0.4, + help="IoU threshold for evaluation (default: 0.4)", + ) + parser.add_argument( + "--evaluate_thresholds", + action="store_true", + help="Evaluate checkpoint model at multiple score thresholds (0.1-0.5) and generate plot", + ) + parser.add_argument( + "--plot_output", + type=str, + default=None, + help="Path to save the precision-recall plot (default: data_dir/precision_recall_curve.png)", + ) + + args = parser.parse_args() + + # Run comparison + compare_models( + checkpoint_path=args.checkpoint_path, + data_dir=args.data_dir, + iou_threshold=args.iou_threshold, + ) + + # Evaluate at multiple thresholds if requested + if args.evaluate_thresholds: + evaluate_multiple_thresholds( + checkpoint_path=args.checkpoint_path, + data_dir=args.data_dir, + iou_threshold=args.iou_threshold, + output_path=args.plot_output, + ) + + +if __name__ == "__main__": + run() diff --git a/src/deepforest/scripts/evaluate_deepwater_horizon.py b/src/deepforest/scripts/evaluate_deepwater_horizon.py new file mode 100644 index 000000000..7f83f3ba8 --- /dev/null +++ b/src/deepforest/scripts/evaluate_deepwater_horizon.py @@ -0,0 +1,527 @@ +"""Evaluate bird detection models on DeepWater Horizon imagery. + +This script: +1. Loads shapefiles from the DeepWater Horizon monitoring program +2. Creates a test.csv file +3. Evaluates both the old and new bird detection models +4. Generates visualization comparisons +""" + +import glob +import os + +import geopandas as gpd +import pandas as pd + +from deepforest import main as df_main +from deepforest.preprocess import split_raster +from deepforest.utilities import read_file +from deepforest.visualize import plot_results + + +def load_shapefiles_and_create_test_csv(data_dir, output_csv="test.csv", output_dir=None): + """Load all shapefiles and create a test.csv file. + + Args: + data_dir: Directory containing shapefiles and images + output_csv: Name of output CSV file + output_dir: Directory to write CSV file (default: tries data_dir, falls back to current directory) + + Returns: + Path to the created CSV file + """ + output_path = os.path.join(data_dir, output_csv) + + # Check if CSV already exists + if os.path.exists(output_path): + print(f"Test CSV already exists at {output_path}, skipping creation.") + # Verify it's readable + try: + existing_df = pd.read_csv(output_path) + print( + f"Found existing {output_csv} with {len(existing_df)} annotations from {len(existing_df['image_path'].unique())} images" + ) + except Exception as e: + print(f"Warning: Could not read existing CSV: {e}. Recreating...") + else: + return output_path + + # Find all shapefiles + shapefiles = glob.glob(os.path.join(data_dir, "*_annotated.shp")) + print(f"Found {len(shapefiles)} shapefiles") + + all_annotations = [] + + for shp_path in shapefiles: + # Extract base name to find corresponding image + base_name = os.path.basename(shp_path).replace("_annotated.shp", "") + + # Find corresponding image file + image_files = glob.glob(os.path.join(data_dir, f"{base_name}*.jpg")) + if not image_files: + print(f"Warning: No image found for {base_name}") + continue + + image_path = image_files[0] + image_filename = os.path.basename(image_path) + + print(f"Processing {base_name}: {image_filename}") + + # Read shapefile directly (coordinates are already in image space) + gdf = gpd.read_file(shp_path) + gdf.geometry = gdf.geometry.scale(xfact=1, yfact=-1, origin=(0, 0)) + + # Set image_path + gdf["image_path"] = image_filename + gdf.crs = None + gdf["label"] = "Bird" + gdf = gdf[gdf.geometry.notna()] + gdf = read_file(gdf, root_dir=data_dir) + + all_annotations.append(gdf) + + # Combine all annotations + combined_df = pd.concat(all_annotations, ignore_index=True) + + combined_df.to_csv(output_path, index=False) + print( + f"\nCreated {output_csv} with {len(combined_df)} annotations from {len(combined_df['image_path'].unique())} images" + ) + print(f"Saved to: {output_path}") + + return output_path + + +def split_test_images_for_evaluation( + test_csv, + data_dir, + patch_size=800, + patch_overlap=0, + output_dir=None, + split_csv_name=None, +): + """Split test images into smaller patches for evaluation using + split_raster. + + Args: + test_csv: Path to test CSV file with full image annotations + data_dir: Directory containing test images + patch_size: Size of patches for splitting (default: 800) + patch_overlap: Overlap between patches (default: 0) + output_dir: Directory to save split images (default: test_splits subdirectory of test_csv location) + split_csv_name: Name of the output CSV file (default: test_split.csv) + + Returns: + Tuple of (split_csv_path, split_dir) where split_csv_path is the path to + the CSV file with split image annotations and split_dir is the directory + containing the split images + """ + # Create output directory for split images + if output_dir is None: + output_dir = os.path.join(os.path.dirname(test_csv), "test_splits") + os.makedirs(output_dir, exist_ok=True) + + # Set default CSV name if not provided + if split_csv_name is None: + split_csv_name = "test_split.csv" + + # Read the test CSV + test_df = read_file(test_csv) + unique_images = test_df["image_path"].unique() + + print( + f"\nSplitting {len(unique_images)} test images into {patch_size}-pixel patches..." + ) + print(f"Output directory: {output_dir}") + + all_split_annotations = [] + + for image_name in unique_images: + image_path = os.path.join(data_dir, image_name) + if not os.path.exists(image_path): + print(f"Warning: Image not found: {image_path}") + continue + + print(f"Processing {image_name}...") + + # Get annotations for this image + image_annotations = test_df[test_df["image_path"] == image_name].copy() + + # Create temporary CSV file for this image's annotations + temp_annotations_file = os.path.join( + output_dir, f"temp_{image_name}_annotations.csv" + ) + image_annotations.to_csv(temp_annotations_file, index=False) + + # Use split_raster to create crops + split_df = split_raster( + annotations_file=temp_annotations_file, + path_to_raster=image_path, + root_dir=os.path.dirname(temp_annotations_file), + patch_size=patch_size, + patch_overlap=patch_overlap, + allow_empty=False, + save_dir=output_dir, + ) + + if not split_df.empty: + all_split_annotations.append(split_df) + print( + f" Created {len(split_df['image_path'].unique())} patches with {len(split_df)} annotations" + ) + else: + print(f" Warning: No patches created for {image_name}") + + # Clean up temporary annotations file + if os.path.exists(temp_annotations_file): + os.remove(temp_annotations_file) + + # Combine all split annotations + if all_split_annotations: + combined_split_df = pd.concat(all_split_annotations, ignore_index=True) + + # Save split CSV + split_csv_path = os.path.join(output_dir, split_csv_name) + combined_split_df.to_csv(split_csv_path, index=False) + + print( + f"\nCreated split test CSV with {len(combined_split_df)} annotations from {len(combined_split_df['image_path'].unique())} patches" + ) + print(f"Saved to: {split_csv_path}") + + return split_csv_path, output_dir + else: + raise ValueError( + "No split annotations were created. Check that images exist and contain valid annotations." + ) + + +def evaluate_models(checkpoint_path, data_dir, test_csv, split_dir, iou_threshold=0.4): + """Evaluate both old and new bird detection models. + + Args: + checkpoint_path: Path to the new checkpoint model + data_dir: Directory containing original test data (not used for evaluation) + test_csv: Path to split test CSV file (for evaluation) + split_dir: Directory containing split images (for evaluation) + iou_threshold: IoU threshold for evaluation + + Returns: + Dictionary with evaluation results for both models + """ + results = {} + + # Evaluate new checkpoint model + print("\n" + "=" * 80) + print("Evaluating NEW checkpoint model...") + print("=" * 80) + checkpoint_model = df_main.deepforest.load_from_checkpoint(checkpoint_path) + checkpoint_model.config.score_thresh = 0.25 + checkpoint_model.model.score_thresh = 0.25 + + # Set up validation configuration using split CSV and split directory + checkpoint_model.config.validation.csv_file = test_csv + checkpoint_model.config.validation.root_dir = split_dir + checkpoint_model.config.validation.iou_threshold = iou_threshold + checkpoint_model.config.validation.val_accuracy_interval = 1 + checkpoint_model.config.workers = 0 + checkpoint_model.create_trainer() + + validation_results = checkpoint_model.trainer.validate(checkpoint_model) + checkpoint_validate = validation_results[0] if validation_results else {} + results["checkpoint"] = checkpoint_validate + + print(f"Box Precision: {checkpoint_validate.get('box_precision', 'N/A')}") + print(f"Box Recall: {checkpoint_validate.get('box_recall', 'N/A')}") + print( + f"Empty Frame Accuracy: {checkpoint_validate.get('empty_frame_accuracy', 'N/A')}" + ) + + # Evaluate old pretrained model + print("\n" + "=" * 80) + print("Evaluating OLD pretrained model (weecology/deepforest-bird)...") + print("=" * 80) + pretrained_model = df_main.deepforest() + pretrained_model.load_model("weecology/deepforest-bird") + pretrained_model.config.score_thresh = 0.25 + pretrained_model.model.score_thresh = 0.25 + + # Set label dictionaries to match + pretrained_model.label_dict = {"Bird": 0} + pretrained_model.numeric_to_label_dict = {0: "Bird"} + pretrained_model.config.label_dict = {"Bird": 0} + pretrained_model.config.num_classes = 1 + + # Set up validation configuration using split CSV and split directory + pretrained_model.config.validation.csv_file = test_csv + pretrained_model.config.validation.root_dir = split_dir + pretrained_model.config.validation.iou_threshold = iou_threshold + pretrained_model.config.validation.val_accuracy_interval = 1 + pretrained_model.config.workers = 0 + pretrained_model.create_trainer() + + validation_results = pretrained_model.trainer.validate(pretrained_model) + pretrained_validate = validation_results[0] if validation_results else {} + results["pretrained"] = pretrained_validate + + print(f"Box Precision: {pretrained_validate.get('box_precision', 'N/A')}") + print(f"Box Recall: {pretrained_validate.get('box_recall', 'N/A')}") + print( + f"Empty Frame Accuracy: {pretrained_validate.get('empty_frame_accuracy', 'N/A')}" + ) + + return results, checkpoint_model, pretrained_model + + +def generate_visualizations( + checkpoint_model, + pretrained_model, + data_dir, + test_csv, + output_dir, + num_images=2, +): + """Generate side-by-side visualizations comparing old and new models. + + Args: + checkpoint_model: New checkpoint model + pretrained_model: Old pretrained model + data_dir: Directory containing test data + test_csv: Path to test CSV file + output_dir: Directory to save visualizations + num_images: Number of images to visualize + """ + import matplotlib.pyplot as plt + + from deepforest.utilities import read_file + + os.makedirs(output_dir, exist_ok=True) + + # Read test CSV to get image list + test_df = read_file(test_csv) + unique_images = test_df["image_path"].unique()[:num_images] + + print(f"\nGenerating visualizations for {len(unique_images)} images...") + + for image_name in unique_images: + image_path = os.path.join(data_dir, image_name) + if not os.path.exists(image_path): + print(f"Warning: Image not found: {image_path}") + continue + + print(f"Processing {image_name}...") + + # Get ground truth + ground_truth = test_df[test_df["image_path"] == image_name].copy() + + # Predict with new model + checkpoint_predictions = checkpoint_model.predict_tile( + path=image_path, patch_size=800, patch_overlap=0 + ) + + # Predict with old model + pretrained_predictions = pretrained_model.predict_tile( + path=image_path, patch_size=800, patch_overlap=0 + ) + + # Create side-by-side comparison using savedir approach + # Save individual plots first, then combine + base_name = os.path.splitext(image_name)[0] + plots_dir = "/blue/ewhite/b.weinstein/bird_detector_retrain/zero_shot/avian_images_annotated/plots" + os.makedirs(plots_dir, exist_ok=True) + + # Plot new model + if len(checkpoint_predictions) > 0: + plot_results( + checkpoint_predictions, + ground_truth=ground_truth, + image=image_path, + savedir=plots_dir, + basename=f"{base_name}_new", + show=False, + ) + else: + # Create empty plot + fig, ax = plt.subplots(figsize=(10, 10)) + ax.text(0.5, 0.5, "No predictions", ha="center", va="center", fontsize=16) + ax.set_title("New Retrained Model - No Predictions", fontsize=14) + plt.savefig( + os.path.join(plots_dir, f"{base_name}_new.png"), + dpi=300, + bbox_inches="tight", + ) + plt.close(fig) + + # Plot old model + if len(pretrained_predictions) > 0: + plot_results( + pretrained_predictions, + ground_truth=ground_truth, + image=image_path, + savedir=plots_dir, + basename=f"{base_name}_old", + show=False, + ) + else: + # Create empty plot + fig, ax = plt.subplots(figsize=(10, 10)) + ax.text(0.5, 0.5, "No predictions", ha="center", va="center", fontsize=16) + ax.set_title("Original Pretrained Model - No Predictions", fontsize=14) + plt.savefig( + os.path.join(plots_dir, f"{base_name}_old.png"), + dpi=300, + bbox_inches="tight", + ) + plt.close(fig) + + # Combine the two images side by side + from PIL import Image as PILImage + + img1 = PILImage.open(os.path.join(plots_dir, f"{base_name}_new.png")) + img2 = PILImage.open(os.path.join(plots_dir, f"{base_name}_old.png")) + + # Resize to same height + height = max(img1.height, img2.height) + img1 = img1.resize( + (int(img1.width * height / img1.height), height), PILImage.Resampling.LANCZOS + ) + img2 = img2.resize( + (int(img2.width * height / img2.height), height), PILImage.Resampling.LANCZOS + ) + + # Combine + combined = PILImage.new("RGB", (img1.width + img2.width, height)) + combined.paste(img1, (0, 0)) + combined.paste(img2, (img1.width, 0)) + + # Save + output_path = os.path.join(plots_dir, f"{base_name}_comparison.png") + combined.save(output_path, dpi=(300, 300)) + + print(f"Saved: {output_path}") + + +def main(): + """Main function.""" + import argparse + + parser = argparse.ArgumentParser( + description="Evaluate bird detection models on DeepWater Horizon imagery" + ) + parser.add_argument( + "--data_dir", + type=str, + default="/blue/ewhite/b.weinstein/bird_detector_retrain/zero_shot/avian_images_annotated", + help="Directory containing shapefiles and images", + ) + parser.add_argument( + "--checkpoint_path", + type=str, + default="/blue/ewhite/b.weinstein/bird_detector_retrain/2022paper/checkpoints/f92a9384135f4481b7372b85d1da5b5f.ckpt", + help="Path to checkpoint file", + ) + parser.add_argument( + "--iou_threshold", + type=float, + default=0.4, + help="IoU threshold for evaluation", + ) + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="Directory to save visualizations (default: data_dir/visualizations)", + ) + parser.add_argument( + "--num_images", + type=int, + default=2, + help="Number of images to visualize", + ) + parser.add_argument( + "--patch_size", + type=int, + default=800, + help="Patch size for splitting images during evaluation (default: 800)", + ) + parser.add_argument( + "--patch_overlap", + type=float, + default=0.0, + help="Patch overlap for splitting images during evaluation (default: 0.0)", + ) + + args = parser.parse_args() + + # Set default output directory (use current directory to avoid permission issues) + if args.output_dir is None: + args.output_dir = os.path.join(os.getcwd(), "visualizations") + + # Step 1: Load shapefiles and create test.csv + print("=" * 80) + print("Step 1: Loading shapefiles and creating test.csv") + print("=" * 80) + test_csv = load_shapefiles_and_create_test_csv(args.data_dir) + + # Step 2: Split test images for evaluation + print("\n" + "=" * 80) + print("Step 2: Splitting test images for evaluation") + print("=" * 80) + split_csv, split_dir = split_test_images_for_evaluation( + test_csv=test_csv, + data_dir=args.data_dir, + patch_size=args.patch_size, + patch_overlap=args.patch_overlap, + ) + + # Step 3: Evaluate models using split images + print("\n" + "=" * 80) + print("Step 3: Evaluating models") + print("=" * 80) + results, checkpoint_model, pretrained_model = evaluate_models( + checkpoint_path=args.checkpoint_path, + data_dir=args.data_dir, + test_csv=split_csv, + split_dir=split_dir, + iou_threshold=args.iou_threshold, + ) + + # Step 4: Generate visualizations using full images + print("\n" + "=" * 80) + print("Step 4: Generating visualizations") + print("=" * 80) + generate_visualizations( + checkpoint_model=checkpoint_model, + pretrained_model=pretrained_model, + data_dir=args.data_dir, + test_csv=test_csv, + output_dir=args.output_dir, + num_images=args.num_images, + ) + + # Print summary + print("\n" + "=" * 80) + print("EVALUATION SUMMARY") + print("=" * 80) + print("\nNew Checkpoint Model:") + print(f" Box Precision: {results['checkpoint'].get('box_precision', 'N/A')}") + print(f" Box Recall: {results['checkpoint'].get('box_recall', 'N/A')}") + print( + f" Empty Frame Accuracy: {results['checkpoint'].get('empty_frame_accuracy', 'N/A')}" + ) + + print("\nOriginal Pretrained Model:") + print(f" Box Precision: {results['pretrained'].get('box_precision', 'N/A')}") + print(f" Box Recall: {results['pretrained'].get('box_recall', 'N/A')}") + print( + f" Empty Frame Accuracy: {results['pretrained'].get('empty_frame_accuracy', 'N/A')}" + ) + + print(f"\nVisualizations saved to: {args.output_dir}") + print(f"Original test CSV saved to: {test_csv}") + print(f"Split test CSV saved to: {split_csv}") + print(f"Split images directory: {split_dir}") + + +if __name__ == "__main__": + main() diff --git a/src/deepforest/scripts/evaluate_patch_size_sensitivity.py b/src/deepforest/scripts/evaluate_patch_size_sensitivity.py new file mode 100644 index 000000000..7f14c1bf5 --- /dev/null +++ b/src/deepforest/scripts/evaluate_patch_size_sensitivity.py @@ -0,0 +1,291 @@ +"""Evaluate sensitivity of box_recall and box_precision to patch_size. + +This script wraps evaluate_deepwater_horizon.py to evaluate multiple +patch sizes for both checkpoint and pretrained models, and generate a +sensitivity plot showing how metrics vary with patch size for +comparison. +""" + +import argparse +import importlib.util +import os + +import matplotlib.pyplot as plt +import pandas as pd + +# Import from evaluate_deepwater_horizon in the same directory +_script_dir = os.path.dirname(os.path.abspath(__file__)) +_eval_module_path = os.path.join(_script_dir, "evaluate_deepwater_horizon.py") +spec = importlib.util.spec_from_file_location( + "evaluate_deepwater_horizon", _eval_module_path +) +eval_module = importlib.util.module_from_spec(spec) +spec.loader.exec_module(eval_module) + +load_shapefiles_and_create_test_csv = eval_module.load_shapefiles_and_create_test_csv +split_test_images_for_evaluation = eval_module.split_test_images_for_evaluation +evaluate_models = eval_module.evaluate_models + + +def evaluate_patch_size_sensitivity( + data_dir, + checkpoint_path, + patch_sizes, + iou_threshold=0.4, + patch_overlap=0.0, +): + """Evaluate both checkpoint and pretrained models across multiple patch + sizes and collect results. + + Args: + data_dir: Directory containing shapefiles and images + checkpoint_path: Path to checkpoint file + patch_sizes: List of patch sizes to evaluate + iou_threshold: IoU threshold for evaluation + patch_overlap: Overlap between patches + + Returns: + DataFrame with patch_size and metrics for both checkpoint and pretrained models + """ + # Step 1: Load shapefiles and create test.csv (only once) + print("=" * 80) + print("Step 1: Loading shapefiles and creating test.csv") + print("=" * 80) + test_csv = load_shapefiles_and_create_test_csv(data_dir) + + results = [] + + for patch_size in patch_sizes: + print("\n" + "=" * 80) + print(f"Evaluating patch size: {patch_size}") + print("=" * 80) + + # Create patch-size-specific output directory + base_output_dir = os.path.join(os.path.dirname(test_csv), "test_splits") + patch_output_dir = os.path.join(base_output_dir, f"patch_{patch_size}") + split_csv_name = f"test_split_patch_{patch_size}.csv" + + # Check if split CSV already exists + split_csv_path = os.path.join(patch_output_dir, split_csv_name) + if os.path.exists(split_csv_path): + print(f"Found existing split CSV at {split_csv_path}, skipping splitting...") + split_dir = patch_output_dir + else: + # Step 2: Split test images for this patch size + print(f"\nSplitting test images for patch size {patch_size}...") + split_csv_path, split_dir = split_test_images_for_evaluation( + test_csv=test_csv, + data_dir=data_dir, + patch_size=patch_size, + patch_overlap=patch_overlap, + output_dir=patch_output_dir, + split_csv_name=split_csv_name, + ) + + # Step 3: Evaluate both models + print(f"\nEvaluating both models for patch size {patch_size}...") + eval_results, _, _ = evaluate_models( + checkpoint_path=checkpoint_path, + data_dir=data_dir, + test_csv=split_csv_path, + split_dir=split_dir, + iou_threshold=iou_threshold, + ) + + # Extract results for both models + checkpoint_results = eval_results.get("checkpoint", {}) + pretrained_results = eval_results.get("pretrained", {}) + + checkpoint_precision = checkpoint_results.get("box_precision", None) + checkpoint_recall = checkpoint_results.get("box_recall", None) + pretrained_precision = pretrained_results.get("box_precision", None) + pretrained_recall = pretrained_results.get("box_recall", None) + + results.append( + { + "patch_size": patch_size, + "checkpoint_precision": checkpoint_precision, + "checkpoint_recall": checkpoint_recall, + "pretrained_precision": pretrained_precision, + "pretrained_recall": pretrained_recall, + } + ) + + print(f"Patch size {patch_size}:") + print( + f" Checkpoint - Precision={checkpoint_precision}, Recall={checkpoint_recall}" + ) + print( + f" Pretrained - Precision={pretrained_precision}, Recall={pretrained_recall}" + ) + + return pd.DataFrame(results) + + +def plot_sensitivity(results_df, output_path): + """Create a plot showing sensitivity of metrics to patch size for both + models. + + Args: + results_df: DataFrame with patch_size and metrics for both checkpoint and pretrained models + output_path: Path to save the plot + """ + fig, ax = plt.subplots(figsize=(12, 7)) + + # Plot checkpoint model (solid lines) + ax.plot( + results_df["patch_size"], + results_df["checkpoint_precision"], + marker="o", + label="Checkpoint Precision", + linewidth=2, + markersize=8, + linestyle="-", + color="C0", + ) + ax.plot( + results_df["patch_size"], + results_df["checkpoint_recall"], + marker="s", + label="Checkpoint Recall", + linewidth=2, + markersize=8, + linestyle="-", + color="C1", + ) + + # Plot pretrained model (dashed lines) + ax.plot( + results_df["patch_size"], + results_df["pretrained_precision"], + marker="o", + label="Pretrained Precision", + linewidth=2, + markersize=8, + linestyle="--", + color="C0", + alpha=0.7, + ) + ax.plot( + results_df["patch_size"], + results_df["pretrained_recall"], + marker="s", + label="Pretrained Recall", + linewidth=2, + markersize=8, + linestyle="--", + color="C1", + alpha=0.7, + ) + + ax.set_xlabel("Patch Size (pixels)", fontsize=12) + ax.set_ylabel("Metric Value", fontsize=12) + ax.set_title( + "Sensitivity of Box Precision and Recall to Patch Size\n(Checkpoint vs Pretrained Model)", + fontsize=14, + ) + ax.legend(fontsize=10, loc="best") + ax.grid(True, alpha=0.3) + ax.set_xlim(left=0) + + # Ensure y-axis shows full range + all_metrics = pd.concat( + [ + results_df["checkpoint_precision"], + results_df["checkpoint_recall"], + results_df["pretrained_precision"], + results_df["pretrained_recall"], + ] + ) + y_min = all_metrics.min() + y_max = all_metrics.max() + y_range = y_max - y_min + ax.set_ylim( + max(0, y_min - 0.1 * y_range), + min(1.0, y_max + 0.1 * y_range), + ) + + plt.tight_layout() + plt.savefig(output_path, dpi=300, bbox_inches="tight") + print(f"\nSaved sensitivity plot to: {output_path}") + plt.close(fig) + + +def main(): + """Main function.""" + parser = argparse.ArgumentParser( + description="Evaluate sensitivity of metrics to patch size" + ) + parser.add_argument( + "--data_dir", + type=str, + default="/blue/ewhite/b.weinstein/bird_detector_retrain/zero_shot/avian_images_annotated", + help="Directory containing shapefiles and images", + ) + parser.add_argument( + "--checkpoint_path", + type=str, + default="/blue/ewhite/b.weinstein/bird_detector_retrain/data/checkpoints/6181df1ab7ac40f291b863a2a9b86024.ckpt", + help="Path to checkpoint file", + ) + parser.add_argument( + "--iou_threshold", + type=float, + default=0.4, + help="IoU threshold for evaluation", + ) + parser.add_argument( + "--patch_overlap", + type=float, + default=0.0, + help="Patch overlap for splitting images (default: 0.0)", + ) + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="Directory to save plots (default: data_dir/plots)", + ) + parser.add_argument( + "--patch_sizes", + type=int, + nargs="+", + default=[200, 400, 600, 800, 1000, 1500, 2000], + help="List of patch sizes to evaluate (default: 200 400 600 800 1000 1500 2000)", + ) + + args = parser.parse_args() + + # Set default output directory + if args.output_dir is None: + args.output_dir = os.path.join(args.data_dir, "plots") + os.makedirs(args.output_dir, exist_ok=True) + + # Evaluate across patch sizes (both models) + results_df = evaluate_patch_size_sensitivity( + data_dir=args.data_dir, + checkpoint_path=args.checkpoint_path, + patch_sizes=args.patch_sizes, + iou_threshold=args.iou_threshold, + patch_overlap=args.patch_overlap, + ) + + # Save results to CSV + results_csv = os.path.join(args.output_dir, "patch_size_sensitivity_results.csv") + results_df.to_csv(results_csv, index=False) + print(f"\nSaved results to: {results_csv}") + + # Create and save plot + plot_path = os.path.join(args.output_dir, "patch_size_sensitivity.png") + plot_sensitivity(results_df, plot_path) + + # Print summary + print("\n" + "=" * 80) + print("SENSITIVITY ANALYSIS SUMMARY") + print("=" * 80) + print(results_df.to_string(index=False)) + print(f"\nPlot saved to: {plot_path}") + + +if __name__ == "__main__": + main() diff --git a/src/deepforest/scripts/evaluate_thresholds.py b/src/deepforest/scripts/evaluate_thresholds.py new file mode 100644 index 000000000..5e6aa3e69 --- /dev/null +++ b/src/deepforest/scripts/evaluate_thresholds.py @@ -0,0 +1,190 @@ +"""Evaluate bird detection model at multiple score thresholds. + +This script evaluates a checkpoint model at multiple score thresholds and +generates a precision-recall curve. + +Example usage: + python evaluate_thresholds.py --checkpoint_path /path/to/checkpoint.ckpt --data_dir /path/to/data +""" + +import argparse +import os + +import matplotlib.pyplot as plt +import numpy as np + +from deepforest import main + + +def evaluate_thresholds( + checkpoint_path, data_dir, iou_threshold=0.4, thresholds=None, output_path=None +): + """Evaluate checkpoint model at multiple score thresholds. + + Args: + checkpoint_path: Path to the checkpoint file + data_dir: Directory containing test.csv and images + iou_threshold: IoU threshold for evaluation (default: 0.4) + thresholds: List of score thresholds to evaluate (default: 0.1 to 0.5 in 0.05 steps) + output_path: Path to save the plot (default: data_dir/precision_recall_curve.png) + + Returns: + dict: Dictionary with thresholds, precision, and recall arrays + """ + if thresholds is None: + thresholds = np.arange(0.1, 0.55, 0.05).round(2).tolist() + + test_csv = os.path.join(data_dir, "test.csv") + + if not os.path.exists(test_csv): + raise FileNotFoundError(f"Test CSV not found: {test_csv}") + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + print("=" * 80) + print("Evaluating Checkpoint Model at Multiple Score Thresholds") + print("=" * 80) + print(f"\nTest dataset: {test_csv}") + print(f"IoU threshold: {iou_threshold}") + print(f"Score thresholds: {thresholds}\n") + + # Load model once + print("Loading checkpoint model...") + model = main.deepforest.load_from_checkpoint(checkpoint_path) + + precision_scores = [] + recall_scores = [] + + print("\nEvaluating at each threshold:") + print("-" * 80) + for i, threshold in enumerate(thresholds): + print( + f"\n[{i + 1}/{len(thresholds)}] Evaluating at score threshold: {threshold:.2f}" + ) + model.config.score_thresh = threshold + model.model.score_thresh = threshold + + results = model.evaluate( + csv_file=test_csv, + root_dir=data_dir, + iou_threshold=iou_threshold, + ) + + precision = results["box_precision"] + recall = results["box_recall"] + + precision_scores.append(precision) + recall_scores.append(recall) + + print(f" Precision: {precision:.4f}") + print(f" Recall: {recall:.4f}") + + # Create results dictionary + threshold_results = { + "thresholds": thresholds, + "precision": precision_scores, + "recall": recall_scores, + } + + # Print summary table + print("\n" + "=" * 80) + print("SUMMARY TABLE") + print("=" * 80) + print(f"\n{'Threshold':<12} {'Precision':<12} {'Recall':<12}") + print("-" * 40) + for thresh, prec, rec in zip( + thresholds, precision_scores, recall_scores, strict=True + ): + print(f"{thresh:<12.2f} {prec:<12.4f} {rec:<12.4f}") + + # Generate plot + if output_path is None: + output_path = os.path.join(data_dir, "precision_recall_curve.png") + + print(f"\nGenerating plot: {output_path}") + plt.figure(figsize=(10, 6)) + plt.plot( + thresholds, precision_scores, "o-", label="Precision", linewidth=2, markersize=8 + ) + plt.plot(thresholds, recall_scores, "s-", label="Recall", linewidth=2, markersize=8) + plt.xlabel("Score Threshold", fontsize=12) + plt.ylabel("Score", fontsize=12) + plt.title( + "Precision and Recall vs Score Threshold\n(Retrained Bird Detection Model)", + fontsize=14, + ) + plt.legend(fontsize=11) + plt.grid(True, alpha=0.3) + plt.xlim(min(thresholds) - 0.02, max(thresholds) + 0.02) + plt.ylim(0, max(max(precision_scores), max(recall_scores)) * 1.1) + + # Add value labels on points + for thresh, prec, rec in zip( + thresholds, precision_scores, recall_scores, strict=True + ): + plt.annotate( + f"{prec:.3f}", + (thresh, prec), + textcoords="offset points", + xytext=(0, 10), + ha="center", + fontsize=8, + ) + plt.annotate( + f"{rec:.3f}", + (thresh, rec), + textcoords="offset points", + xytext=(0, -15), + ha="center", + fontsize=8, + ) + + plt.tight_layout() + plt.savefig(output_path, dpi=300, bbox_inches="tight") + print(f"Plot saved to: {output_path}") + + return threshold_results + + +def run(): + """Main function.""" + parser = argparse.ArgumentParser( + description="Evaluate checkpoint model at multiple score thresholds" + ) + parser.add_argument( + "--checkpoint_path", + type=str, + required=True, + help="Path to the checkpoint file", + ) + parser.add_argument( + "--data_dir", + type=str, + required=True, + help="Directory containing test.csv and images", + ) + parser.add_argument( + "--iou_threshold", + type=float, + default=0.4, + help="IoU threshold for evaluation (default: 0.4)", + ) + parser.add_argument( + "--plot_output", + type=str, + default=None, + help="Path to save the precision-recall plot (default: data_dir/precision_recall_curve.png)", + ) + + args = parser.parse_args() + + evaluate_thresholds( + checkpoint_path=args.checkpoint_path, + data_dir=args.data_dir, + iou_threshold=args.iou_threshold, + output_path=args.plot_output, + ) + + +if __name__ == "__main__": + run() diff --git a/src/deepforest/scripts/prepare_birds.py b/src/deepforest/scripts/prepare_birds.py new file mode 100644 index 000000000..e0e0f4a89 --- /dev/null +++ b/src/deepforest/scripts/prepare_birds.py @@ -0,0 +1,691 @@ +"""Prepare bird detection training data from multiple sources. + +This script collects annotations from multiple data sources, maps labels to "Bird", +creates symlinks to a single output directory, and generates train/test splits. +This is a documentation/example script - users should adapt paths to their own data. + +Example paths are hardcoded below (actual data not publicly available). +""" + +import argparse +import json +import os + +import numpy as np +import pandas as pd +from PIL import Image +from sklearn.model_selection import train_test_split + +from deepforest.preprocess import split_raster +from deepforest.utilities import read_file + +# Data source file paths (adapt these to your own data locations) +DATA_SOURCES = [ + "/orange/ewhite/b.weinstein/Drones_for_Ducks/uas-imagery-of-migratory-waterfowl/crowdsourced/20240220_dronesforducks_zooniverse_refined.json", + "/orange/ewhite/b.weinstein/izembek-lagoon-waterfowl/izembek-lagoon-birds-metadata.json", + "/orange/ewhite/b.weinstein/bird_detector/generalization/crops/training_annotations.csv", + "/blue/ewhite/b.weinstein/BOEM/UBFAI Images with Detection Data/crops/train.csv", +] + +# Path to existing test dataset (if provided, skip train/test split and put all new data in train) +EXISTING_TEST = None # e.g., "/path/to/existing/test.csv" + +# Nuisance labels to exclude (will be filtered out) +NUISANCE_LABELS = {"buoy", "buoys", "trash", "Trash", "boat", "sargassum"} + + +def load_coco_with_bboxes(json_file): + """Load COCO format JSON file with bounding boxes (bbox) instead of + segmentation. + + Args: + json_file: Path to COCO JSON file with bbox annotations + + Returns: + DataFrame with image_path, xmin, ymin, xmax, ymax, label columns + """ + with open(json_file) as f: + coco_data = json.load(f) + + # Create mapping from image_id to file_name + image_ids = {image["id"]: image["file_name"] for image in coco_data["images"]} + + # Create mapping from category_id to category name (if available) + category_ids = {} + if "categories" in coco_data: + category_ids = { + cat["id"]: cat.get("name", f"category_{cat['id']}") + for cat in coco_data["categories"] + } + + annotations = [] + for annotation in coco_data["annotations"]: + # Skip if image_id doesn't exist in images + image_id = annotation["image_id"] + if image_id not in image_ids: + continue + + # COCO bbox format: [x, y, width, height] where (x, y) is top-left corner + try: + bbox = annotation["bbox"] + except KeyError: + continue + + x = bbox[0] + y = bbox[1] + width = bbox[2] + height = bbox[3] + + # Convert to DeepForest format: xmin, ymin, xmax, ymax + xmin = x + ymin = y + xmax = x + width + ymax = y + height + + # Get category label + category_id = annotation.get("category_id", 1) + label = category_ids.get(category_id, "Bird") + + annotations.append( + { + "image_path": image_ids[image_id], + "xmin": xmin, + "ymin": ymin, + "xmax": xmax, + "ymax": ymax, + "label": label, + } + ) + + return pd.DataFrame(annotations) + + +def load_annotations_from_source(source_path): + """Load annotations from a data source file. + + Args: + source_path: Path to annotation file (CSV or JSON) + + Returns: + DataFrame with annotations and root_dir attribute + """ + if not os.path.exists(source_path): + raise FileNotFoundError(f"Source file does not exist: {source_path}") + + if source_path.endswith(".csv"): + df = read_file(source_path) + elif source_path.endswith(".json"): + df = load_coco_with_bboxes(source_path) + else: + raise ValueError(f"Unsupported file type: {source_path}") + + # Add root_dir attribute (directory containing the annotation file) + df.root_dir = os.path.dirname(source_path) + + return df + + +def map_labels_to_bird(df): + """Map all labels to "Bird" except nuisance labels which are filtered out. + + Args: + df: DataFrame with label column + + Returns: + DataFrame with labels mapped to "Bird" and nuisance labels removed + """ + # Filter out nuisance labels + if "label" in df.columns: + mask = ~df["label"].str.lower().isin([n.lower() for n in NUISANCE_LABELS]) + df = df[mask].copy() + + # Map all remaining labels to "Bird" + df["label"] = "Bird" + + return df + + +def create_blank_images(output_dir, num_images=100, image_size=(400, 400)): + """Create blank white images with empty annotations. + + Args: + output_dir: Directory to save images and annotations + num_images: Number of blank images to create + image_size: Tuple of (width, height) for images + + Returns: + DataFrame with empty annotations for blank images + """ + blank_annotations = [] + + for i in range(num_images): + # Create blank white image + blank_image = Image.new("RGB", image_size, color="white") + image_filename = f"blank_image_{i:03d}.png" + image_path = os.path.join(output_dir, image_filename) + blank_image.save(image_path) + + # Create empty annotation (0,0,0,0 coordinates indicate empty frame) + blank_annotations.append( + { + "image_path": image_filename, + "xmin": 0, + "ymin": 0, + "xmax": 0, + "ymax": 0, + "label": "Bird", + } + ) + + return pd.DataFrame(blank_annotations) + + +def create_symlink(source, target): + """Create a symlink, handling existing files. + + Args: + source: Source file path + target: Target symlink path + """ + # Remove target if it exists + if os.path.exists(target) or os.path.islink(target): + os.remove(target) + + # Create parent directory if needed + os.makedirs(os.path.dirname(target), exist_ok=True) + + # Create symlink + os.symlink(source, target) + + +def check_negative_coordinates(df): + """Check for negative bounding box coordinates. + + Args: + df: DataFrame with xmin, ymin, xmax, ymax columns + + Returns: + DataFrame with rows that have negative coordinates + """ + required_cols = ["xmin", "ymin", "xmax", "ymax"] + for col in required_cols: + if col not in df.columns: + return pd.DataFrame() + + # Find rows with any negative coordinates + negative_mask = ( + (df["xmin"] < 0) | (df["ymin"] < 0) | (df["xmax"] < 0) | (df["ymax"] < 0) + ) + return df[negative_mask].copy() + + +def clip_boxes_to_image_bounds(df, image_dir): + """Clip bounding box coordinates to image boundaries. + + Clips negative coordinates to 0 and coordinates beyond image dimensions + to the image edges. Ensures boxes remain valid (xmax > xmin, ymax > ymin). + + Args: + df: DataFrame with image_path, xmin, ymin, xmax, ymax columns + image_dir: Directory containing the images + + Returns: + DataFrame with clipped coordinates + """ + df = df.copy() + required_cols = ["image_path", "xmin", "ymin", "xmax", "ymax"] + for col in required_cols: + if col not in df.columns: + return df + + # Track how many boxes were clipped + clipped_count = 0 + invalid_count = 0 + + # Process each unique image + unique_images = df["image_path"].unique() + for img_path in unique_images: + # Get full image path + full_img_path = os.path.join(image_dir, img_path) + + if not os.path.exists(full_img_path): + continue + + try: + # Load image to get dimensions + img = Image.open(full_img_path) + img_width, img_height = img.size + + # Get annotations for this image + img_mask = df["image_path"] == img_path + img_indices = df[img_mask].index + + for idx in img_indices: + original_xmin = df.at[idx, "xmin"] + original_ymin = df.at[idx, "ymin"] + original_xmax = df.at[idx, "xmax"] + original_ymax = df.at[idx, "ymax"] + + # Clip coordinates to image boundaries + xmin = max(0, min(original_xmin, img_width - 1)) + ymin = max(0, min(original_ymin, img_height - 1)) + xmax = max(xmin + 1, min(original_xmax, img_width)) + ymax = max(ymin + 1, min(original_ymax, img_height)) + + # Check if clipping occurred + if ( + xmin != original_xmin + or ymin != original_ymin + or xmax != original_xmax + or ymax != original_ymax + ): + clipped_count += 1 + df.at[idx, "xmin"] = xmin + df.at[idx, "ymin"] = ymin + df.at[idx, "xmax"] = xmax + df.at[idx, "ymax"] = ymax + + # Check if box is still valid + if xmax <= xmin or ymax <= ymin: + invalid_count += 1 + + except Exception as e: + print(f" Warning: Error processing image {img_path}: {e}") + continue + + if clipped_count > 0: + print(f" Clipped {clipped_count} bounding boxes to image boundaries") + if invalid_count > 0: + print(f" Warning: {invalid_count} boxes became invalid after clipping") + + return df + + +def process_izembek_with_splitting(df, root_dir, output_dir, image_files_map): + """Process Izembek dataset by splitting images into 800-pixel crops. + + Args: + df: DataFrame with annotations + root_dir: Root directory for images + output_dir: Output directory for crops + image_files_map: Map from original path to symlink name + + Returns: + DataFrame with crop annotations + """ + + # Create temporary directory for crops + crops_dir = os.path.join(output_dir, "izembek_crops") + os.makedirs(crops_dir, exist_ok=True) + + crop_annotations_list = [] + unique_images = df["image_path"].unique() + + print(f" Splitting {len(unique_images)} images into 2000-pixel crops...") + + for img_path in unique_images: + # Construct full source path + if os.path.isabs(img_path): + source_img_path = img_path + else: + source_img_path = os.path.join(root_dir, img_path) + + if not os.path.exists(source_img_path): + # Try alternative locations + alt_paths = [ + os.path.join(root_dir, os.path.basename(img_path)), + ] + found = False + for alt_path in alt_paths: + if os.path.exists(alt_path): + source_img_path = alt_path + found = True + break + if not found: + print(f" Warning: Image not found: {source_img_path}") + continue + + # Get image basename for matching with annotations + image_basename = os.path.basename(source_img_path) + + # Filter annotations for this image and update image_path to basename + img_annotations = df[df["image_path"] == img_path].copy() + if img_annotations.empty: + continue + + # Update image_path to basename for split_raster matching + img_annotations["image_path"] = image_basename + + # Save temporary annotations file for this image + temp_annotations_file = os.path.join( + crops_dir, f"temp_{image_basename}_annotations.csv" + ) + img_annotations.to_csv(temp_annotations_file, index=False) + + try: + # Use split_raster to create crops + crop_df = split_raster( + annotations_file=temp_annotations_file, + path_to_raster=source_img_path, + root_dir=os.path.dirname(temp_annotations_file), + patch_size=2000, + patch_overlap=0, + allow_empty=False, + save_dir=crops_dir, + ) + + # Process each crop + for crop_img_path in crop_df["image_path"].unique(): + crop_full_path = os.path.join(crops_dir, crop_img_path) + + if not os.path.exists(crop_full_path): + continue + + # Create unique symlink name + crop_basename = crop_img_path + symlink_name = crop_basename + counter = 1 + while symlink_name in image_files_map.values(): + name, ext = os.path.splitext(crop_basename) + symlink_name = f"{name}_{counter}{ext}" + counter += 1 + + # Create symlink to crop + target_path = os.path.join(output_dir, symlink_name) + try: + create_symlink(crop_full_path, target_path) + image_files_map[crop_img_path] = symlink_name + except Exception as e: + print(f" Warning: Failed to create symlink for {crop_img_path}: {e}") + continue + + # Update image paths in crop dataframe to use symlink name + crop_df.loc[crop_df["image_path"] == crop_img_path, "image_path"] = ( + symlink_name + ) + + crop_annotations_list.append(crop_df) + + except Exception as e: + print(f" Warning: Failed to split image {img_path}: {e}") + continue + finally: + # Clean up temporary annotations file + if os.path.exists(temp_annotations_file): + os.remove(temp_annotations_file) + + if crop_annotations_list: + return pd.concat(crop_annotations_list, ignore_index=True) + else: + return pd.DataFrame() + + +def filter_small_boxes(df, min_area=1, epsilon=1e-6): + """Filter out bounding boxes with zero or single-pixel area. + + Args: + df: DataFrame with xmin, ymin, xmax, ymax columns + min_area: Minimum area (in pixels) for a box to be kept (default: 1) + epsilon: Small value for floating point comparison (default: 1e-6) + + Returns: + DataFrame with small boxes removed + """ + df = df.copy() + required_cols = ["xmin", "ymin", "xmax", "ymax"] + for col in required_cols: + if col not in df.columns: + return df + + # Calculate width, height, and area + width = df["xmax"] - df["xmin"] + height = df["ymax"] - df["ymin"] + area = width * height + + # Round area to handle floating point precision issues + # Single-pixel boxes (width=1, height=1) should have area=1.0 + area_rounded = np.round(area, decimals=6) + + # Filter out boxes with invalid dimensions or area <= min_area + # Filter if: width <= 0, height <= 0, or rounded area <= min_area + # This catches single-pixel boxes (width=1, height=1, area=1) + valid_mask = (width > epsilon) & (height > epsilon) & (area_rounded > min_area) + + removed_count = (~valid_mask).sum() + if removed_count > 0: + print( + f" Removed {removed_count} bounding boxes with area <= {min_area} pixel(s) or single-pixel dimensions" + ) + + return df[valid_mask].copy() + + +def main(): + """Main function to prepare bird detection training data.""" + parser = argparse.ArgumentParser( + description="Prepare bird detection training data from multiple sources" + ) + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Output directory for prepared data (images and CSV files)", + ) + parser.add_argument( + "--test_size", + type=float, + default=0.1, + help="Fraction of data to use for testing (default: 0.2)", + ) + parser.add_argument( + "--num_blank_images", + type=int, + default=100, + help="Number of blank white images to generate (default: 100)", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for train/test split (default: 42)", + ) + + args = parser.parse_args() + + # Create output directory + os.makedirs(args.output_dir, exist_ok=True) + + print("Loading annotations from multiple sources...") + all_annotations = [] + image_files_map = {} # Map from original path to symlink name + + # Load annotations from all sources + for source_path in DATA_SOURCES: + print(f"\nProcessing source: {source_path}") + df = load_annotations_from_source(source_path) + df = map_labels_to_bird(df) + + if df.empty: + print(f" No annotations after filtering for {source_path}") + continue + + # Get root directory for images + root_dir = ( + df.root_dir if hasattr(df, "root_dir") else os.path.dirname(source_path) + ) + + # Special case: Drones for Ducks images are in /images subdirectory + if "drones_for_ducks" in source_path.lower(): + root_dir = os.path.join(root_dir, "images") + + # Ensure required columns exist + required_cols = ["image_path", "xmin", "ymin", "xmax", "ymax", "label"] + missing_cols = [col for col in required_cols if col not in df.columns] + if missing_cols: + print(f" Warning: Missing columns {missing_cols}, skipping...") + continue + + # Special case: Izembek dataset - split into 2000-pixel crops + if "izembek" in source_path.lower(): + print( + " Using split_raster to create 2000-pixel crops with allow_empty=False" + ) + df = process_izembek_with_splitting( + df, root_dir, args.output_dir, image_files_map + ) + if df.empty: + print(f" No crop annotations generated for {source_path}") + continue + all_annotations.append(df) + print( + f" Loaded {len(df)} crop annotations from {df['image_path'].nunique()} crop images" + ) + continue + + # Handle image paths - create symlinks + unique_images = df["image_path"].unique() + for img_path in unique_images: + # Construct full source path + if os.path.isabs(img_path): + source_img_path = img_path + else: + source_img_path = os.path.join(root_dir, img_path) + + if not os.path.exists(source_img_path): + # Try alternative locations + alt_paths = [ + os.path.join(root_dir, os.path.basename(img_path)), + os.path.join(os.path.dirname(source_path), img_path), + ] + found = False + for alt_path in alt_paths: + if os.path.exists(alt_path): + source_img_path = alt_path + found = True + break + if not found: + print(f" Warning: Image not found: {source_img_path}") + continue + + # Create unique symlink name + img_basename = os.path.basename(img_path) + symlink_name = img_basename + counter = 1 + while symlink_name in image_files_map.values(): + name, ext = os.path.splitext(img_basename) + symlink_name = f"{name}_{counter}{ext}" + counter += 1 + + # Create symlink + target_path = os.path.join(args.output_dir, symlink_name) + try: + create_symlink(source_img_path, target_path) + image_files_map[img_path] = symlink_name + except Exception as e: + print(f" Warning: Failed to create symlink for {img_path}: {e}") + continue + + # Update image paths in dataframe to use symlink name + df.loc[df["image_path"] == img_path, "image_path"] = symlink_name + + all_annotations.append(df) + print(f" Loaded {len(df)} annotations from {len(unique_images)} images") + + if not all_annotations: + raise ValueError("No annotations were loaded from any source!") + + # Combine all annotations + combined_df = pd.concat(all_annotations, ignore_index=True) + + # Check for negative coordinates before clipping + print("\nChecking for negative bounding box coordinates...") + negative_coords_df = check_negative_coordinates(combined_df) + if not negative_coords_df.empty: + print(f" Found {len(negative_coords_df)} annotations with negative coordinates") + print(f" Affected images: {negative_coords_df['image_path'].nunique()}") + print("\n Summary of negative coordinates:") + print(f" xmin < 0: {(combined_df['xmin'] < 0).sum()}") + print(f" ymin < 0: {(combined_df['ymin'] < 0).sum()}") + print(f" xmax < 0: {(combined_df['xmax'] < 0).sum()}") + print(f" ymax < 0: {(combined_df['ymax'] < 0).sum()}") + + # Clip boxes to image boundaries + print("\nClipping bounding boxes to image boundaries...") + combined_df = clip_boxes_to_image_bounds(combined_df, args.output_dir) + + # Verify clipping worked + negative_after = check_negative_coordinates(combined_df) + if negative_after.empty: + print(" All negative coordinates have been clipped.") + else: + print( + f" Warning: {len(negative_after)} annotations still have negative coordinates after clipping" + ) + else: + print(" No negative coordinates found.") + + # Filter out boxes with zero or single-pixel area + print("\nFiltering out boxes with zero or single-pixel area...") + initial_count = len(combined_df) + combined_df = filter_small_boxes(combined_df, min_area=1) + removed_count = initial_count - len(combined_df) + if removed_count > 0: + print(f" Removed {removed_count} boxes (kept {len(combined_df)} boxes)") + + # Add blank images + print(f"\nGenerating {args.num_blank_images} blank white images...") + blank_df = create_blank_images(args.output_dir, args.num_blank_images) + combined_df = pd.concat([combined_df, blank_df], ignore_index=True) + + print(f"\nTotal annotations: {len(combined_df)}") + print(f"Total unique images: {combined_df['image_path'].nunique()}") + + # Save CSV files + train_csv = os.path.join(args.output_dir, "train.csv") + test_csv = os.path.join(args.output_dir, "test.csv") + + # Ensure required columns are present and in correct order + required_cols = ["image_path", "xmin", "ymin", "xmax", "ymax", "label"] + + # If existing test dataset is provided, skip split and put all new data in train + if EXISTING_TEST and os.path.exists(EXISTING_TEST): + print(f"\nUsing existing test dataset: {EXISTING_TEST}") + print("Putting all new data in training set...") + train_df = combined_df[required_cols].copy() + train_df.to_csv(train_csv, index=False) + print( + f"\nSaved training annotations: {train_csv} ({len(train_df)} annotations, {train_df['image_path'].nunique()} images)" + ) + print(f"Using existing test dataset: {EXISTING_TEST}") + else: + # Split into train/test by image_path (to avoid data leakage) + print( + f"\nSplitting into train/test ({1 - args.test_size:.0%}/{args.test_size:.0%})..." + ) + unique_images = combined_df["image_path"].unique() + train_images, test_images = train_test_split( + unique_images, test_size=args.test_size, random_state=args.seed + ) + + train_df = combined_df[combined_df["image_path"].isin(train_images)].copy() + test_df = combined_df[combined_df["image_path"].isin(test_images)].copy() + + train_df = train_df[required_cols] + test_df = test_df[required_cols] + + train_df.to_csv(train_csv, index=False) + test_df.to_csv(test_csv, index=False) + + print( + f"\nSaved training annotations: {train_csv} ({len(train_df)} annotations, {len(train_images)} images)" + ) + print( + f"Saved test annotations: {test_csv} ({len(test_df)} annotations, {len(test_images)} images)" + ) + + print(f"\nOutput directory: {args.output_dir}") + print("\nData preparation complete!") + + +if __name__ == "__main__": + main() diff --git a/src/deepforest/scripts/push_bird_model_to_hf.py b/src/deepforest/scripts/push_bird_model_to_hf.py new file mode 100644 index 000000000..601f87b64 --- /dev/null +++ b/src/deepforest/scripts/push_bird_model_to_hf.py @@ -0,0 +1,86 @@ +"""Push trained bird detection model to HuggingFace Hub via PR. + +This script loads a trained model checkpoint and creates a pull request +on HuggingFace Hub to update the weecology/deepforest-bird model. + +Example usage: + python push_bird_model_to_hf.py --checkpoint path/to/checkpoint.ckpt +""" + +import argparse +import os +from pathlib import Path + +from dotenv import load_dotenv +from huggingface_hub import login + +from deepforest import main + + +def run(): + """Main function to push model to HuggingFace via PR.""" + parser = argparse.ArgumentParser( + description="Push trained bird model to HuggingFace Hub via PR" + ) + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the model checkpoint file (.ckpt)", + ) + parser.add_argument( + "--repo-id", + type=str, + default="weecology/deepforest-bird", + help="HuggingFace repository ID (default: weecology/deepforest-bird)", + ) + parser.add_argument( + "--commit-message", + type=str, + default="Update model weights", + help="Commit message for the PR", + ) + + args = parser.parse_args() + + # Load HF token from .env + load_dotenv() + hf_token = os.getenv("HF_TOKEN") + if not hf_token: + raise ValueError( + "HF_TOKEN not found in .env file. Please add your HuggingFace token to .env" + ) + + # Login to HuggingFace + login(token=hf_token) + + # Verify checkpoint exists + checkpoint_path = Path(args.checkpoint) + if not checkpoint_path.exists(): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + print(f"Loading model from checkpoint: {checkpoint_path}") + # Load model from checkpoint + model = main.deepforest.load_from_checkpoint(str(checkpoint_path)) + + # Ensure label_dict is set (should be loaded from checkpoint, but verify) + if not hasattr(model, "label_dict") or model.label_dict is None: + print("Warning: label_dict not found in checkpoint, setting default Bird label") + model.label_dict = {"Bird": 0} + model.numeric_to_label_dict = {0: "Bird"} + + print(f"Model loaded with label_dict: {model.label_dict}") + + # Push to HuggingFace Hub - this will automatically create a PR + print(f"Pushing model to {args.repo_id} and creating PR...") + model.model.push_to_hub( + args.repo_id, + commit_message=args.commit_message, + create_pr=True, + ) + + print(f"\nSuccessfully created PR to update {args.repo_id}!") + + +if __name__ == "__main__": + run() diff --git a/src/deepforest/scripts/submit_train_birds.sh b/src/deepforest/scripts/submit_train_birds.sh new file mode 100755 index 000000000..9cac60af9 --- /dev/null +++ b/src/deepforest/scripts/submit_train_birds.sh @@ -0,0 +1,24 @@ +#!/bin/bash +#SBATCH --job-name=train_birds # Job name +#SBATCH --mail-type=END # Mail events +#SBATCH --mail-user=benweinstein2010@gmail.com # Where to send mail +#SBATCH --account=ewhite +#SBATCH --nodes=1 # Number of MPI ran +#SBATCH --cpus-per-task=10 +#SBATCH --mem=200GB +#SBATCH --time=48:00:00 #Time limit hrs:min:sec +#SBATCH --output=/home/b.weinstein/logs/train_birds%j.out # Standard output and error log +#SBATCH --error=/home/b.weinstein/logs/train_birds%j.err +#SBATCH --partition=hpg-b200 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus=1 + +# Example usage: +# First prepare the data: +uv run python src/deepforest/scripts/prepare_birds.py --output_dir /blue/ewhite/b.weinstein/bird_detector_retrain/data/ + +srun uv run python src/deepforest/scripts/train_birds.py \ + --data_dir /blue/ewhite/b.weinstein/bird_detector_retrain/data/ \ + --batch_size 32 \ + --workers 10 \ + --epochs 40 diff --git a/src/deepforest/scripts/train_birds.py b/src/deepforest/scripts/train_birds.py new file mode 100644 index 000000000..11f5a3372 --- /dev/null +++ b/src/deepforest/scripts/train_birds.py @@ -0,0 +1,266 @@ +"""Train DeepForest bird detection model. + +This script trains a bird detection model using the weecology/deepforest-bird +pretrained model as a starting point. + +Example usage: + python train_birds.py --data_dir /path/to/prepared/data --batch_size 12 --workers 5 +""" + +import argparse +import os + +import pandas as pd +import torch +from omegaconf import OmegaConf +from pytorch_lightning.loggers import CometLogger + +from deepforest import callbacks, main + + +def run(): + """Main training function.""" + parser = argparse.ArgumentParser(description="Train DeepForest bird detection model") + parser.add_argument( + "--data_dir", + type=str, + required=True, + help="Directory containing train.csv, test.csv, and images", + ) + parser.add_argument( + "--batch_size", + type=int, + default=12, + help="Batch size for training (default: 12)", + ) + parser.add_argument( + "--workers", + type=int, + default=5, + help="Number of workers for data loading (default: 5)", + ) + parser.add_argument( + "--epochs", + type=int, + default=12, + help="Number of training epochs (default: 12)", + ) + parser.add_argument( + "--lr", + type=float, + default=0.001, + help="Learning rate (default: 0.001)", + ) + parser.add_argument( + "--checkpoint_dir", + type=str, + default=None, + help="Directory to save model checkpoints (default: data_dir/checkpoints)", + ) + parser.add_argument( + "--fast_dev_run", + action="store_true", + help="Run a fast development run with a single batch", + ) + + args = parser.parse_args() + + # Set matmul precision to high for faster training on Tensor Core GPUs + if torch.cuda.is_available(): + torch.set_float32_matmul_precision("high") + print("Set torch.float32_matmul_precision to 'high' for faster training") + + # Set up paths + train_csv = os.path.join(args.data_dir, "train.csv") + test_csv = os.path.join(args.data_dir, "test.csv") + + if not os.path.exists(train_csv): + raise FileNotFoundError(f"Training CSV not found: {train_csv}") + if not os.path.exists(test_csv): + raise FileNotFoundError(f"Test CSV not found: {test_csv}") + + if args.checkpoint_dir is None: + checkpoint_dir = os.path.join(args.data_dir, "checkpoints") + else: + checkpoint_dir = args.checkpoint_dir + os.makedirs(checkpoint_dir, exist_ok=True) + + print("Initializing DeepForest model...") + # Initialize DeepForest model + m = main.deepforest() + + # Load the pretrained tree model as a starting point + # print("Loading pretrained tree model: weecology/deepforest-tree") + m.load_model("weecology/deepforest-tree") + + # Set label dictionaries for single "Bird" class + m.label_dict = {"Bird": 0} + m.numeric_to_label_dict = {0: "Bird"} + m.config.label_dict = {"Bird": 0} + m.config.num_classes = 1 + + m.config.score_thresh = 0.25 + m.model.score_thresh = 0.25 + + # Configure training data paths + m.config["train"]["csv_file"] = train_csv + m.config["train"]["root_dir"] = args.data_dir + m.config["train"]["fast_dev_run"] = args.fast_dev_run + m.config["train"]["epochs"] = args.epochs + m.config["train"]["lr"] = args.lr + m.config["train"]["scheduler"]["params"]["patience"] = 3 + + # Configure validation data paths + m.config["validation"]["csv_file"] = test_csv + m.config["validation"]["root_dir"] = args.data_dir + m.config["validation"]["val_accuracy_interval"] = 1 + m.config["validation"]["size"] = 800 + + # Configure data loading + m.config["batch_size"] = args.batch_size + m.config["workers"] = args.workers + + # Configure augmentations with modern options + # Using zoom augmentations (RandomResizedCrop), rotations, and other augmentations + # Use OmegaConf.update to bypass strict type validation + augmentations_config = OmegaConf.create( + { + "train": { + "augmentations": [ + { + "RandomResizedCrop": { + "size": (800, 800), + "scale": (0.3, 1.0), + "p": 0.5, + } + }, + {"Rotate": {"degrees": 15, "p": 0.5}}, + {"HorizontalFlip": {"p": 0.5}}, + {"VerticalFlip": {"p": 0.3}}, + {"PadIfNeeded": {"size": (1000, 1000)}}, + # {"RandomBrightnessContrast": {"brightness": 0.2, "contrast": 0.2, "p": 0.5}}, + # {"HueSaturationValue": {"hue": 0.1, "saturation": 0.1, "p": 0.3}}, + # {"ZoomBlur": {"max_factor": (1.0, 1.03), "step_factor": (0.01, 0.02), "p": 0.3}}, + ] + } + } + ) + OmegaConf.set_struct(m.config, False) + m.config = OmegaConf.merge(m.config, augmentations_config) + OmegaConf.set_struct(m.config, True) + + # Configure scheduler (similar to BOEM script) + m.config["train"]["scheduler"]["params"]["eps"] = 0 + + # Set up Comet logger (optional, will skip if not configured) + comet_logger = None + try: + comet_logger = CometLogger() + comet_logger.experiment.add_tag("bird-detection") + + # Log training and test set sizes + + train_df = pd.read_csv(train_csv) + test_df = pd.read_csv(test_csv) + comet_logger.experiment.log_table("train.csv", train_df) + comet_logger.experiment.log_table("test.csv", test_df) + + # Log training parameters + devices = torch.cuda.device_count() if torch.cuda.is_available() else 0 + comet_logger.experiment.log_parameter("devices", devices) + comet_logger.experiment.log_parameter("workers", m.config["workers"]) + comet_logger.experiment.log_parameter("batch_size", m.config["batch_size"]) + comet_logger.experiment.log_parameter("train_size", len(train_df)) + comet_logger.experiment.log_parameter("test_size", len(test_df)) + comet_logger.experiment.log_parameter("epochs", args.epochs) + comet_logger.experiment.log_parameter("learning_rate", args.lr) + + print(f"Comet logging enabled: {comet_logger.experiment.get_key()}") + except Exception as e: + print(f"Warning: Could not initialize Comet logger: {e}") + print("Continuing without Comet logging...") + comet_logger = None + + # Set up image callback for validation visualization + images_dir = os.path.join(checkpoint_dir, "images") + os.makedirs(images_dir, exist_ok=True) + im_callback = callbacks.ImagesCallback( + save_dir=images_dir, + prediction_samples=20, # Number of validation images to log + dataset_samples=20, # Number of dataset samples to log at start + every_n_epochs=1, # Log predictions every epoch + ) + + # Create trainer with GPU support + print("Creating trainer...") + # For DDP, each process uses 1 device. PyTorch Lightning will handle + + m.create_trainer( + logger=comet_logger, + callbacks=[im_callback], + devices=devices, + strategy="ddp", + precision="16-mixed", # Use mixed precision training for faster performance + fast_dev_run=args.fast_dev_run, + enable_progress_bar=True, + ) + + # Train the model + print("\nStarting training...") + m.trainer.fit(m) + m.trainer.validate(m) + + # Save the model checkpoint + checkpoint_path = os.path.join(checkpoint_dir, f"{comet_logger.experiment.id}.ckpt") + print(f"\nSaving checkpoint to: {checkpoint_path}") + m.trainer.save_checkpoint(checkpoint_path) + + # Evaluate on zero-shot dataset + print("\n" + "=" * 80) + print("Evaluating on zero-shot dataset (DeepWater Horizon)") + print("=" * 80) + + # Update validation config for zero-shot dataset + m.config.validation.csv_file = "/blue/ewhite/b.weinstein/bird_detector_retrain/zero_shot/avian_images_annotated/test_splits/test_split_patch_600.csv" + m.config.validation.root_dir = "/blue/ewhite/b.weinstein/bird_detector_retrain/zero_shot/avian_images_annotated/test_splits/patch_600" + m.config.validation.iou_threshold = 0.4 + + # Create new trainer for zero-shot evaluation + m.create_trainer() + + # Evaluate on zero-shot dataset + zero_shot_results = m.trainer.validate(m) + zero_shot_metrics = zero_shot_results[0] if zero_shot_results else {} + + print("\nZero-shot evaluation results:") + print(f" Box Precision: {zero_shot_metrics.get('box_precision', 'N/A')}") + print(f" Box Recall: {zero_shot_metrics.get('box_recall', 'N/A')}") + print( + f" Empty Frame Accuracy: {zero_shot_metrics.get('empty_frame_accuracy', 'N/A')}" + ) + + # log the zero-shot evaluation results to the comet logger + if comet_logger: + comet_logger.experiment.log_metric( + "zero_shot_box_precision", zero_shot_metrics.get("box_precision", "N/A") + ) + comet_logger.experiment.log_metric( + "zero_shot_box_recall", zero_shot_metrics.get("box_recall", "N/A") + ) + comet_logger.experiment.log_metric( + "zero_shot_empty_frame_accuracy", + zero_shot_metrics.get("empty_frame_accuracy", "N/A"), + ) + + if comet_logger: + # Log global steps + global_steps = torch.tensor( + m.trainer.global_step, dtype=torch.int32, device=m.device + ) + comet_logger.experiment.log_metric("global_steps", global_steps.item()) + + print("\nTraining complete!") + + +if __name__ == "__main__": + run() From 7b024a76c53c63ff78fd95e7892ed00da29724dd Mon Sep 17 00:00:00 2001 From: bw4sz Date: Tue, 21 Apr 2026 09:37:43 -0700 Subject: [PATCH 2/2] Remove trailing whitespace in 02_prebuilt.md --- docs/user_guide/02_prebuilt.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user_guide/02_prebuilt.md b/docs/user_guide/02_prebuilt.md index 3f423944b..161172a20 100644 --- a/docs/user_guide/02_prebuilt.md +++ b/docs/user_guide/02_prebuilt.md @@ -34,7 +34,7 @@ The model was initially described in [Ecological Applications](https://esajourna Using over 250,000 annotations from 13 projects from around the world, we develop a general bird detection model that achieves over 65% recall and 50% precision on novel aerial data without any local training despite differences in species, habitat, and imaging methodology. Fine-tuning this model with only 1000 local annotations increases these values to an average of 84% recall and 69% precision by building on the general features learned from other data sources. > -The bird detection model has been updated and retrained from the original `weecology/deepforest-bird` model. The updated model was fine-tuned starting from the tree detection model (`weecology/deepforest-tree`) and trained on data from both Weinstein et al. 2022 as well as new additional bird detection data from multiple sources including https://lila.science/. The result is a dataset with over a million bird detections from around the world. Training details and metrics can be viewed on the [Comet dashboard](https://www.comet.com/bw4sz/bird-detector/6181df1ab7ac40f291b863a2a9b86024?&prevPath=%2Fbw4sz%2Fbird-detector%2Fview%2Fnew%2Fexperiments). +The bird detection model has been updated and retrained from the original `weecology/deepforest-bird` model. The updated model was fine-tuned starting from the tree detection model (`weecology/deepforest-tree`) and trained on data from both Weinstein et al. 2022 as well as new additional bird detection data from multiple sources including https://lila.science/. The result is a dataset with over a million bird detections from around the world. Training details and metrics can be viewed on the [Comet dashboard](https://www.comet.com/bw4sz/bird-detector/6181df1ab7ac40f291b863a2a9b86024?&prevPath=%2Fbw4sz%2Fbird-detector%2Fview%2Fnew%2Fexperiments). ### Example Predictions