Skip to content

Commit b36d799

Browse files
committed
Fix #962: Don't perform unnecessary version checks
1 parent f317f21 commit b36d799

File tree

5 files changed

+0
-156
lines changed

5 files changed

+0
-156
lines changed

cuda_bindings/cuda/bindings/_internal/cufile_linux.pyx

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -32,24 +32,6 @@ cdef extern from "<dlfcn.h>" nogil:
3232

3333
const void* RTLD_DEFAULT 'RTLD_DEFAULT'
3434

35-
cdef int get_cuda_version():
36-
cdef void* handle = NULL
37-
cdef int err, driver_ver = 0
38-
39-
# Load driver to check version
40-
handle = dlopen('libcuda.so.1', RTLD_NOW | RTLD_GLOBAL)
41-
if handle == NULL:
42-
err_msg = dlerror()
43-
raise NotSupportedError(f'CUDA driver is not found ({err_msg.decode()})')
44-
cuDriverGetVersion = dlsym(handle, "cuDriverGetVersion")
45-
if cuDriverGetVersion == NULL:
46-
raise RuntimeError('something went wrong')
47-
err = (<int (*)(int*) noexcept nogil>cuDriverGetVersion)(&driver_ver)
48-
if err != 0:
49-
raise RuntimeError('something went wrong')
50-
51-
return driver_ver
52-
5335

5436
###############################################################################
5537
# Wrapper init
@@ -116,8 +98,6 @@ cdef int _check_or_init_cufile() except -1 nogil:
11698
cdef void* handle = NULL
11799

118100
with gil, __symbol_lock:
119-
driver_ver = get_cuda_version()
120-
121101
# Load function
122102
global __cuFileHandleRegister
123103
__cuFileHandleRegister = dlsym(RTLD_DEFAULT, 'cuFileHandleRegister')

cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -30,24 +30,6 @@ cdef extern from "<dlfcn.h>" nogil:
3030

3131
const void* RTLD_DEFAULT 'RTLD_DEFAULT'
3232

33-
cdef int get_cuda_version():
34-
cdef void* handle = NULL
35-
cdef int err, driver_ver = 0
36-
37-
# Load driver to check version
38-
handle = dlopen('libcuda.so.1', RTLD_NOW | RTLD_GLOBAL)
39-
if handle == NULL:
40-
err_msg = dlerror()
41-
raise NotSupportedError(f'CUDA driver is not found ({err_msg.decode()})')
42-
cuDriverGetVersion = dlsym(handle, "cuDriverGetVersion")
43-
if cuDriverGetVersion == NULL:
44-
raise RuntimeError('something went wrong')
45-
err = (<int (*)(int*) noexcept nogil>cuDriverGetVersion)(&driver_ver)
46-
if err != 0:
47-
raise RuntimeError('something went wrong')
48-
49-
return driver_ver
50-
5133

5234
###############################################################################
5335
# Wrapper init
@@ -85,8 +67,6 @@ cdef int _check_or_init_nvjitlink() except -1 nogil:
8567
cdef void* handle = NULL
8668

8769
with gil, __symbol_lock:
88-
driver_ver = get_cuda_version()
89-
9070
# Load function
9171
global __nvJitLinkCreate
9272
__nvJitLinkCreate = dlsym(RTLD_DEFAULT, 'nvJitLinkCreate')

cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -11,64 +11,18 @@ from .utils import FunctionNotFoundError, NotSupportedError
1111

1212
from cuda.pathfinder import load_nvidia_dynamic_lib
1313

14-
from libc.stddef cimport wchar_t
1514
from libc.stdint cimport uintptr_t
16-
from cpython cimport PyUnicode_AsWideCharString, PyMem_Free
17-
18-
from .utils import NotSupportedError
1915

2016
cdef extern from "windows.h" nogil:
2117
ctypedef void* HMODULE
22-
ctypedef void* HANDLE
2318
ctypedef void* FARPROC
24-
ctypedef unsigned long DWORD
25-
ctypedef const wchar_t *LPCWSTR
2619
ctypedef const char *LPCSTR
2720

28-
cdef DWORD LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
29-
cdef DWORD LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
30-
cdef DWORD LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
31-
32-
HMODULE _LoadLibraryExW "LoadLibraryExW"(
33-
LPCWSTR lpLibFileName,
34-
HANDLE hFile,
35-
DWORD dwFlags
36-
)
37-
3821
FARPROC _GetProcAddress "GetProcAddress"(HMODULE hModule, LPCSTR lpProcName)
3922

40-
cdef inline uintptr_t LoadLibraryExW(str path, HANDLE hFile, DWORD dwFlags):
41-
cdef uintptr_t result
42-
cdef wchar_t* wpath = PyUnicode_AsWideCharString(path, NULL)
43-
with nogil:
44-
result = <uintptr_t>_LoadLibraryExW(
45-
wpath,
46-
hFile,
47-
dwFlags
48-
)
49-
PyMem_Free(wpath)
50-
return result
51-
5223
cdef inline void *GetProcAddress(uintptr_t hModule, const char* lpProcName) nogil:
5324
return _GetProcAddress(<HMODULE>hModule, lpProcName)
5425

