Skip to content

Commit 73ba7fe

Browse files
committed
Merge remote-tracking branch 'origin/main' into explicit-graph-construction
2 parents b55782a + f720e48 commit 73ba7fe

File tree

14 files changed

+357
-54
lines changed

14 files changed

+357
-54
lines changed

cuda_core/cuda/core/_linker.pyx

Lines changed: 10 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@ from dataclasses import dataclass
2929
from typing import Union
3030
from warnings import warn
3131

32+
from cuda.pathfinder import optional_cuda_import
3233
from cuda.core._device import Device
3334
from cuda.core._module import ObjectCode
3435
from cuda.core._utils.clear_error_support import assert_type
@@ -649,23 +650,20 @@ def _decide_nvjitlink_or_driver() -> bool:
649650
" For best results, consider upgrading to a recent version of"
650651
)
651652

652-
try:
653-
__import__("cuda.bindings.nvjitlink") # availability check
654-
except ModuleNotFoundError:
653+
nvjitlink_module = optional_cuda_import(
654+
"cuda.bindings.nvjitlink",
655+
probe_function=lambda module: module.version(), # probe triggers nvJitLink runtime load
656+
)
657+
if nvjitlink_module is None:
655658
warn_txt = f"cuda.bindings.nvjitlink is not available, therefore {warn_txt_common} cuda-bindings."
656659
else:
657660
from cuda.bindings._internal import nvjitlink
658661

659-
try:
660-
if _nvjitlink_has_version_symbol(nvjitlink):
661-
_use_nvjitlink_backend = True
662-
return False # Use nvjitlink
663-
except RuntimeError:
664-
warn_detail = "not available"
665-
else:
666-
warn_detail = "too old (<12.3)"
662+
if _nvjitlink_has_version_symbol(nvjitlink):
663+
_use_nvjitlink_backend = True
664+
return False # Use nvjitlink
667665
warn_txt = (
668-
f"{'nvJitLink*.dll' if sys.platform == 'win32' else 'libnvJitLink.so*'} is {warn_detail}."
666+
f"{'nvJitLink*.dll' if sys.platform == 'win32' else 'libnvJitLink.so*'} is too old (<12.3)."
669667
f" Therefore cuda.bindings.nvjitlink is not usable and {warn_txt_common} nvJitLink."
670668
)
671669

cuda_core/cuda/core/_memoryview.pyx

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1096,6 +1096,8 @@ cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None):
10961096
buf.exporting_obj = obj
10971097
buf.metadata = cai_data
10981098
buf.dl_tensor = NULL
1099+
# Validate shape/strides/typestr eagerly so constructor paths fail fast.
1100+
buf.get_layout()
10991101
buf.ptr, buf.readonly = cai_data["data"]
11001102
buf.is_device_accessible = True
11011103
if buf.ptr != 0:
@@ -1138,6 +1140,8 @@ cpdef StridedMemoryView view_as_array_interface(obj, view=None):
11381140
buf.exporting_obj = obj
11391141
buf.metadata = data
11401142
buf.dl_tensor = NULL
1143+
# Validate shape/strides/typestr eagerly so constructor paths fail fast.
1144+
buf.get_layout()
11411145
buf.ptr, buf.readonly = data["data"]
11421146
buf.is_device_accessible = False
11431147
buf.device_id = handle_return(driver.cuCtxGetDevice())

cuda_core/cuda/core/_program.pyx

Lines changed: 13 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@ import threading
1414
from warnings import warn
1515

1616
from cuda.bindings import driver, nvrtc
17+
from cuda.pathfinder import optional_cuda_import
1718

1819
from libcpp.vector cimport vector
1920

@@ -461,8 +462,8 @@ class ProgramOptions:
461462
# =============================================================================
462463

463464
# Module-level state for NVVM lazy loading
464-
cdef object_nvvm_module = None
465-
cdef bint _nvvm_import_attempted = False
465+
_nvvm_module = None
466+
_nvvm_import_attempted = False
466467

467468

468469
def _get_nvvm_module():
@@ -484,18 +485,21 @@ def _get_nvvm_module():
484485
"Please update cuda-bindings to use NVVM features."
485486
)
486487

487-
from cuda.bindings import nvvm
488-
from cuda.bindings._internal.nvvm import _inspect_function_pointer
489-
490-
if _inspect_function_pointer("__nvvmCreateProgram") == 0:
491-
raise RuntimeError("NVVM library (libnvvm) is not available in this Python environment. ")
488+
nvvm = optional_cuda_import(
489+
"cuda.bindings.nvvm",
490+
probe_function=lambda module: module.version(), # probe triggers libnvvm load
491+
)
492+
if nvvm is None:
493+
raise RuntimeError(
494+
"NVVM support is unavailable: cuda.bindings.nvvm is missing or libnvvm cannot be loaded."
495+
)
492496

493497
_nvvm_module = nvvm
494498
return _nvvm_module
495499

496-
except RuntimeError as e:
500+
except RuntimeError:
497501
_nvvm_module = None
498-
raise e
502+
raise
499503

500504
def _find_libdevice_path():
501505
"""Find libdevice*.bc for NVVM compilation using cuda.pathfinder."""

cuda_core/docs/source/release/0.7.x-notes.rst

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -40,3 +40,6 @@ Fixes and enhancements
4040
linking operations to the C level and releasing the GIL during backend calls. This benefits
4141
workloads that create many programs or linkers, and enables concurrent compilation in
4242
multithreaded applications.
43+
- Improved optional dependency handling for NVVM and nvJitLink imports so that only genuinely
44+
missing optional modules are treated as unavailable; unrelated import failures now surface
45+
normally, and ``cuda.core`` now depends directly on ``cuda-pathfinder``.

cuda_core/pixi.lock

Lines changed: 29 additions & 30 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)