Skip to content

Commit e9b378f

Browse files
committed
Privatize the Python-imported symbols in utils.pxi.in
1 parent 8636340 commit e9b378f

File tree

1 file changed

+42
-43
lines changed

1 file changed

+42
-43
lines changed

cuda_bindings/cuda/bindings/_lib/utils.pxi.in

Lines changed: 42 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,10 @@ from libc.stdlib cimport calloc, free
66
from libc.stdint cimport int32_t, uint32_t, int64_t, uint64_t
77
from libc.stddef cimport wchar_t
88
from libc.string cimport memcpy
9-
from enum import Enum
10-
from typing import List, Tuple
11-
import ctypes
9+
from enum import Enum as _Enum
10+
import ctypes as _ctypes
1211
cimport cuda.bindings.cydriver as cydriver
13-
import cuda.bindings.driver as driver
12+
import cuda.bindings.driver as _driver
1413
cimport cuda.bindings._lib.param_packer as param_packer
1514

1615
cdef void* _callocWrapper(length, size):
@@ -21,25 +20,25 @@ cdef void* _callocWrapper(length, size):
2120

2221
cdef class _HelperKernelParams:
2322
supported_types = { # excluding void_p and None, which are handled specially
24-
ctypes.c_bool,
25-
ctypes.c_char,
26-
ctypes.c_wchar,
27-
ctypes.c_byte,
28-
ctypes.c_ubyte,
29-
ctypes.c_short,
30-
ctypes.c_ushort,
31-
ctypes.c_int,
32-
ctypes.c_uint,
33-
ctypes.c_long,
34-
ctypes.c_ulong,
35-
ctypes.c_longlong,
36-
ctypes.c_ulonglong,
37-
ctypes.c_size_t,
38-
ctypes.c_float,
39-
ctypes.c_double
23+
_ctypes.c_bool,
24+
_ctypes.c_char,
25+
_ctypes.c_wchar,
26+
_ctypes.c_byte,
27+
_ctypes.c_ubyte,
28+
_ctypes.c_short,
29+
_ctypes.c_ushort,
30+
_ctypes.c_int,
31+
_ctypes.c_uint,
32+
_ctypes.c_long,
33+
_ctypes.c_ulong,
34+
_ctypes.c_longlong,
35+
_ctypes.c_ulonglong,
36+
_ctypes.c_size_t,
37+
_ctypes.c_float,
38+
_ctypes.c_double
4039
}
4140

42-
max_param_size = max(ctypes.sizeof(max(_HelperKernelParams.supported_types, key=lambda t:ctypes.sizeof(t))), sizeof(void_ptr))
41+
max_param_size = max(_ctypes.sizeof(max(_HelperKernelParams.supported_types, key=lambda t:_ctypes.sizeof(t))), sizeof(void_ptr))
4342

4443
def __cinit__(self, kernelParams):
4544
self._pyobj_acquired = False
@@ -56,7 +55,7 @@ cdef class _HelperKernelParams:
5655
raise RuntimeError("Argument 'kernelParams' failed to retrieve buffer through Buffer Protocol")
5756
self._pyobj_acquired = True
5857
self._ckernelParams = <void**><void_ptr>self._pybuffer.buf
59-
elif isinstance(kernelParams, (Tuple)) and len(kernelParams) == 2 and isinstance(kernelParams[0], (Tuple)) and isinstance(kernelParams[1], (Tuple)):
58+
elif isinstance(kernelParams, (tuple)) and len(kernelParams) == 2 and isinstance(kernelParams[0], (tuple)) and isinstance(kernelParams[1], (tuple)):
6059
# Hard run, construct and fill out contigues memory using provided kernel values and types based
6160
if len(kernelParams[0]) != len(kernelParams[1]):
6261
raise TypeError("Argument 'kernelParams' has tuples with different length")
@@ -73,44 +72,44 @@ cdef class _HelperKernelParams:
7372
# special cases for None
7473
if callable(getattr(value, 'getPtr', None)):
7574
self._ckernelParams[idx] = <void*><void_ptr>value.getPtr()
76-
elif isinstance(value, (ctypes.Structure)):
77-
self._ckernelParams[idx] = <void*><void_ptr>ctypes.addressof(value)
78-
elif isinstance(value, (Enum)):
75+
elif isinstance(value, (_ctypes.Structure)):
76+
self._ckernelParams[idx] = <void*><void_ptr>_ctypes.addressof(value)
77+
elif isinstance(value, (_Enum)):
7978
self._ckernelParams[idx] = &(self._ckernelParamsData[data_idx])
8079
(<int*>self._ckernelParams[idx])[0] = value.value
8180
data_idx += sizeof(int)
8281
else:
83-
raise TypeError("Provided argument is of type {} but expected Type {}, {} or CUDA Binding structure with getPtr() attribute".format(type(value), type(ctypes.Structure), type(ctypes.c_void_p)))
82+
raise TypeError("Provided argument is of type {} but expected Type {}, {} or CUDA Binding structure with getPtr() attribute".format(type(value), type(_ctypes.Structure), type(_ctypes.c_void_p)))
8483
elif ctype in _HelperKernelParams.supported_types:
8584
self._ckernelParams[idx] = &(self._ckernelParamsData[data_idx])
8685

