From 6b33f6a93d0f66bf4f378b5b56239e85b6aa77b2 Mon Sep 17 00:00:00 2001 From: sunnycho100 Date: Wed, 18 Mar 2026 01:34:43 -0500 Subject: [PATCH 1/2] Allow degrees=None in RandomAffine to disable rotation --- test/test_transforms.py | 16 ++++++++++++++++ test/test_transforms_v2.py | 19 +++++++++++++++++++ torchvision/transforms/transforms.py | 18 ++++++++++++++---- torchvision/transforms/v2/_geometry.py | 18 ++++++++++++++---- 4 files changed, 63 insertions(+), 8 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index d93800d59bc..d4e31ec905e 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -2164,6 +2164,22 @@ def test_random_affine(): with pytest.raises(ValueError): transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10, 0, 10, 0, 10]) + # degrees=None should work when another transform param is given + t = transforms.RandomAffine(degrees=None, translate=[0.2, 0.3]) + assert t.degrees == [0.0, 0.0] + + # omitting degrees positionally should also work + t = transforms.RandomAffine(translate=[0.2, 0.3]) + assert t.degrees == [0.0, 0.0] + + # degrees=None with no other params must raise + with pytest.raises(ValueError, match="at least one of translate, scale, or shear"): + transforms.RandomAffine(degrees=None) + + # all-defaults (no args) must raise + with pytest.raises(ValueError, match="at least one of translate, scale, or shear"): + transforms.RandomAffine() + # assert fill being either a Sequence or a Number with pytest.raises(TypeError): transforms.RandomAffine(0, fill={}) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 30d7ba69bea..93af4a9e399 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1838,6 +1838,25 @@ def test_transform_unknown_fill_error(self): with pytest.raises(TypeError, match="Got inappropriate fill arg"): transforms.RandomAffine(degrees=0, fill="fill") + def test_transform_degrees_none(self): + # degrees=None disables rotation when another param is given + t = transforms.RandomAffine(degrees=None, translate=(0.2, 0.3)) + assert t.degrees == [0.0, 0.0] + + # positional omission works the same way + t = transforms.RandomAffine(translate=(0.2, 0.3)) + assert t.degrees == [0.0, 0.0] + + def test_transform_degrees_none_no_other_param_error(self): + # degrees=None with nothing else enabled must raise + with pytest.raises(ValueError, match="at least one of translate, scale, or shear"): + transforms.RandomAffine(degrees=None) + + def test_transform_no_args_error(self): + # zero arguments must raise + with pytest.raises(ValueError, match="at least one of translate, scale, or shear"): + transforms.RandomAffine() + class TestVerticalFlip: @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8]) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index e33b3e28194..1d5b7755ed2 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1398,9 +1398,10 @@ class RandomAffine(torch.nn.Module): to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. Args: - degrees (sequence or number): Range of degrees to select from. + degrees (sequence or number, optional): Range of degrees to select from. If degrees is a number instead of sequence like (min, max), the range of degrees - will be (-degrees, +degrees). Set to 0 to deactivate rotations. + will be (-degrees, +degrees). Set to 0 or ``None`` to deactivate rotations. + If ``None``, at least one of ``translate``, ``scale``, or ``shear`` must be provided. translate (tuple, optional): tuple of maximum absolute fraction for horizontal and vertical translations. For example translate=(a, b), then horizontal shift is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is @@ -1428,7 +1429,7 @@ class RandomAffine(torch.nn.Module): def __init__( self, - degrees, + degrees=None, translate=None, scale=None, shear=None, @@ -1442,7 +1443,16 @@ def __init__( if isinstance(interpolation, int): interpolation = _interpolation_modes_from_int(interpolation) - self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) + # Allow degrees=None to disable rotation, but require at least one + # other geometric transform to be specified. + if degrees is None: + if translate is None and scale is None and shear is None: + raise ValueError( + "If degrees is None, at least one of translate, scale, or shear must be provided." + ) + self.degrees = [0.0, 0.0] + else: + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) if translate is not None: _check_sequence_input(translate, "translate", req_sizes=(2,)) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 20e4c9e5942..191ea998f5d 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -660,9 +660,10 @@ class RandomAffine(Transform): the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. Args: - degrees (sequence or number): Range of degrees to select from. + degrees (sequence or number, optional): Range of degrees to select from. If degrees is a number instead of sequence like (min, max), the range of degrees - will be (-degrees, +degrees). Set to 0 to deactivate rotations. + will be (-degrees, +degrees). Set to 0 or ``None`` to deactivate rotations. + If ``None``, at least one of ``translate``, ``scale``, or ``shear`` must be provided. translate (tuple, optional): tuple of maximum absolute fraction for horizontal and vertical translations. For example translate=(a, b), then horizontal shift is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is @@ -695,7 +696,7 @@ class RandomAffine(Transform): def __init__( self, - degrees: Union[numbers.Number, Sequence], + degrees: Optional[Union[numbers.Number, Sequence]] = None, translate: Optional[Sequence[float]] = None, scale: Optional[Sequence[float]] = None, shear: Optional[Union[int, float, Sequence[float]]] = None, @@ -704,7 +705,16 @@ def __init__( center: Optional[list[float]] = None, ) -> None: super().__init__() - self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) + # Allow degrees=None to disable rotation, but require at least one + # other geometric transform to be specified. + if degrees is None: + if translate is None and scale is None and shear is None: + raise ValueError( + "If degrees is None, at least one of translate, scale, or shear must be provided." + ) + self.degrees = [0.0, 0.0] + else: + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) if translate is not None: _check_sequence_input(translate, "translate", req_sizes=(2,)) for t in translate: From 2dcef64504ee8b9812bf1a213bb2a43919fd376b Mon Sep 17 00:00:00 2001 From: sunnycho100 Date: Wed, 18 Mar 2026 02:20:23 -0500 Subject: [PATCH 2/2] Refactor autoaugment: extract _AutoAugmentBase to eliminate duplicated fill logic --- torchvision/transforms/autoaugment.py | 76 ++++++++++++--------------- 1 file changed, 35 insertions(+), 41 deletions(-) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index 20291d09b94..126d8d6aa5a 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -100,8 +100,30 @@ class AutoAugmentPolicy(Enum): SVHN = "svhn" -# FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class -class AutoAugment(torch.nn.Module): +class _AutoAugmentBase(torch.nn.Module): + """Base class for AutoAugment, RandAugment, TrivialAugmentWide, and AugMix.""" + + def __init__( + self, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[list[float]] = None, + ) -> None: + super().__init__() + self.interpolation = interpolation + self.fill = fill + + def _get_fill(self, img: Tensor) -> Optional[list[float]]: + fill = self.fill + channels, height, width = F.get_dimensions(img) + if isinstance(img, Tensor): + if isinstance(fill, (int, float)): + fill = [float(fill)] * channels + elif fill is not None: + fill = [float(f) for f in fill] + return fill + + +class AutoAugment(_AutoAugmentBase): r"""AutoAugment data augmentation method based on `"AutoAugment: Learning Augmentation Strategies from Data" `_. If the image is torch Tensor, it should be of type torch.uint8, and it is expected @@ -124,10 +146,8 @@ def __init__( interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[list[float]] = None, ) -> None: - super().__init__() + super().__init__(interpolation=interpolation, fill=fill) self.policy = policy - self.interpolation = interpolation - self.fill = fill self.policies = self._get_policies(policy) def _get_policies( @@ -259,13 +279,8 @@ def forward(self, img: Tensor) -> Tensor: Returns: PIL Image or Tensor: AutoAugmented image. """ - fill = self.fill + fill = self._get_fill(img) channels, height, width = F.get_dimensions(img) - if isinstance(img, Tensor): - if isinstance(fill, (int, float)): - fill = [float(fill)] * channels - elif fill is not None: - fill = [float(f) for f in fill] transform_id, probs, signs = self.get_params(len(self.policies)) @@ -284,7 +299,7 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}(policy={self.policy}, fill={self.fill})" -class RandAugment(torch.nn.Module): +class RandAugment(_AutoAugmentBase): r"""RandAugment data augmentation method based on `"RandAugment: Practical automated data augmentation with a reduced search space" `_. @@ -311,12 +326,10 @@ def __init__( interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[list[float]] = None, ) -> None: - super().__init__() + super().__init__(interpolation=interpolation, fill=fill) self.num_ops = num_ops self.magnitude = magnitude self.num_magnitude_bins = num_magnitude_bins - self.interpolation = interpolation - self.fill = fill def _augmentation_space(self, num_bins: int, image_size: tuple[int, int]) -> dict[str, tuple[Tensor, bool]]: return { @@ -344,13 +357,8 @@ def forward(self, img: Tensor) -> Tensor: Returns: PIL Image or Tensor: Transformed image. """ - fill = self.fill + fill = self._get_fill(img) channels, height, width = F.get_dimensions(img) - if isinstance(img, Tensor): - if isinstance(fill, (int, float)): - fill = [float(fill)] * channels - elif fill is not None: - fill = [float(f) for f in fill] op_meta = self._augmentation_space(self.num_magnitude_bins, (height, width)) for _ in range(self.num_ops): @@ -377,7 +385,7 @@ def __repr__(self) -> str: return s -class TrivialAugmentWide(torch.nn.Module): +class TrivialAugmentWide(_AutoAugmentBase): r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" `_. If the image is torch Tensor, it should be of type torch.uint8, and it is expected @@ -399,10 +407,8 @@ def __init__( interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[list[float]] = None, ) -> None: - super().__init__() + super().__init__(interpolation=interpolation, fill=fill) self.num_magnitude_bins = num_magnitude_bins - self.interpolation = interpolation - self.fill = fill def _augmentation_space(self, num_bins: int) -> dict[str, tuple[Tensor, bool]]: return { @@ -430,13 +436,7 @@ def forward(self, img: Tensor) -> Tensor: Returns: PIL Image or Tensor: Transformed image. """ - fill = self.fill - channels, height, width = F.get_dimensions(img) - if isinstance(img, Tensor): - if isinstance(fill, (int, float)): - fill = [float(fill)] * channels - elif fill is not None: - fill = [float(f) for f in fill] + fill = self._get_fill(img) op_meta = self._augmentation_space(self.num_magnitude_bins) op_index = int(torch.randint(len(op_meta), (1,)).item()) @@ -463,7 +463,7 @@ def __repr__(self) -> str: return s -class AugMix(torch.nn.Module): +class AugMix(_AutoAugmentBase): r"""AugMix data augmentation method based on `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" `_. If the image is torch Tensor, it should be of type torch.uint8, and it is expected @@ -494,7 +494,7 @@ def __init__( interpolation: InterpolationMode = InterpolationMode.BILINEAR, fill: Optional[list[float]] = None, ) -> None: - super().__init__() + super().__init__(interpolation=interpolation, fill=fill) self._PARAMETER_MAX = 10 if not (1 <= severity <= self._PARAMETER_MAX): raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.") @@ -503,8 +503,6 @@ def __init__( self.chain_depth = chain_depth self.alpha = alpha self.all_ops = all_ops - self.interpolation = interpolation - self.fill = fill def _augmentation_space(self, num_bins: int, image_size: tuple[int, int]) -> dict[str, tuple[Tensor, bool]]: s = { @@ -549,14 +547,10 @@ def forward(self, orig_img: Tensor) -> Tensor: Returns: PIL Image or Tensor: Transformed image. """ - fill = self.fill + fill = self._get_fill(orig_img) channels, height, width = F.get_dimensions(orig_img) if isinstance(orig_img, Tensor): img = orig_img - if isinstance(fill, (int, float)): - fill = [float(fill)] * channels - elif fill is not None: - fill = [float(f) for f in fill] else: img = self._pil_to_tensor(orig_img)