|
4 | 4 | from collections.abc import Sequence |
5 | 5 | from typing import Any, Optional, TYPE_CHECKING, Union |
6 | 6 |
|
| 7 | +import numpy as np |
7 | 8 | import PIL.Image |
8 | 9 | import torch |
9 | 10 | from torch.nn.functional import grid_sample, interpolate, pad as torch_pad |
@@ -2529,6 +2530,111 @@ def elastic_video( |
2529 | 2530 | return elastic_image(video, displacement, interpolation=interpolation, fill=fill) |
2530 | 2531 |
|
2531 | 2532 |
|
| 2533 | +if CVCUDA_AVAILABLE: |
| 2534 | + _cvcuda_interp = { |
| 2535 | + InterpolationMode.BILINEAR: cvcuda.Interp.LINEAR, |
| 2536 | + "bilinear": cvcuda.Interp.LINEAR, |
| 2537 | + "linear": cvcuda.Interp.LINEAR, |
| 2538 | + 2: cvcuda.Interp.LINEAR, |
| 2539 | + InterpolationMode.BICUBIC: cvcuda.Interp.CUBIC, |
| 2540 | + "bicubic": cvcuda.Interp.CUBIC, |
| 2541 | + 3: cvcuda.Interp.CUBIC, |
| 2542 | + InterpolationMode.NEAREST: cvcuda.Interp.NEAREST, |
| 2543 | + "nearest": cvcuda.Interp.NEAREST, |
| 2544 | + 0: cvcuda.Interp.NEAREST, |
| 2545 | + InterpolationMode.BOX: cvcuda.Interp.BOX, |
| 2546 | + "box": cvcuda.Interp.BOX, |
| 2547 | + 4: cvcuda.Interp.BOX, |
| 2548 | + InterpolationMode.HAMMING: cvcuda.Interp.HAMMING, |
| 2549 | + "hamming": cvcuda.Interp.HAMMING, |
| 2550 | + 5: cvcuda.Interp.HAMMING, |
| 2551 | + InterpolationMode.LANCZOS: cvcuda.Interp.LANCZOS, |
| 2552 | + "lanczos": cvcuda.Interp.LANCZOS, |
| 2553 | + 1: cvcuda.Interp.LANCZOS, |
| 2554 | + } |
| 2555 | + |
| 2556 | + |
| 2557 | +def _elastic_cvcuda( |
| 2558 | + image: "cvcuda.Tensor", |
| 2559 | + displacement: torch.Tensor, |
| 2560 | + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, |
| 2561 | + fill: _FillTypeJIT = None, |
| 2562 | +) -> "cvcuda.Tensor": |
| 2563 | + if not isinstance(displacement, torch.Tensor): |
| 2564 | + raise TypeError("Argument displacement should be a Tensor") |
| 2565 | + |
| 2566 | + # Input image is NHWC format: (N, H, W, C) |
| 2567 | + batch_size, height, width, num_channels = image.shape |
| 2568 | + device = torch.device("cuda") |
| 2569 | + dtype = torch.float32 |
| 2570 | + |
| 2571 | + expected_shape = (1, height, width, 2) |
| 2572 | + if expected_shape != displacement.shape: |
| 2573 | + raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}") |
| 2574 | + |
| 2575 | + # cvcuda.remap only supports uint8 for 3-channel images, float32 for 1-channel |
| 2576 | + input_dtype = image.dtype |
| 2577 | + if num_channels == 3 and input_dtype != cvcuda.Type.U8: |
| 2578 | + raise ValueError(f"cvcuda.remap requires uint8 dtype for 3-channel images, but got {input_dtype}") |
| 2579 | + elif num_channels == 1 and input_dtype != cvcuda.Type.F32: |
| 2580 | + raise ValueError(f"cvcuda.remap requires float32 dtype for 1-channel images, but got {input_dtype}") |
| 2581 | + |
| 2582 | + # Build normalized grid: identity + displacement |
| 2583 | + # _create_identity_grid returns (1, H, W, 2) with values in [-1, 1] |
| 2584 | + identity_grid = _create_identity_grid((height, width), device=device, dtype=dtype) |
| 2585 | + grid = identity_grid.add_(displacement.to(dtype=dtype, device=device)) |
| 2586 | + |
| 2587 | + # Convert normalized grid [-1, 1] to absolute pixel coordinates [0, width-1], [0, height-1] |
| 2588 | + # grid[..., 0] is x (horizontal), grid[..., 1] is y (vertical) |
| 2589 | + map_x = (grid[..., 0] + 1) * (width - 1) / 2.0 |
| 2590 | + map_y = (grid[..., 1] + 1) * (height - 1) / 2.0 |
| 2591 | + |
| 2592 | + # Stack into (1, H, W, 2) map tensor |
| 2593 | + pixel_map = torch.stack([map_x, map_y], dim=-1) |
| 2594 | + |
| 2595 | + # Expand map for batch if needed |
| 2596 | + if batch_size > 1: |
| 2597 | + pixel_map = pixel_map.expand(batch_size, -1, -1, -1) |
| 2598 | + |
| 2599 | + # Create cvcuda map tensor (NHWC layout with 2 channels for x,y) |
| 2600 | + cv_map = cvcuda.as_tensor(pixel_map.contiguous(), "NHWC") |
| 2601 | + |
| 2602 | + # Resolve interpolation |
| 2603 | + src_interp = _cvcuda_interp.get(interpolation, cvcuda.Interp.LINEAR) |
| 2604 | + |
| 2605 | + # Resolve border mode and value |
| 2606 | + if fill is None: |
| 2607 | + border_mode = cvcuda.Border.CONSTANT |
| 2608 | + border_value = np.array([], dtype=np.float32) |
| 2609 | + elif isinstance(fill, (int, float)): |
| 2610 | + border_mode = cvcuda.Border.CONSTANT |
| 2611 | + border_value = np.array([fill], dtype=np.float32) |
| 2612 | + elif isinstance(fill, (list, tuple)): |
| 2613 | + border_mode = cvcuda.Border.CONSTANT |
| 2614 | + border_value = np.array(fill, dtype=np.float32) |
| 2615 | + else: |
| 2616 | + border_mode = cvcuda.Border.CONSTANT |
| 2617 | + border_value = np.array([], dtype=np.float32) |
| 2618 | + |
| 2619 | + # Call cvcuda.remap |
| 2620 | + output = cvcuda.remap( |
| 2621 | + image, |
| 2622 | + cv_map, |
| 2623 | + src_interp=src_interp, |
| 2624 | + map_interp=cvcuda.Interp.LINEAR, |
| 2625 | + map_type=cvcuda.Remap.ABSOLUTE, |
| 2626 | + align_corners=False, |
| 2627 | + border=border_mode, |
| 2628 | + border_value=border_value, |
| 2629 | + ) |
| 2630 | + |
| 2631 | + return output |
| 2632 | + |
| 2633 | + |
| 2634 | +if CVCUDA_AVAILABLE: |
| 2635 | + _elastic_cvcuda = _register_kernel_internal(elastic, cvcuda.Tensor)(_elastic_cvcuda) |
| 2636 | + |
| 2637 | + |
2532 | 2638 | def center_crop(inpt: torch.Tensor, output_size: list[int]) -> torch.Tensor: |
2533 | 2639 | """See :class:`~torchvision.transforms.v2.RandomCrop` for details.""" |
2534 | 2640 | if torch.jit.is_scripting(): |
|
0 commit comments