@@ -91,6 +91,43 @@ class GraphDebugPrintOptions:
9191 extra_topo_info : bool = False
9292 conditional_node_params : bool = False
9393
def _to_flags(self) -> int:
    """Fold the enabled boolean options into one driver flag bitmask (internal use).

    Each option field ``<name>`` corresponds to the
    ``driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_<NAME>`` member.

    Returns:
        The bitwise OR of the members whose option is truthy, or ``0`` when
        no option is enabled.
    """
    option_names = (
        "verbose",
        "runtime_types",
        "kernel_node_params",
        "memcpy_node_params",
        "memset_node_params",
        "host_node_params",
        "event_node_params",
        "ext_semas_signal_node_params",
        "ext_semas_wait_node_params",
        "kernel_node_attributes",
        "handles",
        "mem_alloc_node_params",
        "mem_free_node_params",
        "batch_mem_op_node_params",
        "extra_topo_info",
        "conditional_node_params",
    )
    bitmask = 0
    for option_name in option_names:
        if getattr(self, option_name):
            # Resolve the enum member only for enabled options, so a member
            # is never looked up unless its option is actually requested.
            member_name = f"CU_GRAPH_DEBUG_DOT_FLAGS_{option_name.upper()}"
            bitmask |= getattr(driver.CUgraphDebugDot_flags, member_name)
    return bitmask
130+
94131
95132@dataclass
96133class GraphCompleteOptions :
@@ -118,6 +155,44 @@ class GraphCompleteOptions:
118155 use_node_priority : bool = False
119156
120157
def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> Graph:
    """Instantiate ``h_graph`` with ``cuGraphInstantiateWithParams`` and wrap it in a ``Graph``.

    Args:
        h_graph: Graph handle forwarded unchanged to the driver call.
        options: Optional instantiation settings. When falsy, the instantiate
            params are passed to the driver as default-constructed.

    Returns:
        The ``Graph`` built from the handle returned by the driver.

    Raises:
        RuntimeError: If ``params.result_out`` reports anything other than
            ``CUDA_GRAPH_INSTANTIATE_SUCCESS`` after the driver call.
    """
    params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS()
    if options:
        flag_enum = driver.CUgraphInstantiate_flags
        bitmask = 0
        if options.auto_free_on_launch:
            bitmask |= flag_enum.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
        if options.upload_stream:
            # Requesting an upload also requires telling the driver which stream to use.
            bitmask |= flag_enum.CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD
            params.hUploadStream = options.upload_stream.handle
        if options.device_launch:
            bitmask |= flag_enum.CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH
        if options.use_node_priority:
            bitmask |= flag_enum.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
        params.flags = bitmask

    graph = Graph._init(handle_return(driver.cuGraphInstantiateWithParams(h_graph, params)))

    outcome = params.result_out
    result_enum = driver.CUgraphInstantiateResult

    # Ordered (member name, message) pairs; members are resolved lazily by name
    # so a member missing from older bindings is never touched.
    known_failures = [
        (
            # NOTE: Should never get here since handle_return should have caught this case.
            "CUDA_GRAPH_INSTANTIATE_ERROR",
            "Instantiation failed for an unexpected reason which is described in the return value of the function.",
        ),
        (
            "CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE",
            "Instantiation failed due to invalid structure, such as cycles.",
        ),
        (
            "CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED",
            "Instantiation for device launch failed because the graph contained an unsupported operation.",
        ),
        (
            "CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED",
            "Instantiation for device launch failed due to the nodes belonging to different contexts.",
        ),
    ]
    if _py_major_minor >= (12, 8):
        # This result code is only consulted when the bindings are 12.8 or newer.
        known_failures.append(
            (
                "CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED",
                "One or more conditional handles are not associated with conditional builders.",
            )
        )

    for member_name, message in known_failures:
        if outcome == getattr(result_enum, member_name):
            raise RuntimeError(message)

    if outcome != result_enum.CUDA_GRAPH_INSTANTIATE_SUCCESS:
        raise RuntimeError(f"Graph instantiation failed with unexpected error code: {outcome}")
    return graph
