@@ -5825,7 +5825,18 @@ def test_kernel_image(self, dtype, device):
58255825 def test_kernel_video (self ):
58265826 check_kernel (F .invert_video , make_video ())
58275827
5828- @pytest .mark .parametrize ("make_input" , [make_image_tensor , make_image , make_image_pil , make_video ])
5828+ @pytest .mark .parametrize (
5829+ "make_input" ,
5830+ [
5831+ make_image_tensor ,
5832+ make_image ,
5833+ make_image_pil ,
5834+ make_video ,
5835+ pytest .param (
5836+ make_image_cvcuda , marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "test requires CVCUDA" )
5837+ ),
5838+ ],
5839+ )
58295840 def test_functional (self , make_input ):
58305841 check_functional (F .invert , make_input ())
58315842
@@ -5836,12 +5847,30 @@ def test_functional(self, make_input):
58365847 (F ._color ._invert_image_pil , PIL .Image .Image ),
58375848 (F .invert_image , tv_tensors .Image ),
58385849 (F .invert_video , tv_tensors .Video ),
5850+ pytest .param (
5851+ F ._color ._invert_cvcuda ,
5852+ "cvcuda.Tensor" ,
5853+ marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "test requires CVCUDA" ),
5854+ ),
58395855 ],
58405856 )
58415857 def test_functional_signature (self , kernel , input_type ):
5858+ if input_type == "cvcuda.Tensor" :
5859+ input_type = _import_cvcuda ().Tensor
58425860 check_functional_kernel_signature_match (F .invert , kernel = kernel , input_type = input_type )
58435861
5844- @pytest .mark .parametrize ("make_input" , [make_image_tensor , make_image_pil , make_image , make_video ])
5862+ @pytest .mark .parametrize (
5863+ "make_input" ,
5864+ [
5865+ make_image_tensor ,
5866+ make_image_pil ,
5867+ make_image ,
5868+ make_video ,
5869+ pytest .param (
5870+ make_image_cvcuda , marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "test requires CVCUDA" )
5871+ ),
5872+ ],
5873+ )
58455874 def test_transform (self , make_input ):
58465875 check_transform (transforms .RandomInvert (p = 1 ), make_input ())
58475876
@@ -5854,6 +5883,16 @@ def test_correctness_image(self, fn):
58545883
58555884 assert_equal (actual , expected )
58565885
5886+ @pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "test requires CVCUDA" )
5887+ @pytest .mark .parametrize ("dtype" , [torch .uint8 , torch .float32 ])
5888+ @pytest .mark .parametrize ("fn" , [F .invert , transform_cls_to_functional (transforms .RandomInvert , p = 1 )])
5889+ def test_correctness_cvcuda (self , dtype , fn ):
5890+ image = make_image (batch_dims = (1 ,), dtype = dtype , device = "cuda" )
5891+ cv_image = F .to_cvcuda_tensor (image )
5892+ actual = F .cvcuda_to_tensor (fn (cv_image ))
5893+ expected = F .invert_image (image )
5894+ assert_equal (actual , expected )
5895+
58575896
58585897class TestPosterize :
58595898 @pytest .mark .parametrize ("dtype" , [torch .uint8 , torch .float32 ])
0 commit comments