diff --git a/src/deepforest/augmentations.py b/src/deepforest/augmentations.py index 5a65d25ce..29741b9fd 100644 --- a/src/deepforest/augmentations.py +++ b/src/deepforest/augmentations.py @@ -51,8 +51,9 @@ def get_available_augmentations() -> list[str]: def get_transform( augmentations: str | list[str] | dict[str, Any] | None = None, + task: str = "box", ) -> A.Compose: - """Create Albumentations transform for bounding boxes. + """Create Albumentations transform for boxes or keypoints. Args: augmentations: Augmentation configuration: @@ -60,6 +61,7 @@ def get_transform( - list: List of augmentation names - dict: Dict with names as keys and params as values - None: No augmentations + task: Task type - "box" for bounding boxes or "keypoint" for keypoints Returns: Composed albumentations transform @@ -79,9 +81,13 @@ def get_transform( ... "HorizontalFlip": {"p": 0.5}, ... "Downscale": {"scale_min": 0.25, "scale_max": 0.75} ... }) + + >>> # Keypoint augmentations + >>> transform = get_transform(augmentations=["HorizontalFlip"], task="keypoint") """ transforms_list = [] bbox_params = None + keypoint_params = None if augmentations is not None: augment_configs = _parse_augmentations(augmentations) @@ -90,12 +96,19 @@ def get_transform( aug_transform = _create_augmentation(aug_name, aug_params) transforms_list.append(aug_transform) - bbox_params = A.BboxParams(format="pascal_voc", label_fields=["category_ids"]) + if task == "box": + bbox_params = A.BboxParams(format="pascal_voc", label_fields=["labels"]) + elif task == "keypoint": + keypoint_params = A.KeypointParams(format="xy", label_fields=["labels"]) + else: + raise ValueError(f"Unsupported task: {task}. Must be 'box' or 'keypoint'.") # Always add ToTensorV2 at the end transforms_list.append(ToTensorV2()) - return A.Compose(transforms_list, bbox_params=bbox_params) + return A.Compose( + transforms_list, bbox_params=bbox_params, keypoint_params=keypoint_params + ) def _parse_augmentations( diff --git a/src/deepforest/callbacks.py b/src/deepforest/callbacks.py index c5883293c..d192dacd4 100644 --- a/src/deepforest/callbacks.py +++ b/src/deepforest/callbacks.py @@ -151,12 +151,12 @@ def _log_last_predictions(self, trainer, pl_module): else: selected_images = df.image_path.unique()[: self.prediction_samples] - # Ensure color is correctly assigned - if self.color is None: - num_classes = len(df["label"].unique()) - results_color = sv.ColorPalette.from_matplotlib("viridis", num_classes) - else: - results_color = self.color + # Ensure color is correctly assigned + if self.color is None: + num_classes = len(df["label"].unique()) + results_color = sv.ColorPalette.from_matplotlib("viridis", num_classes) + else: + results_color = self.color for image_name in selected_images: pred_df = df[df.image_path == image_name] diff --git a/src/deepforest/conf/bird.yaml b/src/deepforest/conf/bird.yaml new file mode 100644 index 000000000..7b7baf2ad --- /dev/null +++ b/src/deepforest/conf/bird.yaml @@ -0,0 +1,10 @@ +# Ensure we inherit from default config + overlay these overrides. +defaults: + - config + - _self_ + +task: 'box' + +model: + name: 'weecology/deepforest-bird' + revision: 'main' diff --git a/src/deepforest/conf/config.yaml b/src/deepforest/conf/config.yaml index 6ae059410..6294b3736 100644 --- a/src/deepforest/conf/config.yaml +++ b/src/deepforest/conf/config.yaml @@ -8,6 +8,7 @@ accelerator: auto batch_size: 1 # Model Architecture +task: 'box' architecture: 'retinanet' num_classes: 1 nms_thresh: 0.05 @@ -66,6 +67,10 @@ train: min_lr: 0 eps: 0.00000001 + # Currently sgd and adamw are supported. If you use Adam, + # make sure your learning rate is lowered sufficiently. + optimizer: sgd + # How many epochs to run for epochs: 1 # Useful debugging flag in pytorch lightning, set to True to get a single batch of training to test settings. diff --git a/src/deepforest/conf/keypoint.yaml b/src/deepforest/conf/keypoint.yaml new file mode 100644 index 000000000..d66eee60a --- /dev/null +++ b/src/deepforest/conf/keypoint.yaml @@ -0,0 +1,33 @@ +# Config file for DeepForest keypoint detection tasks + +# Ensure we inherit from default config + overlay these overrides. +defaults: + - config + - _self_ + +# Task and Model Architecture +task: 'keypoint' +architecture: 'DeformableDetr' + +# Keypoint-specific parameters for Deformable DETR +# point_cost: Relative weight of point distance in matching cost (default: 5.0) +# point_loss_coefficient: Weight of point loss in total loss (default: 5.0) +# point_loss_type: Type of loss for coordinates - "l1" (default) or "mse" +point_cost: 5.0 +point_loss_coefficient: 5.0 +point_loss_type: 'l1' + +# For keypoint detection, start from pretrained Deformable DETR backbone +# Override with our DETR backbone once trained. +model: + name: 'SenseTime/deformable-detr' + revision: 'main' + +# Transformer-based models often prefer lower learning rates +train: + lr: 0.0001 + +# Pixel distance threshold for keypoint matching (instead of IoU for boxes) +# A prediction is considered correct if within this many pixels of ground truth +validation: + pixel_distance_threshold: 10.0 diff --git a/src/deepforest/conf/livestock.yaml b/src/deepforest/conf/livestock.yaml new file mode 100644 index 000000000..79fee6f25 --- /dev/null +++ b/src/deepforest/conf/livestock.yaml @@ -0,0 +1,10 @@ +# Ensure we inherit from default config + overlay these overrides. +defaults: + - config + - _self_ + +task: 'box' + +model: + name: 'weecology/deepforest-livestock' + revision: 'main' diff --git a/src/deepforest/conf/schema.py b/src/deepforest/conf/schema.py index f78e01d94..cdf866431 100644 --- a/src/deepforest/conf/schema.py +++ b/src/deepforest/conf/schema.py @@ -1,7 +1,5 @@ from dataclasses import dataclass, field -from omegaconf import MISSING - @dataclass class ModelConfig: @@ -53,6 +51,11 @@ class TrainConfig: architectures, such as transformers-based models which sometimes prefer a lower learning rate. + The optimizer can be "sgd" (with momentum=0.9) or "adamw". SGD is + the default and works well for RetinaNet. AdamW is recommended for + transformer-based models like DeformableDetr, typically with a lower + learning rate (e.g., 1e-4 to 5e-4). + The number of epochs should be user-specified and depends on the size of the dataset (e.g. how many iterations the model will train for and how diverse the imagery is). DeepForest uses Lightning to @@ -60,14 +63,16 @@ class TrainConfig: sanity checking. """ - csv_file: str | None = MISSING - root_dir: str | None = MISSING + csv_file: str | None = None + root_dir: str | None = None lr: float = 0.001 + optimizer: str = "sgd" scheduler: SchedulerConfig = field(default_factory=SchedulerConfig) epochs: int = 1 fast_dev_run: bool = False preload_images: bool = False augmentations: list[str] | None = field(default_factory=lambda: ["HorizontalFlip"]) + log_root: str = "logs" @dataclass @@ -79,14 +84,16 @@ class ValidationConfig: converged or is overfitting. """ - csv_file: str | None = MISSING - root_dir: str | None = MISSING + csv_file: str | None = None + root_dir: str | None = None preload_images: bool = False size: int | None = None iou_threshold: float = 0.4 val_accuracy_interval: int = 20 lr_plateau_target: str = "val_loss" augmentations: list[str] | None = field(default_factory=lambda: []) + # Keypoint-specific validation (used when task="keypoint") + pixel_distance_threshold: float = 10.0 @dataclass @@ -130,21 +137,27 @@ class Config: accelerator: str = "auto" batch_size: int = 1 + task: str = "box" architecture: str = "retinanet" num_classes: int = 1 label_dict: dict[str, int] = field(default_factory=lambda: {"Tree": 0}) + # Keypoint-specific parameters (used when task="keypoint") + point_cost: float = 5.0 + point_loss_coefficient: float = 5.0 + point_loss_type: str = "l1" + nms_thresh: float = 0.05 score_thresh: float = 0.1 model: ModelConfig = field(default_factory=ModelConfig) # Preprocessing - path_to_raster: str | None = MISSING + path_to_raster: str | None = None patch_size: int = 400 patch_overlap: float = 0.05 - annotations_xml: str | None = MISSING - rgb_dir: str | None = MISSING - path_to_rgb: str | None = MISSING + annotations_xml: str | None = None + rgb_dir: str | None = None + path_to_rgb: str | None = None train: TrainConfig = field(default_factory=TrainConfig) validation: ValidationConfig = field(default_factory=ValidationConfig) diff --git a/src/deepforest/conf/tree.yaml b/src/deepforest/conf/tree.yaml new file mode 100644 index 000000000..c50d0fdbd --- /dev/null +++ b/src/deepforest/conf/tree.yaml @@ -0,0 +1,10 @@ +# Ensure we inherit from default config + overlay these overrides. +defaults: + - config + - _self_ + +task: 'box' + +model: + name: 'weecology/deepforest-tree' + revision: 'main' diff --git a/src/deepforest/datasets/training.py b/src/deepforest/datasets/training.py index 790d4b70c..dcd923386 100644 --- a/src/deepforest/datasets/training.py +++ b/src/deepforest/datasets/training.py @@ -3,6 +3,7 @@ import os import numpy as np +import pandas as pd import shapely import torch from PIL import Image @@ -106,7 +107,7 @@ def collate_fn(self, batch): return images, targets, image_names def load_image(self, idx): - img_name = os.path.join(self.root_dir, self.image_names[idx]) + img_name = os.path.join(self.root_dir, os.path.basename(self.image_names[idx])) image = np.array(Image.open(img_name).convert("RGB")) / 255 image = image.astype("float32") return image @@ -173,20 +174,225 @@ def __getitem__(self, idx): augmented = self.transform( image=image, bboxes=targets["boxes"], - category_ids=targets["labels"].astype(np.int64), + labels=targets["labels"].astype(np.int64), ) image = augmented["image"] # Convert boxes to tensor boxes = np.array(augmented["bboxes"]) boxes = torch.from_numpy(boxes).float() - labels = np.array(augmented["category_ids"]) + labels = np.array(augmented["labels"]) labels = torch.from_numpy(labels.astype(np.int64)) targets = {"boxes": boxes, "labels": labels} return image, targets, self.image_names[idx] +# TODO: Combine training dataset classes to reduce duplication? +class KeypointDataset(Dataset): + """Dataset for keypoint/point detection. + + Supports two CSV formats: + + 1. Keypoint format (direct): + image_path,x,y,label + - x, y: Keypoint coordinates in pixels + + 2. Bounding box format (auto-converted to center points): + image_path,xmin,ymin,xmax,ymax,label + + Args: + csv_file: Path to CSV file with keypoint annotations + root_dir: Directory containing images + transforms: Function applied to each sample + augmentations: Augmentation configuration + label_dict: Mapping from string labels to class IDs + preload_images: Preload all images into memory + + Returns: + List of (image, target) pairs where target contains: + - "points": torch.Tensor of shape (N, 2) - (x, y) coordinates + - "labels": torch.Tensor of shape (N,) - class indices + """ + + def __init__( + self, + csv_file, + root_dir, + *, + transforms=None, + augmentations=None, + label_dict=None, + preload_images=False, + ): + """ + Args: + csv_file (str): Path to the CSV file containing keypoint annotations. + root_dir (str): Directory containing all referenced images. + transforms (callable, optional): Function applied to each sample. Defaults to None. + augmentations (str | list | dict, optional): Augmentation configuration. + label_dict (dict[str, int]): Mapping from string labels to integer class IDs. + preload_images (bool): If True, preload all images into memory. Defaults to False. + + Returns: + list: A list of (image, target) pairs, where each target is a dict with: + - "points": torch.Tensor of shape (N, 2) + - "labels": torch.Tensor of shape (N,) + """ + self.annotations = pd.read_csv(csv_file) + self.root_dir = root_dir + + # Check if CSV has keypoint columns or box columns + keypoint_columns = {"image_path", "x", "y", "label"} + box_columns_xyxy = {"image_path", "xmin", "ymin", "xmax", "ymax", "label"} + + if not keypoint_columns.issubset(self.annotations.columns): + if box_columns_xyxy.issubset(self.annotations.columns): + # Use box centers as keypoints + self.annotations["x"] = ( + self.annotations["xmin"] + self.annotations["xmax"] + ) / 2 + self.annotations["y"] = ( + self.annotations["ymin"] + self.annotations["ymax"] + ) / 2 + else: + raise ValueError( + f"CSV must contain either keypoint columns {keypoint_columns} " + f"or box columns {box_columns_xyxy}. " + f"Found: {set(self.annotations.columns)}" + ) + + # Initialize label_dict with default if None + if label_dict is None: + label_dict = {"Tree": 0} + self.label_dict = label_dict + + if transforms is None: + self.transform = get_transform(augmentations=augmentations, task="keypoint") + else: + self.transform = transforms + + self.image_names = self.annotations.image_path.unique() + self.preload_images = preload_images + + self._validate_labels() + + # Pin data to memory if desired + if self.preload_images: + print("Pinning dataset to GPU memory") + self.image_dict = {} + for idx, _ in enumerate(self.image_names): + self.image_dict[idx] = self.load_image(idx) + + def _validate_labels(self): + """Validate that all labels in annotations exist in label_dict. + + Raises: + ValueError: If any label in annotations is missing from label_dict + """ + csv_labels = self.annotations["label"].unique() + missing_labels = [label for label in csv_labels if label not in self.label_dict] + + if missing_labels: + raise ValueError( + f"Labels {missing_labels} are missing from label_dict. " + f"Please ensure all labels in the annotations exist as keys in label_dict." + ) + + def __len__(self): + return len(self.image_names) + + def collate_fn(self, batch): + """Collate function for DataLoader.""" + images = [item[0] for item in batch] + targets = [item[1] for item in batch] + image_names = [item[2] for item in batch] + + return images, targets, image_names + + def load_image(self, idx): + """Load and preprocess an image.""" + img_name = os.path.join(self.root_dir, os.path.basename(self.image_names[idx])) + image = np.array(Image.open(img_name).convert("RGB")) / 255 + image = image.astype("float32") + return image + + def annotations_for_path(self, image_path, return_tensor=False): + """Construct target dictionary for a given image path. + + Args: + image_path (str): Path to image, expected to be in dataframe + return_tensor (bool): If true, convert fields from numpy to tensor + + Returns: + target dictionary with points and labels entries + """ + image_annotations = self.annotations[self.annotations.image_path == image_path] + targets = {} + + # Extract x, y coordinates as points + targets["points"] = image_annotations[["x", "y"]].values.astype("float32") + + # Labels need to be encoded + targets["labels"] = image_annotations.label.apply( + lambda x: self.label_dict[x] + ).values.astype(np.int64) + + if return_tensor: + for k, v in targets.items(): + targets[k] = torch.from_numpy(v) + + return targets + + def __getitem__(self, idx): + # Read image if not in memory + if self.preload_images: + image = self.image_dict[idx] + else: + image = self.load_image(idx) + + # TODO: Not sure we need this any more, will check before review + orig_h, orig_w = image.shape[0], image.shape[1] + + targets = self.annotations_for_path(self.image_names[idx]) + + # If image has no annotations, don't augment + if len(targets["points"]) == 0: + points = torch.zeros((0, 2), dtype=torch.float32) + labels = torch.zeros(0, dtype=torch.int64) + # channels last + image = np.rollaxis(image, 2, 0) + image = torch.from_numpy(image).float() + targets = {"points": points, "labels": labels} + + return image, targets, self.image_names[idx] + + augmented = self.transform( + image=image, + keypoints=targets["points"], + labels=targets["labels"].astype(np.int64), + ) + image = augmented["image"] + + # Convert points back from augmented keypoints + if len(augmented["keypoints"]) > 0: + points = np.array(augmented["keypoints"]) + points = torch.from_numpy(points).float() + else: + points = torch.zeros((0, 2), dtype=torch.float32) + + labels = np.array(augmented["labels"]) + labels = torch.from_numpy(labels.astype(np.int64)) + + targets = { + "points": points, + "labels": labels, + "orig_size": torch.tensor([orig_h, orig_w], dtype=torch.int64), + } + + return image, targets, self.image_names[idx] + + # ---------- ImageFolder alignment utilities ---------- diff --git a/src/deepforest/evaluate.py b/src/deepforest/evaluate.py index d4f76a9af..02b12c750 100644 --- a/src/deepforest/evaluate.py +++ b/src/deepforest/evaluate.py @@ -7,7 +7,7 @@ import pandas as pd import shapely -from deepforest import IoU +from deepforest import IoU, keypoint_distance from deepforest.utilities import determine_geometry_type @@ -80,17 +80,39 @@ def compute_class_recall(results): return class_recall -def __evaluate_wrapper__(predictions, ground_df, iou_threshold, numeric_to_label_dict): +def __evaluate_wrapper__( + predictions, + ground_df, + match_threshold=None, + iou_threshold=None, + numeric_to_label_dict=None, +): """Evaluate a set of predictions against a ground truth csv file Args: predictions: a pandas dataframe, if supplied a root dir is needed to give the relative path of files in df.name. The labels in ground truth and predictions must match. If one is numeric, the other must be numeric. ground_df: a pandas dataframe, if supplied a root dir is needed to give the relative path of files in df.name - iou_threshold: intersection-over-union threshold, see deepforest.evaluate + match_threshold: matching threshold - IoU for boxes (default 0.4), pixel distance for keypoints (default 10.0) + iou_threshold: DEPRECATED - use match_threshold instead + numeric_to_label_dict: dictionary mapping numeric labels to string labels Returns: results: a dictionary of results with keys, results, box_recall, box_precision, class_recall """ + + if iou_threshold is not None: + warnings.warn( + "iou_threshold parameter is deprecated and will be removed in a future version. " + "Use match_threshold instead (IoU for boxes, pixel distance for keypoints).", + DeprecationWarning, + stacklevel=2, + ) + if match_threshold is None: + match_threshold = iou_threshold + else: + raise ValueError("Found both iou_threshold and match_threshold set.") + # remove empty samples from ground truth - ground_df = ground_df[~((ground_df.xmin == 0) & (ground_df.xmax == 0))] + if "xmin" in ground_df.columns and "xmax" in ground_df.columns: + ground_df = ground_df[~((ground_df.xmin == 0) & (ground_df.xmax == 0))] # Default results for blank predictions if predictions.empty: @@ -138,10 +160,12 @@ def __evaluate_wrapper__(predictions, ground_df, iou_threshold, numeric_to_label prediction_geometry = determine_geometry_type(predictions) if prediction_geometry == "point": - raise NotImplementedError("Point evaluation is not yet implemented") + results = evaluate_keypoints( + predictions=predictions, ground_df=ground_df, pixel_threshold=match_threshold + ) elif prediction_geometry == "box": results = evaluate_boxes( - predictions=predictions, ground_df=ground_df, iou_threshold=iou_threshold + predictions=predictions, ground_df=ground_df, iou_threshold=match_threshold ) else: raise NotImplementedError(f"Geometry type {prediction_geometry} not implemented") @@ -398,3 +422,173 @@ def point_recall(predictions, ground_df): class_recall = compute_class_recall(matched_results) return {"results": results, "box_recall": box_recall, "class_recall": class_recall} + + +def evaluate_image_keypoints(predictions, ground_df): + """Match predicted keypoints to ground truth keypoints for one image using + Hungarian algorithm. + + Args: + predictions: a geopandas dataframe with Point geometry + ground_df: a geopandas dataframe with Point geometry + + Returns: + result: pandas dataframe with matched keypoints, distance, and labels + """ + plot_names = predictions["image_path"].unique() + if len(plot_names) > 1: + raise ValueError( + f"More than one plot passed to image keypoint evaluation: {plot_names}" + ) + + # Use keypoint_distance module for Hungarian matching + result = keypoint_distance.compute_distances(ground_df, predictions) + + # Map prediction/truth IDs back to their original labels + pred_label_dict = predictions.label.to_dict() + ground_label_dict = ground_df.label.to_dict() + result["predicted_label"] = result.prediction_id.map(pred_label_dict) + result["true_label"] = result.truth_id.map(ground_label_dict) + + return result + + +def evaluate_keypoints(predictions, ground_df, pixel_threshold=10.0): + """Evaluate keypoint detection predictions against ground truth. + + Uses Hungarian algorithm to optimally match predicted keypoints to ground truth + based on Euclidean pixel distance. This is mostly identical to the box evaluation. + + Args: + predictions: a geopandas dataframe with Point geometry + ground_df: a geopandas dataframe with Point geometry + pixel_threshold: maximum pixel distance for a match to be considered valid + + Returns: + dict with keys: + - results: dataframe of matched keypoints with distance and labels + - recall: overall recall (proportion of ground truth matched) + - precision: overall precision (proportion of predictions matched) + - class_recall: per-class recall and precision metrics + - predictions: original predictions dataframe + - ground_df: original ground truth dataframe + """ + # If all empty ground truth, return 0 recall and precision + if ground_df.empty: + return { + "results": None, + "recall": None, + "precision": 0, + "class_recall": None, + "predictions": predictions, + "ground_df": ground_df, + } + + # Convert to GeoDataFrame if needed and create Point geometries + if not isinstance(predictions, gpd.GeoDataFrame): + if "geometry" not in predictions.columns and all( + col in predictions.columns for col in ["x", "y"] + ): + predictions = predictions.copy() + predictions["geometry"] = predictions.apply( + lambda row: shapely.geometry.Point(row.x, row.y), axis=1 + ) + predictions = gpd.GeoDataFrame(predictions, geometry="geometry") + + if not isinstance(ground_df, gpd.GeoDataFrame): + if "geometry" not in ground_df.columns and all( + col in ground_df.columns for col in ["x", "y"] + ): + ground_df = ground_df.copy() + ground_df["geometry"] = ground_df.apply( + lambda row: shapely.geometry.Point(row.x, row.y), axis=1 + ) + ground_df = gpd.GeoDataFrame(ground_df, geometry="geometry") + + # Pre-group predictions by image + predictions_by_image = { + name: group.reset_index(drop=True) + for name, group in predictions.groupby("image_path") + } + + # Run evaluation on all images + results = [] + recalls = [] + precisions = [] + for image_path, group in ground_df.groupby("image_path"): + # Predictions for this image + image_predictions = predictions_by_image.get(image_path, pd.DataFrame()) + if not isinstance(image_predictions, pd.DataFrame) or image_predictions.empty: + image_predictions = pd.DataFrame() + + # If empty, add to list without computing matching + if image_predictions.empty: + # Reset index to ensure consistent DataFrame creation + group_reset = group.reset_index(drop=True) + result = pd.DataFrame( + { + "truth_id": group_reset.index.values, + "prediction_id": pd.Series([None] * len(group_reset), dtype="object"), + "distance": pd.Series([np.inf] * len(group_reset), dtype="float64"), + "predicted_label": pd.Series( + [None] * len(group_reset), dtype="object" + ), + "score": pd.Series([None] * len(group_reset), dtype="float64"), + "match": pd.Series([False] * len(group_reset), dtype="bool"), + "true_label": group_reset.label.astype("object"), + "geometry": group_reset.geometry, + } + ) + recalls.append(0) + results.append(result) + continue + else: + group = group.reset_index(drop=True) + result = evaluate_image_keypoints( + predictions=image_predictions, ground_df=group + ) + + result["image_path"] = image_path + result["match"] = result.distance <= pixel_threshold + # Convert None to False for boolean consistency + result["match"] = result["match"].fillna(False) + true_positive = sum(result["match"]) + recall = true_positive / result.shape[0] + precision = true_positive / image_predictions.shape[0] + + recalls.append(recall) + precisions.append(precision) + results.append(result) + + # Concatenate results + if results: + results = pd.concat(results, ignore_index=True) + else: + columns = [ + "truth_id", + "prediction_id", + "distance", + "predicted_label", + "score", + "match", + "true_label", + "geometry", + "image_path", + ] + results = pd.DataFrame(columns=columns) + + precision = np.mean(precisions) + recall = np.mean(recalls) + + # Only matching keypoints are considered in class recall + matched_results = results[results.match] + class_recall = compute_class_recall(matched_results) + + return { + "results": results, + "precision": precision, + "recall": recall, + "class_recall": class_recall, + "predictions": predictions, + "ground_df": ground_df, + } diff --git a/src/deepforest/keypoint_distance.py b/src/deepforest/keypoint_distance.py new file mode 100644 index 000000000..a3a6bc7ab --- /dev/null +++ b/src/deepforest/keypoint_distance.py @@ -0,0 +1,138 @@ +"""Keypoint Distance Module for matching predicted keypoints to ground truth. + +Similar to IoU.py but uses Euclidean pixel distance instead of +intersection-over-union. +""" + +import geopandas as gpd +import numpy as np +import pandas as pd +from scipy.optimize import linear_sum_assignment + + +def _compute_distances(predictions: "gpd.GeoDataFrame", ground_truth: "gpd.GeoDataFrame"): + """Computes pairwise Euclidean distances between all predicted and ground + truth keypoints. + + Args: + predictions: GeoDataFrame with Point geometry (predicted keypoints) + ground_truth: GeoDataFrame with Point geometry (ground truth keypoints) + + Returns: + distances: (n_truth, n_pred) array of Euclidean distances in pixels + truth_ids: (n_truth,) truth index values + pred_ids: (n_pred,) prediction index values + """ + # Extract coordinates from Point geometries + pred_coords = np.array([[p.x, p.y] for p in predictions.geometry]) + truth_coords = np.array([[p.x, p.y] for p in ground_truth.geometry]) + + pred_ids = predictions.index.to_numpy() + truth_ids = ground_truth.index.to_numpy() + + n_pred = len(pred_coords) + n_truth = len(truth_coords) + + # Handle empty cases + if n_pred == 0 or n_truth == 0: + return ( + np.full((n_truth, n_pred), np.inf, dtype=float), + truth_ids, + pred_ids, + ) + + # Compute pairwise Euclidean distances + # Broadcasting: (n_truth, 1, 2) - (1, n_pred, 2) = (n_truth, n_pred, 2) + distances = np.sqrt( + ((truth_coords[:, np.newaxis, :] - pred_coords[np.newaxis, :, :]) ** 2).sum( + axis=2 + ) + ) + + return distances, truth_ids, pred_ids + + +# TODO - consider making this a shared/generic function with IoU where we can pass in +# indices + costs. +def compute_distances(ground_truth: "gpd.GeoDataFrame", predictions: "gpd.GeoDataFrame"): + """Match predicted keypoints to ground truth using Hungarian algorithm with + pixel distance. + + This function performs matching between ground truth and predicted keypoints. + For each ground truth keypoint, we compute the Euclidean pixel distance to all + predictions. These distances are used as the cost matrix for Hungarian matching, + which ensures that each ground truth is matched to at most one prediction, and + each prediction is used at most once, minimizing the total distance. + + No filtering on distance threshold or score is performed - that happens downstream. + + Args: + ground_truth: a geopandas dataframe with Point geometry + predictions: a geopandas dataframe with Point geometry + + Returns: + distance_df: dataframe with columns: + - prediction_id: matched prediction ID (or None if no match) + - truth_id: ground truth ID + - distance: Euclidean pixel distance + - score: prediction confidence score (if available) + - geometry: ground truth geometry + """ + # Compute pairwise distances + distance_matrix, truth_ids, pred_ids = _compute_distances( + predictions=predictions, ground_truth=ground_truth + ) + + if distance_matrix.size == 0: + # No matches, early exit + return pd.DataFrame( + { + "prediction_id": pd.Series(dtype="float64"), + "truth_id": pd.Series(dtype=truth_ids.dtype), + "distance": pd.Series(dtype="float64"), + "score": pd.Series(dtype="float64"), + "geometry": pd.Series(dtype=object), + } + ) + + # Linear sum assignment (minimizes total distance) + # We want to MINIMIZE distance, so no need for maximize=True + row_ind, col_ind = linear_sum_assignment(distance_matrix, maximize=False) + match_for_truth = dict(zip(row_ind, col_ind, strict=False)) + + # Score lookup + pred_scores = predictions["score"].to_dict() if "score" in predictions.columns else {} + + # Build rows for every truth element (unmatched => None, distance inf) + records = [] + for t_idx, truth_id in enumerate(truth_ids): + # If we matched this truth keypoint + if t_idx in match_for_truth: + # Look up matching prediction and corresponding distance and score + p_idx = match_for_truth[t_idx] + matched_id = pred_ids[p_idx] + distance = float(distance_matrix[t_idx, p_idx]) + score = pred_scores.get(matched_id, None) + else: + matched_id = None + distance = np.inf + score = None + + records.append( + { + "prediction_id": matched_id, + "truth_id": truth_id, + "distance": distance, + "score": score, + } + ) + + # Output dataframe + distance_df = pd.DataFrame.from_records(records) + distance_df = distance_df.merge( + ground_truth.assign(truth_id=truth_ids)[["truth_id", "geometry"]], + on="truth_id", + how="left", + ) + + return distance_df diff --git a/src/deepforest/main.py b/src/deepforest/main.py index 3d53375cf..d775dbd15 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -1,5 +1,4 @@ # entry point for deepforest model -import importlib import os import warnings @@ -29,7 +28,7 @@ class deepforest(pl.LightningModule): model: DeepForest model object existing_train_dataloader: PyTorch dataloader for training data existing_val_dataloader: PyTorch dataloader for validation data - config: DeepForest configuration object + config: DeepForest configuration object or name (tree, bird, etc) config_args: Dictionary of config overrides """ @@ -39,17 +38,19 @@ def __init__( transforms=None, existing_train_dataloader=None, existing_val_dataloader=None, - config: DictConfig = None, + config: str | DictConfig = "tree", config_args: dict | None = None, ): super().__init__() - # If not provided, load default config via OmegaConf. - if config is None: - config = utilities.load_config(overrides=config_args) + # Default/string config name + if isinstance(config, str): + config = utilities.load_config(config_name=config, overrides=config_args) # Hub overrides elif "config_args" in config: config = utilities.load_config(overrides=config["config_args"]) + elif config is None: + config = utilities.load_config(overrides=config_args) elif config_args is not None: warnings.warn( f"Ignoring options as configuration object was provided: {config_args}", @@ -114,10 +115,35 @@ def load_model(self, model_name=None, revision=None): if revision is None: revision = self.config.model.revision - model_class = importlib.import_module( - f"deepforest.models.{self.config.architecture}" - ) - self.model = model_class.Model(config=self.config).create_model( + # TODO - utility/model function for this block + # Load appropriate model based on task type + if self.config.task == "box": + # For box detection, use specified architecture + if self.config.architecture == "DeformableDetr": + from deepforest.models.DeformableDetr import Model + elif self.config.architecture == "retinanet": + from deepforest.models.retinanet import Model + else: + raise ValueError( + f"Unknown architecture: '{self.config.architecture}'. " + f"Supported: 'DeformableDetr', 'retinanet'" + ) + elif self.config.task == "keypoint": + # For keypoint detection, use keypoint model + if self.config.architecture == "DeformableDetr": + from deepforest.models.keypoint import Model + else: + raise ValueError( + f"Unknown architecture: '{self.config.architecture}'. " + f"Supported: 'DeformableDetr'" + ) + else: + raise ValueError( + f"Invalid task type: '{self.config.task}'. " + f"Must be either 'box' or 'keypoint'." + ) + + self.model = Model(config=self.config).create_model( pretrained=model_name, revision=revision ) @@ -180,10 +206,27 @@ def create_model(self, initialize_model=False): None """ if self.config.model.name is None or initialize_model: - model_class = importlib.import_module( - f"deepforest.models.{self.config.architecture}" - ) - self.model = model_class.Model(config=self.config).create_model() + # TODO: DRY + # Load appropriate model based on task type + if self.config.task == "box": + if self.config.architecture == "DeformableDetr": + from deepforest.models.DeformableDetr import Model + elif self.config.architecture == "retinanet": + from deepforest.models.retinanet import Model + else: + raise ValueError( + f"Unknown architecture: '{self.config.architecture}'. " + f"Supported: 'DeformableDetr', 'retinanet'" + ) + elif self.config.task == "keypoint": + from deepforest.models.keypoint import Model + else: + raise ValueError( + f"Invalid task type: '{self.config.task}'. " + f"Must be either 'box' or 'keypoint'." + ) + + self.model = Model(config=self.config).create_model() self.set_labels(self.config.label_dict) else: self.load_model() @@ -310,14 +353,31 @@ def load_dataset( ds: a pytorch dataset """ - ds = training.BoxDataset( - csv_file=csv_file, - root_dir=root_dir, - transforms=transforms, - label_dict=self.label_dict, - augmentations=augmentations, - preload_images=preload_images, - ) + # Create appropriate dataset based on task type + # TODO: could this be a factory function/similar so we call dataset(task=x)? + if self.config.task == "box": + ds = training.BoxDataset( + csv_file=csv_file, + root_dir=root_dir, + transforms=transforms, + label_dict=self.label_dict, + augmentations=augmentations, + preload_images=preload_images, + ) + elif self.config.task == "keypoint": + ds = training.KeypointDataset( + csv_file=csv_file, + root_dir=root_dir, + transforms=transforms, + label_dict=self.label_dict, + augmentations=augmentations, + preload_images=preload_images, + ) + else: + raise ValueError( + f"Invalid task type: '{self.config.task}'. " + f"Must be either 'box' or 'keypoint'." + ) if len(ds) == 0: raise ValueError( f"Dataset from {csv_file} is empty. Check CSV for valid entries and columns." @@ -604,11 +664,11 @@ def predict_tile( # Perform mosaic for each image_path, or all if image_path is None mosaic_results = [] if results["image_path"].isnull().all(): - mosaic_results.append(predict.mosiac(results, iou_threshold=iou_threshold)) + mosaic_results.append(predict.mosaic(results, iou_threshold=iou_threshold)) else: for image_path in results["image_path"].unique(): image_results = results[results["image_path"] == image_path] - image_mosaic = predict.mosiac(image_results, iou_threshold=iou_threshold) + image_mosaic = predict.mosaic(image_results, iou_threshold=iou_threshold) image_mosaic["image_path"] = image_path mosaic_results.append(image_mosaic) @@ -654,17 +714,28 @@ def training_step(self, batch, batch_idx): # allow for empty data if data augmentation is generated images, targets, image_names = batch - loss_dict = self.model.forward(images, targets) + model_output = self.model.forward(images, targets) + + # TODO: This is messy to handle some other metrics from the keypoint model, but + # ideally we'd log them from somewhere else. + + # Handle different return formats (keypoint returns dict, box returns loss_dict) + if isinstance(model_output, dict) and "loss_dict" in model_output: + loss_dict = model_output["loss_dict"] + else: + loss_dict = model_output # sum of regression and classification loss - losses = sum(loss_dict.values()) + losses = sum([loss_dict[k] for k in loss_dict.keys() if k != "cardinality_error"]) # Log loss for key, value in loss_dict.items(): self.log(f"train_{key}", value, on_epoch=True, batch_size=len(images)) - # Log sum of losses - self.log("train_loss", losses, on_epoch=True, batch_size=len(images)) + # Log sum of losses and show in pbar + self.log( + "train_loss", losses, on_epoch=True, batch_size=len(images), prog_bar=True + ) return losses @@ -676,7 +747,14 @@ def validation_step(self, batch, batch_idx): # Torchvision does not return loss in eval mode. self.model.train() with torch.no_grad(): - loss_dict = self.model.forward(images, targets) + model_output = self.model.forward(images, targets) + + # TODO: same as above + # Handle different return formats (keypoint returns dict, box returns loss_dict) + if isinstance(model_output, dict) and "loss_dict" in model_output: + loss_dict = model_output["loss_dict"] + else: + loss_dict = model_output # sum of regression and classification loss losses = sum(loss_dict.values()) @@ -690,8 +768,8 @@ def validation_step(self, batch, batch_idx): except MisconfigurationException: pass - # In eval model, return predictions to calculate prediction metrics - preds = self.model.eval() + # In eval mode, return predictions to calculate prediction metrics + self.model.eval() with torch.no_grad(): preds = self.model.forward(images, targets) @@ -700,12 +778,22 @@ def validation_step(self, batch, batch_idx): filtered_preds = [] filtered_targets = [] for i, target in enumerate(targets): - if target["boxes"].shape[0] > 0: + # Check for non-empty targets based on task type + if "boxes" in target: + has_annotations = target["boxes"].shape[0] > 0 + elif "points" in target: + has_annotations = target["points"].shape[0] > 0 + else: + has_annotations = False + + if has_annotations: filtered_preds.append(preds[i]) filtered_targets.append(target) - self.iou_metric.update(filtered_preds, filtered_targets) - self.mAP_metric.update(filtered_preds, filtered_targets) + # Box/polygon metrics + if self.config.task == "box": + self.iou_metric.update(filtered_preds, filtered_targets) + self.mAP_metric.update(filtered_preds, filtered_targets) # Log the predictions if you want to use them for evaluation logs for i, result in enumerate(preds): @@ -785,6 +873,7 @@ def calculate_empty_frame_accuracy(self, ground_df, predictions_df): return empty_accuracy def log_epoch_metrics(self): + # Should be zero for points as the IoU metric has nothing in it. if len(self.iou_metric.groundtruth_labels) > 0: output = self.iou_metric.compute() # Lightning bug: claims this is a warning but it's not. See issue #16218 in Lightning-AI/pytorch-lightning @@ -911,9 +1000,18 @@ def predict_batch(self, images, preprocess_fn=None): return results def configure_optimizers(self): - optimizer = optim.SGD( - self.model.parameters(), lr=self.config.train.lr, momentum=0.9 - ) + optimizer_type = self.config.train.optimizer.lower() + + if optimizer_type == "adamw": + optimizer = optim.AdamW(self.model.parameters(), lr=self.config.train.lr) + elif optimizer_type == "sgd": + optimizer = optim.SGD( + self.model.parameters(), lr=self.config.train.lr, momentum=0.9 + ) + else: + raise ValueError( + f"Unknown optimizer: '{optimizer_type}'. Supported: 'sgd', 'adamw'" + ) scheduler_config = self.config.train.scheduler scheduler_type = scheduler_config.type @@ -977,18 +1075,22 @@ def lr_lambda(epoch): def evaluate( self, csv_file, + match_threshold=None, iou_threshold=None, root_dir=None, size=None, batch_size=None, predictions=None, ): - """Compute intersection-over-union and precision/recall for a given - iou_threshold. + """Compute precision/recall metrics for predictions against ground + truth. Args: - csv_file: location of a csv file with columns "name","xmin","ymin","xmax","ymax","label" - iou_threshold: float [0,1] intersection-over-union threshold for true positive + csv_file: location of a csv file with columns "name","xmin","ymin","xmax","ymax","label" (boxes) + or "image_path","x","y","label" (keypoints) + match_threshold: matching threshold - IoU [0,1] for boxes, pixel distance for keypoints. + If None, uses config.validation.iou_threshold or pixel_distance_threshold. + iou_threshold: deprecated - use match_threshold instead batch_size: int, the batch size to use for prediction. If None, uses the batch size of the model. size: int, the size to resize the images to. If None, no resizing is done. predictions: list of predictions to use for evaluation. If None, predictions are generated from the model. @@ -996,10 +1098,35 @@ def evaluate( Returns: dict: Results dictionary containing precision, recall and other metrics """ + # Deprecate IoU threshold in favour of a more generic matching threshold + # TODO: keep separate / different eval for evaluate wrapper? + if iou_threshold is not None: + import warnings + + warnings.warn( + "iou_threshold parameter is deprecated and will be removed in a future version. " + "Use match_threshold instead (IoU for boxes, pixel distance for keypoints).", + DeprecationWarning, + stacklevel=2, + ) + if match_threshold is None: + match_threshold = iou_threshold + self.model.eval() ground_df = utilities.read_file(csv_file) ground_df["label"] = ground_df.label.apply(lambda x: self.label_dict[x]) + # Convert box geometries to points for keypoint tasks + if self.config.task == "keypoint": + geom_type = utilities.determine_geometry_type(ground_df) + if geom_type == "box": + # Convert box centers to points (same as KeypointDataset does) + import shapely.geometry + + ground_df["geometry"] = ground_df.geometry.apply( + lambda box: shapely.geometry.Point(box.centroid.x, box.centroid.y) + ) + if root_dir is None: root_dir = os.path.dirname(csv_file) @@ -1009,17 +1136,20 @@ def evaluate( csv_file, root_dir, size=size, batch_size=batch_size ) - if iou_threshold is None: - iou_threshold = self.config.validation.iou_threshold + if match_threshold is None: + # Use task-specific threshold from config + if self.config.task == "keypoint": + match_threshold = self.config.validation.pixel_distance_threshold + else: + match_threshold = self.config.validation.iou_threshold results = evaluate_iou.__evaluate_wrapper__( predictions=predictions, ground_df=ground_df, - iou_threshold=iou_threshold, + match_threshold=match_threshold, numeric_to_label_dict=self.numeric_to_label_dict, ) - # empty frame accuracy empty_accuracy = self.calculate_empty_frame_accuracy(ground_df, predictions) results["empty_frame_accuracy"] = empty_accuracy diff --git a/src/deepforest/models/keypoint.py b/src/deepforest/models/keypoint.py new file mode 100644 index 000000000..8cb0c3dcf --- /dev/null +++ b/src/deepforest/models/keypoint.py @@ -0,0 +1,946 @@ +"""This code is largely derived from the transformers +DeformableDetrForObjectDetection class, with additional support for processing +and loss calculation. Several functions have sections that are copied mostly +verbatim due to only a few lines changing. + +Under the Apache 2.0 license, transformers code is copyright 2018- The +Hugging Face team. All rights reserved. +https://github.com/huggingface/transformers?tab=Apache-2.0-1-ov-file +""" + +from dataclasses import dataclass + +import numpy as np +import PIL +import torch +from scipy.optimize import linear_sum_assignment +from torch import nn +from transformers import ( + DeformableDetrConfig, + DeformableDetrModel, + DeformableDetrPreTrainedModel, +) +from transformers.image_utils import ChannelDimension, get_image_size +from transformers.loss.loss_deformable_detr import DeformableDetrImageLoss +from transformers.loss.loss_for_object_detection import HungarianMatcher +from transformers.models.deformable_detr.image_processing_deformable_detr import ( + DeformableDetrImageProcessor, +) +from transformers.models.deformable_detr.modeling_deformable_detr import ( + DeformableDetrMLPPredictionHead, + inverse_sigmoid, +) +from transformers.utils import ModelOutput + + +class DeformableDetrKeypointConfig(DeformableDetrConfig): + """Configuration for Deformable DETR keypoint detection. + + Extends DeformableDetrConfig with keypoint-specific parameters. + """ + + def __init__( + self, + point_cost: float = 5.0, + point_loss_coefficient: float = 5.0, + point_loss_type: str = "l1", + **kwargs, + ): + """ + Args: + point_cost: The relative weight of the point distance in the matching cost. + point_loss_coefficient: The coefficient for the point loss in the total loss. + point_loss_type: Type of loss for point coordinates. Options: "l1" (default, standard for DETR) or "mse" (L2). + **kwargs: Additional arguments passed to DeformableDetrConfig. + """ + super().__init__(**kwargs) + self.point_cost = point_cost + self.point_loss_coefficient = point_loss_coefficient + if point_loss_type not in ["l1", "mse"]: + raise ValueError( + f"point_loss_type must be 'l1' or 'mse', got '{point_loss_type}'" + ) + self.point_loss_type = point_loss_type + + +@dataclass +class DeformableDetrKeypointDetectionOutput(ModelOutput): + r"""init_reference_points (`torch.FloatTensor` of shape `(batch_size, + num_queries, 2)`): + + Initial reference points sent through the Transformer decoder. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 2)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*): + NOT CURRENTLY SUPPORTED. Would be used for two-stage detection where encoder predicts + initial keypoint proposals. Currently always None. + enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 2)`, *optional*): + NOT CURRENTLY SUPPORTED. Would be used for two-stage detection where encoder predicts + initial keypoint coordinates. Currently always None. + """ + + loss: torch.FloatTensor | None = None + loss_dict: dict | None = None + logits: torch.FloatTensor | None = None + pred_points: torch.FloatTensor | None = None + auxiliary_outputs: list[dict] | None = None + init_reference_points: torch.FloatTensor | None = None + last_hidden_state: torch.FloatTensor | None = None + intermediate_hidden_states: torch.FloatTensor | None = None + intermediate_reference_points: torch.FloatTensor | None = None + decoder_hidden_states: tuple[torch.FloatTensor] | None = None + decoder_attentions: tuple[torch.FloatTensor] | None = None + cross_attentions: tuple[torch.FloatTensor] | None = None + encoder_last_hidden_state: torch.FloatTensor | None = None + encoder_hidden_states: tuple[torch.FloatTensor] | None = None + encoder_attentions: tuple[torch.FloatTensor] | None = None + enc_outputs_class: torch.FloatTensor | None = None + enc_outputs_coord_logits: torch.FloatTensor | None = None + + +class DeformableDetrKeypointMatcher(HungarianMatcher): + """Hungarian matcher for keypoint detection using L2 distance for matching + cost. + + Note: The matcher always uses L2 (Euclidean) distance for computing matching cost. + The actual training loss (L1 vs MSE) is configured separately in DeformableDetrKeypointLoss. + + Args: + class_cost: Relative weight of the classification error in the matching cost. + point_cost: Relative weight of the L2 point distance in the matching cost. + """ + + def __init__(self, class_cost: float = 1, point_cost: float = 1): + # Map point_cost to parent's bbox_cost parameter (semantically it's for points here) + super().__init__(class_cost=class_cost, bbox_cost=point_cost, giou_cost=0) + + @torch.no_grad() + def forward(self, outputs, targets): + """ + Matches predicted keypoints to ground truth using: + - Classification cost (focal loss) + - L2 point distance cost + """ + batch_size, num_queries = outputs["logits"].shape[:2] + + # Flatten to compute cost matrices in a batch + out_prob = ( + outputs["logits"].flatten(0, 1).sigmoid() + ) # [batch_size * num_queries, num_classes] + out_points = outputs["pred_points"].flatten(0, 1) # [batch_size * num_queries, 2] + + # Concatenate target labels and points + target_ids = torch.cat([v["class_labels"] for v in targets]) + target_points = torch.cat([v["points"] for v in targets]) + + # Compute approximate classification cost + class_cost = -out_prob[:, target_ids] + + # Compute L1 point distance cost + point_cost = torch.cdist(out_points, target_points, p=1) + + # Final cost matrix + # Note: self.bbox_cost was set to point_cost value in __init__ + cost_matrix = self.class_cost * class_cost + self.bbox_cost * point_cost + cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu() + + sizes = [len(v["points"]) for v in targets] + + indices = [ + linear_sum_assignment(c[i]) + for i, c in enumerate(cost_matrix.split(sizes, -1)) + ] + return [ + (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) + for i, j in indices + ] + + +class DeformableDetrKeypointLoss(DeformableDetrImageLoss): + """Loss for keypoint detection using focal loss and configurable point + distance loss. + + Inherits loss_labels, loss_cardinality, and forward from DeformableDetrImageLoss. + Only adds loss computation for points via get_loss override. + + Args: + loss_type: Type of loss for coordinates - "l1" (default, standard for DETR) or "mse" (L2). + """ + + def __init__(self, matcher, num_classes, focal_alpha, losses, loss_type="l1"): + super().__init__(matcher, num_classes, focal_alpha, losses) + if loss_type not in ["l1", "mse"]: + raise ValueError(f"loss_type must be 'l1' or 'mse', got '{loss_type}'") + self.loss_type = loss_type + + def loss_points(self, outputs, targets, indices, num_objects): + """Distance loss for keypoint coordinates (L1 or MSE based on + config).""" + idx = self._get_source_permutation_idx(indices) + src_points = outputs["pred_points"][idx] + target_points = torch.cat( + [t["points"][i] for t, (_, i) in zip(targets, indices, strict=False)], dim=0 + ) + + # Apply configured loss type + if self.loss_type == "l1": + loss_point = nn.functional.l1_loss( + src_points, target_points, reduction="none" + ) + else: # mse + loss_point = nn.functional.mse_loss( + src_points, target_points, reduction="none" + ) + + losses = {f"loss_point_{self.loss_type}": loss_point.sum() / num_objects} + + return losses + + def get_loss(self, loss, outputs, targets, indices, num_objects): + """Extend parent loss map to support 'points'.""" + loss_map = { + "labels": self.loss_labels, + "cardinality": self.loss_cardinality, + "points": self.loss_points, + } + if loss not in loss_map: + raise ValueError(f"Loss {loss} not supported") + return loss_map[loss](outputs, targets, indices, num_objects) + + +def DeformableDetrForKeypointDetectionLoss( + logits, + labels, + device, + pred_points, + config, + outputs_class=None, + outputs_coord=None, + **kwargs, +): + """Loss function for keypoint detection.""" + # Point matching can use Hungarian, just like boxes + point_cost = getattr( + config, "point_cost", config.bbox_cost + ) # Fallback for backwards compatibility + matcher = DeformableDetrKeypointMatcher( + class_cost=config.class_cost, point_cost=point_cost + ) + + # Setup the criterion, default L1 but L2/MSE is also allowed + losses = ["labels", "points", "cardinality"] + loss_type = getattr(config, "point_loss_type", "l1") + criterion = DeformableDetrKeypointLoss( + matcher=matcher, + num_classes=config.num_labels, + focal_alpha=config.focal_alpha, + losses=losses, + loss_type=loss_type, + ) + criterion.to(device) + + # Compute individual losses + outputs_loss = {} + auxiliary_outputs = None + outputs_loss["logits"] = logits + outputs_loss["pred_points"] = pred_points + if config.auxiliary_loss: + # Adapt _set_aux_loss for points instead of boxes + auxiliary_outputs = [ + {"logits": a, "pred_points": b} + for a, b in zip(outputs_class[:-1], outputs_coord[:-1], strict=False) + ] + outputs_loss["auxiliary_outputs"] = auxiliary_outputs + + loss_dict = criterion(outputs_loss, labels) + # Compute total loss + point_loss_coefficient = getattr( + config, "point_loss_coefficient", config.bbox_loss_coefficient + ) + weight_dict = {"loss_ce": 1, f"loss_point_{loss_type}": point_loss_coefficient} + if config.auxiliary_loss: + aux_weight_dict = {} + for i in range(config.decoder_layers - 1): + aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + + loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict) + + return loss, loss_dict, auxiliary_outputs + + +def prepare_keypoint_annotation( + image, + target, + input_data_format: ChannelDimension | str | None = None, +): + """Convert keypoint annotations into the format expected by DeformableDetr. + + Expected input format: + { + "image_id": int, + "annotations": [ + { + "category_id": int, + "keypoints": [x1,y1,...], + "keypoints": [x, y] or [[x1, y1], [x2, y2], ...], # Single or multiple keypoints + }, + ... + ] + } + + Output format: + { + "image_id": array, + "labels": array of shape (num_keypoints,), + "points": array of shape (num_keypoints, 2), # Internal representation + "orig_size": array of shape (2,) # [height, width] + } + """ + image_height, image_width = get_image_size(image, channel_dim=input_data_format) + + image_id = target["image_id"] + image_id = np.asarray([image_id], dtype=np.int64) + + # Get all annotations for the given image + annotations = target["annotations"] + + classes = [obj["category_id"] for obj in annotations] + classes = np.asarray(classes, dtype=np.int64) + + # Extract keypoints + keypoints_list = [] + for obj in annotations: + if "keypoints" in obj: + kpts = obj["keypoints"] + kpts = np.asarray(kpts, dtype=np.float32) + + # Handle different keypoint formats + if kpts.ndim == 1: + # Single keypoint: [x, y] + if len(kpts) != 2: + raise ValueError( + f"Expected keypoint to have 2 coordinates (x, y), got {len(kpts)}" + ) + kpts = kpts.reshape(1, 2) + elif kpts.ndim == 2: + # Multiple keypoints: [[x1, y1], [x2, y2], ...] + if kpts.shape[1] != 2: + raise ValueError( + f"Expected keypoints to have 2 coordinates per point (x, y), got {kpts.shape[1]}" + ) + else: + raise ValueError(f"Invalid keypoint format with {kpts.ndim} dimensions") + # If objects, they're separate + elif "bbox" in obj: + x, y, w, h = obj["bbox"] + kpt = np.asarray([[x + w / 2, y + h / 2]], dtype=np.float32) + + keypoints_list.append(kpts) + + # Flatten all keypoints and their corresponding class labels + all_points = [] + all_classes = [] + + for i, kpts in enumerate(keypoints_list): + for kpt in kpts: + all_points.append(kpt) + all_classes.append(classes[i]) + + points = np.array(all_points, dtype=np.float32).reshape(-1, 2) + + # Clip points to image boundaries + points[:, 0] = points[:, 0].clip(min=0, max=image_width) + points[:, 1] = points[:, 1].clip(min=0, max=image_height) + + new_target = {} + new_target["image_id"] = image_id + new_target["labels"] = np.array(all_classes, dtype=np.int64) + # Transformers library expects "class_labels" key for loss computation + new_target["class_labels"] = new_target["labels"] + new_target["points"] = points + new_target["orig_size"] = np.asarray( + [int(image_height), int(image_width)], dtype=np.int64 + ) + + return new_target + + +def normalize_keypoint_annotation(annotation: dict, image_size: tuple[int, int]) -> dict: + """Normalize keypoint annotations to [0, 1] coordinate space. + + Args: + annotation: Dictionary containing "points" key with shape (num_points, 2) + image_size: Tuple of (height, width) + + Returns: + Normalized annotation dictionary with points scaled to [0, 1] + """ + image_height, image_width = image_size + norm_annotation = annotation.copy() + + # Normalize points: divide x by width, y by height + if "points" in annotation: + norm_annotation["points"] = annotation["points"] / np.array( + [image_width, image_height], dtype=np.float32 + ) + + # Ensure both "labels" and "class_labels" are preserved for transformers compatibility + if "labels" in norm_annotation and "class_labels" not in norm_annotation: + norm_annotation["class_labels"] = norm_annotation["labels"] + + return norm_annotation + + +class DeformableDetrKeypointImageProcessor(DeformableDetrImageProcessor): + """Image processor for keypoint detection with Deformable DETR. + + Extends DeformableDetrImageProcessor to handle keypoint annotations + instead of bounding boxes. Uses "keypoints" in external API and + "points" for internal model representation. + """ + + def prepare_annotation( + self, + image: np.ndarray, + target: dict, + format: str | None = None, + return_segmentation_masks=None, + masks_path=None, + input_data_format: str | ChannelDimension | None = None, + ) -> dict: + """Prepare a keypoint annotation for feeding into DeformableDetr model. + + Overrides parent to handle keypoint annotations instead of + bounding boxes. + """ + return prepare_keypoint_annotation( + image, target, input_data_format=input_data_format + ) + + def normalize_annotation(self, annotation: dict, image_size: tuple[int, int]) -> dict: + """Normalize keypoint annotations to [0, 1] coordinate space. + + Called by parent's preprocess() method when + do_convert_annotations=True. + """ + return normalize_keypoint_annotation(annotation, image_size) + + def resize_annotation( + self, + annotation, + orig_size, + size, + resample: PIL.Image.Resampling = PIL.Image.Resampling.NEAREST, + ) -> dict: + """Resize the annotation to match the resized image. + + This is an override of the existing function in transformers to + handle keypoints only (since we don't care about other + annotations here). Since the processor may resize samples, we + also need to scale keypoints to match (even though they are then + scaled to [0,1] anyway). + """ + ratios = tuple( + float(s) / float(s_orig) for s, s_orig in zip(size, orig_size, strict=False) + ) + ratio_height, ratio_width = ratios + + new_annotation = dict(annotation) + new_annotation["points"] *= np.array([ratio_height, ratio_width]) + + return new_annotation + + def post_process_keypoint_detection( + self, + outputs, + threshold: float = 0.5, + target_sizes: torch.Tensor | list[tuple] = None, + top_k: int = 100, + ): + """Converts the raw output of DeformableDetrForKeypointDetection into + final keypoints. + + Args: + outputs: Raw outputs of the model with 'logits' and 'pred_points' + threshold: Score threshold to keep keypoint predictions + target_sizes: Tensor of shape (batch_size, 2) or list of tuples (height, width) + top_k: Keep only top k keypoints before filtering by threshold + + Returns: + List of dictionaries with 'scores', 'labels', and 'keypoints' for each image + """ + out_logits = outputs.logits if hasattr(outputs, "logits") else outputs["logits"] + out_points = ( + outputs.pred_points + if hasattr(outputs, "pred_points") + else outputs["pred_points"] + ) + + if target_sizes is not None: + if len(out_logits) != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + # Get class probabilities + prob = out_logits.sigmoid() + prob = prob.view(out_logits.shape[0], -1) + k_value = min(top_k, prob.size(1)) + topk_values, topk_indexes = torch.topk(prob, k_value, dim=1) + scores = topk_values + + # Get corresponding point indices and labels + topk_points_idx = torch.div( + topk_indexes, out_logits.shape[2], rounding_mode="floor" + ) + labels = topk_indexes % out_logits.shape[2] + + # Gather the corresponding points + points = torch.gather( + out_points, 1, topk_points_idx.unsqueeze(-1).repeat(1, 1, 2) + ) + + # Convert from relative [0, 1] to absolute [0, height/width] coordinates + if target_sizes is not None: + if isinstance(target_sizes, list): + img_h = torch.Tensor([i[0] for i in target_sizes]) + img_w = torch.Tensor([i[1] for i in target_sizes]) + else: + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h], dim=1).to(points.device) + points = points * scale_fct[:, None, :] + + # Filter by threshold and return results for each item in batch + results = [] + for result in zip(scores, labels, points, strict=False): + score, label, point = result + + # Filter all outputs by score threshold + mask = score > threshold + score = score[mask] + label = label[mask] + keypoint = point[mask] + + results.append( + { + "scores": score, + "labels": label, + "keypoints": keypoint, + } + ) + + return results + + +class DeformableDetrForKeypointDetection(DeformableDetrPreTrainedModel): + # When using clones, all layers > 0 will be clones, but layer 0 *is* required + _tied_weights_keys = [r"point_embed\.[1-9]\d*", r"class_embed\.[1-9]\d*"] + # We can't initialize the model on meta device as some weights are modified during the initialization + _no_split_modules = None + + def __init__(self, config: DeformableDetrConfig): + super().__init__(config) + + # Deformable DETR encoder-decoder model + self.model = DeformableDetrModel(config) + + # Detection heads on top + self.class_embed = nn.Linear(config.d_model, config.num_labels) + + # 2D output for x/y + self.point_embed = DeformableDetrMLPPredictionHead( + input_dim=config.d_model, + hidden_dim=config.d_model, + output_dim=2, + num_layers=3, + ) + + # Currently no support for with_box_refine (iterative refinement) or two_stage + # with_box_refine: Would create independent prediction heads per decoder layer + # for iterative point coordinate refinement + # two_stage: Would add encoder-based proposal generation before decoder refinement + # (requires with_box_refine=True) + num_pred = config.decoder_layers + + # Weight-tied prediction heads across decoder layers + self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)]) + self.point_embed = nn.ModuleList([self.point_embed for _ in range(num_pred)]) + self.model.decoder.point_embed = None + + if config.two_stage: + raise NotImplementedError( + "Two-stage keypoint detection is not currently supported. " + "This would require implementing encoder-side proposal generation. " + "Set config.two_stage=False to use standard single-stage detection." + ) + + # Initialize weights and apply final processing + self.post_init() + + def loss_function(self, *args, **kwargs): + """Wrapper for the keypoint detection loss function.""" + return DeformableDetrForKeypointDetectionLoss(*args, **kwargs) + + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: torch.LongTensor | None = None, + decoder_attention_mask: torch.FloatTensor | None = None, + encoder_outputs: torch.FloatTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + decoder_inputs_embeds: torch.FloatTensor | None = None, + labels: list[dict] | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple[torch.FloatTensor] | DeformableDetrKeypointDetectionOutput: + r"""For full documentation, look at DeformableDetrForObjectDetection. + + labels (`list[Dict]` of len `(batch_size,)`, *optional*): + Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the + following 2 keys: 'labels' and 'points' (the class labels and object centers (points) of an image in the batch + respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of objects + in the image,)` and the points a `torch.FloatTensor` of shape `(number of points in the image, 2)`. + ``` + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # First, send images through DETR base model to obtain encoder + decoder outputs + outputs = self.model( + pixel_values, + pixel_mask=pixel_mask, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs.intermediate_hidden_states if return_dict else outputs[2] + init_reference = outputs.init_reference_points if return_dict else outputs[0] + inter_references = ( + outputs.intermediate_reference_points if return_dict else outputs[3] + ) + + # class logits + predicted points + outputs_classes = [] + outputs_coords = [] + + # References in Deformable DETR are 2D points corresponding to object + # centers. This naturally leads to keypoints anyway. + for level in range(hidden_states.shape[1]): + if level == 0: + reference = init_reference + else: + reference = inter_references[:, level - 1] + reference = inverse_sigmoid(reference) + outputs_class = self.class_embed[level](hidden_states[:, level]) + delta_point = self.point_embed[level](hidden_states[:, level]) + + # For keypoints: reference points are (x, y) with shape [..., 2] + if reference.shape[-1] != 2: + raise ValueError( + f"Keypoint detection requires 2D reference points (x, y), " + f"but got shape [..., {reference.shape[-1]}]" + ) + + # Add predicted delta to reference in logit space + outputs_coord_logits = delta_point + reference + # Convert back to [0, 1] normalized coordinates + outputs_coord = outputs_coord_logits.sigmoid() + + outputs_classes.append(outputs_class) + outputs_coords.append(outputs_coord) + outputs_class = torch.stack(outputs_classes) + outputs_coord = torch.stack(outputs_coords) + + logits = outputs_class[-1] + pred_points = outputs_coord[-1] + + loss, loss_dict, auxiliary_outputs = None, None, None + if labels is not None: + loss, loss_dict, auxiliary_outputs = self.loss_function( + logits, + labels, + self.device, + pred_points, + self.config, + outputs_class, + outputs_coord, + ) + if not return_dict: + if auxiliary_outputs is not None: + output = (logits, pred_points) + auxiliary_outputs + outputs + else: + output = (logits, pred_points) + outputs + tuple_outputs = ((loss, loss_dict) + output) if loss is not None else output + + return tuple_outputs + + dict_outputs = DeformableDetrKeypointDetectionOutput( + loss=loss, + loss_dict=loss_dict, + logits=logits, + pred_points=pred_points, + auxiliary_outputs=auxiliary_outputs, + last_hidden_state=outputs.last_hidden_state, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + intermediate_hidden_states=outputs.intermediate_hidden_states, + intermediate_reference_points=outputs.intermediate_reference_points, + init_reference_points=outputs.init_reference_points, + enc_outputs_class=outputs.enc_outputs_class, + enc_outputs_coord_logits=outputs.enc_outputs_coord_logits, + ) + + return dict_outputs + + +def keypoint_to_coco(targets): + if not isinstance(targets, list): + targets = [targets] + + coco_targets = [] + for target in targets: + annotations_for_target = [] + for i, (label, point) in enumerate( + zip(target["labels"], target["points"], strict=False) + ): + if isinstance(point, torch.Tensor): + point = point.tolist() + if isinstance(label, torch.Tensor): + label = label.item() + + annotations_for_target.append( + { + "id": i, + "image_id": i, + "category_id": label, + "keypoints": point, # [x, y] + } + ) + + coco_target = {"image_id": 0, "annotations": annotations_for_target} + + # Preserve orig_size if available for coordinate scaling during preprocessing + if "orig_size" in target: + coco_target["orig_size"] = target["orig_size"] + + coco_targets.append(coco_target) + + return coco_targets + + +class KeypointDetrWrapper(nn.Module): + """Wrapper for DeformableDetrForKeypointDetection that handles + preprocessing and postprocessing transparently. + + This class translates between DeepForest's KeypointDataset format + and the transformers keypoint model format. + """ + + def __init__(self, config, name, revision, **hf_args): + """Initialize a DeformableDetrForKeypointDetection model. + + Args: + config: DeepForest config object + name: HuggingFace model name or path + revision: Model revision/branch + **hf_args: Additional arguments for from_pretrained + """ + super().__init__() + self.config = config + + # Import here to avoid circular imports + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=UserWarning) + + # Create keypoint model config with keypoint-specific parameters + from transformers import DeformableDetrConfig + + # Load base config from pretrained model + model_config = DeformableDetrConfig.from_pretrained( + name, revision=revision, **hf_args + ) + + # Add keypoint-specific loss parameters if they exist in config + if hasattr(self.config, "point_cost"): + model_config.point_cost = self.config.point_cost + if hasattr(self.config, "point_loss_coefficient"): + model_config.point_loss_coefficient = self.config.point_loss_coefficient + if hasattr(self.config, "point_loss_type"): + model_config.point_loss_type = self.config.point_loss_type + + # Create model from config + self.net = DeformableDetrForKeypointDetection(model_config) + + # Load pretrained weights if available (backbone only, not prediction heads) + if name is not None: + from transformers import DeformableDetrModel + + pretrained_base = DeformableDetrModel.from_pretrained( + name, revision=revision, **hf_args + ) + # Copy only the base model weights (encoder/decoder, not classification/bbox heads) + self.net.model.load_state_dict(pretrained_base.state_dict(), strict=False) + + # Create processor + self.processor = DeformableDetrKeypointImageProcessor() + + # Update label mappings + if hasattr(self.config, "label_dict"): + self.label_dict = self.config.label_dict + else: + self.label_dict = {"Tree": 0} + self.num_classes = model_config.num_labels + + def _prepare_targets(self, targets): + """Translate KeypointDataset targets to COCO keypoint format. + + Args: + targets: List of dicts with "points" (N, 2) and "labels" (N,) + + Returns: + List of dicts in COCO annotation format for keypoints + """ + return keypoint_to_coco(targets) + + def forward(self, images, targets=None, prepare_targets=True): + """Forward pass for keypoint detection. + + Args: + images: Input images (list of tensors or batch tensor) + targets: Optional targets for training + prepare_targets: Whether to convert targets to COCO format + + Returns: + If training: loss dictionary + If inference: list of dicts with "keypoints", "scores", "labels" + """ + if targets and prepare_targets: + targets = self._prepare_targets(targets) + + encoded_inputs = self.processor.preprocess( + images=images, + annotations=targets, + return_tensors="pt", + do_rescale=False, # Dataset already normalized [0,255]→[0,1] + # Processor still does: resize, normalize (ImageNet), pad + ) + + # Move tensors to model device + for k, v in encoded_inputs.items(): + if isinstance(v, torch.Tensor): + encoded_inputs[k] = v.to(self.net.device) + + if isinstance(encoded_inputs.get("labels"), list): + encoded_inputs["labels"] = [ + { + key: val.to(self.net.device) if isinstance(val, torch.Tensor) else val + for key, val in target.items() + } + for target in encoded_inputs["labels"] + ] + + preds = self.net(**encoded_inputs) + + if targets is None or not self.training: + # Inference mode: post-process and return predictions + # Use original image sizes from targets to scale predictions back to original coordinate space + target_sizes = [t["orig_size"].cpu().tolist() for t in targets] + + results = self.processor.post_process_keypoint_detection( + preds, + threshold=self.config.score_thresh, + target_sizes=target_sizes, + ) + return results + else: + # Training mode: return loss dict and predictions for logging + return { + "loss_dict": preds.loss_dict, + "pred_points": preds.pred_points, + "targets": encoded_inputs.get("labels"), + } + + +class Model: + """Model factory for keypoint detection following DeepForest interface. + + This class provides a simple interface to create keypoint detection + models compatible with the DeepForest training pipeline. + """ + + def __init__(self, config, **kwargs): + """Initialize model factory. + + Args: + config: DeepForest configuration object + """ + self.config = config + + def create_model( + self, + pretrained: str | None = "SenseTime/deformable-detr", + *, + revision: str | None = "main", + map_location: str | torch.device | None = None, + **hf_args, + ) -> KeypointDetrWrapper: + """Create a keypoint detection model from pretrained weights. + + The model starts from a pretrained Deformable DETR backbone (encoder/decoder) + but initializes new prediction heads for the configured number of classes. + + Args: + pretrained: HuggingFace model name or path. If None, random initialization. + revision: Model revision/branch + map_location: Device to load model onto + **hf_args: Additional arguments for from_pretrained + + Returns: + KeypointDetrWrapper model ready for training or inference + """ + # Set label mapping if provided + if pretrained is None: + hf_args.setdefault("id2label", self.config.numeric_to_label_dict) + + model = KeypointDetrWrapper( + self.config, + name=pretrained, + revision=revision, + num_labels=self.config.num_classes, + **hf_args, + ) + + if map_location is not None: + model = model.to(map_location) + + return model + + +__all__ = [ + "DeformableDetrKeypointConfig", + "DeformableDetrKeypointDetectionOutput", + "DeformableDetrForKeypointDetection", + "DeformableDetrKeypointMatcher", + "DeformableDetrKeypointLoss", + "DeformableDetrKeypointImageProcessor", + "KeypointDetrWrapper", + "Model", +] diff --git a/src/deepforest/predict.py b/src/deepforest/predict.py index ab685b296..0afa02284 100644 --- a/src/deepforest/predict.py +++ b/src/deepforest/predict.py @@ -1,5 +1,6 @@ # Prediction utilities import os +from warnings import warn import numpy as np import pandas as pd @@ -23,7 +24,7 @@ def _predict_image_( model: a deepforest.main.model object image: a tensor of shape (channels, height, width) path: optional path to read image from disk instead of passing image arg - nms_thresh: Non-max suppression threshold, see config.nms_thresh + nms_thresh: Non-max suppression threshold, see config.nms_thresh (only for box detection) Returns: df: A pandas dataframe of predictions (Default) img: The input with predictions overlaid (Optional) @@ -35,13 +36,14 @@ def _predict_image_( with torch.no_grad(): prediction = model(image.unsqueeze(0)) + df = utilities.format_geometry(prediction[0]) + # return None for no predictions - if len(prediction[0]["boxes"]) == 0: + if df is None: return None - df = utilities.format_boxes(prediction[0]) - - if df.label.nunique() > 1: + # NMS for boxes only + if "xmin" in df.columns and df.label.nunique() > 1: df = across_class_nms(df, iou_threshold=nms_thresh) # Add image path if provided @@ -51,28 +53,40 @@ def _predict_image_( return df -def transform_coordinates(boxes): - """Transform box coordinates from window space to original image space. +def transform_coordinates(predictions): + """Transform coordinates from window space to original image space. Args: - boxes: DataFrame of predictions with xmin, ymin, xmax, ymax, window_xmin, window_ymin columns + predictions: DataFrame of predictions with coordinate columns and window_xmin, window_ymin Returns: DataFrame with transformed coordinates """ - boxes = boxes.copy() - boxes["xmin"] += boxes["window_xmin"] - boxes["xmax"] += boxes["window_xmin"] - boxes["ymin"] += boxes["window_ymin"] - boxes["ymax"] += boxes["window_ymin"] + predictions = predictions.copy() + + # Handle box coordinates + if "xmin" in predictions.columns: + predictions["xmin"] += predictions["window_xmin"] + predictions["xmax"] += predictions["window_xmin"] + predictions["ymin"] += predictions["window_ymin"] + predictions["ymax"] += predictions["window_ymin"] + + # Cast to int + predictions["xmin"] = predictions["xmin"].astype(int) + predictions["ymin"] = predictions["ymin"].astype(int) + predictions["xmax"] = predictions["xmax"].astype(int) + predictions["ymax"] = predictions["ymax"].astype(int) + + # Handle keypoint coordinates + elif "x" in predictions.columns and "y" in predictions.columns: + predictions["x"] += predictions["window_xmin"] + predictions["y"] += predictions["window_ymin"] - # Cast to int - boxes["xmin"] = boxes["xmin"].astype(int) - boxes["ymin"] = boxes["ymin"].astype(int) - boxes["xmax"] = boxes["xmax"].astype(int) - boxes["ymax"] = boxes["ymax"].astype(int) + # Cast to int + predictions["x"] = predictions["x"].astype(int) + predictions["y"] = predictions["y"].astype(int) - return boxes + return predictions def apply_nms(boxes, scores, labels, iou_threshold): @@ -109,7 +123,7 @@ def apply_nms(boxes, scores, labels, iou_threshold): ) -def mosiac(predictions, iou_threshold=0.1): +def mosaic(predictions, iou_threshold=0.1): """Mosaic predictions from overlapping windows. Args: @@ -119,22 +133,31 @@ def mosiac(predictions, iou_threshold=0.1): Returns: A pandas dataframe of predictions. """ - predicted_boxes = transform_coordinates(predictions) + predicted_results = transform_coordinates(predictions) # Skip NMS if there's is one or less prediction - if predicted_boxes.shape[0] <= 1: - return predicted_boxes + if predicted_results.shape[0] <= 1: + return predicted_results + + # TODO: Should probably have an aggregation function here. + # For keypoints, eturn transformed coordinates + if "x" in predicted_results.columns and "y" in predicted_results.columns: + warn( + "Keypoint merging for overlapping windows is not yet supported, returning all points.", + stacklevel=2, + ) + return predicted_results print( - f"{predicted_boxes.shape[0]} predictions in overlapping windows, applying non-max suppression" + f"{predicted_results.shape[0]} box predictions in overlapping windows, applying non-max suppression" ) # Convert to tensors boxes = torch.tensor( - predicted_boxes[["xmin", "ymin", "xmax", "ymax"]].values, dtype=torch.float32 + predicted_results[["xmin", "ymin", "xmax", "ymax"]].values, dtype=torch.float32 ) - scores = torch.tensor(predicted_boxes.score.values, dtype=torch.float32) - labels = predicted_boxes.label.values + scores = torch.tensor(predicted_results.score.values, dtype=torch.float32) + labels = predicted_results.label.values # Apply NMS filtered_boxes = apply_nms(boxes, scores, labels, iou_threshold) diff --git a/src/deepforest/utilities.py b/src/deepforest/utilities.py index 76f91d2ad..2e6d5d5c2 100644 --- a/src/deepforest/utilities.py +++ b/src/deepforest/utilities.py @@ -7,6 +7,7 @@ import pandas as pd import rasterio import shapely +import torch import xmltodict from omegaconf import DictConfig, OmegaConf from PIL import Image @@ -45,6 +46,10 @@ def load_config( yaml_cfg = OmegaConf.load(yaml_path) + # Drop Hydra-specific override and inherit directly from base + if "defaults" in yaml_cfg: + yaml_cfg.pop("defaults") + # Merge in sequence (overrides last) config = OmegaConf.merge(base, yaml_cfg, overrides) @@ -309,7 +314,7 @@ def determine_geometry_type(df): geometry_type = "box" elif "polygon" in df.keys(): geometry_type = "polygon" - elif "points" in df.keys(): + elif "points" in df.keys() or "keypoints" in df.keys(): geometry_type = "point" return geometry_type @@ -331,18 +336,15 @@ def format_geometry(predictions, scores=True, geom_type=None): if geom_type == "box": df = format_boxes(predictions, scores=scores) - if df is None: - return None - elif geom_type == "polygon": raise ValueError("Polygon predictions are not yet supported for formatting") elif geom_type == "point": - raise ValueError("Point predictions are not yet supported for formatting") + df = format_points(predictions, scores=scores) return df -def format_boxes(prediction, scores=True): +def format_boxes(prediction: dict[str, torch.Tensor], scores=True): """Format a retinanet prediction into a pandas dataframe for a single image. @@ -370,6 +372,39 @@ def format_boxes(prediction, scores=True): return df +def format_points(prediction: dict[str, torch.Tensor], scores=True): + """Format a keypoint prediction into a pandas dataframe for a single image. + + Args: + prediction: a dictionary with keys 'keypoints' (or 'points'), 'labels', and optionally 'scores' + scores: Whether points come with scores, during prediction, or without scores, as in during training. + Returns: + df: a pandas dataframe with columns x, y, label, score (optional), geometry + """ + + if "keypoints" in prediction: + points_key = "keypoints" + elif "points" in prediction: + points_key = "points" + else: + raise ValueError( + "Prediction dict must contain either 'keypoints' or 'points' key" + ) + + if len(prediction[points_key]) == 0: + return None + + points = prediction[points_key].cpu().detach().numpy() + df = pd.DataFrame(points, columns=["x", "y"]) + df["label"] = prediction["labels"].cpu().detach().numpy() + + if scores and "scores" in prediction: + df["score"] = prediction["scores"].cpu().detach().numpy() + + df["geometry"] = df.apply(lambda row: shapely.geometry.Point(row.x, row.y), axis=1) + return df + + def read_coco(json_file): """Read a COCO format JSON file and return a pandas dataframe. diff --git a/src/deepforest/visualize.py b/src/deepforest/visualize.py index 092caae23..dfc8f695b 100644 --- a/src/deepforest/visualize.py +++ b/src/deepforest/visualize.py @@ -491,7 +491,14 @@ def _plot_image_with_geometry( detections=detections, ) elif geom_type == "point": - point_annotator = sv.VertexAnnotator(color=sv_color, radius=radius) + # TODO can we abuse DotAnnotator and pass in a zero-area bbox with the keypoint as coords? + # VertexAnnotator doesn't accept ColorPalette, only single Color + # If we have a palette, use the first color + if isinstance(sv_color, sv.ColorPalette): + point_color = sv_color.colors[0] + else: + point_color = sv_color + point_annotator = sv.VertexAnnotator(color=point_color, radius=radius) annotated_frame = point_annotator.annotate( scene=image.copy(), key_points=detections ) diff --git a/tests/test_augmentations.py b/tests/test_augmentations.py index 8ec1f41d1..5e1737e09 100644 --- a/tests/test_augmentations.py +++ b/tests/test_augmentations.py @@ -172,7 +172,7 @@ def test_bbox_params(): transform_repr = repr(transform) assert "bbox_params" in transform_repr assert "'format': 'pascal_voc'" in transform_repr - assert "'label_fields': ['category_ids']" in transform_repr + assert "'label_fields': ['labels']" in transform_repr def test_blur_augmentations(): diff --git a/tests/test_keypoint.py b/tests/test_keypoint.py new file mode 100644 index 000000000..6cc6df289 --- /dev/null +++ b/tests/test_keypoint.py @@ -0,0 +1,478 @@ +# Test keypoint detection model and loss functions +import numpy as np +import pytest +import torch +import pytorch_lightning as pl +from pytorch_lightning.callbacks import EarlyStopping +from torch.utils.data import DataLoader +from transformers import DeformableDetrConfig + +from deepforest.models.keypoint import ( + DeformableDetrForKeypointDetection, + DeformableDetrKeypointMatcher, + DeformableDetrKeypointLoss, + DeformableDetrKeypointImageProcessor, + DeformableDetrKeypointConfig, +) + + +@pytest.fixture +def config(): + """Create a test configuration.""" + config = DeformableDetrConfig.from_pretrained("SenseTime/deformable-detr") + config.num_labels = 5 # Use fewer labels for testing + return config + + +@pytest.fixture +def keypoint_matcher(): + """Create a keypoint matcher with standard costs.""" + return DeformableDetrKeypointMatcher( + class_cost=2.0, + point_cost=5.0, + ) + + +@pytest.fixture +def keypoint_loss(keypoint_matcher): + """Create a keypoint loss criterion.""" + return DeformableDetrKeypointLoss( + matcher=keypoint_matcher, + num_classes=5, + focal_alpha=0.25, + losses=["labels", "points", "cardinality"], + ) + + +@pytest.fixture +def sample_targets(): + """Create sample targets for testing.""" + return [ + { + "class_labels": torch.tensor([0, 1, 2]), + "points": torch.tensor([[0.2, 0.3], [0.5, 0.6], [0.8, 0.9]]), + }, + { + "class_labels": torch.tensor([0, 1, 2, 3]), + "points": torch.tensor([[0.2, 0.3], [0.5, 0.6], [0.8, 0.9], [0.3, 0.4]]), + } + ] + + +def create_predictions_from_targets(targets, num_queries=10, num_classes=5, jitter_std=0.0): + """ + Create predictions that match targets with optional jitter. + + Args: + targets: List of target dicts + num_queries: Number of query predictions per image + num_classes: Number of keypoint classes + jitter_std: Standard deviation of Gaussian jitter to add to positions + + Returns: + Dict with 'logits' and 'pred_points' + """ + batch_size = len(targets) + logits = torch.full((batch_size, num_queries, num_classes), -10.0) + pred_points = torch.rand(batch_size, num_queries, 2) + + for i, target in enumerate(targets): + num_kpts = len(target["class_labels"]) + for j in range(num_kpts): + logits[i, j, target["class_labels"][j]] = 10.0 + jitter = torch.randn(2) * jitter_std if jitter_std > 0 else 0 + pred_points[i, j] = torch.clamp(target["points"][j] + jitter, 0, 1) + + return {"logits": logits, "pred_points": pred_points} + + +"""Test suite for keypoint detection loss functions.""" + +def test_loss_identical_predictions(keypoint_loss, sample_targets): + """Test that loss is very low when predictions perfectly match targets.""" + outputs = create_predictions_from_targets(sample_targets, jitter_std=0.0) + loss_dict = keypoint_loss(outputs, sample_targets) + + assert loss_dict["loss_ce"] < 0.5, f"Classification loss too high: {loss_dict['loss_ce']}" + assert loss_dict["loss_point_l1"] < 0.01, f"Point loss too high: {loss_dict['loss_point_l1']}" + +def test_loss_small_jitter(keypoint_loss, sample_targets): + """Test that loss is small when predictions have small positional errors (~5%).""" + outputs = create_predictions_from_targets(sample_targets, jitter_std=0.01) + loss_dict = keypoint_loss(outputs, sample_targets) + + assert loss_dict["loss_ce"] < 0.5, f"Classification loss too high: {loss_dict['loss_ce']}" + assert 0 < loss_dict["loss_point_l1"] < 0.1, f"Point loss out of expected range: {loss_dict['loss_point_l1']}" + +def test_loss_shuffled_predictions(keypoint_loss): + """Test that Hungarian matching correctly handles shuffled predictions.""" + targets = [{ + "class_labels": torch.tensor([0, 1, 2]), + "points": torch.tensor([[0.2, 0.3], [0.5, 0.6], [0.8, 0.9]]), + }] + + # Create predictions in reversed order + num_queries = 10 + logits = torch.full((1, num_queries, 5), -10.0) + pred_points = torch.rand(1, num_queries, 2) + + shuffled_order = [2, 1, 0] + for j, target_idx in enumerate(shuffled_order): + logits[0, j, targets[0]["class_labels"][target_idx]] = 10.0 + pred_points[0, j] = targets[0]["points"][target_idx] + + loss_dict = keypoint_loss({"logits": logits, "pred_points": pred_points}, targets) + + assert loss_dict["loss_ce"] < 0.5, f"Classification loss too high: {loss_dict['loss_ce']}" + assert loss_dict["loss_point_l1"] < 0.01, f"Point loss too high (matching failed?): {loss_dict['loss_point_l1']}" + +def test_loss_wrong_classes(keypoint_loss): + """Test that classification loss is high when classes are incorrect.""" + targets = [{ + "class_labels": torch.tensor([0, 1, 2]), + "points": torch.tensor([[0.2, 0.3], [0.5, 0.6], [0.8, 0.9]]), + }] + + num_queries = 10 + logits = torch.full((1, num_queries, 5), -10.0) + pred_points = torch.rand(1, num_queries, 2) + + # Assign wrong classes (shifted by 1) but correct locations + for j in range(3): + wrong_class = (targets[0]["class_labels"][j] + 1) % 5 + logits[0, j, wrong_class] = 10.0 + pred_points[0, j] = targets[0]["points"][j] + + loss_dict = keypoint_loss({"logits": logits, "pred_points": pred_points}, targets) + + assert loss_dict["loss_ce"] > 1.0, f"Classification loss too low: {loss_dict['loss_ce']}" + +def test_loss_no_keypoints(keypoint_loss): + """Test loss computation with images that have no keypoints.""" + targets = [{ + "class_labels": torch.tensor([], dtype=torch.long), + "points": torch.empty(0, 2), + }] + + outputs = {"logits": torch.randn(1, 10, 5), "pred_points": torch.rand(1, 10, 2)} + loss_dict = keypoint_loss(outputs, targets) + + assert "loss_ce" in loss_dict + assert "loss_point_l1" in loss_dict + assert not torch.isnan(loss_dict["loss_ce"]) + assert not torch.isnan(loss_dict["loss_point_l1"]) + +"""Test suite for keypoint Hungarian matcher.""" + +def test_matcher_perfect_match(keypoint_matcher): + """Test matcher with perfect correspondence between predictions and targets.""" + outputs = { + "logits": torch.tensor([[[10.0, -10.0], [-10.0, 10.0], [-10.0, -10.0]]]), + "pred_points": torch.tensor([[[0.2, 0.3], [0.5, 0.6], [0.8, 0.9]]]), + } + targets = [{ + "class_labels": torch.tensor([0, 1]), + "points": torch.tensor([[0.2, 0.3], [0.5, 0.6]]), + }] + + indices = keypoint_matcher(outputs, targets) + pred_indices, target_indices = indices[0] + + assert len(pred_indices) == 2 + assert len(target_indices) == 2 + +def test_matcher_handles_more_predictions(keypoint_matcher): + """Test that matcher handles more predictions than targets.""" + outputs = { + "logits": torch.randn(1, 100, 5), + "pred_points": torch.rand(1, 100, 2), + } + targets = [{ + "class_labels": torch.tensor([0, 1, 2]), + "points": torch.rand(3, 2), + }] + + indices = keypoint_matcher(outputs, targets) + pred_indices, target_indices = indices[0] + + assert len(pred_indices) == 3 + assert len(target_indices) == 3 + + +"""Test suite for keypoint image processor.""" + +def test_processor_single_keypoint(): + """Test processor with single keypoint per annotation.""" + processor = DeformableDetrKeypointImageProcessor(do_resize=False, do_normalize=False, do_pad=False) + image = np.random.randint(0, 255, (800, 600, 3), dtype=np.uint8) + annotation = { + "image_id": 1, + "annotations": [ + {"category_id": 0, "keypoints": [100.0, 150.0]}, + {"category_id": 1, "keypoints": [200.0, 250.0]}, + ] + } + + # Call processor with images and annotations + result = processor(image, annotations=annotation, do_convert_annotations=True) + + # Check that we got normalized labels + assert "labels" in result + labels = result["labels"][0] + assert labels["points"].shape == (2, 2) + assert labels["class_labels"].shape == (2,) + # Points should be normalized + assert labels["points"].max() <= 1.0 + assert labels["points"].min() >= 0.0 + +def test_processor_multiple_keypoints(): + """Test processor with multiple keypoints per annotation.""" + processor = DeformableDetrKeypointImageProcessor(do_resize=False, do_normalize=False, do_pad=False) + image = np.random.randint(0, 255, (800, 600, 3), dtype=np.uint8) + annotation = { + "image_id": 1, + "annotations": [ + {"category_id": 0, "keypoints": [[100.0, 150.0], [120.0, 170.0]]}, + ] + } + + result = processor(image, annotations=annotation, do_convert_annotations=True) + + labels = result["labels"][0] + assert labels["points"].shape == (2, 2) + assert labels["class_labels"].shape == (2,) + assert all(labels["class_labels"] == 0) + +def test_processor_normalization(): + """Test that processor properly normalizes coordinates.""" + processor = DeformableDetrKeypointImageProcessor(do_resize=False, do_normalize=False, do_pad=False) + image = np.random.randint(0, 255, (800, 600, 3), dtype=np.uint8) + annotation = { + "image_id": 1, + "annotations": [ + {"category_id": 0, "keypoints": [300.0, 400.0]}, # Middle of image + ] + } + + result = processor(image, annotations=annotation, do_convert_annotations=True) + + labels = result["labels"][0] + # Check normalization: 300/600=0.5, 400/800=0.5 + assert np.allclose(labels["points"][0], [0.5, 0.5]) + +def test_processor_post_process_keypoints(): + """Test post-processing of keypoint predictions.""" + processor = DeformableDetrKeypointImageProcessor() + from deepforest.models.keypoint import DeformableDetrKeypointDetectionOutput + + outputs = DeformableDetrKeypointDetectionOutput( + logits=torch.tensor([[[0.9, 0.1], [0.8, 0.2], [0.3, 0.7]]]), + pred_points=torch.tensor([[[0.5, 0.5], [0.3, 0.7], [0.8, 0.2]]]), + ) + + results = processor.post_process_keypoint_detection( + outputs, + threshold=0.5, + target_sizes=[(800, 600)], + top_k=10 + ) + + assert len(results) == 1 + result = results[0] + assert "keypoints" in result + assert "scores" in result + assert "labels" in result + assert result["keypoints"].max() <= 800 + + +"""Test suite for the full keypoint detection model.""" + +def test_model_forward_inference(config): + """Test model forward pass in inference mode.""" + model = DeformableDetrForKeypointDetection(config) + model.eval() + + batch_size = 2 + pixel_values = torch.randn(batch_size, 3, 800, 800) + + with torch.no_grad(): + outputs = model(pixel_values) + + assert outputs.logits.shape == (batch_size, 300, config.num_labels) + assert outputs.pred_points.shape == (batch_size, 300, 2) + assert outputs.loss is None + +def test_model_forward_training(config): + """Test model forward pass in training mode with labels.""" + model = DeformableDetrForKeypointDetection(config) + model.train() + + batch_size = 2 + pixel_values = torch.randn(batch_size, 3, 800, 800) + labels = [ + {"class_labels": torch.tensor([0, 1]), "points": torch.rand(2, 2)}, + {"class_labels": torch.tensor([2, 3, 4]), "points": torch.rand(3, 2)} + ] + + outputs = model(pixel_values, labels=labels) + + assert outputs.loss is not None + assert outputs.loss_dict is not None + assert "loss_ce" in outputs.loss_dict + assert "loss_point_l1" in outputs.loss_dict + assert not torch.isnan(outputs.loss) + +# TODO: Remove or simplify this when we have integration with the main library sorted out. +def test_model_train_overfit(): + """Test model can overfit to memorize 10 keypoints""" + # Create small config for faster training + config = DeformableDetrConfig.from_pretrained("SenseTime/deformable-detr") + config.num_labels = 3 + config.decoder_layers = 4 + config.encoder_layers = 4 + config.num_queries = 20 + + processor = DeformableDetrKeypointImageProcessor() + + # Create fixed dataset with 10 well-separated keypoints + image_height, image_width = 800, 600 + fixed_keypoints_pixel = np.array([ + [100, 100], [500, 100], [100, 700], [500, 700], [300, 400], + [150, 250], [450, 250], [150, 550], [450, 550], [300, 150], + ]) + fixed_classes = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2, 0]) + + image = np.random.randint(0, 255, (image_height, image_width, 3), dtype=np.uint8) + annotation = { + "image_id": 0, + "annotations": [ + {"category_id": int(cls), "keypoints": kpt.tolist()} + for cls, kpt in zip(fixed_classes, fixed_keypoints_pixel) + ] + } + + result = processor( + image, + annotations=annotation, + do_resize=False, + do_normalize=False, + do_pad=False, + do_convert_annotations=True, + return_tensors="pt" + ) + + labels = [result["labels"][0]] + + # Lightning module + class KeypointLightningModule(pl.LightningModule): + def __init__(self, model): + super().__init__() + self.model = model + self.train_losses = [] + + def training_step(self, batch, batch_idx): + pixel_values, labels = batch + outputs = self.model(pixel_values, labels=labels) + self.log("train_loss", outputs.loss, prog_bar=True) + self.train_losses.append(outputs.loss.item()) + return outputs.loss + + def configure_optimizers(self): + return torch.optim.AdamW(self.parameters(), lr=1e-3) + + # Dataset/dataloader + class KeypointDataset(torch.utils.data.Dataset): + def __init__(self, pixel_values, labels): + self.pixel_values = pixel_values.squeeze(0) + self.labels = labels[0] + + def __len__(self): + return 1 + + def __getitem__(self, idx): + return self.pixel_values, self.labels + + def collate_fn(batch): + pixel_values = torch.stack([item[0] for item in batch]) + labels = [item[1] for item in batch] + return pixel_values, labels + + pixel_values = torch.randn(1, 3, image_height, image_width) + dataset = KeypointDataset(pixel_values, labels) + dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn) + + lightning_model = KeypointLightningModule(DeformableDetrForKeypointDetection(config)) + + # Train with early stopping + trainer = pl.Trainer( + max_epochs=200, + devices=1, + logger=False, + enable_checkpointing=False, + enable_progress_bar=False, + enable_model_summary=False + ) + + trainer.fit(lightning_model, dataloader) + + # Verify loss reduction + initial_loss = lightning_model.train_losses[0] + final_loss = lightning_model.train_losses[-1] + assert final_loss < initial_loss * 0.3, \ + f"Loss did not decrease enough: {final_loss:.4f} vs {initial_loss:.4f}" + + # Test inference and post-processing + lightning_model.model.eval() + with torch.no_grad(): + outputs = lightning_model.model(pixel_values) + + results = processor.post_process_keypoint_detection( + outputs, + threshold=0.3, + target_sizes=[(image_height, image_width)], + top_k=50 + ) + + result = results[0] + num_predictions = len(result["keypoints"]) + assert num_predictions > 0, "No keypoints detected after training!" + + # Check predictions are well distributed + if num_predictions >= 2: + from scipy.spatial.distance import pdist + pred_points = result["keypoints"].numpy() + # Reshape if flattened (N*2,) -> (N, 2) + if pred_points.ndim == 1: + pred_points = pred_points.reshape(-1, 2) + + # Check localization accuracy + if num_predictions >= len(fixed_keypoints_pixel): + pred_points = result["keypoints"].numpy() + # Reshape if flattened (N*2,) -> (N, 2) + if pred_points.ndim == 1: + pred_points = pred_points.reshape(-1, 2) + errors = [] + for gt_point in fixed_keypoints_pixel: + distances = np.linalg.norm(pred_points - gt_point, axis=1) + errors.append(distances.min()) + + mean_error = np.mean(errors) + assert mean_error < 10, f"Mean localization error too high: {mean_error:.1f} pixels" + +def test_keypoint_config(): + """Test that DeformableDetrKeypointConfig has keypoint-specific parameters.""" + config = DeformableDetrKeypointConfig( + num_labels=5, + point_cost=3.0, + point_loss_coefficient=7.0 + ) + + assert hasattr(config, 'point_cost') + assert hasattr(config, 'point_loss_coefficient') + assert config.point_cost == 3.0 + assert config.point_loss_coefficient == 7.0 + assert config.num_labels == 5 + + # Test that it still has parent class attributes + assert hasattr(config, 'class_cost') + assert hasattr(config, 'bbox_cost') diff --git a/tests/test_keypoint_evaluation.py b/tests/test_keypoint_evaluation.py new file mode 100644 index 000000000..79aa5bcb0 --- /dev/null +++ b/tests/test_keypoint_evaluation.py @@ -0,0 +1,283 @@ +"""Focused tests for keypoint distance metrics and evaluation.""" +import numpy as np +import pandas as pd +import pytest +import geopandas as gpd +from shapely.geometry import Point + +from deepforest import keypoint_distance, evaluate + + +# Helper to reduce Point(x, y) repetition +def pts(coords): + """Convert [(x, y), ...] to [Point(x, y), ...].""" + return [Point(x, y) for x, y in coords] + + +# ============================================================================ +# Distance Computation Tests +# ============================================================================ + +def test_compute_distances_perfect_match(): + """Test distance computation with perfectly matching keypoints.""" + points = [(100, 100), (200, 200)] + predictions = gpd.GeoDataFrame({"geometry": pts(points), "score": [0.9, 0.8]}) + ground_truth = gpd.GeoDataFrame({"geometry": pts(points)}) + + result = keypoint_distance.compute_distances(ground_truth, predictions) + + assert len(result) == 2 + assert all(result["distance"] < 0.01) + assert all(result["prediction_id"].notna()) + assert list(result["score"]) == [0.9, 0.8] + + +def test_compute_distances_known_distances(): + """Test distance computation returns correct Euclidean distances (3-4-5 triangle).""" + result = keypoint_distance.compute_distances( + gpd.GeoDataFrame({"geometry": pts([(3, 4)])}), + gpd.GeoDataFrame({"geometry": pts([(0, 0)]), "score": [0.9]}) + ) + assert result["distance"].iloc[0] == pytest.approx(5.0) + + +def test_compute_distances_optimal_matching(): + """Test Hungarian algorithm finds optimal assignment, not greedy.""" + result = keypoint_distance.compute_distances( + gpd.GeoDataFrame({"geometry": pts([(200, 201), (100, 101)])}), + gpd.GeoDataFrame({"geometry": pts([(100, 100), (200, 200)]), "score": [0.9, 0.8]}) + ) + assert all(result["distance"] < 2.0) + + +def test_compute_distances_more_predictions(): + """Test matching when there are more predictions than ground truth.""" + result = keypoint_distance.compute_distances( + gpd.GeoDataFrame({"geometry": pts([(100, 100), (200, 200)])}), + gpd.GeoDataFrame({"geometry": pts([(i*100, i*100) for i in range(4)]), "score": [0.9, 0.8, 0.7, 0.6]}) + ) + assert len(result) == 2 + assert result["prediction_id"].notna().sum() == 2 + + +def test_compute_distances_more_ground_truth(): + """Test matching when there are more ground truth than predictions.""" + result = keypoint_distance.compute_distances( + gpd.GeoDataFrame({"geometry": pts([(i*100, i*100) for i in range(4)])}), + gpd.GeoDataFrame({"geometry": pts([(0, 0), (100, 100)]), "score": [0.9, 0.8]}) + ) + assert len(result) == 4 + assert result["prediction_id"].notna().sum() == 2 + assert result["prediction_id"].isna().sum() == 2 + assert all(result[result["prediction_id"].isna()]["distance"] == np.inf) + + +def test_compute_distances_empty_predictions(): + """Test handling of empty predictions.""" + result = keypoint_distance.compute_distances( + gpd.GeoDataFrame({"geometry": pts([(100, 100), (200, 200)])}), + gpd.GeoDataFrame({"geometry": pts([]), "score": []}) + ) + assert len(result) == 0 + + +def test_compute_distances_empty_ground_truth(): + """Test handling of empty ground truth.""" + result = keypoint_distance.compute_distances( + gpd.GeoDataFrame({"geometry": pts([])}), + gpd.GeoDataFrame({"geometry": pts([(100, 100)]), "score": [0.9]}) + ) + assert len(result) == 0 + + +def test_compute_distances_without_scores(): + """Test matching works even without score column.""" + result = keypoint_distance.compute_distances( + gpd.GeoDataFrame({"geometry": pts([(100, 100), (200, 200)])}), + gpd.GeoDataFrame({"geometry": pts([(100, 100), (200, 200)])}) + ) + assert len(result) == 2 + assert all(result["score"].isna()) + + +# ============================================================================ +# Image-level Keypoint Evaluation Tests +# ============================================================================ + +def test_evaluate_image_keypoints_perfect_match(): + """Test image-level evaluation with perfect matches.""" + points = [(100, 100), (200, 200)] + predictions = gpd.GeoDataFrame({ + "geometry": pts(points), + "score": [0.9, 0.8], + "label": ["Tree", "Bird"], + "image_path": ["img1.jpg", "img1.jpg"] + }) + ground_truth = gpd.GeoDataFrame({ + "geometry": pts(points), + "label": ["Tree", "Bird"], + "image_path": ["img1.jpg", "img1.jpg"] + }) + + result = evaluate.evaluate_image_keypoints(predictions, ground_truth) + + assert len(result) == 2 + assert all(result["distance"] < 0.01) + assert list(result["predicted_label"]) == ["Tree", "Bird"] + assert list(result["true_label"]) == ["Tree", "Bird"] + + +def test_evaluate_image_keypoints_label_mapping(): + """Test that labels are correctly mapped from indices.""" + result = evaluate.evaluate_image_keypoints( + gpd.GeoDataFrame({"geometry": pts([(100, 100)]), "score": [0.9], "label": ["Tree"], "image_path": ["img1.jpg"]}), + gpd.GeoDataFrame({"geometry": pts([(100, 100)]), "label": ["Bird"], "image_path": ["img1.jpg"]}) + ) + assert result["predicted_label"].iloc[0] == "Tree" + assert result["true_label"].iloc[0] == "Bird" + + +def test_evaluate_image_keypoints_multiple_images_error(): + """Test that function raises error with multiple images.""" + with pytest.raises(ValueError, match="More than one plot"): + evaluate.evaluate_image_keypoints( + gpd.GeoDataFrame({"geometry": pts([(100, 100), (200, 200)]), "label": ["Tree", "Tree"], "image_path": ["img1.jpg", "img2.jpg"]}), + gpd.GeoDataFrame({"geometry": pts([(100, 100)]), "label": ["Tree"], "image_path": ["img1.jpg"]}) + ) + + +# ============================================================================ +# Full Keypoint Evaluation Tests +# ============================================================================ + +def test_evaluate_keypoints_recall_precision(): + """Test recall and precision calculation with pixel threshold.""" + predictions = gpd.GeoDataFrame({ + "geometry": pts([(100, 100), (205, 205), (350, 350)]), + "score": [0.9, 0.8, 0.7], + "label": ["Tree", "Tree", "Tree"], + "image_path": ["img1.jpg", "img1.jpg", "img1.jpg"] + }) + ground_truth = gpd.GeoDataFrame({ + "geometry": pts([(100, 100), (200, 200)]), + "label": ["Tree", "Tree"], + "image_path": ["img1.jpg", "img1.jpg"] + }) + + result = evaluate.evaluate_keypoints(predictions, ground_truth, pixel_threshold=10.0) + + assert result["recall"] == 1.0 + assert result["precision"] == pytest.approx(2/3) + assert result["results"]["match"].sum() == 2 + + +def test_evaluate_keypoints_threshold_filtering(): + """Test that pixel threshold correctly filters matches.""" + predictions = gpd.GeoDataFrame({ + "geometry": pts([(100, 100), (250, 250)]), + "score": [0.9, 0.8], + "label": ["Tree", "Tree"], + "image_path": ["img1.jpg", "img1.jpg"] + }) + ground_truth = gpd.GeoDataFrame({ + "geometry": pts([(100, 100), (200, 200)]), + "label": ["Tree", "Tree"], + "image_path": ["img1.jpg", "img1.jpg"] + }) + + assert evaluate.evaluate_keypoints(predictions, ground_truth, pixel_threshold=10.0)["results"]["match"].sum() == 1 + assert evaluate.evaluate_keypoints(predictions, ground_truth, pixel_threshold=100.0)["results"]["match"].sum() == 2 + + +def test_evaluate_keypoints_empty_predictions(): + """Test evaluation with no predictions.""" + result = evaluate.evaluate_keypoints( + gpd.GeoDataFrame({"geometry": pts([]), "label": [], "image_path": []}), + gpd.GeoDataFrame({"geometry": pts([(100, 100)]), "label": ["Tree"], "image_path": ["img1.jpg"]}) + ) + assert result["recall"] == 0.0 + assert pd.isna(result["precision"]) + assert result["class_recall"] is None + + +def test_evaluate_keypoints_empty_ground_truth(): + """Test evaluation with no ground truth.""" + result = evaluate.evaluate_keypoints( + gpd.GeoDataFrame({"geometry": pts([(100, 100)]), "score": [0.9], "label": ["Tree"], "image_path": ["img1.jpg"]}), + gpd.GeoDataFrame({"geometry": pts([]), "label": [], "image_path": []}) + ) + assert result["results"] is None + assert result["recall"] is None + assert result["precision"] == 0.0 + + +def test_evaluate_keypoints_multi_image(): + """Test evaluation across multiple images.""" + predictions = gpd.GeoDataFrame({ + "geometry": pts([(100, 100), (200, 200), (300, 300)]), + "score": [0.9, 0.8, 0.7], + "label": ["Tree", "Tree", "Bird"], + "image_path": ["img1.jpg", "img1.jpg", "img2.jpg"] + }) + ground_truth = predictions.copy() + + result = evaluate.evaluate_keypoints(predictions, ground_truth, pixel_threshold=5.0) + + assert result["recall"] == 1.0 + assert result["precision"] == 1.0 + assert len(result["results"]) == 3 + assert set(result["results"]["image_path"]) == {"img1.jpg", "img2.jpg"} + + +def test_evaluate_keypoints_class_recall(): + """Test per-class recall and precision calculation.""" + predictions = gpd.GeoDataFrame({ + "geometry": pts([(100, 100), (205, 205), (300, 300)]), + "score": [0.9, 0.8, 0.7], + "label": ["Tree", "Tree", "Bird"], + "image_path": ["img1.jpg", "img1.jpg", "img1.jpg"] + }) + ground_truth = gpd.GeoDataFrame({ + "geometry": pts([(100, 100), (200, 200), (300, 300)]), + "label": ["Tree", "Tree", "Bird"], + "image_path": ["img1.jpg", "img1.jpg", "img1.jpg"] + }) + + result = evaluate.evaluate_keypoints(predictions, ground_truth, pixel_threshold=10.0) + class_recall = result["class_recall"] + + assert class_recall[class_recall["label"] == "Tree"]["recall"].iloc[0] == 1.0 + assert class_recall[class_recall["label"] == "Tree"]["precision"].iloc[0] == 1.0 + assert class_recall[class_recall["label"] == "Bird"]["recall"].iloc[0] == 1.0 + assert class_recall[class_recall["label"] == "Bird"]["precision"].iloc[0] == 1.0 + + +def test_evaluate_keypoints_wrong_labels(): + """Test evaluation when predicted labels don't match ground truth.""" + result = evaluate.evaluate_keypoints( + gpd.GeoDataFrame({"geometry": pts([(100, 100), (200, 200)]), "score": [0.9, 0.8], "label": ["Tree", "Bird"], "image_path": ["img1.jpg", "img1.jpg"]}), + gpd.GeoDataFrame({"geometry": pts([(100, 100), (200, 200)]), "label": ["Bird", "Tree"], "image_path": ["img1.jpg", "img1.jpg"]}), + pixel_threshold=5.0 + ) + assert result["results"]["match"].sum() == 2 + assert all(result["class_recall"]["recall"] == 0.0) + + +def test_evaluate_keypoints_partial_image_matches(): + """Test evaluation where images have different match rates.""" + predictions = gpd.GeoDataFrame({ + "geometry": pts([(100, 100), (150, 150), (200, 200)]), + "score": [0.9, 0.85, 0.8], + "label": ["Tree", "Tree", "Tree"], + "image_path": ["img1.jpg", "img1.jpg", "img2.jpg"] + }) + ground_truth = gpd.GeoDataFrame({ + "geometry": pts([(100, 100), (150, 150), (200, 200), (250, 250)]), + "label": ["Tree", "Tree", "Tree", "Tree"], + "image_path": ["img1.jpg", "img1.jpg", "img2.jpg", "img2.jpg"] + }) + + result = evaluate.evaluate_keypoints(predictions, ground_truth, pixel_threshold=5.0) + + assert result["recall"] == 0.75 + assert result["precision"] == 1.0 diff --git a/tests/test_main_keypoint.py b/tests/test_main_keypoint.py new file mode 100644 index 000000000..bcac7ce9a --- /dev/null +++ b/tests/test_main_keypoint.py @@ -0,0 +1,132 @@ +"""Unit tests for main.py keypoint detection integration.""" +import tempfile +from pathlib import Path + +import pandas as pd +import pytest +from PIL import Image +import numpy as np + +from deepforest import main + + +def test_main_load_keypoint_config(): + """Test that main.py can load keypoint configuration.""" + model = main.deepforest(config="keypoint") + + assert model.config.task == "keypoint" + assert model.config.architecture == "DeformableDetr" + assert hasattr(model, 'model') + + +def test_main_create_keypoint_model(): + """Test that main.py creates keypoint model correctly.""" + model = main.deepforest(config="keypoint") + + assert model.model is not None + assert hasattr(model.model, 'net') + assert hasattr(model.model, 'processor') + + +def test_main_load_keypoint_dataset(): + """Test that main.py can load keypoint dataset.""" + model = main.deepforest(config="keypoint") + + with tempfile.TemporaryDirectory() as tmpdir: + tmpdir = Path(tmpdir) + + # Create dummy image + img = Image.fromarray(np.random.randint(0, 255, (400, 300, 3), dtype=np.uint8)) + img_path = tmpdir / "test.jpg" + img.save(img_path) + + # Create keypoint CSV + csv_path = tmpdir / "keypoints.csv" + df = pd.DataFrame({ + "image_path": ["test.jpg", "test.jpg", "test.jpg"], + "x": [100, 200, 150], + "y": [150, 250, 200], + "label": ["Tree", "Tree", "Tree"] + }) + df.to_csv(csv_path, index=False) + + # Load dataset + dataloader = model.load_dataset( + csv_file=str(csv_path), + root_dir=str(tmpdir), + batch_size=1, + shuffle=False + ) + + assert len(dataloader.dataset) == 1 + + # Get one sample + images, targets, paths = dataloader.dataset[0] + assert images.shape[0] == 3 # channels + assert "points" in targets + assert "labels" in targets + assert targets["points"].shape == (3, 2) # 3 keypoints, (x,y) + assert targets["labels"].shape == (3,) + + +def test_main_invalid_task_raises(): + """Test that invalid task type raises ValueError.""" + from deepforest.conf.schema import Config + + config = Config() + config.task = "invalid_task" + + with pytest.raises(ValueError, match="Invalid task type"): + model = main.deepforest() + model.config = config + model.create_model() + + +def test_main_box_task_still_works(): + """Test that box detection still works after keypoint changes.""" + model = main.deepforest() # Default is box task + + assert model.config.task == "box" + assert hasattr(model, 'model') + + +def test_main_keypoint_with_box_csv(): + """Test that KeypointDataset auto-converts box CSV to keypoints.""" + model = main.deepforest(config="keypoint") + + with tempfile.TemporaryDirectory() as tmpdir: + tmpdir = Path(tmpdir) + + # Create dummy image + img = Image.fromarray(np.random.randint(0, 255, (400, 300, 3), dtype=np.uint8)) + img_path = tmpdir / "test.jpg" + img.save(img_path) + + # Create box CSV (should auto-convert to keypoints) + csv_path = tmpdir / "boxes.csv" + df = pd.DataFrame({ + "image_path": ["test.jpg"], + "xmin": [50], + "ymin": [100], + "xmax": [150], + "ymax": [200], + "label": ["Tree"] + }) + df.to_csv(csv_path, index=False) + + # Load dataset - should convert boxes to keypoints (center) + dataloader = model.load_dataset( + csv_file=str(csv_path), + root_dir=str(tmpdir), + batch_size=1, + shuffle=False + ) + + images, targets, paths = dataloader.dataset[0] + + # Check that keypoint is at center of box + assert targets["points"].shape == (1, 2) + expected_x = (50 + 150) / 2 + expected_y = (100 + 200) / 2 + assert targets["points"][0, 0].item() == pytest.approx(expected_x) + assert targets["points"][0, 1].item() == pytest.approx(expected_y) diff --git a/tests/test_utilities.py b/tests/test_utilities.py index 8346fa0cc..395e00323 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -640,9 +640,11 @@ def test_format_geometry_point(): "scores": torch.tensor([0.9, 0.8]) } - # Format geometry should raise ValueError since point predictions are not supported - with pytest.raises(ValueError, match="Point predictions are not yet supported for formatting"): - utilities.format_geometry(prediction, geom_type="point") + result = utilities.format_geometry(prediction, geom_type="point") + assert isinstance(result, pd.DataFrame) + assert "geometry" in result.columns + assert len(result) == 2 + assert isinstance(result.iloc[0]["geometry"], geometry.Point) def test_format_geometry_polygon():