55-
cdef int get_cuda_version():
56-
cdef int err, driver_ver = 0
57-
58-
# Load driver to check version
59-
handle = LoadLibraryExW("nvcuda.dll", NULL, LOAD_LIBRARY_SEARCH_SYSTEM32)
60-
if handle == 0:
61-
raise NotSupportedError('CUDA driver is not found')
62-
cuDriverGetVersion = GetProcAddress(handle, 'cuDriverGetVersion')
63-
if cuDriverGetVersion == NULL:
64-
raise RuntimeError('something went wrong')
65-
err = (<int (*)(int*) noexcept nogil>cuDriverGetVersion)(&driver_ver)
66-
if err != 0:
67-
raise RuntimeError('something went wrong')
68-
69-
return driver_ver
70-
71-
7226

7327
###############################################################################
7428
# Wrapper init
@@ -99,8 +53,6 @@ cdef int _check_or_init_nvjitlink() except -1 nogil:
9953
return 0
10054

10155
with gil, __symbol_lock:
102-
driver_ver = get_cuda_version()
103-
10456
# Load library
10557
handle = load_nvidia_dynamic_lib("nvJitLink")._handle_uint
10658

cuda_bindings/cuda/bindings/_internal/nvvm_linux.pyx

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -30,24 +30,6 @@ cdef extern from "<dlfcn.h>" nogil:
3030

3131
const void* RTLD_DEFAULT 'RTLD_DEFAULT'
3232

33-
cdef int get_cuda_version():
34-
cdef void* handle = NULL
35-
cdef int err, driver_ver = 0
36-
37-
# Load driver to check version
38-
handle = dlopen('libcuda.so.1', RTLD_NOW | RTLD_GLOBAL)
39-
if handle == NULL:
40-
err_msg = dlerror()
41-
raise NotSupportedError(f'CUDA driver is not found ({err_msg.decode()})')
42-
cuDriverGetVersion = dlsym(handle, "cuDriverGetVersion")
43-
if cuDriverGetVersion == NULL:
44-
raise RuntimeError('something went wrong')
45-
err = (<int (*)(int*) noexcept nogil>cuDriverGetVersion)(&driver_ver)
46-
if err != 0:
47-
raise RuntimeError('something went wrong')
48-
49-
return driver_ver
50-
5133

5234
###############################################################################
5335
# Wrapper init
@@ -84,8 +66,6 @@ cdef int _check_or_init_nvvm() except -1 nogil:
8466
cdef void* handle = NULL
8567

8668
with gil, __symbol_lock:
87-
driver_ver = get_cuda_version()
88-
8969
# Load function
9070
global __nvvmGetErrorString
9171
__nvvmGetErrorString = dlsym(RTLD_DEFAULT, 'nvvmGetErrorString')

cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -11,64 +11,18 @@ from .utils import FunctionNotFoundError, NotSupportedError
1111

1212
from cuda.pathfinder import load_nvidia_dynamic_lib
1313

14-
from libc.stddef cimport wchar_t
1514
from libc.stdint cimport uintptr_t
16-
from cpython cimport PyUnicode_AsWideCharString, PyMem_Free
17-
18-
from .utils import NotSupportedError
1915

2016
cdef extern from "windows.h" nogil:
2117
ctypedef void* HMODULE
22-
ctypedef void* HANDLE
2318
ctypedef void* FARPROC
24-
ctypedef unsigned long DWORD
25-
ctypedef const wchar_t *LPCWSTR
2619
ctypedef const char *LPCSTR
2720

28-
cdef DWORD LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
29-
cdef DWORD LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
30-
cdef DWORD LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
31-
32-
HMODULE _LoadLibraryExW "LoadLibraryExW"(
33-
LPCWSTR lpLibFileName,
34-
HANDLE hFile,
35-
DWORD dwFlags
36-
)
37-
3821
FARPROC _GetProcAddress "GetProcAddress"(HMODULE hModule, LPCSTR lpProcName)
3922

40-
cdef inline uintptr_t LoadLibraryExW(str path, HANDLE hFile, DWORD dwFlags):
41-
cdef uintptr_t result
42-
cdef wchar_t* wpath = PyUnicode_AsWideCharString(path, NULL)
43-
with nogil:
44-
result = <uintptr_t>_LoadLibraryExW(
45-
wpath,
46-
hFile,
47-
dwFlags
48-
)
49-
PyMem_Free(wpath)
50-
return result
51-
5223
cdef inline void *GetProcAddress(uintptr_t hModule, const char* lpProcName) nogil:
5324
return _GetProcAddress(<HMODULE>hModule, lpProcName)
5425

55-
cdef int get_cuda_version():
56-
cdef int err, driver_ver = 0
57-
58-
# Load driver to check version
59-
handle = LoadLibraryExW("nvcuda.dll", NULL, LOAD_LIBRARY_SEARCH_SYSTEM32)
60-
if handle == 0:
61-
raise NotSupportedError('CUDA driver is not found')
62-
cuDriverGetVersion = GetProcAddress(handle, 'cuDriverGetVersion')
63-
if cuDriverGetVersion == NULL:
64-
raise RuntimeError('something went wrong')
65-
err = (<int (*)(int*) noexcept nogil>cuDriverGetVersion)(&driver_ver)
66-
if err != 0:
67-
raise RuntimeError('something went wrong')
68-
69-
return driver_ver
70-
71-
7226

7327
###############################################################################
7428
# Wrapper init
@@ -98,8 +52,6 @@ cdef int _check_or_init_nvvm() except -1 nogil:
9852
return 0
9953

10054
with gil, __symbol_lock:
101-
driver_ver = get_cuda_version()
102-
10355
# Load library
10456
handle = load_nvidia_dynamic_lib("nvvm")._handle_uint
10557

0 commit comments

Comments
 (0)