Skip to content

Commit 310b994

Browse files
committed
Add explicit CUDA graph construction API (GraphDef, GraphNode)
Introduces GraphDef and GraphNode types for explicit CUDA graph construction, with a full node hierarchy, shared instantiation helper with GraphCompleteOptions support, and comprehensive tests. Made-with: Cursor
1 parent 3ed5217 commit 310b994

16 files changed

+5318
-141
lines changed

cuda_core/cuda/core/_cpp/resource_handles.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ decltype(&cuLibraryLoadData) p_cuLibraryLoadData = nullptr;
5656
decltype(&cuLibraryUnload) p_cuLibraryUnload = nullptr;
5757
decltype(&cuLibraryGetKernel) p_cuLibraryGetKernel = nullptr;
5858

59+
// Graph
60+
decltype(&cuGraphDestroy) p_cuGraphDestroy = nullptr;
61+
5962
// Linker
6063
decltype(&cuLinkDestroy) p_cuLinkDestroy = nullptr;
6164

@@ -901,6 +904,57 @@ LibraryHandle get_kernel_library(const KernelHandle& h) noexcept {
901904
return get_box(h)->h_library;
902905
}
903906

907+
// ============================================================================
908+
// Graph Handles
909+
// ============================================================================
910+
911+
namespace {
912+
struct GraphBox {
913+
CUgraph resource;
914+
GraphHandle h_parent; // Keeps parent alive for child/branch graphs
915+
};
916+
} // namespace
917+
918+
// Take ownership of `graph`: the returned handle's deleter calls
// cuGraphDestroy (through p_cuGraphDestroy) and then frees the backing box.
GraphHandle create_graph_handle(CUgraph graph) {
    // Deleter run once the last handle sharing this box goes away.
    const auto destroy_box = [](const GraphBox* owned) {
        // NOTE(review): presumably drops the Python GIL for the duration of
        // the driver call - confirm against GILReleaseGuard's definition.
        GILReleaseGuard gil;
        p_cuGraphDestroy(owned->resource);
        delete owned;
    };

    // No parent to keep alive, hence the empty h_parent.
    std::shared_ptr<const GraphBox> box(new GraphBox{graph, {}}, destroy_box);

    // Aliasing constructor: shares ownership of the box while the stored
    // pointer targets the CUgraph member inside it.
    return GraphHandle(box, &box->resource);
}
929+
930+
// Build a handle for `graph` with no custom deleter, so releasing it never
// calls cuGraphDestroy. A copy of `h_parent` is stored in the box and is
// therefore held for as long as the returned handle (or a copy of it) exists.
GraphHandle create_graph_handle_ref(CUgraph graph, const GraphHandle& h_parent) {
    std::shared_ptr<const GraphBox> box =
        std::make_shared<const GraphBox>(GraphBox{graph, h_parent});
    const CUgraph* resource_ptr = &box->resource;
    // Aliasing constructor: ownership is of the box, the pointee is its CUgraph.
    return GraphHandle(box, resource_ptr);
}
934+
935+
namespace {
936+
struct GraphNodeBox {
937+
CUgraphNode resource;
938+
GraphHandle h_graph;
939+
};
940+
} // namespace
941+
942+
// Recover the GraphNodeBox that backs a node handle.
// Handles made by create_graph_node_handle point at the box's `resource`
// member, so stepping back by that member's offset gives the box address.
// Performs no null check; callers must pass a non-empty handle.
static const GraphNodeBox* get_box(const GraphNodeHandle& h) {
    const char* member_addr = reinterpret_cast<const char*>(h.get());
    const char* box_addr = member_addr - offsetof(GraphNodeBox, resource);
    return reinterpret_cast<const GraphNodeBox*>(box_addr);
}
948+
949+
// Build a node handle with no custom deleter. A copy of `h_graph` is stored
// alongside the node, so that graph handle is held for as long as the returned
// node handle (or a copy of it) exists.
GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_graph) {
    std::shared_ptr<const GraphNodeBox> box =
        std::make_shared<const GraphNodeBox>(GraphNodeBox{node, h_graph});
    const CUgraphNode* resource_ptr = &box->resource;
    // Aliasing constructor: ownership is of the box, the pointee is its CUgraphNode.
    return GraphNodeHandle(box, resource_ptr);
}
953+
954+
// Return the graph handle stored next to the node in its box.
// An empty node handle yields an empty graph handle.
GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept {
    if (!h) {
        return GraphHandle{};
    }
    return get_box(h)->h_graph;
}
957+
904958
// ============================================================================
905959
// Graphics Resource Handles
906960
// ============================================================================

