Skip to content

Commit 29df2ef

Browse files
committed
simplify elastic cvcuda code more
1 parent bb6056a commit 29df2ef

File tree

1 file changed

+6
-11
lines changed

1 file changed

+6
-11
lines changed

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2551,7 +2551,6 @@ def _elastic_cvcuda(
25512551
if not isinstance(displacement, torch.Tensor):
25522552
raise TypeError("Argument displacement should be a Tensor")
25532553

2554-
# Input image is NHWC format: (N, H, W, C)
25552554
batch_size, height, width, num_channels = image.shape
25562555
device = torch.device("cuda")
25572556
dtype = torch.float32
@@ -2567,6 +2566,10 @@ def _elastic_cvcuda(
25672566
elif num_channels == 1 and input_dtype != cvcuda.Type.F32:
25682567
raise ValueError(f"cvcuda.remap requires float32 dtype for 1-channel images, but got {input_dtype}")
25692568

2569+
interp = _cvcuda_interp.get(interpolation, cvcuda.Interp.LINEAR)
2570+
if interp is None:
2571+
raise ValueError(f"Invalid interpolation mode: {interpolation}")
2572+
25702573
# Build normalized grid: identity + displacement
25712574
# _create_identity_grid returns (1, H, W, 2) with values in [-1, 1]
25722575
identity_grid = _create_identity_grid((height, width), device=device, dtype=dtype)
@@ -2587,28 +2590,20 @@ def _elastic_cvcuda(
25872590
# Create cvcuda map tensor (NHWC layout with 2 channels for x,y)
25882591
cv_map = cvcuda.as_tensor(pixel_map.contiguous(), "NHWC")
25892592

2590-
# Resolve interpolation
2591-
src_interp = _cvcuda_interp.get(interpolation, cvcuda.Interp.LINEAR)
2592-
2593-
# Resolve border mode and value
2593+
border_mode = cvcuda.Border.CONSTANT
25942594
if fill is None:
2595-
border_mode = cvcuda.Border.CONSTANT
25962595
border_value = np.array([], dtype=np.float32)
25972596
elif isinstance(fill, (int, float)):
2598-
border_mode = cvcuda.Border.CONSTANT
25992597
border_value = np.array([fill], dtype=np.float32)
26002598
elif isinstance(fill, (list, tuple)):
2601-
border_mode = cvcuda.Border.CONSTANT
26022599
border_value = np.array(fill, dtype=np.float32)
26032600
else:
2604-
border_mode = cvcuda.Border.CONSTANT
26052601
border_value = np.array([], dtype=np.float32)
26062602

2607-
# Call cvcuda.remap
26082603
output = cvcuda.remap(
26092604
image,
26102605
cv_map,
2611-
src_interp=src_interp,
2606+
src_interp=interp,
26122607
map_interp=cvcuda.Interp.LINEAR,
26132608
map_type=cvcuda.Remap.ABSOLUTE,
26142609
align_corners=False,

0 commit comments

Comments
 (0)