Commit f8aab31

wip elastic

1 parent fbea584 commit f8aab31

2 files changed: 122 additions, 0 deletions

test/test_transforms_v2.py

Lines changed: 16 additions & 0 deletions
@@ -3355,6 +3355,9 @@ def test_kernel_video(self):
             make_segmentation_mask,
             make_video,
             make_keypoints,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available")
+            ),
         ],
     )
     def test_functional(self, make_input):
@@ -3370,9 +3373,16 @@ def test_functional(self, make_input):
             (F.elastic_mask, tv_tensors.Mask),
             (F.elastic_video, tv_tensors.Video),
             (F.elastic_keypoints, tv_tensors.KeyPoints),
+            pytest.param(
+                F._geometry._elastic_cvcuda,
+                "cvcuda.Tensor",
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available"),
+            ),
         ],
     )
     def test_functional_signature(self, kernel, input_type):
+        if input_type == "cvcuda.Tensor":
+            input_type = _import_cvcuda().Tensor
         check_functional_kernel_signature_match(F.elastic, kernel=kernel, input_type=input_type)

     @pytest.mark.parametrize(
@@ -3385,6 +3395,9 @@ def test_functional_signature(self, kernel, input_type):
             make_segmentation_mask,
             make_video,
             make_keypoints,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available")
+            ),
         ],
     )
     def test_displacement_error(self, make_input):
@@ -3406,6 +3419,9 @@ def test_displacement_error(self, make_input):
             make_segmentation_mask,
             make_video,
             make_keypoints,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available")
+            ),
         ],
     )
     # ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
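
Note: the new parametrize entries rely on two helpers that already exist in the test module but are not part of this diff: the CVCUDA_AVAILABLE flag used in the skipif marks, and _import_cvcuda(), which lets the signature test pass "cvcuda.Tensor" as a placeholder string and only resolve the real type inside the test body. Their actual implementation is not shown in this commit; the following is only a rough sketch of the lazy-import pattern they presumably follow (the names are real, the bodies are assumptions):

# Sketch only -- the real helpers live elsewhere in torchvision; bodies are assumed.
import importlib.util

# True when the cvcuda package can be located, without importing it at collection time.
CVCUDA_AVAILABLE = importlib.util.find_spec("cvcuda") is not None


def _import_cvcuda():
    # Deferred import: only executed inside tests that are not skipped.
    import cvcuda

    return cvcuda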

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 106 additions & 0 deletions
@@ -4,6 +4,7 @@
 from collections.abc import Sequence
 from typing import Any, Optional, TYPE_CHECKING, Union

+import numpy as np
 import PIL.Image
 import torch
 from torch.nn.functional import grid_sample, interpolate, pad as torch_pad
@@ -2529,6 +2530,111 @@ def elastic_video(
     return elastic_image(video, displacement, interpolation=interpolation, fill=fill)


+if CVCUDA_AVAILABLE:
+    _cvcuda_interp = {
+        InterpolationMode.BILINEAR: cvcuda.Interp.LINEAR,
+        "bilinear": cvcuda.Interp.LINEAR,
+        "linear": cvcuda.Interp.LINEAR,
+        2: cvcuda.Interp.LINEAR,
+        InterpolationMode.BICUBIC: cvcuda.Interp.CUBIC,
+        "bicubic": cvcuda.Interp.CUBIC,
+        3: cvcuda.Interp.CUBIC,
+        InterpolationMode.NEAREST: cvcuda.Interp.NEAREST,
+        "nearest": cvcuda.Interp.NEAREST,
+        0: cvcuda.Interp.NEAREST,
+        InterpolationMode.BOX: cvcuda.Interp.BOX,
+        "box": cvcuda.Interp.BOX,
+        4: cvcuda.Interp.BOX,
+        InterpolationMode.HAMMING: cvcuda.Interp.HAMMING,
+        "hamming": cvcuda.Interp.HAMMING,
+        5: cvcuda.Interp.HAMMING,
+        InterpolationMode.LANCZOS: cvcuda.Interp.LANCZOS,
+        "lanczos": cvcuda.Interp.LANCZOS,
+        1: cvcuda.Interp.LANCZOS,
+    }
+
+
+def _elastic_cvcuda(
+    image: "cvcuda.Tensor",
+    displacement: torch.Tensor,
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
+    fill: _FillTypeJIT = None,
+) -> "cvcuda.Tensor":
+    if not isinstance(displacement, torch.Tensor):
+        raise TypeError("Argument displacement should be a Tensor")
+
+    # Input image is NHWC format: (N, H, W, C)
+    batch_size, height, width, num_channels = image.shape
+    device = torch.device("cuda")
+    dtype = torch.float32
+
+    expected_shape = (1, height, width, 2)
+    if expected_shape != displacement.shape:
+        raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")
+
+    # cvcuda.remap only supports uint8 for 3-channel images, float32 for 1-channel
+    input_dtype = image.dtype
+    if num_channels == 3 and input_dtype != cvcuda.Type.U8:
+        raise ValueError(f"cvcuda.remap requires uint8 dtype for 3-channel images, but got {input_dtype}")
+    elif num_channels == 1 and input_dtype != cvcuda.Type.F32:
+        raise ValueError(f"cvcuda.remap requires float32 dtype for 1-channel images, but got {input_dtype}")
+
+    # Build normalized grid: identity + displacement
+    # _create_identity_grid returns (1, H, W, 2) with values in [-1, 1]
+    identity_grid = _create_identity_grid((height, width), device=device, dtype=dtype)
+    grid = identity_grid.add_(displacement.to(dtype=dtype, device=device))
+
+    # Convert normalized grid [-1, 1] to absolute pixel coordinates [0, width-1], [0, height-1]
+    # grid[..., 0] is x (horizontal), grid[..., 1] is y (vertical)
+    map_x = (grid[..., 0] + 1) * (width - 1) / 2.0
+    map_y = (grid[..., 1] + 1) * (height - 1) / 2.0
+
+    # Stack into (1, H, W, 2) map tensor
+    pixel_map = torch.stack([map_x, map_y], dim=-1)
+
+    # Expand map for batch if needed
+    if batch_size > 1:
+        pixel_map = pixel_map.expand(batch_size, -1, -1, -1)
+
+    # Create cvcuda map tensor (NHWC layout with 2 channels for x,y)
+    cv_map = cvcuda.as_tensor(pixel_map.contiguous(), "NHWC")
+
+    # Resolve interpolation
+    src_interp = _cvcuda_interp.get(interpolation, cvcuda.Interp.LINEAR)
+
+    # Resolve border mode and value
+    if fill is None:
+        border_mode = cvcuda.Border.CONSTANT
+        border_value = np.array([], dtype=np.float32)
+    elif isinstance(fill, (int, float)):
+        border_mode = cvcuda.Border.CONSTANT
+        border_value = np.array([fill], dtype=np.float32)
+    elif isinstance(fill, (list, tuple)):
+        border_mode = cvcuda.Border.CONSTANT
+        border_value = np.array(fill, dtype=np.float32)
+    else:
+        border_mode = cvcuda.Border.CONSTANT
+        border_value = np.array([], dtype=np.float32)
+
+    # Call cvcuda.remap
+    output = cvcuda.remap(
+        image,
+        cv_map,
+        src_interp=src_interp,
+        map_interp=cvcuda.Interp.LINEAR,
+        map_type=cvcuda.Remap.ABSOLUTE,
+        align_corners=False,
+        border=border_mode,
+        border_value=border_value,
+    )
+
+    return output
+
+
+if CVCUDA_AVAILABLE:
+    _elastic_cvcuda = _register_kernel_internal(elastic, cvcuda.Tensor)(_elastic_cvcuda)
+
+
 def center_crop(inpt: torch.Tensor, output_size: list[int]) -> torch.Tensor:
     """See :class:`~torchvision.transforms.v2.RandomCrop` for details."""
     if torch.jit.is_scripting():
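
For orientation, here is a rough end-to-end sketch of how the newly registered kernel would be exercised. Shapes, sizes, and the displacement magnitude are illustrative only; the sketch assumes a CUDA device with CV-CUDA installed and uses only calls that appear in the diff above (cvcuda.as_tensor, the F.elastic dispatcher, and the uint8/NHWC/3-channel requirement enforced by _elastic_cvcuda):

# Illustrative usage sketch, not part of the commit.
import torch
import cvcuda
import torchvision.transforms.v2.functional as F

# 3-channel uint8 NHWC batch -- the dtype/channel combination _elastic_cvcuda accepts.
image_torch = torch.randint(0, 256, (1, 64, 64, 3), dtype=torch.uint8, device="cuda")
image_cv = cvcuda.as_tensor(image_torch, "NHWC")

# Displacement field in normalized [-1, 1] coordinates, shape (1, H, W, 2) as the kernel requires.
displacement = 0.05 * torch.randn(1, 64, 64, 2, device="cuda")

# With the registration above, F.elastic should dispatch cvcuda.Tensor inputs to _elastic_cvcuda.
out = F.elastic(image_cv, displacement)
print(out.shape, out.dtype)  # cvcuda.Tensor with the input's NHWC shape and dtype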