cuda_core/cuda/core/_cpp/resource_handles.hpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ extern decltype(&cuLibraryLoadData) p_cuLibraryLoadData;
9292
extern decltype(&cuLibraryUnload) p_cuLibraryUnload;
9393
extern decltype(&cuLibraryGetKernel) p_cuLibraryGetKernel;
9494

95+
// Graph
96+
extern decltype(&cuGraphDestroy) p_cuGraphDestroy;
97+
9598
// Linker
9699
extern decltype(&cuLinkDestroy) p_cuLinkDestroy;
97100

@@ -143,6 +146,8 @@ using EventHandle = std::shared_ptr<const CUevent>;
143146
using MemoryPoolHandle = std::shared_ptr<const CUmemoryPool>;
144147
using LibraryHandle = std::shared_ptr<const CUlibrary>;
145148
using KernelHandle = std::shared_ptr<const CUkernel>;
149+
using GraphHandle = std::shared_ptr<const CUgraph>;
150+
using GraphNodeHandle = std::shared_ptr<const CUgraphNode>;
146151
using GraphicsResourceHandle = std::shared_ptr<const CUgraphicsResource>;
147152
using NvrtcProgramHandle = std::shared_ptr<const nvrtcProgram>;
148153
using NvvmProgramHandle = std::shared_ptr<const NvvmProgramValue>;
@@ -371,6 +376,33 @@ KernelHandle create_kernel_handle_ref(CUkernel kernel);
371376
// Returns empty handle if the kernel has no library dependency.
372377
LibraryHandle get_kernel_library(const KernelHandle& h) noexcept;
373378

379+
// ============================================================================
380+
// Graph handle functions
381+
// ============================================================================
382+
383+
// Wrap an externally-created CUgraph with RAII cleanup.
384+
// When the last reference is released, cuGraphDestroy is called automatically.
385+
// The caller must have already created the graph via cuGraphCreate.
386+
GraphHandle create_graph_handle(CUgraph graph);
387+
388+
// Create a non-owning graph handle that keeps h_parent alive.
389+
// Use for graphs owned by a child/conditional node in a parent graph.
390+
// The child graph will NOT be destroyed when this handle is released,
391+
// but h_parent will be prevented from destruction while this handle exists.
392+
GraphHandle create_graph_handle_ref(CUgraph graph, const GraphHandle& h_parent);
393+
394+
// ============================================================================
395+
// Graph node handle functions
396+
// ============================================================================
397+
398+
// Create a node handle. Nodes are owned by their parent graph (not
399+
// independently destroyable). The GraphHandle dependency ensures the
400+
// graph outlives any node reference.
401+
GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_graph);
402+
403+
// Extract the owning graph handle from a node handle.
// Returns an empty handle if the node handle is empty.
404+
GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept;
405+
374406
// ============================================================================
375407
// Graphics resource handle functions
376408
// ============================================================================
@@ -467,6 +499,14 @@ inline CUkernel as_cu(const KernelHandle& h) noexcept {
467499
return h ? *h : nullptr;
468500
}
469501

502+
// Unwrap a graph handle to the raw CUgraph; an empty handle maps to nullptr.
inline CUgraph as_cu(const GraphHandle& h) noexcept {
    if (h) {
        return *h;
    }
    return nullptr;
}
505+
506+
// Unwrap a node handle to the raw CUgraphNode; an empty handle maps to nullptr.
inline CUgraphNode as_cu(const GraphNodeHandle& h) noexcept {
    if (h) {
        return *h;
    }
    return nullptr;
}
509+
470510
inline CUgraphicsResource as_cu(const GraphicsResourceHandle& h) noexcept {
471511
return h ? *h : nullptr;
472512
}
@@ -517,6 +557,14 @@ inline std::intptr_t as_intptr(const KernelHandle& h) noexcept {
517557
return reinterpret_cast<std::intptr_t>(as_cu(h));
518558
}
519559

560+
inline std::intptr_t as_intptr(const GraphHandle& h) noexcept {
561+
return reinterpret_cast<std::intptr_t>(as_cu(h));
562+
}
563+
564+
inline std::intptr_t as_intptr(const GraphNodeHandle& h) noexcept {
565+
return reinterpret_cast<std::intptr_t>(as_cu(h));
566+
}
567+
520568
inline std::intptr_t as_intptr(const GraphicsResourceHandle& h) noexcept {
521569
return reinterpret_cast<std::intptr_t>(as_cu(h));
522570
}
@@ -595,6 +643,17 @@ inline PyObject* as_py(const KernelHandle& h) noexcept {
595643
return detail::make_py("cuda.bindings.driver", "CUkernel", as_intptr(h));
596644
}
597645

