Skip to content

Commit f07dba3

Browse files
committed
Extract shared graph instantiation helper
Move instantiation logic from GraphBuilder.complete() into a shared _instantiate_graph() helper so GraphDef.instantiate() can reuse it with full GraphCompleteOptions support and error handling. Made-with: Cursor
1 parent a60a56a commit f07dba3

File tree

2 files changed

+48
-53
lines changed

2 files changed

+48
-53
lines changed

cuda_core/cuda/core/_graph/__init__.py

Lines changed: 39 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,44 @@ class GraphCompleteOptions:
155155
use_node_priority: bool = False
156156

157157

158+
def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> Graph:
159+
params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS()
160+
if options:
161+
flags = 0
162+
if options.auto_free_on_launch:
163+
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
164+
if options.upload_stream:
165+
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD
166+
params.hUploadStream = options.upload_stream.handle
167+
if options.device_launch:
168+
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH
169+
if options.use_node_priority:
170+
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
171+
params.flags = flags
172+
173+
graph = Graph._init(handle_return(driver.cuGraphInstantiateWithParams(h_graph, params)))
174+
if params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_ERROR:
175+
raise RuntimeError(
176+
"Instantiation failed for an unexpected reason which is described in the return value of the function."
177+
)
178+
elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE:
179+
raise RuntimeError("Instantiation failed due to invalid structure, such as cycles.")
180+
elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED:
181+
raise RuntimeError(
182+
"Instantiation for device launch failed because the graph contained an unsupported operation."
183+
)
184+
elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED:
185+
raise RuntimeError("Instantiation for device launch failed due to the nodes belonging to different contexts.")
186+
elif (
187+
_py_major_minor >= (12, 8)
188+
and params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED
189+
):
190+
raise RuntimeError("One or more conditional handles are not associated with conditional builders.")
191+
elif params.result_out != driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_SUCCESS:
192+
raise RuntimeError(f"Graph instantiation failed with unexpected error code: {params.result_out}")
193+
return graph
194+
195+
158196
class GraphBuilder:
159197
"""Represents a graph under construction.
160198
@@ -317,53 +355,7 @@ def complete(self, options: GraphCompleteOptions | None = None) -> Graph:
317355
if not self._building_ended:
318356
raise RuntimeError("Graph has not finished building.")
319357

320-
if (_driver_ver < 12000) or (_py_major_minor < (12, 0)):
321-
flags = 0
322-
if options:
323-
if options.auto_free_on_launch:
324-
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
325-
if options.use_node_priority:
326-
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
327-
return Graph._init(handle_return(driver.cuGraphInstantiateWithFlags(self._mnff.graph, flags)))
328-
329-
params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS()
330-
if options:
331-
flags = 0
332-
if options.auto_free_on_launch:
333-
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
334-
if options.upload_stream:
335-
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD
336-
params.hUploadStream = options.upload_stream.handle
337-
if options.device_launch:
338-
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH
339-
if options.use_node_priority:
340-
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
341-
params.flags = flags
342-
343-
graph = Graph._init(handle_return(driver.cuGraphInstantiateWithParams(self._mnff.graph, params)))
344-
if params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_ERROR:
345-
# NOTE: Should never get here since the handle_return should have caught this case
346-
raise RuntimeError(
347-
"Instantiation failed for an unexpected reason which is described in the return value of the function."
348-
)
349-
elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE:
350-
raise RuntimeError("Instantiation failed due to invalid structure, such as cycles.")
351-
elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED:
352-
raise RuntimeError(
353-
"Instantiation for device launch failed because the graph contained an unsupported operation."
354-
)
355-
elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED:
356-
raise RuntimeError(
357-
"Instantiation for device launch failed due to the nodes belonging to different contexts."
358-
)
359-
elif (
360-
_py_major_minor >= (12, 8)
361-
and params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED
362-
):
363-
raise RuntimeError("One or more conditional handles are not associated with conditional builders.")
364-
elif params.result_out != driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_SUCCESS:
365-
raise RuntimeError(f"Graph instantiation failed with unexpected error code: {params.result_out}")
366-
return graph
358+
return _instantiate_graph(self._mnff.graph, options)
367359

368360
def debug_dot_print(self, path, options: GraphDebugPrintOptions | None = None):
369361
"""Generates a DOT debug file for the graph builder.

cuda_core/cuda/core/_graph/_graphdef.pyx

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -457,20 +457,23 @@ cdef class GraphDef:
457457
"""
458458
return self._entry.switch(condition, count)
459459

460-
def instantiate(self):
460+
def instantiate(self, options=None):
461461
"""Instantiate the graph definition into an executable Graph.
462462
463+
Parameters
464+
----------
465+
options : :obj:`~_graph.GraphCompleteOptions`, optional
466+
Customizable dataclass for graph instantiation options.
467+
463468
Returns
464469
-------
465470
Graph
466471
An executable graph that can be launched on a stream.
467472
"""
468-
from cuda.core._graph import Graph
469-
from cuda.core._utils.cuda_utils import handle_return
473+
from cuda.core._graph import _instantiate_graph
470474

471-
graph_exec = handle_return(driver.cuGraphInstantiate(
472-
driver.CUgraph(as_intptr(self._h_graph)), 0))
473-
return Graph._init(graph_exec)
475+
return _instantiate_graph(
476+
driver.CUgraph(as_intptr(self._h_graph)), options)
474477

475478
def debug_dot_print(self, path: str, options=None) -> None:
476479
"""Write a GraphViz DOT representation of the graph to a file.

0 commit comments

Comments
 (0)