Skip to content

Commit 026185d

Browse files
committed
rgb to gray and gray to rgb done
1 parent fbea584 commit 026185d

File tree

2 files changed

+158
-8
lines changed

2 files changed

+158
-8
lines changed

test/test_transforms_v2.py

Lines changed: 101 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6410,7 +6410,17 @@ class TestRgbToGrayscale:
64106410
def test_kernel_image(self, dtype, device):
64116411
check_kernel(F.rgb_to_grayscale_image, make_image(dtype=dtype, device=device))
64126412

6413-
@pytest.mark.parametrize(
    "make_input",
    [
        make_image_tensor,
        make_image_pil,
        make_image,
        pytest.param(
            make_image_cvcuda,
            marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
        ),
    ],
)
def test_functional(self, make_input):
    """Run the generic functional checks for ``F.rgb_to_grayscale`` on every input flavour.

    The CV-CUDA input is only exercised when ``CVCUDA_AVAILABLE`` is truthy;
    otherwise that parametrization is skipped.
    """
    inpt = make_input()
    check_functional(F.rgb_to_grayscale, inpt)
64166426

@@ -6420,23 +6430,62 @@ def test_functional(self, make_input):
64206430
(F.rgb_to_grayscale_image, torch.Tensor),
64216431
(F._color._rgb_to_grayscale_image_pil, PIL.Image.Image),
64226432
(F.rgb_to_grayscale_image, tv_tensors.Image),
6433+
pytest.param(
6434+
F._color._rgb_to_grayscale_cvcuda,
6435+
"cvcuda.Tensor",
6436+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
6437+
),
64236438
],
64246439
)
64256440
def test_functional_signature(self, kernel, input_type):
6441+
if input_type == "cvcuda.Tensor":
6442+
input_type = _import_cvcuda().Tensor
64266443
check_functional_kernel_signature_match(F.rgb_to_grayscale, kernel=kernel, input_type=input_type)
64276444

64286445
@pytest.mark.parametrize("transform", [transforms.Grayscale(), transforms.RandomGrayscale(p=1)])
@pytest.mark.parametrize(
    "make_input",
    [
        make_image_tensor,
        make_image_pil,
        make_image,
        pytest.param(
            make_image_cvcuda,
            marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
        ),
    ],
)
def test_transform(self, transform, make_input):
    """Run the generic transform checks for ``Grayscale`` and ``RandomGrayscale``.

    The ``RandomGrayscale`` + CV-CUDA combination is skipped (see the skip
    message for the reason).
    """
    uses_cvcuda = make_input is make_image_cvcuda
    if uses_cvcuda and isinstance(transform, transforms.RandomGrayscale):
        pytest.skip("CV-CUDA does not support RandomGrayscale, will have num_output_channels == 3")

    inpt = make_input()
    check_transform(transform, inpt)
64326461

64336462
@pytest.mark.parametrize("num_output_channels", [1, 3])
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
@pytest.mark.parametrize(
    "make_input",
    [
        make_image,
        pytest.param(
            make_image_cvcuda,
            marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
        ),
    ],
)
@pytest.mark.parametrize("fn", [F.rgb_to_grayscale, transform_cls_to_functional(transforms.Grayscale)])
def test_image_correctness(self, num_output_channels, color_space, make_input, fn):
    """Compare the grayscale conversion against the PIL code path.

    The output of ``fn`` must match, within an absolute tolerance of 1, the
    result of converting the same image through PIL.
    """
    uses_cvcuda = make_input is make_image_cvcuda
    if uses_cvcuda and num_output_channels == 3:
        pytest.skip("CV-CUDA does not support num_output_channels == 3")

    image = make_input(dtype=torch.uint8, device="cpu", color_space=color_space)

    actual = fn(image, num_output_channels=num_output_channels)

    if uses_cvcuda:
        # Convert both the result and the input back to CPU torch tensors and
        # drop the batch dimension so they can go through the PIL reference path.
        actual = F.cvcuda_to_tensor(actual).to(device="cpu").squeeze(0)
        image = F.cvcuda_to_tensor(image).to(device="cpu").squeeze(0)

    pil_reference = F.rgb_to_grayscale(F.to_pil_image(image), num_output_channels=num_output_channels)
    expected = F.to_image(pil_reference)

    assert_equal(actual, expected, rtol=0, atol=1)