194+
195+
121196class GraphBuilder :
122197 """Represents a graph under construction.
123198
@@ -280,53 +355,7 @@ def complete(self, options: GraphCompleteOptions | None = None) -> Graph:
280355 if not self ._building_ended :
281356 raise RuntimeError ("Graph has not finished building." )
282357
283- if (_driver_ver < 12000 ) or (_py_major_minor < (12 , 0 )):
284- flags = 0
285- if options :
286- if options .auto_free_on_launch :
287- flags |= driver .CUgraphInstantiate_flags .CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
288- if options .use_node_priority :
289- flags |= driver .CUgraphInstantiate_flags .CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
290- return Graph ._init (handle_return (driver .cuGraphInstantiateWithFlags (self ._mnff .graph , flags )))
291-
292- params = driver .CUDA_GRAPH_INSTANTIATE_PARAMS ()
293- if options :
294- flags = 0
295- if options .auto_free_on_launch :
296- flags |= driver .CUgraphInstantiate_flags .CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
297- if options .upload_stream :
298- flags |= driver .CUgraphInstantiate_flags .CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD
299- params .hUploadStream = options .upload_stream .handle
300- if options .device_launch :
301- flags |= driver .CUgraphInstantiate_flags .CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH
302- if options .use_node_priority :
303- flags |= driver .CUgraphInstantiate_flags .CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
304- params .flags = flags
305-
306- graph = Graph ._init (handle_return (driver .cuGraphInstantiateWithParams (self ._mnff .graph , params )))
307- if params .result_out == driver .CUgraphInstantiateResult .CUDA_GRAPH_INSTANTIATE_ERROR :
308- # NOTE: Should never get here since the handle_return should have caught this case
309- raise RuntimeError (
310- "Instantiation failed for an unexpected reason which is described in the return value of the function."
311- )
312- elif params .result_out == driver .CUgraphInstantiateResult .CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE :
313- raise RuntimeError ("Instantiation failed due to invalid structure, such as cycles." )
314- elif params .result_out == driver .CUgraphInstantiateResult .CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED :
315- raise RuntimeError (
316- "Instantiation for device launch failed because the graph contained an unsupported operation."
317- )
318- elif params .result_out == driver .CUgraphInstantiateResult .CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED :
319- raise RuntimeError (
320- "Instantiation for device launch failed due to the nodes belonging to different contexts."
321- )
322- elif (
323- _py_major_minor >= (12 , 8 )
324- and params .result_out == driver .CUgraphInstantiateResult .CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED
325- ):
326- raise RuntimeError ("One or more conditional handles are not associated with conditional builders." )
327- elif params .result_out != driver .CUgraphInstantiateResult .CUDA_GRAPH_INSTANTIATE_SUCCESS :
328- raise RuntimeError (f"Graph instantiation failed with unexpected error code: { params .result_out } " )
329- return graph
358+ return _instantiate_graph (self ._mnff .graph , options )
330359
331360 def debug_dot_print (self , path , options : GraphDebugPrintOptions | None = None ):
332361 """Generates a DOT debug file for the graph builder.
@@ -341,41 +370,7 @@ def debug_dot_print(self, path, options: GraphDebugPrintOptions | None = None):
341370 """
342371 if not self ._building_ended :
343372 raise RuntimeError ("Graph has not finished building." )
344- flags = 0
345- if options :
346- if options .verbose :
347- flags |= driver .CUgraphDebugDot_flags .CU_GRAPH_DEBUG_DOT_FLAGS_VERBOSE
348- if options .runtime_types :
349- flags |= driver .CUgraphDebugDot_flags .CU_GRAPH_DEBUG_DOT_FLAGS_RUNTIME_TYPES
350- if options .kernel_node_params :
351- flags |= driver .CUgraphDebugDot_flags .CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_PARAMS
352- if options .memcpy_node_params :
353- flags |= driver .CUgraphDebugDot_flags .CU_GRAPH_DEBUG_DOT_FLAGS_MEMCPY_NODE_PARAMS
354- if options .memset_node_params :
355- flags |= driver .CUgraphDebugDot_flags .CU_GRAPH_DEBUG_DOT_FLAGS_MEMSET_NODE_PARAMS
356- if options .host_node_params :
357- flags |= driver .CUgraphDebugDot_flags .CU_GRAPH_DEBUG_DOT_FLAGS_HOST_NODE_PARAMS
358- if options .event_node_params :
359- flags |= driver .CUgraphDebugDot_flags .CU_GRAPH_DEBUG_DOT_FLAGS_EVENT_NODE_PARAMS
360- if options .ext_semas_signal_node_params :
361- flags |= driver .CUgraphDebugDot_flags .CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_SIGNAL_NODE_PARAMS
362- if options .ext_semas_wait_node_params :
363- flags |= driver .CUgraphDebugDot_flags .CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_WAIT_NODE_PARAMS
364- if options .kernel_node_attributes :
365- flags |= driver .CUgraphDebugDot_flags .CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_ATTRIBUTES
366- if options .handles :
367- flags |= driver .CUgraphDebugDot_flags .CU_GRAPH_DEBUG_DOT_FLAGS_HANDLES
368- if options .mem_alloc_node_params :
369- flags |= driver .CUgraphDebugDot_flags .CU_GRAPH_DEBUG_DOT_FLAGS_MEM_ALLOC_NODE_PARAMS
370- if options .mem_free_node_params :
371- flags |= driver .CUgraphDebugDot_flags .CU_GRAPH_DEBUG_DOT_FLAGS_MEM_FREE_NODE_PARAMS
372- if options .batch_mem_op_node_params :
373- flags |= driver .CUgraphDebugDot_flags .CU_GRAPH_DEBUG_DOT_FLAGS_BATCH_MEM_OP_NODE_PARAMS
374- if options .extra_topo_info :
375- flags |= driver .CUgraphDebugDot_flags .CU_GRAPH_DEBUG_DOT_FLAGS_EXTRA_TOPO_INFO
376- if options .conditional_node_params :
377- flags |= driver .CUgraphDebugDot_flags .CU_GRAPH_DEBUG_DOT_FLAGS_CONDITIONAL_NODE_PARAMS
378-
373+ flags = options ._to_flags () if options else 0
379374 handle_return (driver .cuGraphDebugDotPrint (self ._mnff .graph , path , flags ))
380375
381376 def split (self , count : int ) -> tuple [GraphBuilder , ...]:
0 commit comments