8786
# handle case where a float is passed as a double
88-
if ctype == ctypes.c_double and isinstance(value, ctypes.c_float):
87+
if ctype == _ctypes.c_double and isinstance(value, _ctypes.c_float):
8988
value = ctype(value.value)
9089
if not isinstance(value, ctype): # make it a ctype
9190
size = param_packer.feed(self._ckernelParams[idx], value, ctype)
9291
if size == 0: # feed failed
9392
value = ctype(value)
94-
size = ctypes.sizeof(ctype)
95-
addr = <void*>(<void_ptr>ctypes.addressof(value))
93+
size = _ctypes.sizeof(ctype)
94+
addr = <void*>(<void_ptr>_ctypes.addressof(value))
9695
memcpy(self._ckernelParams[idx], addr, size)
9796
else:
98-
size = ctypes.sizeof(ctype)
99-
addr = <void*>(<void_ptr>ctypes.addressof(value))
97+
size = _ctypes.sizeof(ctype)
98+
addr = <void*>(<void_ptr>_ctypes.addressof(value))
10099
memcpy(self._ckernelParams[idx], addr, size)
101100
data_idx += size
102-
elif ctype == ctypes.c_void_p:
101+
elif ctype == _ctypes.c_void_p:
103102
# special cases for void_p
104-
if isinstance(value, (int, ctypes.c_void_p)):
103+
if isinstance(value, (int, _ctypes.c_void_p)):
105104
self._ckernelParams[idx] = &(self._ckernelParamsData[data_idx])
106-
(<void_ptr*>self._ckernelParams[idx])[0] = value.value if isinstance(value, (ctypes.c_void_p)) else value
105+
(<void_ptr*>self._ckernelParams[idx])[0] = value.value if isinstance(value, (_ctypes.c_void_p)) else value
107106
data_idx += sizeof(void_ptr)
108107
elif callable(getattr(value, 'getPtr', None)):
109108
self._ckernelParams[idx] = &(self._ckernelParamsData[data_idx])
110109
(<void_ptr*>self._ckernelParams[idx])[0] = value.getPtr()
111110
data_idx += sizeof(void_ptr)
112111
else:
113-
raise TypeError("Provided argument is of type {} but expected Type {}, {} or CUDA Binding structure with getPtr() attribute".format(type(value), type(int), type(ctypes.c_void_p)))
112+
raise TypeError("Provided argument is of type {} but expected Type {}, {} or CUDA Binding structure with getPtr() attribute".format(type(value), type(int), type(_ctypes.c_void_p)))
114113
else:
115114
raise TypeError("Unsupported type: " + str(type(ctype)))
116115
idx += 1
@@ -136,7 +135,7 @@ cdef class _HelperInputVoidPtr:
136135
elif isinstance(ptr, (int)):
137136
# Easy run, user gave us an already configured void** address
138137
self._cptr = <void*><void_ptr>ptr
139-
elif isinstance(ptr, (driver.CUdeviceptr)):
138+
elif isinstance(ptr, (_driver.CUdeviceptr)):
140139
self._cptr = <void*><void_ptr>int(ptr)
141140
elif PyObject_CheckBuffer(ptr):
142141
# Easy run, get address from Python Buffer Protocol
@@ -173,7 +172,7 @@ cdef class _HelperCUmemPool_attribute:
173172
{{if 'CU_MEMPOOL_ATTR_USED_MEM_CURRENT'}}cydriver.CUmemPool_attribute_enum.CU_MEMPOOL_ATTR_USED_MEM_CURRENT,{{endif}}
174173
{{if 'CU_MEMPOOL_ATTR_USED_MEM_HIGH'}}cydriver.CUmemPool_attribute_enum.CU_MEMPOOL_ATTR_USED_MEM_HIGH,{{endif}}):
175174
if self._is_getter:
176-
self._cuuint64_t_val = driver.cuuint64_t()
175+
self._cuuint64_t_val = _driver.cuuint64_t()
177176
self._cptr = <void*><void_ptr>self._cuuint64_t_val.getPtr()
178177
else:
179178
self._cptr = <void*><void_ptr>init_value.getPtr()
@@ -244,7 +243,7 @@ cdef class _HelperCUpointer_attribute:
244243
self._attr = attr.value
245244
if self._attr in ({{if 'CU_POINTER_ATTRIBUTE_CONTEXT'}}cydriver.CUpointer_attribute_enum.CU_POINTER_ATTRIBUTE_CONTEXT,{{endif}}):
246245
if self._is_getter:
247-
self._ctx = driver.CUcontext()
246+
self._ctx = _driver.CUcontext()
248247
self._cptr = <void*><void_ptr>self._ctx.getPtr()
249248
else:
250249
self._cptr = <void*><void_ptr>init_value.getPtr()
@@ -258,7 +257,7 @@ cdef class _HelperCUpointer_attribute:
258257
elif self._attr in ({{if 'CU_POINTER_ATTRIBUTE_DEVICE_POINTER'}}cydriver.CUpointer_attribute_enum.CU_POINTER_ATTRIBUTE_DEVICE_POINTER,{{endif}}
259258
{{if 'CU_POINTER_ATTRIBUTE_RANGE_START_ADDR'}}cydriver.CUpointer_attribute_enum.CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,{{endif}}):
260259
if self._is_getter:
261-
self._devptr = driver.CUdeviceptr()
260+
self._devptr = _driver.CUdeviceptr()
262261
self._cptr = <void*><void_ptr>self._devptr.getPtr()
263262
else:
264263
self._cptr = <void*><void_ptr>init_value.getPtr()
@@ -267,7 +266,7 @@ cdef class _HelperCUpointer_attribute:
267266
self._cptr = <void*>&self._void
268267
elif self._attr in ({{if 'CU_POINTER_ATTRIBUTE_P2P_TOKENS'}}cydriver.CUpointer_attribute_enum.CU_POINTER_ATTRIBUTE_P2P_TOKENS,{{endif}}):
269268
if self._is_getter:
270-
self._token = driver.CUDA_POINTER_ATTRIBUTE_P2P_TOKENS()
269+
self._token = _driver.CUDA_POINTER_ATTRIBUTE_P2P_TOKENS()
271270
self._cptr = <void*><void_ptr>self._token.getPtr()
272271
else:
273272
self._cptr = <void*><void_ptr>init_value.getPtr()
@@ -285,7 +284,7 @@ cdef class _HelperCUpointer_attribute:
285284
self._cptr = <void*>&self._size
286285
elif self._attr in ({{if 'CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE'}}cydriver.CUpointer_attribute_enum.CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE,{{endif}}):
287286
if self._is_getter:
288-
self._mempool = driver.CUmemoryPool()
287+
self._mempool = _driver.CUmemoryPool()
289288
self._cptr = <void*><void_ptr>self._mempool.getPtr()
290289
else:
291290
self._cptr = <void*><void_ptr>init_value.getPtr()
@@ -341,7 +340,7 @@ cdef class _HelperCUgraphMem_attribute:
341340
{{if 'CU_GRAPH_MEM_ATTR_RESERVED_MEM_CURRENT' in found_values}}cydriver.CUgraphMem_attribute_enum.CU_GRAPH_MEM_ATTR_RESERVED_MEM_CURRENT,{{endif}}
342341
{{if 'CU_GRAPH_MEM_ATTR_RESERVED_MEM_HIGH' in found_values}}cydriver.CUgraphMem_attribute_enum.CU_GRAPH_MEM_ATTR_RESERVED_MEM_HIGH,{{endif}}):
343342
if self._is_getter:
344-
self._cuuint64_t_val = driver.cuuint64_t()
343+
self._cuuint64_t_val = _driver.cuuint64_t()
345344
self._cptr = <void*><void_ptr>self._cuuint64_t_val.getPtr()
346345
else:
347346
self._cptr = <void*><void_ptr>init_value.getPtr()
@@ -554,7 +553,7 @@ cdef class _HelperCUmemAllocationHandleType:
554553
{{endif}}
555554
{{if 'CU_MEM_HANDLE_TYPE_FABRIC' in found_values}}
556555
elif self._type in (cydriver.CUmemAllocationHandleType_enum.CU_MEM_HANDLE_TYPE_FABRIC,):
557-
self._mem_fabric_handle = driver.CUmemFabricHandle()
556+
self._mem_fabric_handle = _driver.CUmemFabricHandle()
558557
self._cptr = <void*><void_ptr>self._mem_fabric_handle.getPtr()
559558
{{endif}}
560559
else:

0 commit comments

Comments
 (0)