Skip to content

Commit dc34375

Browse files
committed
Add 'crop' option to RandomRotation that center-crops the image to remove any padding regions introduced by the rotation
1 parent 6676c01 commit dc34375

3 files changed

Lines changed: 122 additions & 3 deletions

File tree

test/test_transforms_v2.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2435,6 +2435,62 @@ def test_functional_image_fast_path_correctness(self, size, angle, expand):
24352435

24362436
torch.testing.assert_close(actual, expected)
24372437

2438+
@pytest.mark.parametrize("size", [(100, 100), (120, 80)])
@pytest.mark.parametrize("angle", [15.0, 30.0, 45.0])
def test_transform_crop_removes_fill(self, size, angle):
    """A constant non-zero image rotated with ``crop=True`` must come back smaller and without zeros."""
    height, width = size
    # Every input pixel is 200 and the fill value is 0, so a 0 in the output can only come from fill.
    source = tv_tensors.Image(torch.full((3, height, width), 200, dtype=torch.uint8))
    rotate_and_crop = transforms.RandomRotation((angle, angle), fill=0, crop=True)

    result = rotate_and_crop(source)
    out_height, out_width = result.shape[-2], result.shape[-1]

    assert result.min().item() > 0, "crop=True output should have no fill pixels"
    assert out_height < height or out_width < width, "crop=True should reduce at least one dimension"
2448+
2449+
@pytest.mark.parametrize("size", [(100, 100), (120, 80)])
@pytest.mark.parametrize("angle", [15.0, 30.0, 45.0])
def test_transform_crop_consistent_across_inputs(self, size, angle):
    """Image, mask and boxes sent through one call must all end up on the same cropped canvas."""
    height, width = size
    sample = (
        tv_tensors.Image(torch.full((3, height, width), 200, dtype=torch.uint8)),
        tv_tensors.Mask(torch.ones(1, height, width, dtype=torch.uint8)),
        tv_tensors.BoundingBoxes(
            torch.tensor([[10.0, 10.0, 50.0, 50.0]]),
            format=tv_tensors.BoundingBoxFormat.XYXY,
            canvas_size=(height, width),
        ),
    )
    rotate_and_crop = transforms.RandomRotation((angle, angle), crop=True)

    out_image, out_mask, out_boxes = rotate_and_crop(*sample)

    # The spatial dims of the image are the reference the other outputs are compared against.
    assert out_image.shape[-2:] == out_mask.shape[-2:]
    assert out_boxes.canvas_size == (out_image.shape[-2], out_image.shape[-1])
2465+
2466+
def test_transform_crop_and_expand_mutually_exclusive(self):
    """Asking for both ``expand`` and ``crop`` is rejected when the transform is constructed."""
    expected_message = "crop and expand are mutually exclusive"
    with pytest.raises(ValueError, match=expected_message):
        transforms.RandomRotation(30, crop=True, expand=True)
2469+
2470+
@pytest.mark.parametrize("angle", [0.0, 90.0, 180.0, 270.0])
def test_transform_crop_zero_angle_preserves_size(self, angle):
    """Rotating a square image by a multiple of 90 degrees with ``crop=True`` keeps its shape."""
    # The input is square (100 x 100), so this does not exercise non-square quarter turns.
    square = tv_tensors.Image(torch.zeros(3, 100, 100, dtype=torch.uint8))
    rotate_and_crop = transforms.RandomRotation((angle, angle), crop=True)

    result = rotate_and_crop(square)

    assert result.shape == square.shape
2477+
2478+
def test_largest_inscribed_crop_size(self):
    """Spot-check the private helper that sizes the crop used by ``RandomRotation(crop=True)``."""
    from torchvision.transforms.v2.functional._geometry import _largest_inscribed_crop_size

    # Arguments are (width, height, angle); the result is (crop_height, crop_width).
    # With no rotation the crop is simply the original size.
    unrotated_cases = {(100, 100): (100, 100), (200, 100): (100, 200)}
    for (width, height), expected_hw in unrotated_cases.items():
        assert _largest_inscribed_crop_size(width, height, 0) == expected_hw

    # 45° square: the inscribed square has side 100 / sqrt(2) ≈ 70.71, floored to 70.
    side_h, side_w = _largest_inscribed_crop_size(100, 100, 45)
    assert side_h == side_w == 70

    # The crop never exceeds the original dimensions.
    for width, height, angle in [(200, 100, 20), (640, 480, 15), (50, 50, 37)]:
        crop_h, crop_w = _largest_inscribed_crop_size(width, height, angle)
        assert crop_h <= height and crop_w <= width
2493+
24382494

24392495
class TestContainerTransforms:
24402496
class BuiltinTransform(transforms.Transform):

torchvision/transforms/v2/_geometry.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,9 @@ class RandomRotation(Transform):
606606
Fill value can be also a dictionary mapping data type to the fill value, e.g.
607607
``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
608608
``Mask`` will be filled with 0.
609+
crop (bool, optional): If ``True``, the rotated output is center-cropped to the largest axis-aligned
610+
rectangle that fits entirely within the rotated image, removing any fill/padding regions introduced
611+
by the rotation. Mutually exclusive with ``expand``. Default is ``False``.
609612
610613
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
611614
@@ -620,11 +623,15 @@ def __init__(
620623
expand: bool = False,
621624
center: Optional[list[float]] = None,
622625
fill: Union[_FillType, dict[Union[type, str], _FillType]] = 0,
626+
crop: bool = False,
623627
) -> None:
624628
super().__init__()
629+
if crop and expand:
630+
raise ValueError("crop and expand are mutually exclusive")
625631
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
626632
self.interpolation = interpolation
627633
self.expand = expand
634+
self.crop = crop
628635

