Skip to content

Commit e411f2d

Browse files
committed
Guard cuMemAllocManaged on concurrent managed access
Restore the cuMemAllocManaged binding, validate concurrent managed access per active device, and drop the test-helper skip for missing symbols.

Made-with: Cursor
1 parent 20d0197 commit e411f2d

File tree

2 files changed

+36
-3
lines changed

2 files changed

+36
-3
lines changed

cuda_bindings/cuda/bindings/driver.pyx.in

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ ctypedef unsigned long long float_ptr
4444
ctypedef unsigned long long double_ptr
4545
ctypedef unsigned long long void_ptr
4646

47+
cdef dict _cu_mem_alloc_managed_concurrent_access_by_device = {}
48+
4749
#: CUDA API version number
4850
CUDA_VERSION = cydriver.CUDA_VERSION
4951

@@ -31212,7 +31214,7 @@ def cuMemHostGetFlags(p):
3121231214
return (_CUresult_SUCCESS, pFlags)
3121331215
{{endif}}
3121431216

31215-
{{if 'MANUALLYDISABLEDcuMemAllocManaged' in found_functions}}
31217+
{{if 'cuMemAllocManaged' in found_functions}}
3121631218

3121731219
@cython.embedsignature(True)
3121831220
def cuMemAllocManaged(size_t bytesize, unsigned int flags):
@@ -31341,6 +31343,39 @@ def cuMemAllocManaged(size_t bytesize, unsigned int flags):
3134131343
--------
3134231344
:py:obj:`~.cuArray3DCreate`, :py:obj:`~.cuArray3DGetDescriptor`, :py:obj:`~.cuArrayCreate`, :py:obj:`~.cuArrayDestroy`, :py:obj:`~.cuArrayGetDescriptor`, :py:obj:`~.cuMemAllocHost`, :py:obj:`~.cuMemAllocPitch`, :py:obj:`~.cuMemcpy2D`, :py:obj:`~.cuMemcpy2DAsync`, :py:obj:`~.cuMemcpy2DUnaligned`, :py:obj:`~.cuMemcpy3D`, :py:obj:`~.cuMemcpy3DAsync`, :py:obj:`~.cuMemcpyAtoA`, :py:obj:`~.cuMemcpyAtoD`, :py:obj:`~.cuMemcpyAtoH`, :py:obj:`~.cuMemcpyAtoHAsync`, :py:obj:`~.cuMemcpyDtoA`, :py:obj:`~.cuMemcpyDtoD`, :py:obj:`~.cuMemcpyDtoDAsync`, :py:obj:`~.cuMemcpyDtoH`, :py:obj:`~.cuMemcpyDtoHAsync`, :py:obj:`~.cuMemcpyHtoA`, :py:obj:`~.cuMemcpyHtoAAsync`, :py:obj:`~.cuMemcpyHtoD`, :py:obj:`~.cuMemcpyHtoDAsync`, :py:obj:`~.cuMemFree`, :py:obj:`~.cuMemFreeHost`, :py:obj:`~.cuMemGetAddressRange`, :py:obj:`~.cuMemGetInfo`, :py:obj:`~.cuMemHostAlloc`, :py:obj:`~.cuMemHostGetDevicePointer`, :py:obj:`~.cuMemsetD2D8`, :py:obj:`~.cuMemsetD2D16`, :py:obj:`~.cuMemsetD2D32`, :py:obj:`~.cuMemsetD8`, :py:obj:`~.cuMemsetD16`, :py:obj:`~.cuMemsetD32`, :py:obj:`~.cuDeviceGetAttribute`, :py:obj:`~.cuStreamAttachMemAsync`, :py:obj:`~.cudaMallocManaged`
3134331345
"""
31346+
# WIP-WIP-WIP THIS CODE NEEDS TO BE PORTED TO THE CODE GENERATOR
31347+
cdef int concurrent_access = 0
31348+
cdef int device_id = 0
31349+
cdef cydriver.CUdevice device
31350+
err = cydriver.cuCtxGetDevice(&device)
31351+
if err != cydriver.CUDA_SUCCESS:
31352+
# cuMemAllocManaged would fail with the same error anyway.
31353+
return (_CUresult(err), None)
31354+
device_id = <int>device
31355+
if device_id in _cu_mem_alloc_managed_concurrent_access_by_device:
31356+
if _cu_mem_alloc_managed_concurrent_access_by_device[device_id] == 0:
31357+
raise RuntimeError(
31358+
"cuMemAllocManaged is not supported when "
31359+
"CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS == 0"
31360+
)
31361+
else:
31362+
err = cydriver.cuDeviceGetAttribute(
31363+
&concurrent_access,
31364+
cydriver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS,
31365+
device,
31366+
)
31367+
if err != cydriver.CUDA_SUCCESS:
31368+
raise RuntimeError(
31369+
"cuDeviceGetAttribute(CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS) failed "
31370+
f"while validating cuMemAllocManaged: {_CUresult(err)}"
31371+
)
31372+
_cu_mem_alloc_managed_concurrent_access_by_device[device_id] = concurrent_access
31373+
if concurrent_access == 0:
31374+
raise RuntimeError(
31375+
"cuMemAllocManaged is not supported when "
31376+
"CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS == 0"
31377+
)
31378+
3134431379
cdef CUdeviceptr dptr = CUdeviceptr()
3134531380
with nogil:
3134631381
err = cydriver.cuMemAllocManaged(<cydriver.CUdeviceptr*>dptr._pvt_ptr, bytesize, flags)

cuda_python_test_helpers/cuda_python_test_helpers/managed_memory.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,6 @@ def _get_concurrent_managed_access(device_id: int) -> int | None:
5151

5252
def managed_memory_skip_reason(device=None) -> str | None:
5353
"""Return a skip reason when managed memory should be avoided."""
54-
if not hasattr(driver, "cuMemAllocManaged"):
55-
return "cuMemAllocManaged is unavailable; treating concurrent managed access as disabled"
5654
device_id = _resolve_device_id(device)
5755
value = _get_concurrent_managed_access(device_id)
5856
if value is None:

0 commit comments

Comments
 (0)