Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 23 additions & 29 deletions cuda_core/cuda/core/_graph/_graph_builder.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ from cuda.core._graph._utils cimport _attach_host_callback_to_graph
from cuda.core._resource_handles cimport as_cu
from cuda.core._stream cimport Stream
from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
from cuda.core._utils.version cimport cy_binding_version, cy_driver_version

from cuda.core._utils.cuda_utils import (
driver,
get_binding_version,
get_driver_version,
handle_return,
)

Expand Down Expand Up @@ -169,7 +169,7 @@ def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) ->
elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED:
raise RuntimeError("Instantiation for device launch failed due to the nodes belonging to different contexts.")
elif (
get_binding_version() >= (12, 8)
cy_binding_version() >= (12, 8, 0)
and params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED
):
raise RuntimeError("One or more conditional handles are not associated with conditional builders.")
Expand Down Expand Up @@ -449,10 +449,10 @@ class GraphBuilder:
The newly created conditional handle.

"""
if get_driver_version() < 12030:
raise RuntimeError(f"Driver version {get_driver_version()} does not support conditional handles")
if get_binding_version() < (12, 3):
raise RuntimeError(f"Binding version {get_binding_version()} does not support conditional handles")
if cy_driver_version() < (12, 3, 0):
raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional handles")
if cy_binding_version() < (12, 3, 0):
raise RuntimeError(f"Binding version {'.'.join(map(str, cy_binding_version()))} does not support conditional handles")
if default_value is not None:
flags = driver.CU_GRAPH_COND_ASSIGN_DEFAULT
else:
Expand Down Expand Up @@ -522,10 +522,10 @@ class GraphBuilder:
The newly created conditional graph builder.

"""
if get_driver_version() < 12030:
raise RuntimeError(f"Driver version {get_driver_version()} does not support conditional if")
if get_binding_version() < (12, 3):
raise RuntimeError(f"Binding version {get_binding_version()} does not support conditional if")
if cy_driver_version() < (12, 3, 0):
raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional if")
if cy_binding_version() < (12, 3, 0):
raise RuntimeError(f"Binding version {'.'.join(map(str, cy_binding_version()))} does not support conditional if")
node_params = driver.CUgraphNodeParams()
node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
node_params.conditional.handle = handle
Expand Down Expand Up @@ -553,10 +553,10 @@ class GraphBuilder:
A tuple of two new graph builders, one for the if branch and one for the else branch.

"""
if get_driver_version() < 12080:
raise RuntimeError(f"Driver version {get_driver_version()} does not support conditional if-else")
if get_binding_version() < (12, 8):
raise RuntimeError(f"Binding version {get_binding_version()} does not support conditional if-else")
if cy_driver_version() < (12, 8, 0):
raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional if-else")
if cy_binding_version() < (12, 8, 0):
raise RuntimeError(f"Binding version {'.'.join(map(str, cy_binding_version()))} does not support conditional if-else")
node_params = driver.CUgraphNodeParams()
node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
node_params.conditional.handle = handle
Expand Down Expand Up @@ -587,10 +587,10 @@ class GraphBuilder:
A tuple of new graph builders, one for each branch.

"""
if get_driver_version() < 12080:
raise RuntimeError(f"Driver version {get_driver_version()} does not support conditional switch")
if get_binding_version() < (12, 8):
raise RuntimeError(f"Binding version {get_binding_version()} does not support conditional switch")
if cy_driver_version() < (12, 8, 0):
raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional switch")
if cy_binding_version() < (12, 8, 0):
raise RuntimeError(f"Binding version {'.'.join(map(str, cy_binding_version()))} does not support conditional switch")
node_params = driver.CUgraphNodeParams()
node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
node_params.conditional.handle = handle
Expand Down Expand Up @@ -618,10 +618,10 @@ class GraphBuilder:
The newly created while loop graph builder.

"""
if get_driver_version() < 12030:
raise RuntimeError(f"Driver version {get_driver_version()} does not support conditional while loop")
if get_binding_version() < (12, 3):
raise RuntimeError(f"Binding version {get_binding_version()} does not support conditional while loop")
if cy_driver_version() < (12, 3, 0):
raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional while loop")
if cy_binding_version() < (12, 3, 0):
raise RuntimeError(f"Binding version {'.'.join(map(str, cy_binding_version()))} does not support conditional while loop")
node_params = driver.CUgraphNodeParams()
node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
node_params.conditional.handle = handle
Expand Down Expand Up @@ -649,12 +649,6 @@ class GraphBuilder:
child_graph : :obj:`~_graph.GraphBuilder`
The child graph builder. Must have finished building.
"""
if (get_driver_version() < 12000) or (get_binding_version() < (12, 0)):
raise NotImplementedError(
f"Launching child graphs is not implemented for versions older than CUDA 12."
f"Found driver version is {get_driver_version()} and binding version is {get_binding_version()}"
)

if not child_graph._building_ended:
raise ValueError("Child graph has not finished building.")