646+
// Hand the graph's integer handle value to detail::make_py together with the
// "cuda.bindings.driver" / "CUgraph" names, and return whatever it produces.
inline PyObject* as_py(const GraphHandle& h) noexcept {
    const std::intptr_t value = as_intptr(h);
    return detail::make_py("cuda.bindings.driver", "CUgraph", value);
}
649+
650+
// Convert a node handle for Python: a handle whose integer value is 0 becomes
// Python None; otherwise the value is passed to detail::make_py with the
// "cuda.bindings.driver" / "CUgraphNode" names.
inline PyObject* as_py(const GraphNodeHandle& h) noexcept {
    const std::intptr_t value = as_intptr(h);
    if (value == 0) {
        Py_RETURN_NONE;
    }
    return detail::make_py("cuda.bindings.driver", "CUgraphNode", value);
}
656+
598657
inline PyObject* as_py(const NvrtcProgramHandle& h) noexcept {
599658
return detail::make_py("cuda.bindings.nvrtc", "nvrtcProgram", as_intptr(h));
600659
}

cuda_core/cuda/core/_graph/__init__.py

Lines changed: 77 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,43 @@ class GraphDebugPrintOptions:
9191
extra_topo_info: bool = False
9292
conditional_node_params: bool = False
9393

94+
def _to_flags(self) -> int:
    """Fold the enabled options into a ``CUgraphDebugDot_flags`` bitmask (internal use).

    Each truthy option contributes the ``CU_GRAPH_DEBUG_DOT_FLAGS_*`` member
    whose suffix is the upper-cased option name. Options that are falsy never
    touch the driver enum, so the result is ``0`` when nothing is enabled.
    """
    # Kept in the same order the flags were originally OR-ed together.
    option_names = (
        "verbose",
        "runtime_types",
        "kernel_node_params",
        "memcpy_node_params",
        "memset_node_params",
        "host_node_params",
        "event_node_params",
        "ext_semas_signal_node_params",
        "ext_semas_wait_node_params",
        "kernel_node_attributes",
        "handles",
        "mem_alloc_node_params",
        "mem_free_node_params",
        "batch_mem_op_node_params",
        "extra_topo_info",
        "conditional_node_params",
    )
    mask = 0
    for option_name in option_names:
        if getattr(self, option_name):
            # Look the enum member up lazily, only for options that are set.
            member_name = "CU_GRAPH_DEBUG_DOT_FLAGS_" + option_name.upper()
            mask |= getattr(driver.CUgraphDebugDot_flags, member_name)
    return mask
130+
94131

95132
@dataclass
96133
class GraphCompleteOptions:
@@ -118,6 +155,44 @@ class GraphCompleteOptions:
118155
use_node_priority: bool = False
119156

120157

158+
def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> Graph:
    """Instantiate ``h_graph`` through ``cuGraphInstantiateWithParams`` and wrap the result.

    Parameters
    ----------
    h_graph
        Graph object passed unchanged to ``driver.cuGraphInstantiateWithParams``.
    options : GraphCompleteOptions, optional
        Instantiation options. When falsy, the instantiate params are used as
        freshly constructed (no flags or upload stream are assigned).

    Returns
    -------
    Graph
        The object produced by ``Graph._init`` for the instantiated graph.

    Raises
    ------
    RuntimeError
        If ``result_out`` of the instantiate params reports anything other
        than success after the driver call.
    """
    params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS()
    if options:
        inst_flags = 0
        if options.auto_free_on_launch:
            inst_flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
        if options.upload_stream:
            # The upload flag and the stream to upload on are set together.
            inst_flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD
            params.hUploadStream = options.upload_stream.handle
        if options.device_launch:
            inst_flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH
        if options.use_node_priority:
            inst_flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
        params.flags = inst_flags

    # NOTE(review): handle_return presumably raises on a failing driver status
    # before the result_out checks below run - confirm against its definition.
    graph = Graph._init(handle_return(driver.cuGraphInstantiateWithParams(h_graph, params)))

    result_codes = driver.CUgraphInstantiateResult
    outcome = params.result_out

    if outcome == result_codes.CUDA_GRAPH_INSTANTIATE_ERROR:
        raise RuntimeError(
            "Instantiation failed for an unexpected reason which is described in the return value of the function."
        )
    if outcome == result_codes.CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE:
        raise RuntimeError("Instantiation failed due to invalid structure, such as cycles.")
    if outcome == result_codes.CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED:
        raise RuntimeError(
            "Instantiation for device launch failed because the graph contained an unsupported operation."
        )
    if outcome == result_codes.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED:
        raise RuntimeError("Instantiation for device launch failed due to the nodes belonging to different contexts.")
    # This result code is only looked up when the version tuple is at least (12, 8).
    if _py_major_minor >= (12, 8) and outcome == result_codes.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED:
        raise RuntimeError("One or more conditional handles are not associated with conditional builders.")
    if outcome != result_codes.CUDA_GRAPH_INSTANTIATE_SUCCESS:
        raise RuntimeError(f"Graph instantiation failed with unexpected error code: {params.result_out}")
    return graph
