@@ -98,17 +98,19 @@ def test_from_tiled_1d(self, dev, skip_if_no_tma):
9898
9999 def test_from_tiled_2d (self , dev , skip_if_no_tma ):
100100 buf = dev .allocate (64 * 64 * 4 ) # 64x64 float32
101+ tensor = _DeviceArray (buf , (64 , 64 ))
101102 desc = TensorMapDescriptor .from_tiled (
102- buf ,
103+ tensor ,
103104 box_dim = (32 , 32 ),
104105 data_type = TensorMapDataType .FLOAT32 ,
105106 )
106107 assert desc is not None
107108
108109 def test_from_tiled_3d (self , dev , skip_if_no_tma ):
109110 buf = dev .allocate (16 * 16 * 16 * 4 ) # 16x16x16 float32
111+ tensor = _DeviceArray (buf , (16 , 16 , 16 ))
110112 desc = TensorMapDescriptor .from_tiled (
111- buf ,
113+ tensor ,
112114 box_dim = (8 , 8 , 8 ),
113115 data_type = TensorMapDataType .FLOAT32 ,
114116 )
@@ -128,8 +130,9 @@ def test_from_tiled_5d(self, dev, skip_if_no_tma):
128130
129131 def test_from_tiled_with_swizzle (self , dev , skip_if_no_tma ):
130132 buf = dev .allocate (64 * 64 * 4 )
133+ tensor = _DeviceArray (buf , (64 , 64 ))
131134 desc = TensorMapDescriptor .from_tiled (
132- buf ,
135+ tensor ,
133136 box_dim = (32 , 32 ),
134137 data_type = TensorMapDataType .FLOAT32 ,
135138 swizzle = TensorMapSwizzle .SWIZZLE_128B ,
@@ -138,8 +141,9 @@ def test_from_tiled_with_swizzle(self, dev, skip_if_no_tma):
138141
139142 def test_from_tiled_with_l2_promotion (self , dev , skip_if_no_tma ):
140143 buf = dev .allocate (64 * 64 * 4 )
144+ tensor = _DeviceArray (buf , (64 , 64 ))
141145 desc = TensorMapDescriptor .from_tiled (
142- buf ,
146+ tensor ,
143147 box_dim = (32 , 32 ),
144148 data_type = TensorMapDataType .FLOAT32 ,
145149 l2_promotion = TensorMapL2Promotion .L2_128B ,
@@ -148,8 +152,9 @@ def test_from_tiled_with_l2_promotion(self, dev, skip_if_no_tma):
148152
149153 def test_from_tiled_with_oob_fill (self , dev , skip_if_no_tma ):
150154 buf = dev .allocate (64 * 64 * 4 )
155+ tensor = _DeviceArray (buf , (64 , 64 ))
151156 desc = TensorMapDescriptor .from_tiled (
152- buf ,
157+ tensor ,
153158 box_dim = (32 , 32 ),
154159 data_type = TensorMapDataType .FLOAT32 ,
155160 oob_fill = TensorMapOOBFill .NAN_REQUEST_ZERO_FMA ,
@@ -162,9 +167,10 @@ class TestTensorMapDescriptorValidation:
162167
163168 def test_invalid_rank_zero (self , dev , skip_if_no_tma ):
164169 buf = dev .allocate (64 )
170+ tensor = _DeviceArray (buf , ()) # 0-dim tensor
165171 with pytest .raises (ValueError , match = "rank must be between 1 and 5" ):
166172 TensorMapDescriptor .from_tiled (
167- buf ,
173+ tensor ,
168174 box_dim = (),
169175 data_type = TensorMapDataType .FLOAT32 ,
170176 )
@@ -286,8 +292,9 @@ class TestTensorMapIm2col:
286292 def test_from_im2col_3d (self , dev , skip_if_no_tma ):
287293 # 3D tensor: batch=1, height=32, channels=64
288294 buf = dev .allocate (1 * 32 * 64 * 4 )
295+ tensor = _DeviceArray (buf , (1 , 32 , 64 ))
289296 desc = TensorMapDescriptor .from_im2col (
290- buf ,
297+ tensor ,
291298 pixel_box_lower_corner = (0 ,),
292299 pixel_box_upper_corner = (4 ,),
293300 channels_per_pixel = 64 ,
@@ -310,9 +317,10 @@ def test_from_im2col_rank_validation(self, dev, skip_if_no_tma):
310317
311318 def test_from_im2col_corner_rank_mismatch (self , dev , skip_if_no_tma ):
312319 buf = dev .allocate (1 * 32 * 64 * 4 )
320+ tensor = _DeviceArray (buf , (1 , 32 , 64 )) # 3D: n_spatial = 1
313321 with pytest .raises (ValueError , match = "pixel_box_lower_corner must have 1 elements" ):
314322 TensorMapDescriptor .from_im2col (
315- buf ,
323+ tensor ,
316324 pixel_box_lower_corner = (0 , 0 ),
317325 pixel_box_upper_corner = (4 ,),
318326 channels_per_pixel = 64 ,
0 commit comments