@@ -6474,7 +6523,17 @@ class TestGrayscaleToRgb:
64746523
def test_kernel_image(self, dtype, device):
64756524
check_kernel(F.grayscale_to_rgb_image, make_image(dtype=dtype, device=device))
64766525

6477-
@pytest.mark.parametrize(
    "make_input",
    [
        make_image_tensor,
        make_image_pil,
        make_image,
        pytest.param(
            make_image_cvcuda,
            marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
        ),
    ],
)
def test_functional(self, make_input):
    """Run the generic functional checks for ``F.grayscale_to_rgb`` on every input flavour.

    The CV-CUDA input is only exercised when ``CVCUDA_AVAILABLE`` is truthy;
    otherwise that parametrization is skipped.
    """
    inpt = make_input()
    check_functional(F.grayscale_to_rgb, inpt)
64806539

@@ -6484,20 +6543,54 @@ def test_functional(self, make_input):
64846543
(F.rgb_to_grayscale_image, torch.Tensor),
64856544
(F._color._rgb_to_grayscale_image_pil, PIL.Image.Image),
64866545
(F.rgb_to_grayscale_image, tv_tensors.Image),
6546+
pytest.param(
6547+
F._color._rgb_to_grayscale_cvcuda,
6548+
"cvcuda.Tensor",
6549+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
6550+
),
64876551
],
64886552
)
64896553
def test_functional_signature(self, kernel, input_type):
6554+
if input_type == "cvcuda.Tensor":
6555+
input_type = _import_cvcuda().Tensor
64906556
check_functional_kernel_signature_match(F.grayscale_to_rgb, kernel=kernel, input_type=input_type)
64916557

6492-
@pytest.mark.parametrize(
    "make_input",
    [
        make_image_tensor,
        make_image_pil,
        make_image,
        pytest.param(
            make_image_cvcuda,
            marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
        ),
    ],
)
def test_transform(self, make_input):
    """Run the generic transform checks for ``transforms.RGB`` on a grayscale input."""
    gray_input = make_input(color_space="GRAY")
    check_transform(transforms.RGB(), gray_input)
64956571

6572+
@pytest.mark.parametrize(
    "make_input",
    [
        make_image,
        pytest.param(
            make_image_cvcuda,
            marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
        ),
    ],
)
@pytest.mark.parametrize("fn", [F.grayscale_to_rgb, transform_cls_to_functional(transforms.RGB)])
def test_image_correctness(self, make_input, fn):
    """Compare the grayscale-to-RGB conversion against the PIL code path.

    The output of ``fn`` must match, within an absolute tolerance of 1, the
    result of converting the same image through PIL.
    """
    image = make_input(dtype=torch.uint8, device="cpu", color_space="GRAY")

    actual = fn(image)

    if make_input is make_image_cvcuda:
        # Convert both the result and the input back to CPU torch tensors and
        # drop the batch dimension so they can go through the PIL reference path.
        actual = F.cvcuda_to_tensor(actual).to(device="cpu").squeeze(0)
        image = F.cvcuda_to_tensor(image).to(device="cpu").squeeze(0)

    pil_reference = F.grayscale_to_rgb(F.to_pil_image(image))
    expected = F.to_image(pil_reference)

    assert_equal(actual, expected, rtol=0, atol=1)

torchvision/transforms/v2/functional/_color.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,38 @@ def _rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int
7373
return _FP.to_grayscale(image, num_output_channels=num_output_channels)
7474

