1- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+ # SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22#
33# SPDX-License-Identifier: Apache-2.0
44
55from __future__ import annotations
66
77from cuda.bindings cimport cydriver
8+
89from cuda.core._memory._memory_pool cimport _MemPool, _MemPoolOptions
910from cuda.core._utils.cuda_utils cimport (
11+ HANDLE_RETURN,
1012 check_or_create_options,
1113)
1214
1315from dataclasses import dataclass
16+ import threading
17+ import warnings
1418
1519__all__ = [' ManagedMemoryResource' , ' ManagedMemoryResourceOptions' ]
1620
@@ -91,6 +95,7 @@ cdef class ManagedMemoryResource(_MemPool):
9195 opts_base._type = cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_MANAGED
9296
9397 super ().__init__(device_id, opts_base)
98+ _check_concurrent_managed_access()
9499 ELSE :
95100 raise RuntimeError (" ManagedMemoryResource requires CUDA 13.0 or later" )
96101
@@ -103,3 +108,47 @@ cdef class ManagedMemoryResource(_MemPool):
103108 def is_host_accessible(self ) -> bool:
104109 """Return True. This memory resource provides host-accessible buffers."""
105110 return True
111+
112+
# Warn-once state: set to True after the concurrent-managed-access check has
# run (whether or not it warned), so the driver query happens at most once
# per process.
cdef bint _concurrent_access_warned = False
# Serializes the check-then-set of _concurrent_access_warned across threads.
cdef object _concurrent_access_lock = threading.Lock()
115+
116+
cdef inline _check_concurrent_managed_access():
    """Warn once per process if the platform lacks concurrent managed memory
    access.

    On such platforms, host access to any managed allocation while any GPU
    kernel is in flight causes a segfault, so the user must synchronize
    before touching managed memory from the host.
    """
    global _concurrent_access_warned
    # Fast path: once the check has completed, skip the lock entirely.
    if _concurrent_access_warned:
        return

    cdef int c_concurrent = 0
    with _concurrent_access_lock:
        # Double-checked locking: another thread may have finished the check
        # while we were waiting on the lock.
        if _concurrent_access_warned:
            return

        # concurrent_managed_access is a system-level attribute for sm_60 and
        # later, so querying device 0 is sufficient.
        with nogil:
            HANDLE_RETURN(cydriver.cuDeviceGetAttribute(
                &c_concurrent,
                cydriver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS,
                0))
        if not c_concurrent:
            warnings.warn(
                "This platform does not support concurrent managed memory access "
                "(Device.properties.concurrent_managed_access is False). Host access to any managed "
                "allocation is forbidden while any GPU kernel is in flight, even "
                "if the kernel does not touch that allocation. Failing to "
                "synchronize before host access will cause a segfault. "
                "See: https://docs.nvidia.com/cuda/cuda-c-programming-guide/"
                "index.html#gpu-exclusive-access-to-managed-memory",
                UserWarning,
                stacklevel=3,
            )

        # Must be set while still holding the lock: otherwise a concurrent
        # caller could pass the inner re-check before the flag is written and
        # emit the warning a second time.
        _concurrent_access_warned = True
149+
150+
def reset_concurrent_access_warning():
    """Testing hook: clear the warn-once flag so the next
    _check_concurrent_managed_access call performs the check again."""
    global _concurrent_access_warned
    _concurrent_access_warned = False
0 commit comments