@@ -155,6 +155,44 @@ class GraphCompleteOptions:
155155 use_node_priority : bool = False
156156
157157
def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> Graph:
    """Instantiate a graph handle and wrap the result in a :obj:`Graph`.

    The given options are translated into instantiate flags on a
    ``CUDA_GRAPH_INSTANTIATE_PARAMS`` struct, the struct is handed to
    ``cuGraphInstantiateWithParams``, and the ``result_out`` field written
    back by the driver is checked before the new graph is returned.

    Parameters
    ----------
    h_graph
        Graph handle forwarded to ``cuGraphInstantiateWithParams``.
    options : :obj:`GraphCompleteOptions`, optional
        Instantiation options. When omitted (or falsy), the params struct is
        passed to the driver without its ``flags`` field being assigned.
        Otherwise each of ``auto_free_on_launch``, ``upload_stream``,
        ``device_launch`` and ``use_node_priority`` that is truthy adds the
        matching instantiate flag; ``upload_stream`` additionally supplies
        ``hUploadStream`` through its ``handle`` attribute.

    Returns
    -------
    :obj:`Graph`
        The object produced by ``Graph._init`` from the instantiated handle.

    Raises
    ------
    RuntimeError
        If ``result_out`` is anything other than
        ``CUDA_GRAPH_INSTANTIATE_SUCCESS``.

    """
    params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS()
    if options:
        flag_bits = driver.CUgraphInstantiate_flags
        flags = 0
        if options.auto_free_on_launch:
            flags |= flag_bits.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
        stream = options.upload_stream
        if stream:
            flags |= flag_bits.CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD
            params.hUploadStream = stream.handle
        if options.device_launch:
            flags |= flag_bits.CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH
        if options.use_node_priority:
            flags |= flag_bits.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
        params.flags = flags

    exec_handle = handle_return(driver.cuGraphInstantiateWithParams(h_graph, params))
    graph = Graph._init(exec_handle)

    outcome = params.result_out
    codes = driver.CUgraphInstantiateResult

    # Known failure codes paired with their messages, checked in this order.
    known_failures = [
        (
            codes.CUDA_GRAPH_INSTANTIATE_ERROR,
            "Instantiation failed for an unexpected reason which is described in the return value of the function.",
        ),
        (
            codes.CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE,
            "Instantiation failed due to invalid structure, such as cycles.",
        ),
        (
            codes.CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED,
            "Instantiation for device launch failed because the graph contained an unsupported operation.",
        ),
        (
            codes.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED,
            "Instantiation for device launch failed due to the nodes belonging to different contexts.",
        ),
    ]
    # The conditional-handle code is only looked up when the (12, 8) version
    # gate passes, so the enum member is never touched otherwise.
    if _py_major_minor >= (12, 8):
        known_failures.append(
            (
                codes.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED,
                "One or more conditional handles are not associated with conditional builders.",
            )
        )

    for code, message in known_failures:
        if outcome == code:
            raise RuntimeError(message)

    # Anything that is neither a known failure nor success is still an error.
    if outcome != codes.CUDA_GRAPH_INSTANTIATE_SUCCESS:
        raise RuntimeError(f"Graph instantiation failed with unexpected error code: {outcome}")
    return graph
194+
195+
158196class GraphBuilder :
159197 """Represents a graph under construction.
160198
@@ -317,53 +355,7 @@ def complete(self, options: GraphCompleteOptions | None = None) -> Graph:
317355 if not self ._building_ended :
318356 raise RuntimeError ("Graph has not finished building." )
319357
320- if (_driver_ver < 12000 ) or (_py_major_minor < (12 , 0 )):
321- flags = 0
322- if options :
323- if options .auto_free_on_launch :
324- flags |= driver .CUgraphInstantiate_flags .CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
325- if options .use_node_priority :
326- flags |= driver .CUgraphInstantiate_flags .CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
327- return Graph ._init (handle_return (driver .cuGraphInstantiateWithFlags (self ._mnff .graph , flags )))
328-
329- params = driver .CUDA_GRAPH_INSTANTIATE_PARAMS ()
330- if options :
331- flags = 0
332- if options .auto_free_on_launch :
333- flags |= driver .CUgraphInstantiate_flags .CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
334- if options .upload_stream :
335- flags |= driver .CUgraphInstantiate_flags .CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD
336- params .hUploadStream = options .upload_stream .handle
337- if options .device_launch :
338- flags |= driver .CUgraphInstantiate_flags .CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH
339- if options .use_node_priority :
340- flags |= driver .CUgraphInstantiate_flags .CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
341- params .flags = flags
342-
343- graph = Graph ._init (handle_return (driver .cuGraphInstantiateWithParams (self ._mnff .graph , params )))
344- if params .result_out == driver .CUgraphInstantiateResult .CUDA_GRAPH_INSTANTIATE_ERROR :
345- # NOTE: Should never get here since the handle_return should have caught this case
346- raise RuntimeError (
347- "Instantiation failed for an unexpected reason which is described in the return value of the function."
348- )
349- elif params .result_out == driver .CUgraphInstantiateResult .CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE :
350- raise RuntimeError ("Instantiation failed due to invalid structure, such as cycles." )
351- elif params .result_out == driver .CUgraphInstantiateResult .CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED :
352- raise RuntimeError (
353- "Instantiation for device launch failed because the graph contained an unsupported operation."
354- )
355- elif params .result_out == driver .CUgraphInstantiateResult .CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED :
356- raise RuntimeError (
357- "Instantiation for device launch failed due to the nodes belonging to different contexts."
358- )
359- elif (
360- _py_major_minor >= (12 , 8 )
361- and params .result_out == driver .CUgraphInstantiateResult .CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED
362- ):
363- raise RuntimeError ("One or more conditional handles are not associated with conditional builders." )
364- elif params .result_out != driver .CUgraphInstantiateResult .CUDA_GRAPH_INSTANTIATE_SUCCESS :
365- raise RuntimeError (f"Graph instantiation failed with unexpected error code: { params .result_out } " )
366- return graph
358+ return _instantiate_graph (self ._mnff .graph , options )
367359
368360 def debug_dot_print (self , path , options : GraphDebugPrintOptions | None = None ):
369361 """Generates a DOT debug file for the graph builder.
0 commit comments