629636
self.fill = fill
630637
self._fill = _setup_fill_arg(fill)
@@ -634,21 +641,37 @@ def __init__(
634641

635642
self.center = center
636643

644+
def _extract_params_for_v1_transform(self) -> dict[str, Any]:
    """Return the parent's v1 parameters with the ``crop`` entry removed.

    Raises:
        ValueError: If the removed ``crop`` entry is truthy, since the v1 transform has no such option.
    """
    v1_params = super()._extract_params_for_v1_transform()
    crop_requested = v1_params.pop("crop")
    if not crop_requested:
        return v1_params
    raise ValueError(
        f"{type(self).__name__}() cannot be scripted when crop=True, "
        "as this feature is not supported by the v1 transform."
    )
652+
637653
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
    """Sample the rotation angle and, when ``self.crop`` is set, the matching crop size.

    Returns:
        dict[str, Any]: Always contains ``"angle"``, drawn uniformly from ``self.degrees``. When
        ``self.crop`` is truthy it also contains ``"crop_hw"``, the value returned by
        ``_largest_inscribed_crop_size`` for the queried input size and the sampled angle.
    """
    low, high = self.degrees[0], self.degrees[1]
    angle = torch.empty(1).uniform_(low, high).item()
    if not self.crop:
        return {"angle": angle}

    # query_size yields (height, width); the helper takes width first.
    height, width = query_size(flat_inputs)
    crop_hw = F._geometry._largest_inscribed_crop_size(width, height, angle)
    return {"angle": angle, "crop_hw": crop_hw}
640660

641661
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
    """Rotate ``inpt`` by the sampled angle, then center-crop it when ``self.crop`` is set.

    Args:
        inpt: The input to transform; its type selects the fill value.
        params: Must contain ``"angle"``; must also contain ``"crop_hw"`` when ``self.crop`` is truthy.

    Returns:
        The rotated input, center-cropped to ``params["crop_hw"]`` if cropping is enabled.
    """
    fill = _get_fill(self._fill, type(inpt))
    # Only the angle is forwarded to the rotate kernel; "crop_hw" is consumed by the crop step below.
    rotate_kwargs = dict(
        angle=params["angle"],
        interpolation=self.interpolation,
        expand=self.expand,
        center=self.center,
        fill=fill,
    )
    rotated = self._call_kernel(F.rotate, inpt, **rotate_kwargs)
    if not self.crop:
        return rotated
    target_size = list(params["crop_hw"])
    return self._call_kernel(F.center_crop, rotated, output_size=target_size)
652675

653676

654677
class RandomAffine(Transform):

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,6 +1335,46 @@ def affine_video(
13351335
)
13361336

13371337

1338+
def _largest_inscribed_crop_size(width: int, height: int, angle: float) -> tuple[int, int]:
1339+
"""Compute the largest axis-aligned rectangle inscribed in a rotated width x height rectangle.
1340+
1341+
Returns ``(crop_height, crop_width)`` as integers.
1342+
"""
1343+
import math
1344+
1345+
angle_rad = math.radians(angle)
1346+
sin_a = abs(math.sin(angle_rad))
1347+
cos_a = abs(math.cos(angle_rad))
1348+
1349+
# Clamp near-zero values to avoid numerical noise from sin(180°) ≈ 1.2e-16 etc.
1350+
if sin_a < 1e-10:
1351+
return height, width
1352+
if cos_a < 1e-10:
1353+
return width, height
1354+
1355+
width_is_longer = width >= height
1356+
side_long = width if width_is_longer else height
1357+
side_short = height if width_is_longer else width
1358+
1359+
if side_short <= 2.0 * sin_a * cos_a * side_long or abs(sin_a - cos_a) < 1e-10:
1360+
# Half-constrained: two crop corners touch the longer side.
1361+
# Also handles the 45° degenerate case via abs(sin_a - cos_a) < 1e-10.
1362+
x = 0.5 * side_short
1363+
if width_is_longer:
1364+
crop_w, crop_h = x / sin_a, x / cos_a
1365+
else:
1366+
crop_w, crop_h = x / cos_a, x / sin_a
1367+
else:
1368+
# Fully constrained: crop touches all four sides
1369+
cos_2a = cos_a * cos_a - sin_a * sin_a
1370+
crop_w = (width * cos_a - height * sin_a) / cos_2a
1371+
crop_h = (height * cos_a - width * sin_a) / cos_2a
1372+
1373+
# Use floor (int()) to guarantee the crop region contains no fill pixels.
1374+
# Clamp to image dimensions for edge cases like wide images rotated near 90°.
1375+
return min(int(crop_h), height), min(int(crop_w), width)
1376+
1377+
13381378
def rotate(
13391379
inpt: torch.Tensor,
13401380
angle: float,

0 commit comments

Comments
 (0)