Skip to content

Commit dbf4a5c

Browse files
committed
fix: normalize_cvcuda move to correct patterns for tests/exporting
1 parent 324cefc commit dbf4a5c

File tree

4 files changed

+52
-82
lines changed

4 files changed

+52
-82
lines changed

test/common_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -400,8 +400,9 @@ def make_image_pil(*args, **kwargs):
400400
return to_pil_image(make_image(*args, **kwargs))
401401

402402

403-
def make_image_cvcuda(*args, **kwargs):
404-
return to_cvcuda_tensor(make_image(*args, **kwargs))
403+
def make_image_cvcuda(*args, batch_dims=(1,), **kwargs):
404+
# explicitly default batch_dims to (1,) since to_cvcuda_tensor requires a batch dimension (ndims == 4)
405+
return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs))
405406

406407

407408
def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"):

test/test_transforms_v2.py

Lines changed: 37 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -5517,7 +5517,17 @@ def test_kernel_image_inplace(self, device):
55175517
def test_kernel_video(self):
55185518
check_kernel(F.normalize_video, make_video(dtype=torch.float32), mean=self.MEAN, std=self.STD)
55195519

5520-
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video])
5520+
@pytest.mark.parametrize(
5521+
"make_input",
5522+
[
5523+
make_image_tensor,
5524+
make_image,
5525+
make_video,
5526+
pytest.param(
5527+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
5528+
),
5529+
],
5530+
)
55215531
def test_functional(self, make_input):
55225532
check_functional(F.normalize, make_input(dtype=torch.float32), mean=self.MEAN, std=self.STD)
55235533

@@ -5527,6 +5537,11 @@ def test_functional(self, make_input):
55275537
(F.normalize_image, torch.Tensor),
55285538
(F.normalize_image, tv_tensors.Image),
55295539
(F.normalize_video, tv_tensors.Video),
5540+
pytest.param(
5541+
F._misc._normalize_cvcuda,
5542+
_import_cvcuda().Tensor,
5543+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
5544+
),
55305545
],
55315546
)
55325547
def test_functional_signature(self, kernel, input_type):
@@ -5555,7 +5570,17 @@ def _sample_input_adapter(self, transform, input, device):
55555570
adapted_input[key] = value
55565571
return adapted_input
55575572

5558-
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video])
5573+
@pytest.mark.parametrize(
5574+
"make_input",
5575+
[
5576+
make_image_tensor,
5577+
make_image,
5578+
make_video,
5579+
pytest.param(
5580+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
5581+
),
5582+
],
5583+
)
55595584
def test_transform(self, make_input):
55605585
check_transform(
55615586
transforms.Normalize(mean=self.MEAN, std=self.STD),
@@ -5579,78 +5604,16 @@ def test_correctness_image(self, mean, std, dtype, fn):
55795604

55805605
assert_equal(actual, expected)
55815606

5582-
5583-
@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
5584-
@needs_cuda
5585-
class TestNormalizeCVCUDA:
5586-
MEANS_STDS = {
5587-
"RGB": TestNormalize.MEANS_STDS,
5588-
"GRAY": [([0.5], [2.0])],
5589-
}
5590-
MEAN_STD = {
5591-
"RGB": MEANS_STDS["RGB"][0],
5592-
"GRAY": MEANS_STDS["GRAY"][0],
5593-
}
5594-
5595-
@pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32])
5596-
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
5597-
@pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)])
5598-
def test_functional(self, color_space, batch_dims, dtype):
5599-
means_stds = self.MEANS_STDS[color_space]
5600-
for mean, std in means_stds:
5601-
image = make_image_cvcuda(color_space=color_space, dtype=dtype, batch_dims=batch_dims)
5602-
check_functional(F.normalize, image, mean=mean, std=std)
5603-
5604-
@pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32])
5605-
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
5606-
@pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)])
5607-
def test_functional_scalar(self, color_space, batch_dims, dtype):
5608-
image = make_image_cvcuda(color_space=color_space, dtype=dtype, batch_dims=batch_dims)
5609-
check_functional(F.normalize, image, mean=0.5, std=2.0)
5610-
5611-
@pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32])
5612-
@pytest.mark.parametrize("batch_dims", [(1,)])
5613-
def test_functional_error(self, dtype, batch_dims):
5614-
rgb_mean, rgb_std = self.MEAN_STD["RGB"]
5615-
gray_mean, gray_std = self.MEAN_STD["GRAY"]
5616-
5617-
with pytest.raises(ValueError, match="Inplace normalization is not supported for CVCUDA."):
5618-
F.normalize(make_image_cvcuda(batch_dims=batch_dims, dtype=dtype), mean=rgb_mean, std=rgb_std, inplace=True)
5619-
5620-
with pytest.raises(ValueError, match="Mean should have 3 elements. Got 1."):
5621-
F.normalize(make_image_cvcuda(batch_dims=batch_dims, color_space="RGB", dtype=dtype), mean=gray_mean, std=rgb_std)
5622-
5623-
with pytest.raises(ValueError, match="Std should have 3 elements. Got 1."):
5624-
F.normalize(make_image_cvcuda(batch_dims=batch_dims, color_space="RGB", dtype=dtype), mean=rgb_mean, std=gray_std)
5625-
5626-
with pytest.raises(ValueError, match="Mean should have 1 elements. Got 3."):
5627-
F.normalize(make_image_cvcuda(batch_dims=batch_dims, color_space="GRAY", dtype=dtype), mean=rgb_mean, std=gray_std)
5628-
5629-
with pytest.raises(ValueError, match="Std should have 1 elements. Got 3."):
5630-
F.normalize(make_image_cvcuda(batch_dims=batch_dims, color_space="GRAY", dtype=dtype), mean=gray_mean, std=rgb_std)
5631-
5632-
@pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32])
5633-
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
5634-
@pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)])
5635-
def test_transform(self, dtype, color_space, batch_dims):
5636-
means_stds = self.MEANS_STDS[color_space]
5637-
for mean, std in means_stds:
5638-
check_transform(
5639-
transforms.Normalize(mean=mean, std=std),
5640-
make_image_cvcuda(color_space=color_space, dtype=dtype, batch_dims=batch_dims),
5641-
)
5642-
5643-
@pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)])
5644-
def test_correctness_image(self, batch_dims):
5645-
mean, std = self.MEAN_STD["RGB"]
5646-
torch_image = make_image(batch_dims=batch_dims, dtype=torch.float32, device="cuda")
5647-
cvc_image = F.to_cvcuda_tensor(torch_image)
5648-
5649-
gold = F.normalize(torch_image, mean=mean, std=std)
5650-
image = F.normalize(cvc_image, mean=mean, std=std)
5651-
image = F.cvcuda_to_tensor(image)
5652-
5653-
assert_close(image, gold, rtol=1e-7, atol=1e-7)
5607+
@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
5608+
@pytest.mark.parametrize(("mean", "std"), MEANS_STDS)
5609+
@pytest.mark.parametrize("dtype", [torch.float32])
5610+
@pytest.mark.parametrize("fn", [F.normalize, transform_cls_to_functional(transforms.Normalize)])
5611+
def test_correctness_cvcuda(self, mean, std, dtype, fn):
5612+
image = make_image(batch_dims=(1,), dtype=dtype, device="cuda")
5613+
cvc_image = F.to_cvcuda_tensor(image)
5614+
actual = F._misc._normalize_cvcuda(cvc_image, mean=mean, std=std)
5615+
expected = fn(image, mean=mean, std=std)
5616+
torch.testing.assert_close(F.cvcuda_to_tensor(actual), expected, rtol=1e-7, atol=1e-7)
56545617

