Skip to content

Commit a3d8797

Browse files
committed
update invert with new PR comments revisions
1 parent 7ccc301 commit a3d8797

File tree

2 files changed

+19
-14
lines changed

2 files changed

+19
-14
lines changed

test/test_transforms_v2.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5874,23 +5874,26 @@ def test_functional_signature(self, kernel, input_type):
58745874
def test_transform(self, make_input):
58755875
check_transform(transforms.RandomInvert(p=1), make_input())
58765876

5877+
@pytest.mark.parametrize(
5878+
"make_input",
5879+
[
5880+
make_image,
5881+
pytest.param(
5882+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
5883+
),
5884+
],
5885+
)
58775886
@pytest.mark.parametrize("fn", [F.invert, transform_cls_to_functional(transforms.RandomInvert, p=1)])
5878-
def test_correctness_image(self, fn):
5879-
image = make_image(dtype=torch.uint8, device="cpu")
5887+
def test_correctness_image(self, make_input, fn):
5888+
image = make_input(dtype=torch.uint8, device="cpu")
58805889

58815890
actual = fn(image)
5882-
expected = F.to_image(F.invert(F.to_pil_image(image)))
58835891

5884-
assert_equal(actual, expected)
5892+
if make_input is make_image_cvcuda:
5893+
image = cvcuda_to_pil_compatible_tensor(image)
5894+
5895+
expected = F.to_image(F.invert(F.to_pil_image(image)))
58855896

5886-
@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
5887-
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
5888-
@pytest.mark.parametrize("fn", [F.invert, transform_cls_to_functional(transforms.RandomInvert, p=1)])
5889-
def test_correctness_cvcuda(self, dtype, fn):
5890-
image = make_image(batch_dims=(1,), dtype=dtype, device="cuda")
5891-
cv_image = F.to_cvcuda_tensor(image)
5892-
actual = F.cvcuda_to_tensor(fn(cv_image))
5893-
expected = F.invert_image(image)
58945897
assert_equal(actual, expected)
58955898

58965899

torchvision/transforms/v2/functional/_color.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -690,13 +690,15 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
690690
return invert_image(video)
691691

692692

693-
if _CVCUDA_AVAILABLE:
693+
if CVCUDA_AVAILABLE:
694694
_invert_cvcuda_tensors = {}
695695

696696

697697
def _invert_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
698698
cvcuda = _import_cvcuda()
699699

700+
# save the tensors into a dictionary only if CV-CUDA is actually used
701+
# we save these here, since they are static and small in size
700702
if "base" not in _invert_cvcuda_tensors:
701703
_invert_cvcuda_tensors["base"] = cvcuda.as_tensor(
702704
torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device="cuda").reshape(1, 1, 1, 3).contiguous(), "NHWC"
@@ -722,7 +724,7 @@ def _invert_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
722724
return cvcuda.normalize(image, base=base, scale=scale, globalscale=1.0, globalshift=shift)
723725

724726

725-
if _CVCUDA_AVAILABLE:
727+
if CVCUDA_AVAILABLE:
726728
_register_kernel_internal(invert, _import_cvcuda().Tensor)(_invert_cvcuda)
727729

728730

0 commit comments

Comments
 (0)