@@ -44,6 +44,8 @@ ctypedef unsigned long long float_ptr
4444ctypedef unsigned long long double_ptr
4545ctypedef unsigned long long void_ptr
4646
47+ cdef dict _cu_mem_alloc_managed_concurrent_access_by_device = {}
48+
4749#: CUDA API version number
4850CUDA_VERSION = cydriver.CUDA_VERSION
4951
@@ -31212,7 +31214,7 @@ def cuMemHostGetFlags(p):
3121231214 return (_CUresult_SUCCESS, pFlags)
3121331215{{endif}}
3121431216
31215- {{if 'MANUALLYDISABLEDcuMemAllocManaged ' in found_functions}}
31217+ {{if 'cuMemAllocManaged ' in found_functions}}
3121631218
3121731219@cython.embedsignature(True)
3121831220def cuMemAllocManaged(size_t bytesize, unsigned int flags):
@@ -31341,6 +31343,39 @@ def cuMemAllocManaged(size_t bytesize, unsigned int flags):
3134131343 --------
3134231344 :py:obj:`~.cuArray3DCreate`, :py:obj:`~.cuArray3DGetDescriptor`, :py:obj:`~.cuArrayCreate`, :py:obj:`~.cuArrayDestroy`, :py:obj:`~.cuArrayGetDescriptor`, :py:obj:`~.cuMemAllocHost`, :py:obj:`~.cuMemAllocPitch`, :py:obj:`~.cuMemcpy2D`, :py:obj:`~.cuMemcpy2DAsync`, :py:obj:`~.cuMemcpy2DUnaligned`, :py:obj:`~.cuMemcpy3D`, :py:obj:`~.cuMemcpy3DAsync`, :py:obj:`~.cuMemcpyAtoA`, :py:obj:`~.cuMemcpyAtoD`, :py:obj:`~.cuMemcpyAtoH`, :py:obj:`~.cuMemcpyAtoHAsync`, :py:obj:`~.cuMemcpyDtoA`, :py:obj:`~.cuMemcpyDtoD`, :py:obj:`~.cuMemcpyDtoDAsync`, :py:obj:`~.cuMemcpyDtoH`, :py:obj:`~.cuMemcpyDtoHAsync`, :py:obj:`~.cuMemcpyHtoA`, :py:obj:`~.cuMemcpyHtoAAsync`, :py:obj:`~.cuMemcpyHtoD`, :py:obj:`~.cuMemcpyHtoDAsync`, :py:obj:`~.cuMemFree`, :py:obj:`~.cuMemFreeHost`, :py:obj:`~.cuMemGetAddressRange`, :py:obj:`~.cuMemGetInfo`, :py:obj:`~.cuMemHostAlloc`, :py:obj:`~.cuMemHostGetDevicePointer`, :py:obj:`~.cuMemsetD2D8`, :py:obj:`~.cuMemsetD2D16`, :py:obj:`~.cuMemsetD2D32`, :py:obj:`~.cuMemsetD8`, :py:obj:`~.cuMemsetD16`, :py:obj:`~.cuMemsetD32`, :py:obj:`~.cuDeviceGetAttribute`, :py:obj:`~.cuStreamAttachMemAsync`, :py:obj:`~.cudaMallocManaged`
3134331345 """
31346+ # WIP-WIP-WIP THIS CODE NEEDS TO BE PORTED TO THE CODE GENERATOR
31347+ cdef int concurrent_access = 0
31348+ cdef int device_id = 0
31349+ cdef cydriver.CUdevice device
31350+ err = cydriver.cuCtxGetDevice(&device)
31351+ if err != cydriver.CUDA_SUCCESS:
31352+ # cuMemAllocManaged would fail with the same error anyway.
31353+ return (_CUresult(err), None)
31354+ device_id = <int>device
31355+ if device_id in _cu_mem_alloc_managed_concurrent_access_by_device:
31356+ if _cu_mem_alloc_managed_concurrent_access_by_device[device_id] == 0:
31357+ raise RuntimeError(
31358+ "cuMemAllocManaged is not supported when "
31359+ "CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS == 0"
31360+ )
31361+ else:
31362+ err = cydriver.cuDeviceGetAttribute(
31363+ &concurrent_access,
31364+ cydriver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS,
31365+ device,
31366+ )
31367+ if err != cydriver.CUDA_SUCCESS:
31368+ raise RuntimeError(
31369+ "cuDeviceGetAttribute(CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS) failed "
31370+ f"while validating cuMemAllocManaged: {_CUresult(err)}"
31371+ )
31372+ _cu_mem_alloc_managed_concurrent_access_by_device[device_id] = concurrent_access
31373+ if concurrent_access == 0:
31374+ raise RuntimeError(
31375+ "cuMemAllocManaged is not supported when "
31376+ "CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS == 0"
31377+ )
31378+
3134431379 cdef CUdeviceptr dptr = CUdeviceptr()
3134531380 with nogil:
3134631381 err = cydriver.cuMemAllocManaged(<cydriver.CUdeviceptr*>dptr._pvt_ptr, bytesize, flags)
0 commit comments