11# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22# SPDX-License-Identifier: Apache-2.0
33
4- import pytest
5-
64import numpy as np
7-
5+ import pytest
86from cuda .core import (
97 Device ,
10- TensorMapDescriptor ,
118 TensorMapDataType ,
9+ TensorMapDescriptor ,
1210 TensorMapIm2ColWideMode ,
1311 TensorMapInterleave ,
1412 TensorMapL2Promotion ,
@@ -28,14 +26,14 @@ def skip_if_no_tma(dev):
2826 pytest .skip ("Device does not support TMA (requires compute capability 9.0+)" )
2927
3028
31-
3229class _DeviceArray :
3330 """Wrap a Buffer with explicit shape via __cuda_array_interface__.
3431
3532 dev.allocate() returns a 1D byte buffer. For multi-dimensional TMA tests
3633 we need the tensor to report a proper shape/dtype so the TMA encoder sees
3734 the correct rank, dimensions, and strides.
3835 """
36+
3937 def __init__ (self , buf , shape , dtype = np .float32 ):
4038 self ._buf = buf # prevent GC
4139 self .__cuda_array_interface__ = {
@@ -225,25 +223,30 @@ def test_invalid_data_type(self, dev, skip_if_no_tma):
225223class TestTensorMapDtypeMapping :
226224 """Test automatic dtype inference from numpy dtypes."""
227225
228- @pytest .mark .parametrize ("np_dtype,expected_tma_dt" , [
229- (np .uint8 , TensorMapDataType .UINT8 ),
230- (np .uint16 , TensorMapDataType .UINT16 ),
231- (np .uint32 , TensorMapDataType .UINT32 ),
232- (np .int32 , TensorMapDataType .INT32 ),
233- (np .uint64 , TensorMapDataType .UINT64 ),
234- (np .int64 , TensorMapDataType .INT64 ),
235- (np .float16 , TensorMapDataType .FLOAT16 ),
236- (np .float32 , TensorMapDataType .FLOAT32 ),
237- (np .float64 , TensorMapDataType .FLOAT64 ),
238- ])
226+ @pytest .mark .parametrize (
227+ "np_dtype,expected_tma_dt" ,
228+ [
229+ (np .uint8 , TensorMapDataType .UINT8 ),
230+ (np .uint16 , TensorMapDataType .UINT16 ),
231+ (np .uint32 , TensorMapDataType .UINT32 ),
232+ (np .int32 , TensorMapDataType .INT32 ),
233+ (np .uint64 , TensorMapDataType .UINT64 ),
234+ (np .int64 , TensorMapDataType .INT64 ),
235+ (np .float16 , TensorMapDataType .FLOAT16 ),
236+ (np .float32 , TensorMapDataType .FLOAT32 ),
237+ (np .float64 , TensorMapDataType .FLOAT64 ),
238+ ],
239+ )
239240 def test_dtype_mapping (self , np_dtype , expected_tma_dt , dev , skip_if_no_tma ):
240241 from cuda .core ._tensor_map import _NUMPY_DTYPE_TO_TMA
242+
241243 assert _NUMPY_DTYPE_TO_TMA [np .dtype (np_dtype )] == expected_tma_dt
242244
243245 def test_bfloat16_mapping (self ):
244246 try :
245- from ml_dtypes import bfloat16
246247 from cuda .core ._tensor_map import _NUMPY_DTYPE_TO_TMA
248+ from ml_dtypes import bfloat16
249+
247250 assert _NUMPY_DTYPE_TO_TMA [np .dtype (bfloat16 )] == TensorMapDataType .BFLOAT16
248251 except ImportError :
249252 pytest .skip ("ml_dtypes not installed" )
0 commit comments