7575

76+
def _rgb_to_grayscale_cvcuda(
    image: "cvcuda.Tensor",
    num_output_channels: int = 1,
) -> "cvcuda.Tensor":
    """Convert a CV-CUDA image tensor to single-channel grayscale.

    The channel count is read from ``image.shape[3]``.
    NOTE(review): this presumes a 4-D, channels-last (NHWC) tensor - confirm
    against the callers.

    Args:
        image: CV-CUDA tensor to convert.
        num_output_channels: Requested number of output channels. Only ``1``
            is accepted by this kernel.

    Returns:
        A new CV-CUDA tensor. A single-channel input is returned as a copy;
        any other input goes through the ``RGB2GRAY`` colour conversion.

    Raises:
        ValueError: If ``num_output_channels`` is neither 1 nor 3, or if it
            is 3 (rejected by this kernel).
    """
    cvcuda = _import_cvcuda()

    if num_output_channels not in (1, 3):
        raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
    if num_output_channels == 3:
        raise ValueError("num_output_channels must be 1 for CV-CUDA, got 3.")

    if image.shape[3] != 1:
        return cvcuda.cvtcolor(image, cvcuda.ColorConversion.RGB2GRAY)

    # Already single-channel: hand back a copy. CV-CUDA has no native clone,
    # so a zero-width copymakeborder is used to duplicate the tensor.
    zero_padding = {"top": 0, "left": 0, "bottom": 0, "right": 0}
    return cvcuda.copymakeborder(
        image,
        border_mode=cvcuda.Border.CONSTANT,
        border_value=[0],
        **zero_padding,
    )


# Register the CV-CUDA kernel for rgb_to_grayscale only when CV-CUDA is available.
if CVCUDA_AVAILABLE:
    _register_kernel_internal(rgb_to_grayscale, _import_cvcuda().Tensor)(_rgb_to_grayscale_cvcuda)
106+
107+
76108
def grayscale_to_rgb(inpt: torch.Tensor) -> torch.Tensor:
77109
"""See :class:`~torchvision.transforms.v2.RGB` for details."""
78110
if torch.jit.is_scripting():
@@ -99,6 +131,31 @@ def grayscale_to_rgb_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
99131
return image.convert(mode="RGB")
100132

101133

134+
def _grayscale_to_rgb_cvcuda(
    image: "cvcuda.Tensor",
) -> "cvcuda.Tensor":
    """Convert a CV-CUDA grayscale image tensor to three-channel RGB.

    The channel count is read from ``image.shape[3]``.
    NOTE(review): this presumes a 4-D, channels-last (NHWC) tensor - confirm
    against the callers.

    Args:
        image: CV-CUDA tensor to convert.

    Returns:
        A new CV-CUDA tensor. A three-channel input is returned as a copy;
        any other input goes through the ``GRAY2RGB`` colour conversion.
    """
    cvcuda = _import_cvcuda()

    if image.shape[3] != 3:
        return cvcuda.cvtcolor(image, cvcuda.ColorConversion.GRAY2RGB)

    # Already three channels: hand back a copy. CV-CUDA has no native clone,
    # so a zero-width copymakeborder is used to duplicate the tensor.
    zero_padding = {"top": 0, "left": 0, "bottom": 0, "right": 0}
    return cvcuda.copymakeborder(
        image,
        border_mode=cvcuda.Border.CONSTANT,
        border_value=[0],
        **zero_padding,
    )


# Register the CV-CUDA kernel for grayscale_to_rgb only when CV-CUDA is available.
if CVCUDA_AVAILABLE:
    _register_kernel_internal(grayscale_to_rgb, _import_cvcuda().Tensor)(_grayscale_to_rgb_cvcuda)
157+
158+
102159
def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
103160
ratio = float(ratio)
104161
fp = image1.is_floating_point()

0 commit comments

Comments
 (0)