194+
195+
121196
class GraphBuilder:
122197
"""Represents a graph under construction.
123198
@@ -280,53 +355,7 @@ def complete(self, options: GraphCompleteOptions | None = None) -> Graph:
280355
if not self._building_ended:
281356
raise RuntimeError("Graph has not finished building.")
282357

283-
if (_driver_ver < 12000) or (_py_major_minor < (12, 0)):
284-
flags = 0
285-
if options:
286-
if options.auto_free_on_launch:
287-
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
288-
if options.use_node_priority:
289-
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
290-
return Graph._init(handle_return(driver.cuGraphInstantiateWithFlags(self._mnff.graph, flags)))
291-
292-
params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS()
293-
if options:
294-
flags = 0
295-
if options.auto_free_on_launch:
296-
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
297-
if options.upload_stream:
298-
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD
299-
params.hUploadStream = options.upload_stream.handle
300-
if options.device_launch:
301-
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH
302-
if options.use_node_priority:
303-
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
304-
params.flags = flags
305-
306-
graph = Graph._init(handle_return(driver.cuGraphInstantiateWithParams(self._mnff.graph, params)))
307-
if params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_ERROR:
308-
# NOTE: Should never get here since the handle_return should have caught this case
309-
raise RuntimeError(
310-
"Instantiation failed for an unexpected reason which is described in the return value of the function."
311-
)
312-
elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE:
313-
raise RuntimeError("Instantiation failed due to invalid structure, such as cycles.")
314-
elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED:
315-
raise RuntimeError(
316-
"Instantiation for device launch failed because the graph contained an unsupported operation."
317-
)
318-
elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED:
319-
raise RuntimeError(
320-
"Instantiation for device launch failed due to the nodes belonging to different contexts."
321-
)
322-
elif (
323-
_py_major_minor >= (12, 8)
324-
and params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED
325-
):
326-
raise RuntimeError("One or more conditional handles are not associated with conditional builders.")
327-
elif params.result_out != driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_SUCCESS:
328-
raise RuntimeError(f"Graph instantiation failed with unexpected error code: {params.result_out}")
329-
return graph
358+
return _instantiate_graph(self._mnff.graph, options)
330359

331360
def debug_dot_print(self, path, options: GraphDebugPrintOptions | None = None):
332361
"""Generates a DOT debug file for the graph builder.
@@ -341,41 +370,7 @@ def debug_dot_print(self, path, options: GraphDebugPrintOptions | None = None):
341370
"""
342371
if not self._building_ended:
343372
raise RuntimeError("Graph has not finished building.")
344-
flags = 0
345-
if options:
346-
if options.verbose:
347-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_VERBOSE
348-
if options.runtime_types:
349-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_RUNTIME_TYPES
350-
if options.kernel_node_params:
351-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_PARAMS
352-
if options.memcpy_node_params:
353-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEMCPY_NODE_PARAMS
354-
if options.memset_node_params:
355-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEMSET_NODE_PARAMS
356-
if options.host_node_params:
357-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_HOST_NODE_PARAMS
358-
if options.event_node_params:
359-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EVENT_NODE_PARAMS
360-
if options.ext_semas_signal_node_params:
361-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_SIGNAL_NODE_PARAMS
362-
if options.ext_semas_wait_node_params:
363-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_WAIT_NODE_PARAMS
364-
if options.kernel_node_attributes:
365-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_ATTRIBUTES
366-
if options.handles:
367-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_HANDLES
368-
if options.mem_alloc_node_params:
369-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEM_ALLOC_NODE_PARAMS
370-
if options.mem_free_node_params:
371-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEM_FREE_NODE_PARAMS
372-
if options.batch_mem_op_node_params:
373-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_BATCH_MEM_OP_NODE_PARAMS
374-
if options.extra_topo_info:
375-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXTRA_TOPO_INFO
376-
if options.conditional_node_params:
377-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_CONDITIONAL_NODE_PARAMS
378-
373+
flags = options._to_flags() if options else 0
379374
handle_return(driver.cuGraphDebugDotPrint(self._mnff.graph, path, flags))
380375

381376
def split(self, count: int) -> tuple[GraphBuilder, ...]:

0 commit comments

Comments
 (0)