diff --git a/.gitignore b/.gitignore index 10c813b9e..663a5b3f1 100644 --- a/.gitignore +++ b/.gitignore @@ -21,7 +21,7 @@ docs/_build/ env/ logs/* .venv/ -lightning_logs/* +lightning_logs* public/ tests/__pycache__ tests/data/* diff --git a/docs/development/cluster.md b/docs/development/cluster.md new file mode 100644 index 000000000..6aabbafd4 --- /dev/null +++ b/docs/development/cluster.md @@ -0,0 +1,170 @@ +# Cluster Distributed Runs + +This page shows supported patterns for running DeepForest across multiple GPUs and multiple nodes on a Slurm-managed cluster (for example HiPerGator). + +## Slurm: `sbatch` and `srun` + +`sbatch` requests the allocation (nodes, GPUs, tasks, memory, time). `srun` inside that batch script starts a **job step** within the same allocation. It does **not** submit a second job or double-charge the scheduler. + +Match `#SBATCH --ntasks-per-node` to `devices` (one Slurm task per GPU) and `#SBATCH --nodes` to `num_nodes`. For multi-GPU DDP, launch with `srun`. For a single GPU, the cluster train script runs the command directly in the batch step. + +Example launchers live under `src/deepforest/scripts/HPC/`. + +## Shared Settings + +Use the same launch pattern for `train`, `evaluate`, and `predict`: + +- `devices=` is the number of GPUs on each node +- `num_nodes=` is the total number of nodes +- `strategy=ddp` enables distributed data parallel execution (use `auto` for single-GPU jobs) +- `workers=0` is required for large-tile prediction with `dataloader_strategy="window"` + +## Environment + +```bash +ml conda +eval "$(conda shell.bash hook)" +conda activate predict +cd /path/to/DeepForest +mkdir -p slurm_logs +``` + +## Train + +Use `src/deepforest/scripts/HPC/run_cluster_train.sbatch` for production training and smoke tests. The launcher script is `run_cluster_train.sh`. + +### Production training (single GPU) + +Defaults use `TRAIN_MODE=train` and `CONFIG_NAME=bird`. 
Submit from the repo root: + +```bash +sbatch src/deepforest/scripts/HPC/run_cluster_train.sbatch +``` + +Hydra overrides and resume: + +```bash +export COMET_EXPERIMENT_NAME="exp_lr_0.0005" +sbatch src/deepforest/scripts/HPC/run_cluster_train.sbatch train.lr=0.0005 train.epochs=80 + +RESUME_CKPT=/path/to/last.ckpt sbatch src/deepforest/scripts/HPC/run_cluster_train.sbatch +``` + +Multi-GPU or multi-node training: set Slurm resources at submit time and pass matching Hydra settings if needed. The script infers `SCENARIO` from the allocation. + +```bash +sbatch --nodes=2 --ntasks-per-node=2 --gpus-per-node=2 --cpus-per-task=8 --mem=128G --time=15:00:00 \ + src/deepforest/scripts/HPC/run_cluster_train.sbatch \ + --strategy ddp devices=2 num_nodes=2 +``` + +### Smoke tests + +Smoke tests use bundled OSBS sample data (`TRAIN_MODE=smoke`, `CONFIG_NAME=smoke`, 1 epoch). Set `SCENARIO` and match `#SBATCH` resources: + +```bash +# 1 GPU +TRAIN_MODE=smoke SCENARIO=1gpu sbatch --nodes=1 --ntasks-per-node=1 --gpus-per-node=1 \ + --cpus-per-task=8 --mem=32G --time=00:30:00 \ + src/deepforest/scripts/HPC/run_cluster_train.sbatch + +# Multi-GPU (one node) +TRAIN_MODE=smoke SCENARIO=multigpu GPUS_PER_NODE=2 sbatch --nodes=1 --ntasks-per-node=2 --gpus-per-node=2 \ + --cpus-per-task=8 --mem=64G --time=00:45:00 \ + src/deepforest/scripts/HPC/run_cluster_train.sbatch + +# Multi-node +TRAIN_MODE=smoke SCENARIO=multinode GPUS_PER_NODE=2 NNODES=2 sbatch --nodes=2 --ntasks-per-node=2 --gpus-per-node=2 \ + --cpus-per-task=8 --mem=64G --time=01:00:00 \ + src/deepforest/scripts/HPC/run_cluster_train.sbatch +``` + +Optional: `export COMET_EXPERIMENT_NAME="my-smoke-run"` before `sbatch`. Disable Comet with `USE_COMET=0`. 
+ +### Train directly in a batch script + +```bash +#SBATCH --nodes=2 +#SBATCH --ntasks-per-node=2 +#SBATCH --gpus-per-node=2 + +srun uv run deepforest train \ + --strategy ddp \ + accelerator=gpu \ + devices=2 \ + num_nodes=2 \ + train.csv_file=/path/to/train.csv \ + train.root_dir=/path/to/train_images \ + validation.csv_file=/path/to/val.csv \ + validation.root_dir=/path/to/val_images +``` + +## Evaluate + +```bash +#SBATCH --nodes=2 +#SBATCH --ntasks-per-node=2 +#SBATCH --gpus-per-node=2 + +srun uv run deepforest evaluate \ + /path/to/ground_truth.csv \ + --root-dir /path/to/images \ + --save-predictions eval_preds.csv \ + -o eval_metrics.csv \ + --strategy ddp \ + accelerator=gpu \ + devices=2 \ + num_nodes=2 +``` + +## Predict From CSV + +For the cluster regression test and example launcher (submit from the repo root): + +```bash +sbatch src/deepforest/scripts/HPC/run_cluster_predict_test.sbatch +``` + +To run your own CSV prediction job directly: + +```bash +srun uv run deepforest predict \ + /path/to/images.csv \ + --mode csv \ + --root-dir /path/to/images \ + -o predictions.csv \ + --strategy ddp \ + accelerator=gpu \ + devices=2 \ + num_nodes=2 +``` + +## Predict A Large Tile + +For large rasters on a cluster, prefer `predict_tile(..., dataloader_strategy="window")`. + +The ready-to-run test launcher is: + +```bash +sbatch src/deepforest/scripts/HPC/run_cluster_predict_tile_test.sbatch +``` + +To run a tiled prediction job directly: + +```bash +#SBATCH --nodes=2 +#SBATCH --ntasks-per-node=2 +#SBATCH --gpus-per-node=2 + +srun uv run python tests/cluster_predict_tile_driver.py \ + --input-path /path/to/tile.tif \ + --output-path tile_predictions.csv \ + --model-name weecology/everglades-bird-species-detector \ + --patch-size 1500 \ + --patch-overlap 0 \ + --dataloader-strategy window \ + --devices 2 \ + --num-nodes 2 +``` + +See also the [multi-GPU and multi-node guide](../user_guide/distributed.md). 
diff --git a/docs/development/index.md b/docs/development/index.md index e6e810d5f..1c745f776 100644 --- a/docs/development/index.md +++ b/docs/development/index.md @@ -5,6 +5,7 @@ ```{toctree} :maxdepth: 1 +cluster authors contributing code_of_conduct diff --git a/docs/user_guide/07_scaling.md b/docs/user_guide/07_scaling.md index 067ddfa0e..8d3f41b72 100644 --- a/docs/user_guide/07_scaling.md +++ b/docs/user_guide/07_scaling.md @@ -1,5 +1,7 @@ # Scaling DeepForest using PyTorch Lightning +For concise launch recipes, see the [multi-GPU and multi-node guide](distributed.md). If you are using a Slurm-managed cluster, see the [cluster developer guide](../development/cluster.md). + ## Increase batch size It is more efficient to run a larger batch size on a single GPU. This is because the overhead of loading data and moving data between the CPU and GPU is relatively large. By running a larger batch size, we can reduce the overhead of these operations. @@ -27,9 +29,7 @@ A few notes that can trip up those less used to multi-gpu training. These are fo 2. Each device gets its own portion of the dataset. This means that they do not interact during forward passes. -3. Make sure to use srun when combining with SLURM! This is an easy one to miss and will cause training to hang without error. Documented here - -https://lightning.ai/docs/pytorch/latest/clouds/cluster_advanced.html#troubleshooting. +3. On SLURM, launch with **`srun`**. Match `#SBATCH --ntasks-per-node` to `devices` and `#SBATCH --nodes` to `num_nodes`. See the [multi-GPU and multi-node guide](distributed.md) and [Lightning SLURM troubleshooting](https://lightning.ai/docs/pytorch/latest/clouds/cluster_advanced.html#troubleshooting). 
## Prediction diff --git a/docs/user_guide/09_configuration_file.md b/docs/user_guide/09_configuration_file.md index bda0de7ba..bd200116f 100644 --- a/docs/user_guide/09_configuration_file.md +++ b/docs/user_guide/09_configuration_file.md @@ -151,17 +151,30 @@ The number of cpus/gpus to use during model training. Deepforest has been tested ### accelerator -Most commonly, `cpu`, `gpu` or `tpu` as well as other [options](https://pytorch-lightning.readthedocs.io/en/1.4.0/advanced/multi_gpu.html) listed: +Most commonly, `cpu`, `gpu` or `tpu` as well as other [options](https://lightning.ai/docs/pytorch/stable/accelerators/gpu.html). -If `gpu`, it can be helpful to specify the data parallelization strategy. This can be done using the `strategy` arg in `main.create_trainer()` +### num_nodes + +Number of machines for distributed training. Default is `1`. Set this to your Slurm node count for multi-node jobs. See [Scaling](07_scaling.md) and [distributed runs](distributed.md). + +### strategy + +Distributed training strategy passed to the Lightning `Trainer`. Default is `auto` (appropriate for single-GPU runs). Use `ddp` for multi-GPU or multi-node training. + +Set in the config file, as Hydra overrides (`strategy=ddp`), or via `create_trainer(strategy="ddp")`. CLI and `create_trainer` kwargs override the config file. ```python -from deepforest import model as m +from deepforest import main -m.create_trainer(logger=comet_logger, strategy="ddp") +m = main.deepforest() +m.config.accelerator = "gpu" +m.config.devices = 4 +m.config.num_nodes = 2 +m.config.strategy = "ddp" +m.create_trainer(logger=comet_logger) ``` -This is passed to the pytorch-lightning trainer, documented in the link above for multi-gpu training. +On Slurm clusters, launch with `srun` so Lightning can read the job environment. Details are in [distributed runs](distributed.md). 
### batch_size diff --git a/docs/user_guide/11_training.md b/docs/user_guide/11_training.md index 310ee9219..e8b8690b1 100644 --- a/docs/user_guide/11_training.md +++ b/docs/user_guide/11_training.md @@ -526,38 +526,28 @@ Usually creating this object does not cost too much computational time. #### Training across multiple nodes on a HPC system -We have heard that this error can appear when trying to deep copy the pytorch lightning module. The trainer object is not pickleable. -For example, on multi-gpu environments when trying to scale the deepforest model the entire module is copied leading to this error. -Setting the trainer object to None and directly using the pytorch object is a reasonable workaround. +On Slurm clusters, launch the training command with `srun` inside your `sbatch` script, and set `devices`, `num_nodes`, and `strategy=ddp` to match your `#SBATCH` allocation. See [Scaling](07_scaling.md) and [distributed runs](distributed.md). -Replace +If you see a **Weakly referenced objects** error when scaling across GPUs, the trainer object may not pickle cleanly when the module is copied. A workaround is to set `m.trainer = None` and construct a `Trainer` directly: ```python m = main.deepforest() -m.create_trainer() -m.trainer.fit(m) -``` - -with - -```python m.trainer = None from pytorch_lightning import Trainer - trainer = Trainer( - accelerator="gpu", - strategy="ddp", - devices=model.config.devices, - enable_checkpointing=False, - max_epochs=model.config.train.epochs, - logger=comet_logger - ) +trainer = Trainer( + accelerator="gpu", + strategy="ddp", + devices=m.config.devices, + num_nodes=m.config.num_nodes, + enable_checkpointing=False, + max_epochs=m.config.train.epochs, + logger=comet_logger, +) trainer.fit(m) ``` -The added benefits of this is more control over the trainer object. -The downside is that it doesn't align with the .config pattern where a user now has to look into the config to create the trainer. -We are open to changing this to be the default pattern in the future and welcome input from users. 
+We are open to making this the default pattern and welcome input from users. #### Visualization during training @@ -598,6 +588,8 @@ We provide a basic script to trigger a training run via CLI. This script is inst If you are using `uv` to manage your Python environment, remember to prefix these commands with `uv run`, for example: `uv run deepforest predict`. ``` +On a Slurm cluster, wrap the command in `srun` inside your batch script (see [Scaling](07_scaling.md) and [distributed runs](distributed.md)). + ```bash deepforest train batch_size=8 train.csv_file=your_labels.csv train.root_dir=some/path ``` diff --git a/docs/user_guide/distributed.md b/docs/user_guide/distributed.md new file mode 100644 index 000000000..3bfc86740 --- /dev/null +++ b/docs/user_guide/distributed.md @@ -0,0 +1,84 @@ +# Multi-GPU and Multi-Node Runs + +DeepForest uses PyTorch Lightning distributed execution. For most multi-GPU and multi-node runs, these settings matter: + +- `accelerator=gpu` +- `devices=` +- `num_nodes=` +- `strategy=ddp` + +On Slurm clusters, launch with **`srun`** inside your job allocation. Match `#SBATCH --ntasks-per-node` to `devices` and `#SBATCH --nodes` to `num_nodes`. See the [cluster developer guide](../development/cluster.md). + +Single-GPU jobs can keep the default `strategy=auto`. 
+ +## Train + +```bash +#SBATCH --nodes=<num_nodes> +#SBATCH --ntasks-per-node=<gpus_per_node> +#SBATCH --gres=gpu:<gpus_per_node> + +srun uv run deepforest train \ + --strategy ddp \ + accelerator=gpu \ + devices=<gpus_per_node> \ + num_nodes=<num_nodes> \ + train.csv_file=/path/to/train.csv \ + train.root_dir=/path/to/train_images \ + validation.csv_file=/path/to/val.csv \ + validation.root_dir=/path/to/val_images +``` + +## Evaluate + +```bash +srun uv run deepforest evaluate \ + /path/to/ground_truth.csv \ + --root-dir /path/to/images \ + --save-predictions eval_preds.csv \ + -o eval_metrics.csv \ + --strategy ddp \ + accelerator=gpu \ + devices=<gpus_per_node> \ + num_nodes=<num_nodes> +``` + +## Predict From CSV + +```bash +srun uv run deepforest predict \ + /path/to/images.csv \ + --mode csv \ + --root-dir /path/to/images \ + -o predictions.csv \ + --strategy ddp \ + accelerator=gpu \ + devices=<gpus_per_node> \ + num_nodes=<num_nodes> +``` + +## Predict A Large Tile + +For large geospatial rasters, use `predict_tile(..., dataloader_strategy="window")` instead of the simple CLI tile mode. The `window` strategy requires `workers=0`, as set in the example below. + +```python +from deepforest.main import deepforest + +m = deepforest() +m.load_model("weecology/everglades-bird-species-detector") +m.config.accelerator = "gpu" +m.config.devices = 2 +m.config.num_nodes = 2 +m.config.strategy = "ddp" +m.config.workers = 0 +m.create_trainer() + +results = m.predict_tile( + path="/path/to/tile.tif", + patch_size=1500, + patch_overlap=0, + dataloader_strategy="window", +) +``` + +Launch that script with the same `srun` Slurm pattern and trainer settings. For a complete cluster example, see the [cluster developer guide](../development/cluster.md). diff --git a/docs/user_guide/index.md b/docs/user_guide/index.md index e087156f5..4639326fb 100644 --- a/docs/user_guide/index.md +++ b/docs/user_guide/index.md @@ -18,6 +18,7 @@ The User Guide covers the core DeepForest package usage and functionalities. 
05_model_architecture 06_multi_species 07_scaling +distributed 08_visualizations 09_configuration_file 10_better diff --git a/pyproject.toml b/pyproject.toml index ec0920c00..e86c0c405 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -155,7 +155,8 @@ filterwarnings = [ ] markers = [ "slow: marks tests that are slow to run", - "integration: marks integration tests" + "integration: marks integration tests", + "cluster: marks tests intended for Slurm cluster runs only" ] [tool.coverage.run] diff --git a/src/deepforest/callbacks.py b/src/deepforest/callbacks.py index dfd728ce8..631b09412 100644 --- a/src/deepforest/callbacks.py +++ b/src/deepforest/callbacks.py @@ -17,7 +17,7 @@ from PIL import Image from pytorch_lightning import Callback -from deepforest import utilities, visualize +from deepforest import distributed, utilities, visualize from deepforest.datasets.training import BoxDataset @@ -54,7 +54,7 @@ def __init__( def on_train_start(self, trainer, pl_module): """Log sample images from training and validation datasets at training start.""" - if trainer.fast_dev_run: + if trainer.fast_dev_run or not distributed.is_global_zero(trainer): return self.trainer = trainer @@ -73,7 +73,11 @@ def on_train_start(self, trainer, pl_module): def on_validation_end(self, trainer, pl_module): """Run callback at validation end.""" - if trainer.sanity_checking or trainer.fast_dev_run: + if ( + trainer.sanity_checking + or trainer.fast_dev_run + or not distributed.is_global_zero(trainer) + ): return if (trainer.current_epoch + 1) % self.every_n_epochs == 0: diff --git a/src/deepforest/conf/bird.yaml b/src/deepforest/conf/bird.yaml new file mode 100644 index 000000000..945f95748 --- /dev/null +++ b/src/deepforest/conf/bird.yaml @@ -0,0 +1,68 @@ +# Config file for DeepForest pytorch module +defaults: + - config + - _self_ + +# Dataloader settings (--batch_size 32 --workers 8; train.sh --cpus-per-task=10). 
+# Scale: increase batch_size if GPU memory allows; increase workers only with more Slurm CPUs +# (train.sh --cpus-per-task); keep workers roughly cpus-per-task minus 1–2. +workers: 8 +batch_size: 32 + +# Model Architecture +architecture: 'retinanet' +score_thresh: 0.1 + +# Set model name to None to initialize from scratch +model: + name: 'weecology/deepforest-bird' + revision: 'refs/pr/5' + +# Trainer precision. Override to 16-mixed for faster training, 32-true for full precision. +precision: 16-mixed + +# On CUDA, setting the matmul precision can provide speed up without affecting model performance +# Start with 'high' for initial tests. +matmul_precision: highest + +# By default, this will be populated from the model +# checkpoint that is selected in model.name/revision. +# e.g. { "Tree": 0 } +label_dict: + Bird: 0 + +num_classes: 1 + +log_root: ./lightning_logs + +train: + epochs: 95 + + csv_file: /blue/ewhite/everglades/cluster_deepforest/datasets/fine_tune_2025/train.csv + root_dir: /blue/ewhite/everglades/cluster_deepforest/datasets/fine_tune_2025/root + #validate_labels: true + + lr: 0.001 + optimizer: + type: SGD + scheduler: + type: cosine + params: + T_max: 95 + eta_min: 0.00005 + + # Data augmentations for training + augmentations: + - HorizontalFlip: {p: 0.5} + - VerticalFlip: {p: 0.5} + - Rotate: {degrees: 15, p: 0.5} + - RandomResizedCrop: {size: [800, 800], scale: [0.3, 1.0], p: 0.5} + +validation: + csv_file: /blue/ewhite/everglades/cluster_deepforest/datasets/fine_tune_2025/val.csv + root_dir: /blue/ewhite/everglades/cluster_deepforest/datasets/fine_tune_2025/root + + # For retinanet you may prefer val_classification, but the default val_loss + # should work with all models + lr_plateau_target: val_loss + val_accuracy_interval: 1 diff --git a/src/deepforest/conf/config.yaml b/src/deepforest/conf/config.yaml index 63ef1a137..cb13d46b1 100644 --- a/src/deepforest/conf/config.yaml +++ b/src/deepforest/conf/config.yaml @@ -5,6 +5,11 @@ workers: 0 
devices: auto accelerator: auto +num_nodes: 1 +strategy: auto +precision: 32-true +sync_batchnorm: False +use_distributed_sampler: True batch_size: 1 # Model Architecture @@ -18,9 +23,6 @@ model: name: 'weecology/deepforest-tree' revision: 'main' -# Trainer precision. Override to 16-mixed for faster training, 32-true for full precision. -precision: - # On CUDA, setting the matmul precision can provide speed up without affecting model performance # Start with 'high' for initial tests. matmul_precision: highest diff --git a/src/deepforest/conf/schema.py b/src/deepforest/conf/schema.py index e9bcd133d..96c2ea6ab 100644 --- a/src/deepforest/conf/schema.py +++ b/src/deepforest/conf/schema.py @@ -160,6 +160,11 @@ class Config: workers: int = 0 devices: int | str = "auto" accelerator: str = "auto" + num_nodes: int = 1 + strategy: str = "auto" + sync_batchnorm: bool = False + use_distributed_sampler: bool = True batch_size: int = 1 - precision: str | None = None + # Trainer precision (single definition of this field); default mirrors conf/config.yaml. + precision: str | None = "32-true" matmul_precision: str = "highest" diff --git a/src/deepforest/conf/smoke.yaml b/src/deepforest/conf/smoke.yaml new file mode 100644 index 000000000..827368541 --- /dev/null +++ b/src/deepforest/conf/smoke.yaml @@ -0,0 +1,20 @@ +# Quick cluster smoke-test config (bundled OSBS sample data). +# Paths are set by src/deepforest/scripts/HPC/run_cluster_train.sh (TRAIN_MODE=smoke). 
+defaults: + - config + - _self_ + +workers: 0 +batch_size: 2 +log_root: ./lightning_logs_smoke + +train: + epochs: 1 + fast_dev_run: false + csv_file: + root_dir: + +validation: + csv_file: + root_dir: + val_accuracy_interval: 1 diff --git a/src/deepforest/datasets/prediction.py b/src/deepforest/datasets/prediction.py index c41ad3869..641ba98fa 100644 --- a/src/deepforest/datasets/prediction.py +++ b/src/deepforest/datasets/prediction.py @@ -38,12 +38,14 @@ def __init__( images=None, patch_size=400, patch_overlap=0, + return_metadata=False, ): self.image = image self.images = images self.path = path self.patch_size = patch_size self.patch_overlap = patch_overlap + self.return_metadata = return_metadata self.items = self.prepare_items() def load_and_preprocess_image( @@ -89,10 +91,18 @@ def __len__(self): def __getitem__(self, idx): """Get the item at the given index.""" - return self.get_crop(idx) + if not self.return_metadata: + return self.get_crop(idx) + + return {"image": self.get_crop(idx), "metadata": self.get_metadata(idx)} def collate_fn(self, batch): """Collate the batch into a list.""" + if self.return_metadata: + return { + "images": [item["image"] for item in batch], + "metadata": [item["metadata"] for item in batch], + } return batch def get_crop_bounds(self, idx): @@ -108,6 +118,13 @@ def get_image_basename(self, idx): """Get the basename of the image at the given index.""" raise NotImplementedError("Subclasses must implement this method") + def get_metadata(self, idx): + """Get metadata needed to postprocess a prediction item.""" + return { + "image_path": self.get_image_basename(idx), + "window_bounds": self.get_crop_bounds(idx), + } + def determine_geometry_type(self, batched_result): """Determine the geometry type of the batched result.""" # Assumes that all geometries are the same in a batch @@ -163,9 +180,20 @@ def postprocess(self, batch, prediction_index): class SingleImage(PredictionDataset): """Take in a single image path, preprocess and 
batch together.""" - def __init__(self, path=None, image=None, patch_size=400, patch_overlap=0): + def __init__( + self, + path=None, + image=None, + patch_size=400, + patch_overlap=0, + return_metadata=False, + ): super().__init__( - path=path, image=image, patch_size=patch_size, patch_overlap=patch_overlap + path=path, + image=image, + patch_size=patch_size, + patch_overlap=patch_overlap, + return_metadata=return_metadata, ) def prepare_items(self): @@ -199,10 +227,10 @@ class FromCSVFile(PredictionDataset): """Take in a csv file with image paths and preprocess and batch together.""" - def __init__(self, csv_file: str, root_dir: str): + def __init__(self, csv_file: str, root_dir: str, return_metadata=False): self.csv_file = csv_file self.root_dir = root_dir - super().__init__() + super().__init__(return_metadata=return_metadata) def prepare_items(self): self.annotations = read_file(self.csv_file) @@ -248,9 +276,10 @@ class _IndexedCrops(list): """List of crops with the dataset index attached for correct batch collation.""" - def __init__(self, index: int, crops: list): + def __init__(self, index: int, crops: list, metadata: list | None = None): super().__init__(crops) self.index = index + self.metadata = metadata or [] class MultiImage(PredictionDataset): @@ -259,7 +288,13 @@ class MultiImage(PredictionDataset): Note: This dataset will load the first image to determine the image dimensions. Images are expected to be the same size. For variable sized images, write a csv file and use the FromCSVFile dataset. """ - def __init__(self, paths: list[str], patch_size: int, patch_overlap: float): + def __init__( + self, + paths: list[str], + patch_size: int, + patch_overlap: float, + return_metadata=False, + ): """ Args: paths (List[str]): List of image paths. 
@@ -270,6 +305,7 @@ def __init__(self, paths: list[str], patch_size: int, patch_overlap: float): self.patch_size = patch_size self.patch_overlap = patch_overlap self.batch_indices = [] + self.return_metadata = return_metadata image = self.load_and_preprocess_image(image_path=self.paths[0]) self.image_height = image.shape[1] @@ -359,7 +395,11 @@ def window_list(self): def __getitem__(self, idx): """Return crops with dataset index so collate_fn can use global indices.""" - return _IndexedCrops(idx, self.get_crop(idx)) + crops = self.get_crop(idx) + metadata = None + if self.return_metadata: + metadata = self.get_metadata(idx, crop_count=len(crops)) + return _IndexedCrops(idx, crops, metadata=metadata) def collate_fn(self, batch): """Collate the batch into a single list of crops. @@ -369,6 +409,11 @@ def collate_fn(self, batch): image_path correctly. """ # item.index is the global dataset index; sublist is the list of crops + if self.return_metadata: + flattened_batch = [crop for item in batch for crop in item] + metadata = [meta for item in batch for meta in item.metadata] + return {"images": flattened_batch, "metadata": metadata} + batch_indices = [ [item.index, sub_idx] for item in batch for sub_idx in range(len(item)) ] @@ -392,6 +437,21 @@ def get_image_basename(self, idx): def get_crop_bounds(self, idx): return self.window_list()[idx] + def get_metadata(self, idx, crop_count=None): + if crop_count is None: + crop_count = len(self.get_crop(idx)) + + windows = self.window_list() + return [ + { + "image_path": self.get_image_basename(idx), + "window_bounds": windows[window_idx] + if window_idx < len(windows) + else None, + } + for window_idx in range(crop_count) + ] + def postprocess(self, batch, prediction_index, original_batch_structure): """Postprocess flattened batch of predictions from multiple images. 
@@ -436,10 +496,15 @@ class TiledRaster(PredictionDataset): A dataset of raster windows """ - def __init__(self, path, patch_size, patch_overlap): + def __init__(self, path, patch_size, patch_overlap, return_metadata=False): if path is None: raise ValueError("path is required for a memory raster dataset") - super().__init__(path=path, patch_size=patch_size, patch_overlap=patch_overlap) + super().__init__( + path=path, + patch_size=patch_size, + patch_overlap=patch_overlap, + return_metadata=return_metadata, + ) def prepare_items(self): # Get raster shape without keeping file open diff --git a/src/deepforest/distributed.py b/src/deepforest/distributed.py new file mode 100644 index 000000000..ea625f75f --- /dev/null +++ b/src/deepforest/distributed.py @@ -0,0 +1,90 @@ +"""Helpers for distributed-safe logging and object gathering.""" + +from __future__ import annotations + +from collections.abc import Iterator +from typing import Any + +import pandas as pd +import torch.distributed as dist +from torch.utils.data import Sampler + + +def is_distributed() -> bool: + """Return True when torch.distributed is initialized.""" + return dist.is_available() and dist.is_initialized() + + +def is_global_zero(trainer: Any | None = None) -> bool: + """Return True on the global-zero rank.""" + if trainer is not None: + return bool(getattr(trainer, "is_global_zero", True)) + + if not is_distributed(): + return True + + return dist.get_rank() == 0 + + +def should_sync(trainer: Any | None = None) -> bool: + """Return True when metrics should be synchronized across ranks.""" + if trainer is not None: + world_size = getattr(trainer, "world_size", 1) + return world_size is not None and world_size > 1 + + return is_distributed() + + +def get_rank() -> int: + """Return the current distributed rank, defaulting to zero.""" + if not is_distributed(): + return 0 + + return dist.get_rank() + + +def get_world_size() -> int: + """Return the distributed world size, defaulting to one.""" + if not 
is_distributed(): + return 1 + + return dist.get_world_size() + + +class FixedOrderSampler(Sampler[int]): + """Yield a fixed sequence of dataset indices.""" + + def __init__(self, indices: list[int]): + self.indices = indices + + def __iter__(self) -> Iterator[int]: + return iter(self.indices) + + def __len__(self) -> int: + return len(self.indices) + + +def gather_object(obj: Any) -> list[Any]: + """Gather a Python object from every rank.""" + if not is_distributed(): + return [obj] + + gathered = [None] * get_world_size() + dist.all_gather_object(gathered, obj) + return gathered + + +def gather_dataframe(frame: pd.DataFrame) -> pd.DataFrame: + """Gather pandas DataFrames from every rank.""" + gathered = gather_object(frame) + non_empty_frames = [ + item for item in gathered if isinstance(item, pd.DataFrame) and not item.empty + ] + + if non_empty_frames: + return pd.concat(non_empty_frames, ignore_index=True) + + if isinstance(frame, pd.DataFrame): + return frame.iloc[0:0].copy() + + return pd.DataFrame() diff --git a/src/deepforest/main.py b/src/deepforest/main.py index c30a9737a..329ca5629 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -2,6 +2,7 @@ import importlib import os import warnings +from numbers import Number import numpy as np import pandas as pd @@ -15,7 +16,7 @@ from torch import optim from torchmetrics.detection import IntersectionOverUnion, MeanAveragePrecision -from deepforest import predict, utilities +from deepforest import distributed, predict, utilities from deepforest.datasets import prediction, training from deepforest.metrics import RecallPrecision @@ -258,6 +259,11 @@ def create_trainer(self, logger=None, callbacks=None, **kwargs): "enable_checkpointing": enable_checkpointing, "devices": self.config.devices, "accelerator": self.config.accelerator, + "num_nodes": self.config.num_nodes, + "strategy": self.config.strategy, + "precision": self.config.precision, + "sync_batchnorm": self.config.sync_batchnorm, + 
"use_distributed_sampler": self.config.use_distributed_sampler, "fast_dev_run": self.config.train.fast_dev_run, "callbacks": callbacks, "limit_val_batches": limit_val_batches, @@ -271,6 +277,33 @@ def create_trainer(self, logger=None, callbacks=None, **kwargs): self.trainer = pl.Trainer(**trainer_args) + def _format_prediction_frame(self, prediction_result, metadata): + """Format a raw model prediction with dataset metadata.""" + formatted_result = utilities.format_geometry(prediction_result) + if formatted_result is None: + return pd.DataFrame() + + window_bounds = metadata.get("window_bounds") + if window_bounds is not None: + formatted_result["window_xmin"] = window_bounds[0] + formatted_result["window_ymin"] = window_bounds[1] + + formatted_result["image_path"] = metadata.get("image_path") + + return formatted_result.reset_index(drop=True) + + def _gather_prediction_frames(self, frames): + """Gather prediction frames across ranks.""" + if not frames: + return distributed.gather_dataframe(pd.DataFrame()) + + local_frames = [frame for frame in frames if not frame.empty] + if not local_frames: + return distributed.gather_dataframe(frames[0].iloc[0:0].copy()) + + local_results = pd.concat(local_frames, ignore_index=True) + return distributed.gather_dataframe(local_results) + def on_fit_start(self): if self.config.train.csv_file is None and self.existing_train_dataloader is None: raise AttributeError( @@ -464,12 +497,24 @@ def predict_dataloader(self, ds, batch_size=None): batch_size = self.config.batch_size else: batch_size = batch_size + sampler = None + if ( + self.config.use_distributed_sampler + and distributed.is_distributed() + and len(ds) < distributed.get_world_size() + ): + rank = distributed.get_rank() + local_indices = [rank] if rank < len(ds) else [] + sampler = distributed.FixedOrderSampler(local_indices) + loader = torch.utils.data.DataLoader( ds, batch_size=batch_size, shuffle=False, + sampler=sampler, num_workers=self.config.workers, 
collate_fn=ds.collate_fn, + pin_memory=self.config.predict.pin_memory, ) return loader @@ -554,7 +599,9 @@ def predict_file( Returns: df: pandas dataframe with bounding boxes, label and scores for each image in the csv file """ - ds = prediction.FromCSVFile(csv_file=csv_file, root_dir=root_dir) + ds = prediction.FromCSVFile( + csv_file=csv_file, root_dir=root_dir, return_metadata=True + ) dataloader = self.predict_dataloader(ds, batch_size=self.config.batch_size) results = predict._dataloader_wrapper_( model=self, @@ -630,6 +677,7 @@ def predict_tile( image=image, patch_overlap=patch_overlap, patch_size=patch_size, + return_metadata=True, ) else: # Check for workers config when using out of memory dataset @@ -643,46 +691,39 @@ def predict_tile( path=image_path, patch_overlap=patch_overlap, patch_size=patch_size, + return_metadata=True, ) dataloader = self.predict_dataloader(ds) batched_results = self.trainer.predict(self, dataloader) - # Flatten list from batched prediction - # Track global window index across batches - global_window_idx = 0 - for batch in batched_results: - for window_result in batch: - image_results.append( - ds.postprocess(window_result, global_window_idx) - ) - global_window_idx += 1 + image_results.append( + predict._flatten_prediction_batches_(batched_results) + ) if not image_results: results = pd.DataFrame() else: - results = pd.concat(image_results) + results = self._gather_prediction_frames(image_results) elif dataloader_strategy == "batch": self.original_batch_structure.clear() ds = prediction.MultiImage( - paths=paths, patch_overlap=patch_overlap, patch_size=patch_size + paths=paths, + patch_overlap=patch_overlap, + patch_size=patch_size, + return_metadata=True, ) dataloader = self.predict_dataloader(ds) batched_results = self.trainer.predict(self, dataloader) - # Flatten list from batched prediction - for idx, batch in enumerate(batched_results): - formatted_result = ds.postprocess( - batch, idx, self.original_batch_structure - ) - 
image_results.append(formatted_result) + image_results.append(predict._flatten_prediction_batches_(batched_results)) if not image_results: results = pd.DataFrame() else: - results = pd.concat(image_results) + results = self._gather_prediction_frames(image_results) else: raise ValueError(f"Invalid dataloader_strategy: {dataloader_strategy}") @@ -763,11 +804,21 @@ def training_step(self, batch, batch_idx): # Log loss for key, value in loss_dict.items(): self.log( - f"train_{key}", value.detach(), on_epoch=True, batch_size=len(images) + f"train_{key}", + value.detach(), + on_epoch=True, + batch_size=len(images), + sync_dist=distributed.should_sync(self.trainer), ) # Log sum of losses - self.log("train_loss", losses.detach(), on_epoch=True, batch_size=len(images)) + self.log( + "train_loss", + losses.detach(), + on_epoch=True, + batch_size=len(images), + sync_dist=distributed.should_sync(self.trainer), + ) return losses @@ -786,9 +837,21 @@ def validation_step(self, batch, batch_idx): # Log losses for key, value in loss_dict.items(): - self.log(f"val_{key}", value.detach(), on_epoch=True, batch_size=len(images)) + self.log( + f"val_{key}", + value.detach(), + on_epoch=True, + batch_size=len(images), + sync_dist=distributed.should_sync(self.trainer), + ) - self.log("val_loss", losses.detach(), on_epoch=True, batch_size=len(images)) + self.log( + "val_loss", + losses.detach(), + on_epoch=True, + batch_size=len(images), + sync_dist=distributed.should_sync(self.trainer), + ) # In eval model, return predictions to calculate prediction metrics self.model.eval() @@ -883,16 +946,38 @@ def _compute_epoch_metrics(self) -> dict: return metrics + def _prepare_metrics_for_sync(self, metrics: dict) -> dict: + """Move scalar metrics onto the module device for NCCL reduction.""" + synced_metrics = {} + + for key, value in metrics.items(): + if isinstance(value, torch.Tensor): + synced_metrics[key] = value.to(self.device) + elif isinstance(value, Number): + synced_metrics[key] = 
torch.tensor(value, device=self.device) + else: + synced_metrics[key] = value + + return synced_metrics + def on_validation_epoch_end(self): """Compute metrics and predictions at the end of the validation epoch.""" if self.trainer.sanity_checking: # optional skip return + gathered_predictions = self._gather_prediction_frames(self.predictions) + self.predictions = ( + [gathered_predictions] if not gathered_predictions.empty else [] + ) + # Log epoch metrics if (self.current_epoch + 1) % self.config.validation.val_accuracy_interval == 0: metrics = self._compute_epoch_metrics() - self.log_dict(metrics) + should_sync = distributed.should_sync(self.trainer) + if should_sync: + metrics = self._prepare_metrics_for_sync(metrics) + self.log_dict(metrics, sync_dist=should_sync) # Manual reset. Lightning does not do this automatically # unless we log the metric objects directly @@ -920,17 +1005,25 @@ def predict_step(self, batch, batch_idx): """ if isinstance(batch, dict): images = batch["images"] - batch_indices = batch["batch_indices"] - self.original_batch_structure.append(batch_indices) + metadata = batch.get("metadata") + batch_indices = batch.get("batch_indices") + if batch_indices is not None: + self.original_batch_structure.append(batch_indices) else: - batch_indices = None images = batch + metadata = None self.model.eval() with torch.no_grad(): preds = self.model.forward(images) - return preds + if metadata is None: + return preds + + return [ + self._format_prediction_frame(prediction_result, sample_metadata) + for prediction_result, sample_metadata in zip(preds, metadata, strict=True) + ] def predict_batch(self, images, preprocess_fn=None): """Predict a batch of images with the deepforest model. 
@@ -1106,7 +1199,7 @@ def evaluate( self.config.validation.val_accuracy_interval = 1 self.create_trainer() - self.trainer.validate(self) + validation_results = self.trainer.validate(self) # Gather predictions from all ranks in multi-GPU settings if self.trainer.world_size > 1: @@ -1127,7 +1220,11 @@ def evaluate( self.predictions = pd.DataFrame() results = {} - results.update(self.trainer.logged_metrics) + if isinstance(validation_results, list): + if validation_results and isinstance(validation_results[0], dict): + results.update(validation_results[0]) + elif isinstance(validation_results, dict): + results.update(validation_results) results["predictions"] = self.predictions if self.model.task == "box" or self.model.task == "point": results["results"] = self.precision_recall_metric.get_results() diff --git a/src/deepforest/predict.py b/src/deepforest/predict.py index 7396a023f..a4422a25f 100644 --- a/src/deepforest/predict.py +++ b/src/deepforest/predict.py @@ -9,7 +9,7 @@ from shapely import affinity from torchvision.ops import nms -from deepforest import utilities +from deepforest import distributed, utilities from deepforest.datasets import cropmodel from deepforest.utilities import read_file @@ -223,6 +223,25 @@ def across_class_nms(predicted_boxes, iou_threshold=0.15): return new_df +def _flatten_prediction_batches_(batched_results): + """Flatten prediction batches returned by Lightning predict().""" + flattened = [] + for batch in batched_results: + if isinstance(batch, pd.DataFrame): + if not batch.empty: + flattened.append(batch) + continue + + for item in batch: + if isinstance(item, pd.DataFrame) and not item.empty: + flattened.append(item) + + if not flattened: + return pd.DataFrame() + + return pd.concat(flattened, ignore_index=True) + + def _dataloader_wrapper_(model, trainer, dataloader, root_dir, crop_model): """ @@ -237,31 +256,19 @@ def _dataloader_wrapper_(model, trainer, dataloader, root_dir, crop_model): results: pandas dataframe with bounding 
boxes, label and scores for each image in the csv file """ batched_results = trainer.predict(model, dataloader) - - # Flatten list from batched prediction - prediction_list = [] - global_image_idx = 0 - for _idx, batch in enumerate(batched_results): - for _image_idx, image_result in enumerate(batch): - formatted_result = dataloader.dataset.postprocess( - image_result, global_image_idx - ) - global_image_idx += 1 - prediction_list.append(formatted_result) - - # Postprocess predictions, return empty dataframe if no predictions - if not prediction_list: - return pd.DataFrame() - - results = pd.concat(prediction_list) + results = distributed.gather_dataframe(_flatten_prediction_batches_(batched_results)) if results.empty: - return results + return pd.DataFrame() # Apply across class NMS for each image processed_results = [] for image_path in results.image_path.unique(): image_results = results[results.image_path == image_path].copy() + if image_results.label.nunique() > 1: + image_results = across_class_nms( + image_results, iou_threshold=model.config.nms_thresh + ) if crop_model: # Flag to check if only one model is passed @@ -278,6 +285,11 @@ def _dataloader_wrapper_(model, trainer, dataloader, root_dir, crop_model): ) processed_results.append(crop_model_results) + else: + processed_results.append(image_results) + + if processed_results: + results = pd.concat(processed_results, ignore_index=True) results = read_file(results, root_dir) diff --git a/src/deepforest/scripts/HPC/run_cluster_predict_test.sbatch b/src/deepforest/scripts/HPC/run_cluster_predict_test.sbatch new file mode 100644 index 000000000..4ced188bc --- /dev/null +++ b/src/deepforest/scripts/HPC/run_cluster_predict_test.sbatch @@ -0,0 +1,55 @@ +#!/usr/bin/env bash +#SBATCH --job-name=deepforest-predict-test +#SBATCH --output=%x-%j.out +#SBATCH --error=%x-%j.err +#SBATCH --nodes=2 +#SBATCH --ntasks-per-node=2 +#SBATCH --gpus-per-node=2 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=80G +#SBATCH --time=05:30:00 
+ +set -euo pipefail + +cd "${SLURM_SUBMIT_DIR:?SLURM_SUBMIT_DIR not set}" +REPO_ROOT="$(pwd)" + +DEFAULT_BIRD_ROOT="/blue/ewhite/b.weinstein/bird_detector_retrain/zero_shot/avian_images_annotated/test_splits/patch_600" +DEFAULT_BIRD_CSV="$DEFAULT_BIRD_ROOT/test_split_patch_600.csv" +CONDA_MODULE="${CONDA_MODULE:-conda}" +CONDA_ENV_NAME="${CONDA_ENV_NAME:-predict}" + +export DEEPFOREST_HPC_PREDICT_CSV="${DEEPFOREST_HPC_PREDICT_CSV:-$DEFAULT_BIRD_CSV}" +export DEEPFOREST_HPC_ROOT_DIR="${DEEPFOREST_HPC_ROOT_DIR:-$DEFAULT_BIRD_ROOT}" + +export RUN_CLUSTER_TESTS="${RUN_CLUSTER_TESTS:-1}" +export DEEPFOREST_HPC_NNODES="${DEEPFOREST_HPC_NNODES:-${SLURM_NNODES:-2}}" +export DEEPFOREST_HPC_GPUS_PER_NODE="${DEEPFOREST_HPC_GPUS_PER_NODE:-2}" +export DEEPFOREST_HPC_OUTPUT_DIR="${DEEPFOREST_HPC_OUTPUT_DIR:-$REPO_ROOT/cluster_test_outputs}" + +mkdir -p "$DEEPFOREST_HPC_OUTPUT_DIR" + +ml "$CONDA_MODULE" +eval "$(conda shell.bash hook)" +conda activate "$CONDA_ENV_NAME" + +echo "Running cluster predict integration test" +echo "CSV: $DEEPFOREST_HPC_PREDICT_CSV" +echo "ROOT_DIR: $DEEPFOREST_HPC_ROOT_DIR" +echo "OUTPUT_DIR: $DEEPFOREST_HPC_OUTPUT_DIR" +echo "NNODES: $DEEPFOREST_HPC_NNODES" +echo "GPUS_PER_NODE: $DEEPFOREST_HPC_GPUS_PER_NODE" +echo "CONDA_MODULE: $CONDA_MODULE" +echo "CONDA_ENV_NAME: $CONDA_ENV_NAME" + +if [[ ! -f "$DEEPFOREST_HPC_PREDICT_CSV" ]]; then + echo "Prediction CSV not found: $DEEPFOREST_HPC_PREDICT_CSV" + exit 1 +fi + +if [[ ! 
-d "$DEEPFOREST_HPC_ROOT_DIR" ]]; then + echo "Prediction root directory not found: $DEEPFOREST_HPC_ROOT_DIR" + exit 1 +fi + +uv run pytest -m cluster tests/test_cluster_predict.py -s diff --git a/src/deepforest/scripts/HPC/run_cluster_predict_tile_test.sbatch b/src/deepforest/scripts/HPC/run_cluster_predict_tile_test.sbatch new file mode 100644 index 000000000..03f76f553 --- /dev/null +++ b/src/deepforest/scripts/HPC/run_cluster_predict_tile_test.sbatch @@ -0,0 +1,68 @@ +#!/usr/bin/env bash +#SBATCH --job-name=deepforest-predict-tile-test +#SBATCH --output=%x-%j.out +#SBATCH --error=%x-%j.err +#SBATCH --nodes=2 +#SBATCH --ntasks-per-node=2 +#SBATCH --gpus-per-node=2 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=80G +#SBATCH --time=05:30:00 + +set -euo pipefail + +cd "${SLURM_SUBMIT_DIR:?SLURM_SUBMIT_DIR not set}" +REPO_ROOT="$(pwd)" + +DEFAULT_TILE_DIR="/blue/ewhite/everglades/projected_mosaics/2023/6thBridge" +CONDA_MODULE="${CONDA_MODULE:-conda}" +CONDA_ENV_NAME="${CONDA_ENV_NAME:-predict}" + +export RUN_CLUSTER_TESTS="${RUN_CLUSTER_TESTS:-1}" +export DEEPFOREST_HPC_NNODES="${DEEPFOREST_HPC_NNODES:-${SLURM_NNODES:-2}}" +export DEEPFOREST_HPC_GPUS_PER_NODE="${DEEPFOREST_HPC_GPUS_PER_NODE:-2}" +export DEEPFOREST_HPC_OUTPUT_DIR="${DEEPFOREST_HPC_OUTPUT_DIR:-$REPO_ROOT/cluster_tile_outputs}" +export DEEPFOREST_HPC_TILE_DIR="${DEEPFOREST_HPC_TILE_DIR:-$DEFAULT_TILE_DIR}" +export DEEPFOREST_HPC_TILE_MODEL_NAME="${DEEPFOREST_HPC_TILE_MODEL_NAME:-weecology/everglades-bird-species-detector}" +export DEEPFOREST_HPC_TILE_PATCH_SIZE="${DEEPFOREST_HPC_TILE_PATCH_SIZE:-1500}" +export DEEPFOREST_HPC_TILE_PATCH_OVERLAP="${DEEPFOREST_HPC_TILE_PATCH_OVERLAP:-0}" +export DEEPFOREST_HPC_TILE_IOU_THRESHOLD="${DEEPFOREST_HPC_TILE_IOU_THRESHOLD:-0.15}" +export DEEPFOREST_HPC_TILE_DATALOADER_STRATEGY="${DEEPFOREST_HPC_TILE_DATALOADER_STRATEGY:-window}" + +mkdir -p "$DEEPFOREST_HPC_OUTPUT_DIR" + +ml "$CONDA_MODULE" +eval "$(conda shell.bash hook)" +conda activate "$CONDA_ENV_NAME" + +if [[ 
-z "${DEEPFOREST_HPC_TILE_PATH:-}" ]]; then + shopt -s globstar nullglob nocaseglob + tile_candidates=("$DEEPFOREST_HPC_TILE_DIR"/**/*.tif "$DEEPFOREST_HPC_TILE_DIR"/*.tif) + shopt -u globstar nullglob nocaseglob + + if (( ${#tile_candidates[@]} == 0 )); then + echo "No TIFF files found under: $DEEPFOREST_HPC_TILE_DIR" + exit 1 + fi + + export DEEPFOREST_HPC_TILE_PATH="${tile_candidates[0]}" +fi + +echo "Running cluster tiled predict integration test" +echo "TILE_PATH: $DEEPFOREST_HPC_TILE_PATH" +echo "OUTPUT_DIR: $DEEPFOREST_HPC_OUTPUT_DIR" +echo "MODEL_NAME: $DEEPFOREST_HPC_TILE_MODEL_NAME" +echo "NNODES: $DEEPFOREST_HPC_NNODES" +echo "GPUS_PER_NODE: $DEEPFOREST_HPC_GPUS_PER_NODE" +echo "PATCH_SIZE: $DEEPFOREST_HPC_TILE_PATCH_SIZE" +echo "PATCH_OVERLAP: $DEEPFOREST_HPC_TILE_PATCH_OVERLAP" +echo "DATALOADER_STRATEGY: $DEEPFOREST_HPC_TILE_DATALOADER_STRATEGY" +echo "CONDA_MODULE: $CONDA_MODULE" +echo "CONDA_ENV_NAME: $CONDA_ENV_NAME" + +if [[ ! -f "$DEEPFOREST_HPC_TILE_PATH" ]]; then + echo "Tile path not found: $DEEPFOREST_HPC_TILE_PATH" + exit 1 +fi + +uv run pytest -m cluster tests/test_cluster_predict_tile.py -s diff --git a/src/deepforest/scripts/HPC/run_cluster_train.sbatch b/src/deepforest/scripts/HPC/run_cluster_train.sbatch new file mode 100644 index 000000000..5f3b5fedc --- /dev/null +++ b/src/deepforest/scripts/HPC/run_cluster_train.sbatch @@ -0,0 +1,17 @@ +#!/usr/bin/env bash +#SBATCH --job-name=deepforest-train +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=10 +#SBATCH --mem=128G +#SBATCH --time=15:00:00 +#SBATCH --output=./slurm_logs/train-%j.out +#SBATCH --error=./slurm_logs/train-%j.err +#SBATCH --signal=SIGUSR1@360 + +set -euo pipefail + +cd "${SLURM_SUBMIT_DIR:?SLURM_SUBMIT_DIR not set}" +HPC_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +bash "$HPC_DIR/run_cluster_train.sh" "$@" diff --git a/src/deepforest/scripts/HPC/run_cluster_train.sh b/src/deepforest/scripts/HPC/run_cluster_train.sh 
new file mode 100755 index 000000000..d3fd0cd39 --- /dev/null +++ b/src/deepforest/scripts/HPC/run_cluster_train.sh @@ -0,0 +1,196 @@ +#!/usr/bin/env bash +# Cluster training launcher for Slurm (HiPerGator and similar). +# +# TRAIN_MODE: +# train - full training (default CONFIG_NAME=bird) +# smoke - 1-epoch OSBS smoke test (default CONFIG_NAME=smoke) +# +# SCENARIO (smoke, or inferred for train when unset): +# 1gpu - single GPU +# multigpu - DDP on one node +# multinode - DDP across nodes +# +# Override at submit time, e.g.: +# sbatch src/deepforest/scripts/HPC/run_cluster_train.sbatch train.lr=0.0005 +# TRAIN_MODE=smoke SCENARIO=multigpu sbatch --nodes=1 --ntasks-per-node=2 --gpus-per-node=2 \ +# src/deepforest/scripts/HPC/run_cluster_train.sbatch +set -euo pipefail + +TRAIN_MODE="${TRAIN_MODE:-train}" +CONFIG_NAME="${CONFIG_NAME:-}" +SCENARIO="${SCENARIO:-}" +GPUS_PER_NODE="${GPUS_PER_NODE:-${SLURM_GPUS_ON_NODE:-${SLURM_GPUS_PER_NODE:-1}}}" +NNODES="${NNODES:-${SLURM_NNODES:-1}}" + +CONDA_MODULE="${CONDA_MODULE:-conda}" +CONDA_ENV_NAME="${CONDA_ENV_NAME:-predict}" +USE_COMET="${USE_COMET:-1}" + +if [[ -z "${SLURM_JOB_ID:-}" ]]; then + echo "SLURM_JOB_ID is not set. Submit with sbatch." 
+ exit 1 +fi + +mkdir -p "${SLURM_SUBMIT_DIR:?SLURM_SUBMIT_DIR not set}/slurm_logs" +cd "${SLURM_SUBMIT_DIR}" +REPO_ROOT="$(pwd)" + +ml "$CONDA_MODULE" +eval "$(conda shell.bash hook)" +conda activate "$CONDA_ENV_NAME" + +if command -v uv &>/dev/null; then + DEEPFOREST_CMD=(uv run deepforest) +elif command -v deepforest &>/dev/null; then + DEEPFOREST_CMD=(deepforest) +else + echo "Neither 'uv' nor 'deepforest' found after conda activate $CONDA_ENV_NAME" + exit 1 +fi + +STRATEGY_ARGS=() +EXTRA_TRAIN_ARGS=() +COMET_ARGS=() +DEVICES=1 +NODES=1 + +if [[ "$TRAIN_MODE" == "smoke" ]]; then + CONFIG_NAME="${CONFIG_NAME:-smoke}" + SCENARIO="${SCENARIO:-1gpu}" + DATA_ROOT="${DATA_ROOT:-$REPO_ROOT/src/deepforest/data}" + TRAIN_CSV="${TRAIN_CSV:-$DATA_ROOT/OSBS_029.csv}" + TRAIN_ROOT="${TRAIN_ROOT:-$DATA_ROOT}" + VAL_CSV="${VAL_CSV:-$DATA_ROOT/OSBS_029.csv}" + VAL_ROOT="${VAL_ROOT:-$DATA_ROOT}" + LOG_ROOT="${LOG_ROOT:-$REPO_ROOT/lightning_logs_smoke}" + EXTRA_TRAIN_ARGS+=(--disable-checkpoint) + COMET_TAGS=(smoke) +else + CONFIG_NAME="${CONFIG_NAME:-bird}" + LOG_ROOT="${LOG_ROOT:-$REPO_ROOT/lightning_logs}" + COMET_TAGS=(bird) + if [[ -z "$SCENARIO" ]]; then + if [[ "$NNODES" -gt 1 ]]; then + SCENARIO=multinode + elif [[ "$GPUS_PER_NODE" -gt 1 ]]; then + SCENARIO=multigpu + else + SCENARIO=1gpu + fi + fi +fi + +case "$SCENARIO" in + 1gpu) + DEVICES=1 + NODES=1 + ;; + multigpu) + DEVICES="${GPUS_PER_NODE}" + NODES=1 + STRATEGY_ARGS=(--strategy ddp) + ;; + multinode) + DEVICES="${GPUS_PER_NODE}" + NODES="${NNODES}" + STRATEGY_ARGS=(--strategy ddp) + ;; + *) + echo "Unknown SCENARIO=$SCENARIO (use 1gpu, multigpu, or multinode)" + exit 1 + ;; +esac + +if [[ "$USE_COMET" == "1" ]]; then + export COMET_WORKSPACE="${COMET_WORKSPACE:-henrykironde}" + export COMET_PROJECT="${COMET_PROJECT:-bird-detector}" + if [[ -z "${COMET_API_KEY:-}" && -f "$HOME/.comet_api_key" ]]; then + export COMET_API_KEY="$(cat "$HOME/.comet_api_key")" + fi + if [[ -z "${COMET_API_KEY:-}" ]]; then + echo 
"WARNING: COMET_API_KEY not set; running without --comet" + else + COMET_ARGS=(--comet) + for tag in "${COMET_TAGS[@]}"; do + COMET_ARGS+=(--tag "$tag") + done + if [[ "$TRAIN_MODE" == "smoke" ]]; then + COMET_ARGS+=(--tag "$SCENARIO") + fi + if [[ -n "${COMET_EXPERIMENT_NAME:-}" ]]; then + COMET_ARGS+=(--experiment-name "$COMET_EXPERIMENT_NAME") + elif [[ "$TRAIN_MODE" == "smoke" ]]; then + COMET_ARGS+=(--experiment-name "smoke-${SCENARIO}-job${SLURM_JOB_ID}") + fi + fi +fi + +if [[ "$TRAIN_MODE" == "smoke" ]]; then + if [[ ! -f "$TRAIN_CSV" ]]; then + echo "Training CSV not found: $TRAIN_CSV" + exit 1 + fi +fi + +if [[ -f "$HOME/.secrets/hf_token" ]]; then + export HF_TOKEN="$(cat "$HOME/.secrets/hf_token")" +fi + +export NCCL_IB_DISABLE="${NCCL_IB_DISABLE:-1}" +export NCCL_NVLS_ENABLE="${NCCL_NVLS_ENABLE:-0}" +export NCCL_DEBUG="${NCCL_DEBUG:-INFO}" +export TORCH_NCCL_ASYNC_ERROR_HANDLING="${TORCH_NCCL_ASYNC_ERROR_HANDLING:-1}" +export PYTHONFAULTHANDLER="${PYTHONFAULTHANDLER:-1}" + +RESUME_CKPT="${RESUME_CKPT:-}" +RESUME_ARGS=() +if [[ -n "$RESUME_CKPT" ]]; then + RESUME_ARGS+=(--resume "$RESUME_CKPT") +fi + +echo "=== DeepForest cluster train ===" +echo "TRAIN_MODE=$TRAIN_MODE" +echo "CONFIG_NAME=$CONFIG_NAME" +echo "SCENARIO=$SCENARIO" +echo "HOSTNAME=$(hostname)" +echo "SLURM_JOB_ID=${SLURM_JOB_ID}" +echo "SLURM_NNODES=$NODES" +echo "GPUS_PER_NODE=$DEVICES" +echo "SLURM_NTASKS_PER_NODE=${SLURM_NTASKS_PER_NODE:-?}" +echo "USE_COMET=$USE_COMET" +echo "================================" + +TRAIN_ARGS=( + --config-name="$CONFIG_NAME" + train + "${STRATEGY_ARGS[@]}" + "${COMET_ARGS[@]}" + "${EXTRA_TRAIN_ARGS[@]}" + accelerator=gpu + "devices=$DEVICES" + "num_nodes=$NODES" + "log_root=$LOG_ROOT" + "${RESUME_ARGS[@]}" + "$@" +) + +if [[ "$TRAIN_MODE" == "smoke" ]]; then + TRAIN_ARGS+=( + workers=0 + "train.csv_file=$TRAIN_CSV" + "train.root_dir=$TRAIN_ROOT" + "validation.csv_file=$VAL_CSV" + "validation.root_dir=$VAL_ROOT" + ) +fi + +if [[ "$SCENARIO" == "1gpu" ]]; 
then + echo "Launching training (single GPU)" + "${DEEPFOREST_CMD[@]}" "${TRAIN_ARGS[@]}" +else + echo "Launching training via srun (distributed)" + srun --kill-on-bad-exit=1 --export=ALL --cpu-bind=none \ + "${DEEPFOREST_CMD[@]}" "${TRAIN_ARGS[@]}" +fi + +echo "Cluster training finished: mode=$TRAIN_MODE scenario=$SCENARIO" diff --git a/src/deepforest/scripts/cli.py b/src/deepforest/scripts/cli.py index 5231cea9f..77e09ddc6 100644 --- a/src/deepforest/scripts/cli.py +++ b/src/deepforest/scripts/cli.py @@ -42,8 +42,8 @@ def main(): ) train_parser.add_argument( "--strategy", - help="Training strategy to use (e.g., 'ddp', 'auto')", - default="auto", + help="Optional training strategy override (e.g., 'ddp'). Defaults to config value.", + default=None, ) train_parser.add_argument( "--resume", diff --git a/src/deepforest/scripts/evaluate.py b/src/deepforest/scripts/evaluate.py index e9bd3a155..bd457ab2a 100644 --- a/src/deepforest/scripts/evaluate.py +++ b/src/deepforest/scripts/evaluate.py @@ -4,6 +4,7 @@ import pandas as pd from omegaconf import DictConfig +from deepforest import distributed from deepforest.main import deepforest @@ -55,10 +56,13 @@ def evaluate( csv_file=ground_truth, root_dir=root_dir, ) + results["class_recall"] = getattr(m.precision_recall_metric, "_class_recall", None) # Save generated predictions if requested and they were generated (not loaded from file) - if save_predictions is not None: + if save_predictions is not None and distributed.is_global_zero(m.trainer): predictions_df = results.get("predictions") + if predictions_df is None: + predictions_df = pd.DataFrame() if predictions_df is not None and not predictions_df.empty: if os.path.dirname(save_predictions): os.makedirs(os.path.dirname(save_predictions), exist_ok=True) @@ -71,6 +75,9 @@ def evaluate( ) # Print results to console + if not distributed.is_global_zero(m.trainer): + return + m.print("Evaluation Results:") for key, value in results.items(): if key not in ["predictions", 
"results", "ground_df", "class_recall"]: diff --git a/src/deepforest/scripts/predict.py b/src/deepforest/scripts/predict.py index f4c8e9a42..fd3e6cdd6 100644 --- a/src/deepforest/scripts/predict.py +++ b/src/deepforest/scripts/predict.py @@ -2,7 +2,7 @@ from omegaconf import DictConfig -from deepforest import utilities +from deepforest import distributed, utilities from deepforest.main import deepforest from deepforest.visualize import plot_results @@ -75,7 +75,7 @@ def predict( else: raise ValueError(f"Invalid prediction mode: {mode}. Pick one of single/tile/csv.") - if output_path is not None: + if output_path is not None and distributed.is_global_zero(m.trainer): if os.path.dirname(output_path): os.makedirs(os.path.dirname(output_path), exist_ok=True) if output_path.endswith(".shp") or output_path.endswith(".gpkg"): @@ -84,5 +84,5 @@ def predict( else: res.to_csv(output_path, index=False) - if plot: + if plot and distributed.is_global_zero(m.trainer): plot_results(res) diff --git a/src/deepforest/scripts/train.py b/src/deepforest/scripts/train.py index 779ac4ada..767d9b74e 100644 --- a/src/deepforest/scripts/train.py +++ b/src/deepforest/scripts/train.py @@ -10,6 +10,7 @@ from pytorch_lightning.callbacks import DeviceStatsMonitor, ModelCheckpoint from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger +from deepforest import distributed from deepforest.callbacks import ImagesCallback from deepforest.main import deepforest @@ -21,7 +22,7 @@ def train( tensorboard: bool = False, trace: bool = False, resume: str | None = None, - strategy: str = "auto", + strategy: str | None = None, experiment_name: str | None = None, tags: list[str] | None = None, ) -> bool: @@ -141,13 +142,15 @@ def train( checkpoint_callback.CHECKPOINT_EQUALS_CHAR = ":" callbacks.append(checkpoint_callback) - m.create_trainer( - logger=loggers, - callbacks=callbacks, - gradient_clip_val=0.5, - accelerator=config.accelerator, - strategy=strategy, - ) + trainer_kwargs = { + "logger": 
loggers, + "callbacks": callbacks, + "gradient_clip_val": 0.5, + } + if strategy is not None: + trainer_kwargs["strategy"] = strategy + + m.create_trainer(**trainer_kwargs) # Add experiment ID to hyperparameters if available if experiment_id is not None: @@ -156,8 +159,9 @@ def train( current_hparams["experiment_id"] = experiment_id m.save_hyperparameters(current_hparams) - os.makedirs(csv_logger.log_dir, exist_ok=True) - OmegaConf.save(config, Path(csv_logger.log_dir) / "config.yaml") + if distributed.is_global_zero(m.trainer): + os.makedirs(csv_logger.log_dir, exist_ok=True) + OmegaConf.save(config, Path(csv_logger.log_dir) / "config.yaml") train_success = False try: @@ -170,12 +174,12 @@ def train( ) warnings.warn(traceback.format_exc(), stacklevel=2) - if trace and torch.cuda.is_available(): + if trace and torch.cuda.is_available() and distributed.is_global_zero(m.trainer): torch.cuda.memory._dump_snapshot( filename=Path(csv_logger.log_dir) / "dump_snapshot.pickle" ) - if checkpoint: + if checkpoint and distributed.is_global_zero(m.trainer): for logger in m.trainer.loggers: if hasattr(logger.experiment, "log_model"): for checkpoint in glob.glob( diff --git a/src/deepforest/utilities.py b/src/deepforest/utilities.py index 5c0591545..57828881c 100644 --- a/src/deepforest/utilities.py +++ b/src/deepforest/utilities.py @@ -315,6 +315,9 @@ def __shapefile_to_annotations__( print(f"CRS of image is {raster_crs}") gdf = geo_to_image_coordinates(gdf, src.bounds, src.res[0]) + gdf = DeepForest_DataFrame(gdf) + gdf.root_dir = os.path.dirname(full_image_path) + return gdf @@ -664,9 +667,11 @@ def read_file( gdf = __assign_image_path__(gdf, image_path=image_path) gdf = __assign_root_dir__(input, gdf, root_dir=root_dir) gdf = DeepForest_DataFrame(gdf) + original_root_dir = gdf.root_dir gdf_list = [] for image_path in gdf.image_path.unique(): - image_annotations = gdf[gdf.image_path == image_path] + image_annotations = DeepForest_DataFrame(gdf[gdf.image_path == image_path]) + 
image_annotations.root_dir = original_root_dir gdf = __shapefile_to_annotations__(image_annotations) gdf_list.append(gdf) @@ -674,6 +679,7 @@ def read_file( gdf = pd.concat(gdf_list) gdf = gpd.GeoDataFrame(gdf) gdf = DeepForest_DataFrame(gdf) + gdf.root_dir = original_root_dir gdf = __check_and_assign_label__(gdf, label=label) elif isinstance(input, pd.DataFrame): @@ -813,6 +819,7 @@ def geo_to_image_coordinates(gdf, image_bounds, image_resolution): if len(image_bounds) != 4: raise ValueError("image_bounds must be a tuple of (left, bottom, right, top)") + root_dir = getattr(gdf, "root_dir", None) transformed_gdf = gdf.copy(deep=True) # unpack image bounds left, bottom, right, top = image_bounds @@ -822,6 +829,9 @@ def geo_to_image_coordinates(gdf, image_bounds, image_resolution): xfact=1 / image_resolution, yfact=-1 / image_resolution, origin=(0, 0) ) transformed_gdf.crs = None + transformed_gdf = DeepForest_DataFrame(transformed_gdf) + if root_dir is not None: + transformed_gdf.root_dir = root_dir return transformed_gdf diff --git a/test.py b/test.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/cluster_predict_tile_driver.py b/tests/cluster_predict_tile_driver.py new file mode 100644 index 000000000..6c84a2a24 --- /dev/null +++ b/tests/cluster_predict_tile_driver.py @@ -0,0 +1,45 @@ +import argparse + +from deepforest import distributed +from deepforest.main import deepforest + + +def main() -> None: + parser = argparse.ArgumentParser(description="Distributed large-tile prediction driver") + parser.add_argument("--input-path", required=True) + parser.add_argument("--output-path", required=True) + parser.add_argument( + "--model-name", + default="weecology/everglades-bird-species-detector", + ) + parser.add_argument("--patch-size", type=int, default=1500) + parser.add_argument("--patch-overlap", type=float, default=0) + parser.add_argument("--iou-threshold", type=float, default=0.15) + parser.add_argument("--dataloader-strategy", 
default="window") + parser.add_argument("--devices", type=int, default=1) + parser.add_argument("--num-nodes", type=int, default=1) + args = parser.parse_args() + + model = deepforest() + model.load_model(model_name=args.model_name) + model.config.accelerator = "gpu" + model.config.devices = args.devices + model.config.strategy = "ddp" + model.config.num_nodes = args.num_nodes + model.config.workers = 0 + model.create_trainer() + + results = model.predict_tile( + path=args.input_path, + patch_size=args.patch_size, + patch_overlap=args.patch_overlap, + iou_threshold=args.iou_threshold, + dataloader_strategy=args.dataloader_strategy, + ) + + if distributed.is_global_zero(model.trainer) and results is not None: + results.to_csv(args.output_path, index=False) + + +if __name__ == "__main__": + main() diff --git a/tests/test_cluster_predict.py b/tests/test_cluster_predict.py new file mode 100644 index 000000000..ea38efbc9 --- /dev/null +++ b/tests/test_cluster_predict.py @@ -0,0 +1,68 @@ +import os +import shlex +import subprocess +from pathlib import Path + +import pandas as pd +import pytest + + +def _required_env(name: str) -> str: + value = os.environ.get(name) + if not value: + pytest.skip(f"{name} is not set") + return value + + +def _build_predict_command(output_path: Path) -> list[str]: + repo_root = Path(__file__).resolve().parents[1] + input_csv = _required_env("DEEPFOREST_HPC_PREDICT_CSV") + root_dir = _required_env("DEEPFOREST_HPC_ROOT_DIR") + nnodes = os.environ.get("DEEPFOREST_HPC_NNODES", os.environ.get("SLURM_NNODES", "1")) + gpus_per_node = os.environ.get("DEEPFOREST_HPC_GPUS_PER_NODE", "1") + + return [ + "srun", + f"--nodes={nnodes}", + f"--ntasks-per-node={gpus_per_node}", + "bash", + "-lc", + ( + f"set -euo pipefail && cd {shlex.quote(str(repo_root))} && " + f"uv run deepforest predict {shlex.quote(input_csv)} " + f"--mode csv --root-dir {shlex.quote(root_dir)} " + f"-o {shlex.quote(str(output_path))} " + f"--strategy ddp accelerator=gpu 
devices={shlex.quote(str(gpus_per_node))} " + f"num_nodes={shlex.quote(str(nnodes))}" + ), + ] + + +@pytest.mark.integration +@pytest.mark.cluster +def test_multinode_predict_cli_on_cluster(tmp_path): + if os.environ.get("RUN_CLUSTER_TESTS") != "1": + pytest.skip("Set RUN_CLUSTER_TESTS=1 to enable cluster integration tests") + + if "SLURM_JOB_ID" not in os.environ: + pytest.skip("This test must run inside a Slurm allocation") + + output_dir = Path( + os.environ.get("DEEPFOREST_HPC_OUTPUT_DIR", str(tmp_path / "cluster_outputs")) + ) + output_dir.mkdir(parents=True, exist_ok=True) + + output_path = output_dir / f"predict_job_{os.environ['SLURM_JOB_ID']}.csv" + if output_path.exists(): + output_path.unlink() + + command = _build_predict_command(output_path) + subprocess.run(command, check=True) + + assert output_path.exists() + + predictions = pd.read_csv(output_path) + assert not predictions.empty + assert {"image_path", "xmin", "ymin", "xmax", "ymax", "score", "label"}.issubset( + predictions.columns + ) diff --git a/tests/test_cluster_predict_tile.py b/tests/test_cluster_predict_tile.py new file mode 100644 index 000000000..f49abdec1 --- /dev/null +++ b/tests/test_cluster_predict_tile.py @@ -0,0 +1,80 @@ +import os +import shlex +import subprocess +from pathlib import Path + +import pandas as pd +import pytest + + +def _required_env(name: str) -> str: + value = os.environ.get(name) + if not value: + pytest.skip(f"{name} is not set") + return value + + +def _build_predict_tile_command(output_path: Path) -> list[str]: + repo_root = Path(__file__).resolve().parents[1] + tile_path = _required_env("DEEPFOREST_HPC_TILE_PATH") + model_name = os.environ.get( + "DEEPFOREST_HPC_TILE_MODEL_NAME", + "weecology/everglades-bird-species-detector", + ) + nnodes = os.environ.get("DEEPFOREST_HPC_NNODES", os.environ.get("SLURM_NNODES", "1")) + gpus_per_node = os.environ.get("DEEPFOREST_HPC_GPUS_PER_NODE", "1") + patch_size = os.environ.get("DEEPFOREST_HPC_TILE_PATCH_SIZE", "1500") + 
patch_overlap = os.environ.get("DEEPFOREST_HPC_TILE_PATCH_OVERLAP", "0") + iou_threshold = os.environ.get("DEEPFOREST_HPC_TILE_IOU_THRESHOLD", "0.15") + dataloader_strategy = os.environ.get("DEEPFOREST_HPC_TILE_DATALOADER_STRATEGY", "window") + + return [ + "srun", + f"--nodes={nnodes}", + f"--ntasks-per-node={gpus_per_node}", + "bash", + "-lc", + ( + f"set -euo pipefail && cd {shlex.quote(str(repo_root))} && " + f"uv run python tests/cluster_predict_tile_driver.py " + f"--input-path {shlex.quote(tile_path)} " + f"--output-path {shlex.quote(str(output_path))} " + f"--model-name {shlex.quote(str(model_name))} " + f"--patch-size {shlex.quote(str(patch_size))} " + f"--patch-overlap {shlex.quote(str(patch_overlap))} " + f"--iou-threshold {shlex.quote(str(iou_threshold))} " + f"--dataloader-strategy {shlex.quote(str(dataloader_strategy))} " + f"--devices {shlex.quote(str(gpus_per_node))} " + f"--num-nodes {shlex.quote(str(nnodes))}" + ), + ] + + +@pytest.mark.integration +@pytest.mark.cluster +def test_multinode_predict_tile_on_cluster(tmp_path): + if os.environ.get("RUN_CLUSTER_TESTS") != "1": + pytest.skip("Set RUN_CLUSTER_TESTS=1 to enable cluster integration tests") + + if "SLURM_JOB_ID" not in os.environ: + pytest.skip("This test must run inside a Slurm allocation") + + output_dir = Path( + os.environ.get("DEEPFOREST_HPC_OUTPUT_DIR", str(tmp_path / "cluster_outputs")) + ) + output_dir.mkdir(parents=True, exist_ok=True) + + output_path = output_dir / f"predict_tile_job_{os.environ['SLURM_JOB_ID']}.csv" + if output_path.exists(): + output_path.unlink() + + command = _build_predict_tile_command(output_path) + subprocess.run(command, check=True) + + assert output_path.exists() + + predictions = pd.read_csv(output_path) + assert not predictions.empty + assert {"image_path", "xmin", "ymin", "xmax", "ymax", "score", "label"}.issubset( + predictions.columns + ) diff --git a/tests/test_datasets_prediction.py b/tests/test_datasets_prediction.py index a0f2519f3..7db1dd2c0 
100644 --- a/tests/test_datasets_prediction.py +++ b/tests/test_datasets_prediction.py @@ -116,3 +116,32 @@ def test_translate_predictions_points(): assert translated.geometry.iloc[0].y == pytest.approx(306.0) assert translated.geometry.iloc[1].x == pytest.approx(210.0) assert translated.geometry.iloc[1].y == pytest.approx(412.0) + + +def test_SingleImage_metadata_batch(): + path = get_data("OSBS_029.png") + ds = SingleImage(path=path, patch_size=300, patch_overlap=0, return_metadata=True) + + sample = ds[0] + assert sample["image"].shape == (3, 300, 300) + assert sample["metadata"]["image_path"] == os.path.basename(path) + assert sample["metadata"]["window_bounds"] == ds.get_crop_bounds(0) + + batch = ds.collate_fn([ds[0], ds[1]]) + assert len(batch["images"]) == 2 + assert len(batch["metadata"]) == 2 + + +def test_MultiImage_metadata_batch(): + path = get_data("OSBS_029.png") + ds = MultiImage( + paths=[path, path], + patch_size=300, + patch_overlap=0, + return_metadata=True, + ) + + batch = ds.collate_fn([ds[0]]) + assert len(batch["images"]) == 4 + assert len(batch["metadata"]) == 4 + assert all(item["image_path"] == os.path.basename(path) for item in batch["metadata"]) diff --git a/tests/test_main.py b/tests/test_main.py index c403483e1..6ebacdf6d 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -169,6 +169,41 @@ def test_tensorboard_logger(m, tmp_path): print("TensorBoard is not installed. 
Skipping test_tensorboard_logger.") +def test_create_trainer_uses_runtime_config(monkeypatch): + captured_kwargs = {} + + class FakeTrainer: + def __init__(self, **kwargs): + captured_kwargs.update(kwargs) + + monkeypatch.setattr(main.pl, "Trainer", FakeTrainer) + + m = main.deepforest( + config_args={ + "model": {"name": None}, + "num_classes": 1, + "label_dict": {"Tree": 0}, + } + ) + + m.config.devices = 2 + m.config.accelerator = "gpu" + m.config.num_nodes = 3 + m.config.strategy = "ddp" + m.config.precision = "16-mixed" + m.config.sync_batchnorm = True + m.config.use_distributed_sampler = True + + m.create_trainer() + + assert captured_kwargs["devices"] == 2 + assert captured_kwargs["accelerator"] == "gpu" + assert captured_kwargs["num_nodes"] == 3 + assert captured_kwargs["strategy"] == "ddp" + assert captured_kwargs["precision"] == "16-mixed" + assert captured_kwargs["sync_batchnorm"] is True + assert captured_kwargs["use_distributed_sampler"] is True + def test_load_model(m): imgpath = get_data("OSBS_029.png") @@ -1205,6 +1240,38 @@ def test_recall_not_lowered_by_unprocessed_images(): f"box_recall={results['box_recall']:.2f}, expected 1.0" ) +def test_predict_dataloader_small_dataset_rank_zero(monkeypatch, m): + ds = prediction.FromCSVFile( + csv_file=get_data("OSBS_029.csv"), + root_dir=os.path.dirname(get_data("OSBS_029.csv")), + return_metadata=True, + ) + + monkeypatch.setattr(main.distributed, "is_distributed", lambda: True) + monkeypatch.setattr(main.distributed, "get_world_size", lambda: 2) + monkeypatch.setattr(main.distributed, "get_rank", lambda: 0) + + loader = m.predict_dataloader(ds, batch_size=1) + + assert len(loader.sampler) == 1 + assert list(iter(loader.sampler)) == [0] + + +def test_predict_dataloader_small_dataset_extra_rank(monkeypatch, m): + ds = prediction.FromCSVFile( + csv_file=get_data("OSBS_029.csv"), + root_dir=os.path.dirname(get_data("OSBS_029.csv")), + return_metadata=True, + ) + + monkeypatch.setattr(main.distributed, 
"is_distributed", lambda: True) + monkeypatch.setattr(main.distributed, "get_world_size", lambda: 2) + monkeypatch.setattr(main.distributed, "get_rank", lambda: 1) + + loader = m.predict_dataloader(ds, batch_size=1) + + assert len(loader.sampler) == 0 + assert list(iter(loader.sampler)) == [] def test_custom_log_root(m, tmpdir): """Test that setting a custom log_root creates logs in the expected location""" custom_log_dir = tmpdir.join("custom_logs")