 import torchvision.transforms.v2 as transforms

 from common_utils import (
+    assert_close,
     assert_equal,
     cache,
     cpu_and_cuda,
@@ -42,7 +43,6 @@
 )

 from torch import nn
-from torch.testing import assert_close
from torch.utils._pytree import tree_flatten, tree_map
 from torch.utils.data import DataLoader, default_collate
 from torchvision import tv_tensors
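Side note: the import hunk above replaces the direct `torch.testing` import of `assert_close` with a shared helper from `common_utils`. Purely as a hedged sketch of what such a wrapper could look like (the actual `common_utils.assert_close` may choose different defaults or add extra checks):

import functools
import torch.testing

# Hypothetical wrapper -- defaults here are an assumption, not the real definition.
assert_close = functools.partial(torch.testing.assert_close, check_stride=False)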
@@ -2619,7 +2619,32 @@ def test_kernel(self, kernel, make_input, input_dtype, output_dtype, device, scale):
             scale=scale,
         )

-    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video])
+    @pytest.mark.parametrize(
+        ("kernel", "input_type"),
+        [
+            (F.to_dtype_image, torch.Tensor),
+            (F.to_dtype_video, tv_tensors.Video),
+            pytest.param(
+                F._misc._to_dtype_image_cvcuda,
+                None,
+                marks=pytest.mark.needs_cvcuda,
+            ),
+        ],
+    )
+    def test_functional_signature(self, kernel, input_type):
+        if kernel is F._misc._to_dtype_image_cvcuda:
+            input_type = _import_cvcuda().Tensor
+        check_functional_kernel_signature_match(F.to_dtype, kernel=kernel, input_type=input_type)
+
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image_tensor,
+            make_image,
+            make_video,
+            pytest.param(make_image_cvcuda, marks=pytest.mark.needs_cvcuda),
+        ],
+    )
     @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
     @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
     @pytest.mark.parametrize("device", cpu_and_cuda())
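The `needs_cvcuda` marker used above is a custom pytest mark. As a sketch only (not torchvision's actual conftest), a marker like this is typically wired up in conftest.py so that marked tests are skipped whenever CV-CUDA is not importable:

import importlib.util

import pytest

def pytest_collection_modifyitems(config, items):
    # Skip needs_cvcuda tests when the cvcuda package is unavailable.
    if importlib.util.find_spec("cvcuda") is not None:
        return
    skip_cvcuda = pytest.mark.skip(reason="cvcuda is not installed")
    for item in items:
        if "needs_cvcuda" in item.keywords:
            item.add_marker(skip_cvcuda)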
@@ -2634,7 +2659,14 @@ def test_functional(self, make_input, input_dtype, output_dtype, device, scale):

     @pytest.mark.parametrize(
         "make_input",
-        [make_image_tensor, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
+        [
+            make_image_tensor,
+            make_image,
+            make_bounding_boxes,
+            make_segmentation_mask,
+            make_video,
+            pytest.param(make_image_cvcuda, marks=pytest.mark.needs_cvcuda),
+        ],
     )
     @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
     @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
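For reference, the `scale` flag that these tests parametrize over controls whether `to_dtype` rescales values into the target dtype's value range rather than merely casting. A small self-contained illustration of that standard torchvision behavior:

import torch
from torchvision.transforms.v2 import functional as F

img = torch.tensor([[[0, 128, 255]]], dtype=torch.uint8)
print(F.to_dtype(img, dtype=torch.float32, scale=True))   # values mapped into [0, 1], e.g. 128/255 ~= 0.502
print(F.to_dtype(img, dtype=torch.float32, scale=False))  # plain cast: 0., 128., 255.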
@@ -2680,25 +2712,69 @@ def fn(value):

         return torch.tensor(tree_map(fn, image.tolist())).to(dtype=output_dtype, device=image.device)

+    def _get_dtype_conversion_atol_cvcuda(self, input_dtype, output_dtype):
+        in_bits = torch.iinfo(input_dtype).bits if not input_dtype.is_floating_point else None
+        out_bits = torch.iinfo(output_dtype).bits if not output_dtype.is_floating_point else None
+        narrows_bits = in_bits is not None and out_bits is not None and out_bits < in_bits
+
+        # int -> int conversion that narrows the bit width: allow atol=1 for rounding differences
+        if narrows_bits:
+            atol = 1
+        # float -> int: allow the same off-by-one difference from rounding the float
+        elif input_dtype.is_floating_point and not output_dtype.is_floating_point:
+            atol = 1
+        # int -> float: allow a small floating-point rounding error
+        elif not input_dtype.is_floating_point and output_dtype.is_floating_point:
+            atol = 1e-7
+        # all other cases should be exact
+        # (uint8 -> uint16 promotion falls into this branch)
+        else:
+            atol = 0
+
+        return atol
+
     @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16])
     @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     @pytest.mark.parametrize("scale", (True, False))
-    def test_image_correctness(self, input_dtype, output_dtype, device, scale):
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(make_image_cvcuda, marks=pytest.mark.needs_cvcuda),
+        ],
+    )
+    @pytest.mark.parametrize("fn", [F.to_dtype, transform_cls_to_functional(transforms.ToDtype)])
+    def test_image_correctness(self, input_dtype, output_dtype, device, scale, make_input, fn):
         if input_dtype.is_floating_point and output_dtype == torch.int64:
             pytest.xfail("float to int64 conversion is not supported")
         if input_dtype == torch.uint8 and output_dtype == torch.uint16 and device == "cuda":
             pytest.xfail("uint8 to uint16 conversion is not supported on cuda")
+        if (
+            input_dtype == torch.uint16
+            and output_dtype == torch.uint8
+            and not scale
+            and make_input is make_image_cvcuda
+        ):
+            pytest.xfail("uint16 to uint8 conversion without scale is not supported for CV-CUDA.")

-        input = make_image(dtype=input_dtype, device=device)
+        input = make_input(dtype=input_dtype, device=device)
+        out = fn(input, dtype=output_dtype, scale=scale)
+
+        if make_input is make_image_cvcuda:
+            input = F.cvcuda_to_tensor(input)
+            out = F.cvcuda_to_tensor(out)

-        out = F.to_dtype(input, dtype=output_dtype, scale=scale)
         expected = self.reference_convert_dtype_image_tensor(input, dtype=output_dtype, scale=scale)

-        if input_dtype.is_floating_point and not output_dtype.is_floating_point and scale:
-            torch.testing.assert_close(out, expected, atol=1, rtol=0)
-        else:
-            torch.testing.assert_close(out, expected)
+        atol, rtol = None, None
+        if make_input is make_image_cvcuda:
+            atol = self._get_dtype_conversion_atol_cvcuda(input_dtype, output_dtype)
+            rtol = 0
+        elif input_dtype.is_floating_point and not output_dtype.is_floating_point and scale:
+            atol, rtol = 1, 0
+
+        torch.testing.assert_close(out, expected, atol=atol, rtol=rtol)

     def was_scaled(self, inpt):
         # this assumes the target dtype is float
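Taken together, the tolerance rule this diff introduces can be restated as a standalone function for quick experimentation. This is a sketch that mirrors `_get_dtype_conversion_atol_cvcuda` above, not part of the PR itself:

import torch

def dtype_conversion_atol(input_dtype, output_dtype):
    in_bits = None if input_dtype.is_floating_point else torch.iinfo(input_dtype).bits
    out_bits = None if output_dtype.is_floating_point else torch.iinfo(output_dtype).bits
    if in_bits is not None and out_bits is not None and out_bits < in_bits:
        return 1  # narrowing int -> int
    if input_dtype.is_floating_point and not output_dtype.is_floating_point:
        return 1  # float -> int
    if not input_dtype.is_floating_point and output_dtype.is_floating_point:
        return 1e-7  # int -> float
    return 0  # exact otherwise, e.g. uint8 -> uint16 promotion

assert dtype_conversion_atol(torch.uint16, torch.uint8) == 1
assert dtype_conversion_atol(torch.uint8, torch.float32) == 1e-7
assert dtype_conversion_atol(torch.uint8, torch.uint16) == 0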