Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2164,6 +2164,22 @@ def test_random_affine():
with pytest.raises(ValueError):
transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10, 0, 10, 0, 10])

# degrees=None should work when another transform param is given
t = transforms.RandomAffine(degrees=None, translate=[0.2, 0.3])
assert t.degrees == [0.0, 0.0]

# omitting degrees positionally should also work
t = transforms.RandomAffine(translate=[0.2, 0.3])
assert t.degrees == [0.0, 0.0]

# degrees=None with no other params must raise
with pytest.raises(ValueError, match="at least one of translate, scale, or shear"):
transforms.RandomAffine(degrees=None)

# all-defaults (no args) must raise
with pytest.raises(ValueError, match="at least one of translate, scale, or shear"):
transforms.RandomAffine()

# assert fill being either a Sequence or a Number
with pytest.raises(TypeError):
transforms.RandomAffine(0, fill={})
Expand Down
19 changes: 19 additions & 0 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1838,6 +1838,25 @@ def test_transform_unknown_fill_error(self):
with pytest.raises(TypeError, match="Got inappropriate fill arg"):
transforms.RandomAffine(degrees=0, fill="fill")

def test_transform_degrees_none(self):
# degrees=None disables rotation when another param is given
t = transforms.RandomAffine(degrees=None, translate=(0.2, 0.3))
assert t.degrees == [0.0, 0.0]

# positional omission works the same way
t = transforms.RandomAffine(translate=(0.2, 0.3))
assert t.degrees == [0.0, 0.0]

def test_transform_degrees_none_no_other_param_error(self):
# degrees=None with nothing else enabled must raise
with pytest.raises(ValueError, match="at least one of translate, scale, or shear"):
transforms.RandomAffine(degrees=None)

def test_transform_no_args_error(self):
# zero arguments must raise
with pytest.raises(ValueError, match="at least one of translate, scale, or shear"):
transforms.RandomAffine()


class TestVerticalFlip:
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
Expand Down
76 changes: 35 additions & 41 deletions torchvision/transforms/autoaugment.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,30 @@ class AutoAugmentPolicy(Enum):
SVHN = "svhn"


# FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class
class AutoAugment(torch.nn.Module):
class _AutoAugmentBase(torch.nn.Module):
"""Base class for AutoAugment, RandAugment, TrivialAugmentWide, and AugMix."""

def __init__(
self,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[list[float]] = None,
) -> None:
super().__init__()
self.interpolation = interpolation
self.fill = fill

def _get_fill(self, img: Tensor) -> Optional[list[float]]:
fill = self.fill
channels, height, width = F.get_dimensions(img)
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * channels
elif fill is not None:
fill = [float(f) for f in fill]
return fill


class AutoAugment(_AutoAugmentBase):
r"""AutoAugment data augmentation method based on
`"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
Expand All @@ -124,10 +146,8 @@ def __init__(
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[list[float]] = None,
) -> None:
super().__init__()
super().__init__(interpolation=interpolation, fill=fill)
self.policy = policy
self.interpolation = interpolation
self.fill = fill
self.policies = self._get_policies(policy)

def _get_policies(
Expand Down Expand Up @@ -259,13 +279,8 @@ def forward(self, img: Tensor) -> Tensor:
Returns:
PIL Image or Tensor: AutoAugmented image.
"""
fill = self.fill
fill = self._get_fill(img)
channels, height, width = F.get_dimensions(img)
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * channels
elif fill is not None:
fill = [float(f) for f in fill]

transform_id, probs, signs = self.get_params(len(self.policies))

Expand All @@ -284,7 +299,7 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}(policy={self.policy}, fill={self.fill})"


class RandAugment(torch.nn.Module):
class RandAugment(_AutoAugmentBase):
r"""RandAugment data augmentation method based on
`"RandAugment: Practical automated data augmentation with a reduced search space"
<https://arxiv.org/abs/1909.13719>`_.
Expand All @@ -311,12 +326,10 @@ def __init__(
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[list[float]] = None,
) -> None:
super().__init__()
super().__init__(interpolation=interpolation, fill=fill)
self.num_ops = num_ops
self.magnitude = magnitude
self.num_magnitude_bins = num_magnitude_bins
self.interpolation = interpolation
self.fill = fill

def _augmentation_space(self, num_bins: int, image_size: tuple[int, int]) -> dict[str, tuple[Tensor, bool]]:
return {
Expand Down Expand Up @@ -344,13 +357,8 @@ def forward(self, img: Tensor) -> Tensor:
Returns:
PIL Image or Tensor: Transformed image.
"""
fill = self.fill
fill = self._get_fill(img)
channels, height, width = F.get_dimensions(img)
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * channels
elif fill is not None:
fill = [float(f) for f in fill]

op_meta = self._augmentation_space(self.num_magnitude_bins, (height, width))
for _ in range(self.num_ops):
Expand All @@ -377,7 +385,7 @@ def __repr__(self) -> str:
return s