Expand Down
4 changes: 2 additions & 2 deletions cuda_core/cuda/core/_graph/_graphdef.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ cdef bint _version_checked = False
cdef bint _check_node_get_params():
global _has_cuGraphNodeGetParams, _version_checked
if not _version_checked:
ver = handle_return(driver.cuDriverGetVersion())
_has_cuGraphNodeGetParams = ver >= 13020
from cuda.core._utils.version import driver_version
_has_cuGraphNodeGetParams = driver_version() >= (13, 2, 0)
_version_checked = True
return _has_cuGraphNodeGetParams

Expand Down
42 changes: 0 additions & 42 deletions cuda_core/cuda/core/_launch_config.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -4,49 +4,16 @@

from libc.string cimport memset

from cuda.core._utils.cuda_utils cimport (
HANDLE_RETURN,
)

import threading

from cuda.core._device import Device
from cuda.core._utils.cuda_utils import (
CUDAError,
cast_to_3_tuple,
driver,
get_binding_version,
)


cdef bint _inited = False
cdef bint _use_ex = False
cdef object _lock = threading.Lock()

# Attribute names for identity comparison and representation
_LAUNCH_CONFIG_ATTRS = ('grid', 'cluster', 'block', 'shmem_size', 'cooperative_launch')


cdef int _lazy_init() except?-1:
global _inited, _use_ex
if _inited:
return 0

cdef tuple _py_major_minor
cdef int _driver_ver
with _lock:
if _inited:
return 0

# binding availability depends on cuda-python version
_py_major_minor = get_binding_version()
HANDLE_RETURN(cydriver.cuDriverGetVersion(&_driver_ver))
_use_ex = (_driver_ver >= 11080) and (_py_major_minor >= (11, 8))
_inited = True

return 0


cdef class LaunchConfig:
"""Customizable launch options.

Expand Down Expand Up @@ -99,8 +66,6 @@ cdef class LaunchConfig:
cooperative_launch : bool, optional
Whether to launch as cooperative kernel (default: False)
"""
_lazy_init()

