Skip to content

Commit 5fb344c

Browse files
committed
Fix #702: Update cyruntime.getLocalRuntimeVersion to use pathfinder
1 parent 8174361 commit 5fb344c

File tree

2 files changed

+26
-18
lines changed

2 files changed

+26
-18
lines changed

cuda_bindings/cuda/bindings/cyruntime.pyx.in

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,18 @@
22
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
33

44
# This code was automatically generated with version 13.0.0. Do not modify it directly.
5+
from libc.stdint cimport uintptr_t
56
cimport cuda.bindings._bindings.cyruntime as cyruntime
67
cimport cython
78

9+
from cuda.pathfinder import load_nvidia_dynamic_lib
10+
from pathlib import Path
11+
{{if 'Windows' == platform.system()}}
12+
import win32api
13+
{{else}}
14+
cimport cuda.bindings._lib.dlfcn as dlfcn
15+
{{endif}}
16+
817
{{if 'cudaDeviceReset' in found_functions}}
918

1019
cdef cudaError_t cudaDeviceReset() except ?cudaErrorCallRequiresNewerDriver nogil:
@@ -1885,35 +1894,28 @@ cdef cudaError_t cudaGraphicsVDPAURegisterOutputSurface(cudaGraphicsResource** r
18851894

18861895
{{if True}}
18871896

1888-
{{if 'Windows' != platform.system()}}
1889-
cimport cuda.bindings._lib.dlfcn as dlfcn
1890-
{{endif}}
1891-
18921897
cdef cudaError_t getLocalRuntimeVersion(int* runtimeVersion) except ?cudaErrorCallRequiresNewerDriver nogil:
1893-
{{if 'Windows' == platform.system()}}
18941898
with gil:
1895-
raise NotImplementedError('"getLocalRuntimeVersion" is unsupported on Windows')
1896-
{{else}}
1897-
# Load
1898-
handle = dlfcn.dlopen('libcudart.so.13', dlfcn.RTLD_NOW)
1899-
if handle == NULL:
1900-
with gil:
1901-
raise RuntimeError(f'Failed to dlopen libcudart.so.13')
1899+
lib = load_nvidia_dynamic_lib("cudart")
1900+
filename = Path(lib.abs_path).name
1901+
handle = <void *><uintptr_t>lib._handle_uint
19021902

1903+
{{if 'Windows' == platform.system()}}
1904+
try:
1905+
__cudaRuntimeGetVersion = <void*><unsigned long long>win32api.GetProcAddress(handle, 'cudaRuntimeGetVersion')
1906+
except:
1907+
pass
1908+
{{else}}
19031909
__cudaRuntimeGetVersion = dlfcn.dlsym(handle, 'cudaRuntimeGetVersion')
1910+
{{endif}}
19041911

19051912
if __cudaRuntimeGetVersion == NULL:
19061913
with gil:
1907-
raise RuntimeError(f'Function "cudaRuntimeGetVersion" not found in libcudart.so.13')
1914+
raise RuntimeError(f'Function "cudaRuntimeGetVersion" not found in {filename}')
19081915

1909-
# Call
19101916
cdef cudaError_t err = cudaSuccess
19111917
err = (<cudaError_t (*)(int*) except ?cudaErrorCallRequiresNewerDriver nogil> __cudaRuntimeGetVersion)(runtimeVersion)
19121918

1913-
# Unload
1914-
dlfcn.dlclose(handle)
1915-
19161919
# Return
19171920
return err
1918-
{{endif}}
19191921
{{endif}}

cuda_bindings/tests/test_cudart.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1400,3 +1400,9 @@ def test_struct_pointer_comparison(target):
14001400
c = target(456)
14011401
assert a != c
14021402
assert hash(a) != hash(c)
1403+
1404+
1405+
def test_getLocalRuntimeVersion():
1406+
err, version = cudart.getLocalRuntimeVersion()
1407+
assertSuccess(err)
1408+
assert version >= 10000 # CUDA 10.0

0 commit comments

Comments
 (0)