class TrivialAugmentWide(torch.nn.Module):
class TrivialAugmentWide(_AutoAugmentBase):
r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
`"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
Expand All @@ -399,10 +407,8 @@ def __init__(
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[list[float]] = None,
) -> None:
super().__init__()
super().__init__(interpolation=interpolation, fill=fill)
self.num_magnitude_bins = num_magnitude_bins
self.interpolation = interpolation
self.fill = fill

def _augmentation_space(self, num_bins: int) -> dict[str, tuple[Tensor, bool]]:
return {
Expand Down Expand Up @@ -430,13 +436,7 @@ def forward(self, img: Tensor) -> Tensor:
Returns:
PIL Image or Tensor: Transformed image.
"""
fill = self.fill
channels, height, width = F.get_dimensions(img)
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * channels
elif fill is not None:
fill = [float(f) for f in fill]
fill = self._get_fill(img)

op_meta = self._augmentation_space(self.num_magnitude_bins)
op_index = int(torch.randint(len(op_meta), (1,)).item())
Expand All @@ -463,7 +463,7 @@ def __repr__(self) -> str:
return s


class AugMix(torch.nn.Module):
class AugMix(_AutoAugmentBase):
r"""AugMix data augmentation method based on
`"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" <https://arxiv.org/abs/1912.02781>`_.
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
Expand Down Expand Up @@ -494,7 +494,7 @@ def __init__(
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[list[float]] = None,
) -> None:
super().__init__()
super().__init__(interpolation=interpolation, fill=fill)
self._PARAMETER_MAX = 10
if not (1 <= severity <= self._PARAMETER_MAX):
raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
Expand All @@ -503,8 +503,6 @@ def __init__(
self.chain_depth = chain_depth
self.alpha = alpha
self.all_ops = all_ops
self.interpolation = interpolation
self.fill = fill

def _augmentation_space(self, num_bins: int, image_size: tuple[int, int]) -> dict[str, tuple[Tensor, bool]]:
s = {
Expand Down Expand Up @@ -549,14 +547,10 @@ def forward(self, orig_img: Tensor) -> Tensor:
Returns:
PIL Image or Tensor: Transformed image.
"""
fill = self.fill
fill = self._get_fill(orig_img)
channels, height, width = F.get_dimensions(orig_img)
if isinstance(orig_img, Tensor):
img = orig_img
if isinstance(fill, (int, float)):
fill = [float(fill)] * channels
elif fill is not None:
fill = [float(f) for f in fill]
else:
img = self._pil_to_tensor(orig_img)

Expand Down
18 changes: 14 additions & 4 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1398,9 +1398,10 @@ class RandomAffine(torch.nn.Module):
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

Args:
degrees (sequence or number): Range of degrees to select from.
degrees (sequence or number, optional): Range of degrees to select from.
If degrees is a number instead of sequence like (min, max), the range of degrees
will be (-degrees, +degrees). Set to 0 to deactivate rotations.
will be (-degrees, +degrees). Set to 0 or ``None`` to deactivate rotations.
If ``None``, at least one of ``translate``, ``scale``, or ``shear`` must be provided.
translate (tuple, optional): tuple of maximum absolute fraction for horizontal
and vertical translations. For example translate=(a, b), then horizontal shift
is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
Expand Down Expand Up @@ -1428,7 +1429,7 @@ class RandomAffine(torch.nn.Module):

def __init__(
self,
degrees,
degrees=None,
translate=None,
scale=None,
shear=None,
Expand All @@ -1442,7 +1443,16 @@ def __init__(
if isinstance(interpolation, int):
interpolation = _interpolation_modes_from_int(interpolation)

self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
# Allow degrees=None to disable rotation, but require at least one
# other geometric transform to be specified.
if degrees is None:
if translate is None and scale is None and shear is None:
raise ValueError(
"If degrees is None, at least one of translate, scale, or shear must be provided."
)
self.degrees = [0.0, 0.0]
else:
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))

if translate is not None:
_check_sequence_input(translate, "translate", req_sizes=(2,))
Expand Down
18 changes: 14 additions & 4 deletions torchvision/transforms/v2/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,9 +660,10 @@ class RandomAffine(Transform):
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

Args:
degrees (sequence or number): Range of degrees to select from.
degrees (sequence or number, optional): Range of degrees to select from.
If degrees is a number instead of sequence like (min, max), the range of degrees
will be (-degrees, +degrees). Set to 0 to deactivate rotations.
will be (-degrees, +degrees). Set to 0 or ``None`` to deactivate rotations.
If ``None``, at least one of ``translate``, ``scale``, or ``shear`` must be provided.
translate (tuple, optional): tuple of maximum absolute fraction for horizontal
and vertical translations. For example translate=(a, b), then horizontal shift
is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
Expand Down Expand Up @@ -695,7 +696,7 @@ class RandomAffine(Transform):

def __init__(
self,
degrees: Union[numbers.Number, Sequence],
degrees: Optional[Union[numbers.Number, Sequence]] = None,
translate: Optional[Sequence[float]] = None,
scale: Optional[Sequence[float]] = None,
shear: Optional[Union[int, float, Sequence[float]]] = None,
Expand All @@ -704,7 +705,16 @@ def __init__(
center: Optional[list[float]] = None,
) -> None:
super().__init__()
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
# Allow degrees=None to disable rotation, but require at least one
# other geometric transform to be specified.
if degrees is None:
if translate is None and scale is None and shear is None:
raise ValueError(
"If degrees is None, at least one of translate, scale, or shear must be provided."
)
self.degrees = [0.0, 0.0]
else:
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
if translate is not None:
_check_sequence_input(translate, "translate", req_sizes=(2,))
for t in translate:
Expand Down
Loading