56555618

56565619
class TestClampBoundingBoxes:

torchvision/transforms/v2/functional/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,6 @@
153153
gaussian_noise_image,
154154
gaussian_noise_video,
155155
normalize,
156-
normalize_cvcuda,
157156
normalize_image,
158157
normalize_video,
159158
sanitize_bounding_boxes,

torchvision/transforms/v2/functional/_misc.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import math
2-
from typing import Optional, Sequence, TYPE_CHECKING
2+
from typing import Optional, TYPE_CHECKING
33

44
import PIL.Image
55
import torch
@@ -79,15 +79,22 @@ def normalize_video(video: torch.Tensor, mean: list[float], std: list[float], in
7979
return normalize_image(video, mean, std, inplace=inplace)
8080

8181

82-
def normalize_cvcuda(
82+
def _normalize_cvcuda(
8383
image: "cvcuda.Tensor",
84-
mean: Sequence[float | int] | float | int,
85-
std: Sequence[float | int] | float | int,
84+
mean: list[float],
85+
std: list[float],
8686
inplace: bool = False,
8787
) -> "cvcuda.Tensor":
88+
cvcuda = _import_cvcuda()
8889
if inplace:
8990
raise ValueError("Inplace normalization is not supported for CVCUDA.")
9091

92+
# CV-CUDA supports signed int and float tensors
93+
# torchvision only supports uint and float, right now CV-CUDA doesnt expose float16, so only check 32
94+
# in the future add float16 once exposed in CV-CUDA
95+
if not (image.dtype == cvcuda.Type.F32):
96+
raise ValueError(f"Input tensor should be a float tensor. Got {image.dtype}.")
97+
9198
channels = image.shape[3]
9299
if isinstance(mean, float | int):
93100
mean = [mean] * channels
@@ -115,7 +122,7 @@ def normalize_cvcuda(
115122

116123

117124
if CVCUDA_AVAILABLE:
118-
_normalize_cvcuda = _register_kernel_internal(normalize, cvcuda.Tensor)(normalize_cvcuda)
125+
_normalize_cvcuda_registered = _register_kernel_internal(normalize, _import_cvcuda().Tensor)(_normalize_cvcuda)
119126

120127

121128
def gaussian_blur(inpt: torch.Tensor, kernel_size: list[int], sigma: Optional[list[float]] = None) -> torch.Tensor:

0 commit comments

Comments
 (0)