diff --git a/src/deepforest/model.py b/src/deepforest/model.py index e667a03e1..9a21dcf7d 100644 --- a/src/deepforest/model.py +++ b/src/deepforest/model.py @@ -265,11 +265,15 @@ def get_transform(self, augmentations): resize_dims = self.config["cropmodel"].get("resize", [224, 224]) interp_name = self.config["cropmodel"].get("resize_interpolation", "bilinear") - interp = ( - transforms.InterpolationMode.NEAREST - if interp_name == "nearest" - else transforms.InterpolationMode.BILINEAR - ) + if interp_name == "nearest": + interp = transforms.InterpolationMode.NEAREST + elif interp_name == "bilinear": + interp = transforms.InterpolationMode.BILINEAR + else: + raise ValueError( + f"Invalid resize_interpolation '{interp_name}'. " + "Supported values are ['nearest', 'bilinear']." + ) data_transforms.append(transforms.Resize(resize_dims, interpolation=interp)) # Apply augmentations if specified @@ -281,7 +285,9 @@ def get_transform(self, augmentations): return transforms.Compose(data_transforms) - def expand_bbox_to_square(self, bbox, image_width, image_height): + def expand_bbox_to_square( + self, bbox: list[float], image_width: int, image_height: int + ) -> list[float]: """Expand a bounding box to a square by extending the shorter side. Parameters: