Skip to content

Commit e0e868a

Browse files
committed
Address comments from PR
1 parent 8c7ea2e commit e0e868a

File tree

4 files changed

+71
-52
lines changed

4 files changed

+71
-52
lines changed

cuda_bindings/cuda/bindings/_internal/cufile_linux.pyx

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ import cython
1818
# Extern
1919
###############################################################################
2020

21+
from .utils import NotSupportedError
22+
2123
cdef extern from "<dlfcn.h>" nogil:
2224
void* dlopen(const char*, int)
2325
char* dlerror()
@@ -32,14 +34,31 @@ cdef extern from "<dlfcn.h>" nogil:
3234

3335
const void* RTLD_DEFAULT 'RTLD_DEFAULT'
3436

37+
# Query the installed CUDA driver for its version number.
#
# Returns the integer written by cuDriverGetVersion().
# Raises NotSupportedError when libcuda.so.1 cannot be dlopen()'d, and
# RuntimeError when the cuDriverGetVersion symbol is missing or the call
# reports a non-zero status.
cdef int get_cuda_version():
    cdef void* handle = NULL
    cdef int err, driver_ver = 0

    # Load driver to check version.
    # NOTE: the handle is never dlclose()'d in this function.
    handle = dlopen('libcuda.so.1', RTLD_NOW | RTLD_GLOBAL)
    if handle == NULL:
        err_msg = dlerror()
        raise NotSupportedError(f'CUDA driver is not found ({err_msg.decode()})')

    cuDriverGetVersion = dlsym(handle, "cuDriverGetVersion")
    if cuDriverGetVersion == NULL:
        # Name the missing symbol so the failure is diagnosable.
        raise RuntimeError('Did not find cuDriverGetVersion symbol in libcuda.so.1')

    # dlsym() yields an untyped pointer; cast it to the driver entry point's
    # signature before calling.
    err = (<int (*)(int*) noexcept nogil>cuDriverGetVersion)(&driver_ver)
    if err != 0:
        # Surface the status code instead of a generic message.
        raise RuntimeError(f'cuDriverGetVersion returned error code {err}')

    return driver_ver
54+
3555

3656
###############################################################################
3757
# Wrapper init
3858
###############################################################################
3959

4060
cdef object __symbol_lock = threading.Lock()
4161
cdef bint __py_cufile_init = False
42-
cdef void* __cuDriverGetVersion = NULL
4362

4463
cdef void* __cuFileHandleRegister = NULL
4564
cdef void* __cuFileHandleDeregister = NULL
@@ -97,24 +116,9 @@ cdef int _check_or_init_cufile() except -1 nogil:
97116
return 0
98117

99118
cdef void* handle = NULL
100-
cdef int err, driver_ver = 0
101119

102120
with gil, __symbol_lock:
103-
# Load driver to check version
104-
handle = dlopen('libcuda.so.1', RTLD_NOW | RTLD_GLOBAL)
105-
if handle == NULL:
106-
err_msg = dlerror()
107-
raise NotSupportedError(f'CUDA driver is not found ({err_msg.decode()})')
108-
global __cuDriverGetVersion
109-
if __cuDriverGetVersion == NULL:
110-
__cuDriverGetVersion = dlsym(handle, "cuDriverGetVersion")
111-
if __cuDriverGetVersion == NULL:
112-
raise RuntimeError('something went wrong')
113-
err = (<int (*)(int*) noexcept nogil>__cuDriverGetVersion)(&driver_ver)
114-
if err != 0:
115-
raise RuntimeError('something went wrong')
116-
#dlclose(handle)
117-
handle = NULL
121+
driver_ver = get_cuda_version()
118122

119123
# Load function
120124
global __cuFileHandleRegister

cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ from cuda.pathfinder import load_nvidia_dynamic_lib
1313

1414
from libc.stddef cimport wchar_t
1515
from libc.stdint cimport uintptr_t
16-
from cpython cimport PyUnicode_AsWideCharString
16+
from cpython cimport PyUnicode_AsWideCharString, PyMem_Free
1717

1818
from .utils import NotSupportedError
1919

@@ -23,6 +23,7 @@ cdef extern from "windows.h" nogil:
2323
ctypedef void* FARPROC
2424
ctypedef unsigned long DWORD
2525
ctypedef const wchar_t *LPCWSTR
26+
ctypedef const char *LPCSTR
2627

2728
cdef DWORD LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
2829
cdef DWORD LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
@@ -36,19 +37,23 @@ cdef extern from "windows.h" nogil:
3637

3738
HMODULE _LoadLibraryW "LoadLibraryW"(LPCWSTR lpLibFileName)
3839

39-
FARPROC _GetProcAddress "GetProcAddress"(HMODULE hModule, const char* lpProcName)
40+
FARPROC _GetProcAddress "GetProcAddress"(HMODULE hModule, LPCSTR lpProcName)
4041

4142
HMODULE _GetModuleHandleW "GetModuleHandleW"(LPCWSTR lpModuleName)
4243

43-
cdef inline uintptr_t LoadLibraryExW(str path, HANDLE hFile, DWORD dwFlags) nogil:
44-
cdef wchar_t* wpath
45-
with gil:
46-
wpath = PyUnicode_AsWideCharString(path, NULL)
47-
return <uintptr_t>_LoadLibraryExW(
48-
wpath,
49-
hFile,
50-
dwFlags
51-
)
44+
# Call the Win32 LoadLibraryExW with a Python ``str`` path.
#
# Returns the module handle as a uintptr_t (0 when the underlying
# _LoadLibraryExW call returns a NULL handle).
cdef inline uintptr_t LoadLibraryExW(str path, HANDLE hFile, DWORD dwFlags):
    cdef uintptr_t result
    # No explicit NULL check is needed: Cython's bundled declaration of
    # PyUnicode_AsWideCharString is ``except NULL``, so a failed conversion
    # propagates the Python exception from this line. The previous bare
    # ``raise`` sat outside any ``except`` block and could not re-raise it.
    # NOTE(review): confirm the ``except NULL`` declaration for the Cython
    # version used to build this module.
    cdef wchar_t* wpath = PyUnicode_AsWideCharString(path, NULL)
    # The GIL is released for the duration of the library load.
    with nogil:
        result = <uintptr_t>_LoadLibraryExW(
            wpath,
            hFile,
            dwFlags
        )
    # The wide-character buffer is owned by the caller; free it.
    PyMem_Free(wpath)
    return result
5257

5358
cdef inline uintptr_t LoadLibraryW(str path) nogil:
5459
cdef wchar_t* wpath

cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ from cuda.pathfinder import load_nvidia_dynamic_lib
1313

1414
from libc.stddef cimport wchar_t
1515
from libc.stdint cimport uintptr_t
16-
from cpython cimport PyUnicode_AsWideCharString
16+
from cpython cimport PyUnicode_AsWideCharString, PyMem_Free
1717

1818
from .utils import NotSupportedError
1919

@@ -23,6 +23,7 @@ cdef extern from "windows.h" nogil:
2323
ctypedef void* FARPROC
2424
ctypedef unsigned long DWORD
2525
ctypedef const wchar_t *LPCWSTR
26+
ctypedef const char *LPCSTR
2627

2728
cdef DWORD LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
2829
cdef DWORD LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
@@ -36,19 +37,23 @@ cdef extern from "windows.h" nogil:
3637

3738
HMODULE _LoadLibraryW "LoadLibraryW"(LPCWSTR lpLibFileName)
3839

39-
FARPROC _GetProcAddress "GetProcAddress"(HMODULE hModule, const char* lpProcName)
40+
FARPROC _GetProcAddress "GetProcAddress"(HMODULE hModule, LPCSTR lpProcName)
4041

4142
HMODULE _GetModuleHandleW "GetModuleHandleW"(LPCWSTR lpModuleName)
4243

43-
cdef inline uintptr_t LoadLibraryExW(str path, HANDLE hFile, DWORD dwFlags) nogil:
44-
cdef wchar_t* wpath
45-
with gil:
46-
wpath = PyUnicode_AsWideCharString(path, NULL)
47-
return <uintptr_t>_LoadLibraryExW(
48-
wpath,
49-
hFile,
50-
dwFlags
51-
)
44+
# Call the Win32 LoadLibraryExW with a Python ``str`` path.
#
# Returns the module handle as a uintptr_t (0 when the underlying
# _LoadLibraryExW call returns a NULL handle).
cdef inline uintptr_t LoadLibraryExW(str path, HANDLE hFile, DWORD dwFlags):
    cdef uintptr_t result
    # No explicit NULL check is needed: Cython's bundled declaration of
    # PyUnicode_AsWideCharString is ``except NULL``, so a failed conversion
    # propagates the Python exception from this line. The previous bare
    # ``raise`` sat outside any ``except`` block and could not re-raise it.
    # NOTE(review): confirm the ``except NULL`` declaration for the Cython
    # version used to build this module.
    cdef wchar_t* wpath = PyUnicode_AsWideCharString(path, NULL)
    # The GIL is released for the duration of the library load.
    with nogil:
        result = <uintptr_t>_LoadLibraryExW(
            wpath,
            hFile,
            dwFlags
        )
    # The wide-character buffer is owned by the caller; free it.
    PyMem_Free(wpath)
    return result
5257

5358
cdef inline uintptr_t LoadLibraryW(str path) nogil:
5459
cdef wchar_t* wpath

cuda_bindings/cuda/bindings/_lib/windll.pxd

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,39 @@
33

44
from libc.stddef cimport wchar_t
55
from libc.stdint cimport uintptr_t
6-
from cpython cimport PyUnicode_AsWideCharString
6+
from cpython cimport PyUnicode_AsWideCharString, PyMem_Free
77

8-
cdef extern from "windows.h":
8+
cdef extern from "windows.h" nogil:
99
ctypedef void* HMODULE
1010
ctypedef void* HANDLE
1111
ctypedef void* FARPROC
1212
ctypedef unsigned long DWORD
1313
ctypedef const wchar_t *LPCWSTR
14+
ctypedef const char *LPCSTR
1415

1516
cdef DWORD LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
1617

1718
HMODULE _LoadLibraryExW "LoadLibraryExW"(
1819
LPCWSTR lpLibFileName,
1920
HANDLE hFile,
2021
DWORD dwFlags
21-
) nogil
22+
)
2223

23-
FARPROC _GetProcAddress "GetProcAddress"(HMODULE hModule, const char* lpProcName) nogil
24+
FARPROC _GetProcAddress "GetProcAddress"(HMODULE hModule, LPCSTR lpProcName)
2425

25-
cdef inline uintptr_t LoadLibraryExW(str path, HANDLE hFile, DWORD dwFlags) nogil:
26-
cdef wchar_t* wpath
27-
with gil:
28-
wpath = PyUnicode_AsWideCharString(path, NULL)
29-
return <uintptr_t>_LoadLibraryExW(
30-
wpath,
31-
hFile,
32-
dwFlags
33-
)
26+
# Call the Win32 LoadLibraryExW with a Python ``str`` path.
#
# Returns the module handle as a uintptr_t (0 when the underlying
# _LoadLibraryExW call returns a NULL handle).
cdef inline uintptr_t LoadLibraryExW(str path, HANDLE hFile, DWORD dwFlags):
    cdef uintptr_t result
    # No explicit NULL check is needed: Cython's bundled declaration of
    # PyUnicode_AsWideCharString is ``except NULL``, so a failed conversion
    # propagates the Python exception from this line. The previous bare
    # ``raise`` sat outside any ``except`` block and could not re-raise it.
    # NOTE(review): confirm the ``except NULL`` declaration for the Cython
    # version used to build this module.
    cdef wchar_t* wpath = PyUnicode_AsWideCharString(path, NULL)
    # The GIL is released for the duration of the library load.
    with nogil:
        result = <uintptr_t>_LoadLibraryExW(
            wpath,
            hFile,
            dwFlags
        )
    # The wide-character buffer is owned by the caller; free it.
    PyMem_Free(wpath)
    return result
3439

3540
# Forward to the Win32 GetProcAddress, accepting the module handle as an
# integer (uintptr_t) and casting it back to HMODULE for the call.
cdef inline FARPROC GetProcAddress(uintptr_t hModule, const char* lpProcName) nogil:
    cdef HMODULE module_handle = <HMODULE>hModule
    return _GetProcAddress(module_handle, lpProcName)

0 commit comments

Comments
 (0)