@@ -77,8 +77,6 @@ NvJitLinkDestroyFn p_nvJitLinkDestroy = nullptr;
7777
7878namespace {
7979
80- using cuda_core::detail::py_is_finalizing;
81-
8280// Helper to release the GIL while calling into the CUDA driver.
8381// This guard is *conditional*: if the caller already dropped the GIL,
8482// we avoid calling PyEval_SaveThread (which requires holding the GIL).
@@ -148,6 +146,51 @@ class GILAcquireGuard {
148146
149147} // namespace
150148
149+ // ============================================================================
150+ // Handle reverse-lookup registry
151+ //
152+ // Maps raw CUDA handles (CUevent, CUkernel, etc.) back to their owning
153+ // shared_ptr so that _ref constructors can recover full metadata.
154+ // Uses weak_ptr to avoid preventing destruction.
155+ // ============================================================================
156+
157+ template <typename Key, typename Handle, typename Hash = std::hash<Key>>
158+ class HandleRegistry {
159+ public:
160+ using MapType = std::unordered_map<Key, std::weak_ptr<typename Handle::element_type>, Hash>;
161+
162+ void register_handle (const Key& key, const Handle& h) {
163+ std::lock_guard<std::mutex> lock (mutex_);
164+ map_[key] = h;
165+ }
166+
167+ void unregister_handle (const Key& key) noexcept {
168+ try {
169+ std::lock_guard<std::mutex> lock (mutex_);
170+ auto it = map_.find (key);
171+ if (it != map_.end () && it->second .expired ()) {
172+ map_.erase (it);
173+ }
174+ } catch (...) {}
175+ }
176+
177+ Handle lookup (const Key& key) {
178+ std::lock_guard<std::mutex> lock (mutex_);
179+ auto it = map_.find (key);
180+ if (it != map_.end ()) {
181+ if (auto h = it->second .lock ()) {
182+ return h;
183+ }
184+ map_.erase (it);
185+ }
186+ return {};
187+ }
188+
189+ private:
190+ std::mutex mutex_;
191+ MapType map_;
192+ };
193+
151194// ============================================================================
152195// Thread-local error handling
153196// ============================================================================
@@ -306,47 +349,98 @@ StreamHandle get_per_thread_stream() {
306349namespace {
307350struct EventBox {
308351 CUevent resource;
352+ bool timing_disabled;
353+ bool busy_waited;
354+ bool ipc_enabled;
355+ int device_id;
356+ ContextHandle h_context;
309357};
310358} // namespace
311359
312- EventHandle create_event_handle (const ContextHandle& h_ctx, unsigned int flags) {
360+ static const EventBox* get_box (const EventHandle& h) {
361+ const CUevent* p = h.get ();
362+ return reinterpret_cast <const EventBox*>(
363+ reinterpret_cast <const char *>(p) - offsetof (EventBox, resource)
364+ );
365+ }
366+
367+ bool get_event_timing_disabled (const EventHandle& h) noexcept {
368+ return h ? get_box (h)->timing_disabled : true ;
369+ }
370+
371+ bool get_event_busy_waited (const EventHandle& h) noexcept {
372+ return h ? get_box (h)->busy_waited : false ;
373+ }
374+
375+ bool get_event_ipc_enabled (const EventHandle& h) noexcept {
376+ return h ? get_box (h)->ipc_enabled : false ;
377+ }
378+
379+ int get_event_device_id (const EventHandle& h) noexcept {
380+ return h ? get_box (h)->device_id : -1 ;
381+ }
382+
383+ ContextHandle get_event_context (const EventHandle& h) noexcept {
384+ return h ? get_box (h)->h_context : ContextHandle{};
385+ }
386+
387+ static HandleRegistry<CUevent, EventHandle> event_registry;
388+
389+ EventHandle create_event_handle (const ContextHandle& h_ctx, unsigned int flags,
390+ bool timing_disabled, bool busy_waited,
391+ bool ipc_enabled, int device_id) {
313392 GILReleaseGuard gil;
314393 CUevent event;
315394 if (CUDA_SUCCESS != (err = p_cuEventCreate (&event, flags))) {
316395 return {};
317396 }
318397
319398 auto box = std::shared_ptr<const EventBox>(
320- new EventBox{event},
399+ new EventBox{event, timing_disabled, busy_waited, ipc_enabled, device_id, h_ctx },
321400 [h_ctx](const EventBox* b) {
401+ event_registry.unregister_handle (b->resource );
322402 GILReleaseGuard gil;
323403 p_cuEventDestroy (b->resource );
324404 delete b;
325405 }
326406 );
327- return EventHandle (box, &box->resource );
407+ EventHandle h (box, &box->resource );
408+ event_registry.register_handle (event, h);
409+ return h;
328410}
329411
330412EventHandle create_event_handle_noctx (unsigned int flags) {
331- return create_event_handle (ContextHandle{}, flags);
413+ return create_event_handle (ContextHandle{}, flags, true , false , false , - 1 );
332414}
333415
334- EventHandle create_event_handle_ipc (const CUipcEventHandle& ipc_handle) {
416+ EventHandle create_event_handle_ref (CUevent event) {
417+ if (auto h = event_registry.lookup (event)) {
418+ return h;
419+ }
420+ auto box = std::make_shared<const EventBox>(EventBox{event, true , false , false , -1 , {}});
421+ return EventHandle (box, &box->resource );
422+ }
423+
424+ EventHandle create_event_handle_ipc (const CUipcEventHandle& ipc_handle,
425+ bool busy_waited) {
335426 GILReleaseGuard gil;
336427 CUevent event;
337428 if (CUDA_SUCCESS != (err = p_cuIpcOpenEventHandle (&event, ipc_handle))) {
338429 return {};
339430 }
340431
341432 auto box = std::shared_ptr<const EventBox>(
342- new EventBox{event},
433+ new EventBox{event, true , busy_waited, true , - 1 , {} },
343434 [](const EventBox* b) {
435+ event_registry.unregister_handle (b->resource );
344436 GILReleaseGuard gil;
345437 p_cuEventDestroy (b->resource );
346438 delete b;
347439 }
348440 );
349- return EventHandle (box, &box->resource );
441+ EventHandle h (box, &box->resource );
442+ event_registry.register_handle (event, h);
443+ return h;
350444}
351445
352446// ============================================================================
@@ -653,61 +747,43 @@ struct ExportDataKeyHash {
653747
654748}
655749
656- static std::mutex ipc_ptr_cache_mutex ;
657- static std::unordered_map<ExportDataKey, std::weak_ptr<DevicePtrBox>, ExportDataKeyHash> ipc_ptr_cache ;
750+ static HandleRegistry<ExportDataKey, DevicePtrHandle, ExportDataKeyHash> ipc_ptr_cache ;
751+ static std::mutex ipc_import_mutex ;
658752
659753DevicePtrHandle deviceptr_import_ipc (const MemoryPoolHandle& h_pool, const void * export_data, const StreamHandle& h_stream) {
660754 auto data = const_cast <CUmemPoolPtrExportData*>(
661755 reinterpret_cast <const CUmemPoolPtrExportData*>(export_data));
662756
663757 if (use_ipc_ptr_cache ()) {
664- // Check cache before calling cuMemPoolImportPointer
665758 ExportDataKey key;
666759 std::memcpy (&key.data , data, sizeof (key.data ));
667760
668- std::lock_guard<std::mutex> lock (ipc_ptr_cache_mutex );
761+ std::lock_guard<std::mutex> lock (ipc_import_mutex );
669762
670- auto it = ipc_ptr_cache.find (key);
671- if (it != ipc_ptr_cache.end ()) {
672- if (auto box = it->second .lock ()) {
673- // Cache hit - return existing handle
674- return DevicePtrHandle (box, &box->resource );
675- }
676- ipc_ptr_cache.erase (it); // Expired entry
763+ if (auto h = ipc_ptr_cache.lookup (key)) {
764+ return h;
677765 }
678766
679- // Cache miss - import the pointer
680767 GILReleaseGuard gil;
681768 CUdeviceptr ptr;
682769 if (CUDA_SUCCESS != (err = p_cuMemPoolImportPointer (&ptr, *h_pool, data))) {
683770 return {};
684771 }
685772
686- // Create new handle with cache-clearing deleter
687773 auto box = std::shared_ptr<DevicePtrBox>(
688774 new DevicePtrBox{ptr, h_stream},
689775 [h_pool, key](DevicePtrBox* b) {
776+ ipc_ptr_cache.unregister_handle (key);
690777 GILReleaseGuard gil;
691- try {
692- std::lock_guard<std::mutex> lock (ipc_ptr_cache_mutex);
693- // Only erase if expired - avoids race where another thread
694- // replaced the entry with a new import before we acquired the lock.
695- auto it = ipc_ptr_cache.find (key);
696- if (it != ipc_ptr_cache.end () && it->second .expired ()) {
697- ipc_ptr_cache.erase (it);
698- }
699- } catch (...) {
700- // Cache cleanup is best-effort - swallow exceptions in destructor context
701- }
702778 p_cuMemFreeAsync (b->resource , as_cu (b->h_stream ));
703779 delete b;
704780 }
705781 );
706- ipc_ptr_cache[key] = box;
707- return DevicePtrHandle (box, &box->resource );
782+ DevicePtrHandle h (box, &box->resource );
783+ ipc_ptr_cache.register_handle (key, h);
784+ return h;
708785
709786 } else {
710- // No caching - simple handle creation
711787 GILReleaseGuard gil;
712788 CUdeviceptr ptr;
713789 if (CUDA_SUCCESS != (err = p_cuMemPoolImportPointer (&ptr, *h_pool, data))) {
@@ -786,25 +862,45 @@ LibraryHandle create_library_handle_ref(CUlibrary library) {
786862namespace {
787863struct KernelBox {
788864 CUkernel resource;
789- LibraryHandle h_library; // Keeps library alive
865+ LibraryHandle h_library;
790866};
791867} // namespace
792868
869+ static const KernelBox* get_box (const KernelHandle& h) {
870+ const CUkernel* p = h.get ();
871+ return reinterpret_cast <const KernelBox*>(
872+ reinterpret_cast <const char *>(p) - offsetof (KernelBox, resource)
873+ );
874+ }
875+
876+ static HandleRegistry<CUkernel, KernelHandle> kernel_registry;
877+
793878KernelHandle create_kernel_handle (const LibraryHandle& h_library, const char * name) {
794879 GILReleaseGuard gil;
795880 CUkernel kernel;
796881 if (CUDA_SUCCESS != (err = p_cuLibraryGetKernel (&kernel, *h_library, name))) {
797882 return {};
798883 }
799884
800- return create_kernel_handle_ref (kernel, h_library);
885+ auto box = std::make_shared<const KernelBox>(KernelBox{kernel, h_library});
886+ KernelHandle h (box, &box->resource );
887+ kernel_registry.register_handle (kernel, h);
888+ return h;
801889}
802890
803- KernelHandle create_kernel_handle_ref (CUkernel kernel, const LibraryHandle& h_library) {
804- auto box = std::make_shared<const KernelBox>(KernelBox{kernel, h_library});
891+ KernelHandle create_kernel_handle_ref (CUkernel kernel) {
892+ if (auto h = kernel_registry.lookup (kernel)) {
893+ return h;
894+ }
895+ auto box = std::make_shared<const KernelBox>(KernelBox{kernel, {}});
805896 return KernelHandle (box, &box->resource );
806897}
807898
899+ LibraryHandle get_kernel_library (const KernelHandle& h) noexcept {
900+ if (!h) return {};
901+ return get_box (h)->h_library ;
902+ }
903+
808904// ============================================================================
809905// Graphics Resource Handles
810906// ============================================================================
0 commit comments