# Convert and validate grid and block dimensions
self.grid = cast_to_3_tuple("LaunchConfig.grid", grid)
self.block = cast_to_3_tuple("LaunchConfig.block", block)
Expand All @@ -110,10 +75,6 @@ cdef class LaunchConfig:
# device compute capability or attributes.
# thread block clusters are supported starting H100
if cluster is not None:
if not _use_ex:
err, drvers = driver.cuDriverGetVersion()
drvers_fmt = f" (got driver version {drvers})" if err == driver.CUresult.CUDA_SUCCESS else ""
raise CUDAError(f"thread block clusters require cuda.bindings & driver 11.8+{drvers_fmt}")
cc = Device().compute_capability
if cc < (9, 0):
raise CUDAError(
Expand Down Expand Up @@ -153,7 +114,6 @@ cdef class LaunchConfig:
return hash(self._identity())

cdef cydriver.CUlaunchConfig _to_native_launch_config(self):
_lazy_init()
cdef cydriver.CUlaunchConfig drv_cfg
cdef cydriver.CUlaunchAttribute attr
memset(&drv_cfg, 0, sizeof(drv_cfg))
Expand Down Expand Up @@ -201,8 +161,6 @@ cpdef object _to_native_launch_config(LaunchConfig config):
driver.CUlaunchConfig
Native CUDA driver launch configuration
"""
_lazy_init()

cdef object drv_cfg = driver.CUlaunchConfig()
cdef list attrs
cdef object attr
Expand Down
66 changes: 9 additions & 57 deletions cuda_core/cuda/core/_launcher.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -15,39 +15,9 @@ from cuda.core._utils.cuda_utils cimport (
check_or_create_options,
HANDLE_RETURN,
)

import threading

from cuda.core._module import Kernel
from cuda.core._stream import Stream
from cuda.core._utils.cuda_utils import (
_reduce_3_tuple,
get_binding_version,
)


cdef bint _inited = False
cdef bint _use_ex = False
cdef object _lock = threading.Lock()


cdef int _lazy_init() except?-1:
global _inited, _use_ex
if _inited:
return 0

cdef int _driver_ver
with _lock:
if _inited:
return 0

# binding availability depends on cuda-python version
_py_major_minor = get_binding_version()
HANDLE_RETURN(cydriver.cuDriverGetVersion(&_driver_ver))
_use_ex = (_driver_ver >= 11080) and (_py_major_minor >= (11, 8))
_inited = True

return 0
from math import prod


def launch(stream: Stream | GraphBuilder | IsStreamT, config: LaunchConfig, kernel: Kernel, *kernel_args):
Expand All @@ -70,49 +40,31 @@ def launch(stream: Stream | GraphBuilder | IsStreamT, config: LaunchConfig, kern

"""
cdef Stream s = Stream_accept(stream, allow_stream_protocol=True)
_lazy_init()
cdef LaunchConfig conf = check_or_create_options(LaunchConfig, config, "launch config")

# TODO: can we ensure kernel_args is valid/safe to use here?
# TODO: merge with HelperKernelParams?
cdef ParamHolder ker_args = ParamHolder(kernel_args)
cdef void** args_ptr = <void**><uintptr_t>(ker_args.ptr)

# Note: We now use CUkernel handles exclusively (CUDA 12+), but they can be cast to
# CUfunction for use with cuLaunchKernel, as both handle types are interchangeable
# for kernel launch purposes.
cdef Kernel ker = <Kernel>kernel
cdef cydriver.CUfunction func_handle = <cydriver.CUfunction>as_cu(ker._h_kernel)

# Note: CUkernel can still be launched via cuLaunchKernel (not just cuLaunchKernelEx).
# We check both binding & driver versions here mainly to see if the "Ex" API is
# available and if so we use it, as it's more feature rich.
if _use_ex:
drv_cfg = conf._to_native_launch_config()
drv_cfg.hStream = as_cu(s._h_stream)
if conf.cooperative_launch:
_check_cooperative_launch(kernel, conf, s)
with nogil:
HANDLE_RETURN(cydriver.cuLaunchKernelEx(&drv_cfg, func_handle, args_ptr, NULL))
else:
# TODO: check if config has any unsupported attrs
HANDLE_RETURN(
cydriver.cuLaunchKernel(
func_handle,
conf.grid[0], conf.grid[1], conf.grid[2],
conf.block[0], conf.block[1], conf.block[2],
conf.shmem_size, as_cu(s._h_stream), args_ptr, NULL
)
)
drv_cfg = conf._to_native_launch_config()
drv_cfg.hStream = as_cu(s._h_stream)
if conf.cooperative_launch:
_check_cooperative_launch(kernel, conf, s)
with nogil:
HANDLE_RETURN(cydriver.cuLaunchKernelEx(&drv_cfg, func_handle, args_ptr, NULL))


cdef _check_cooperative_launch(kernel: Kernel, config: LaunchConfig, stream: Stream):
dev = stream.device
num_sm = dev.properties.multiprocessor_count
max_grid_size = (
kernel.occupancy.max_active_blocks_per_multiprocessor(_reduce_3_tuple(config.block), config.shmem_size) * num_sm
kernel.occupancy.max_active_blocks_per_multiprocessor(prod(config.block), config.shmem_size) * num_sm
)
if _reduce_3_tuple(config.grid) > max_grid_size:
if prod(config.grid) > max_grid_size:
# For now let's try not to be smart and adjust the grid size behind users' back.
# We explicitly ask users to adjust.
x, y, z = config.grid
Expand Down
12 changes: 4 additions & 8 deletions cuda_core/cuda/core/_linker.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ from cuda.core._utils.cuda_utils import (
CUDAError,
check_or_create_options,
driver,
handle_return,
is_sequence,
)

Expand Down Expand Up @@ -620,9 +619,8 @@ cdef inline void Linker_annotate_error_log(Linker self, object e):

# TODO: revisit this treatment for py313t builds
_driver = None # populated if nvJitLink cannot be used
_driver_ver = None
_inited = False
_use_nvjitlink_backend = False # set by _decide_nvjitlink_or_driver()
_use_nvjitlink_backend = None # set by _decide_nvjitlink_or_driver()

# Input type mappings populated by _lazy_init() with C-level enum ints.
_nvjitlink_input_types = None
Expand All @@ -637,13 +635,10 @@ def _nvjitlink_has_version_symbol(nvjitlink) -> bool:
# Note: this function is reused in the tests
def _decide_nvjitlink_or_driver() -> bool:
"""Return True if falling back to the cuLink* driver APIs."""
global _driver_ver, _driver, _use_nvjitlink_backend
if _driver_ver is not None:
global _driver, _use_nvjitlink_backend
if _use_nvjitlink_backend is not None:
return not _use_nvjitlink_backend

_driver_ver = handle_return(driver.cuDriverGetVersion())
_driver_ver = (_driver_ver // 1000, (_driver_ver % 1000) // 10)

warn_txt_common = (
"the driver APIs will be used instead, which do not support"
" minor version compatibility or linking LTO IRs."
Expand All @@ -668,6 +663,7 @@ def _decide_nvjitlink_or_driver() -> bool:
)

warn(warn_txt, stacklevel=2, category=RuntimeWarning)
_use_nvjitlink_backend = False
_driver = driver
return True

Expand Down
5 changes: 2 additions & 3 deletions cuda_core/cuda/core/_memory/_virtual_memory_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
Transaction,
check_or_create_options,
driver,
get_binding_version,
)
from cuda.core._utils.cuda_utils import (
_check_driver_error as raise_if_driver_error,
)
from cuda.core._utils.version import binding_version

__all__ = ["VirtualMemoryResource", "VirtualMemoryResourceOptions"]

Expand Down Expand Up @@ -99,8 +99,7 @@ class VirtualMemoryResourceOptions:
_t = driver.CUmemAllocationType
# CUDA 13+ exposes MANAGED in CUmemAllocationType; older 12.x does not
_allocation_type = {"pinned": _t.CU_MEM_ALLOCATION_TYPE_PINNED} # noqa: RUF012
ver_major, ver_minor = get_binding_version()
if ver_major >= 13:
if binding_version() >= (13, 0, 0):
_allocation_type["managed"] = _t.CU_MEM_ALLOCATION_TYPE_MANAGED

@staticmethod
Expand Down
Loading
Loading