55import weakref
66from dataclasses import dataclass
77
8+ from libc.stdint cimport intptr_t
9+
810from cuda.bindings cimport cydriver
911
1012from cuda.core._graph._utils cimport _attach_host_callback_to_graph
@@ -14,6 +16,7 @@ from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
1416from cuda.core._utils.version cimport cy_binding_version, cy_driver_version
1517
1618from cuda.core._utils.cuda_utils import (
19+ CUDAError,
1720 driver,
1821 handle_return,
1922)
@@ -783,24 +786,42 @@ class Graph:
783786 """
784787 return self._mnff.graph
785788
786- def update(self , builder: GraphBuilder ) :
787- """ Update the graph using new build configuration from the builder .
789+ def update(self , source: " GraphBuilder | GraphDef") -> None :
790+ """Update the graph using a new graph definition .
788791
789- The topology of the provided builder must be identical to this graph.
792+ The topology of the provided source must be identical to this graph.
790793
791794 Parameters
792795 ----------
793- builder : :obj:`~_graph.GraphBuilder`
794- The builder to update the graph with.
796+ source : :obj:`~_graph.GraphBuilder` or :obj:`~_graph._graph_def.GraphDef`
797+ The graph definition to update from. A GraphBuilder must have
798+ finished building.
795799
796800 """
797- if not builder._building_ended:
798- raise ValueError (" Graph has not finished building." )
801+ from cuda.core._graph._graph_def import GraphDef
802+
803+ cdef cydriver.CUgraph cu_graph
804+ cdef cydriver.CUgraphExec cu_exec = < cydriver.CUgraphExec>< intptr_t> int (self ._mnff.graph)
799805
800- # Update the graph with the new nodes from the builder
801- exec_update_result = handle_return(driver.cuGraphExecUpdate(self ._mnff.graph, builder._mnff.graph))
802- if exec_update_result.result != driver.CUgraphExecUpdateResult.CU_GRAPH_EXEC_UPDATE_SUCCESS:
803- raise RuntimeError (f" Failed to update graph: {exec_update_result.result()}" )
806+ if isinstance(source , GraphBuilder ):
807+ if not source._building_ended:
808+ raise ValueError (" Graph has not finished building." )
809+ cu_graph = < cydriver.CUgraph>< intptr_t> int (source._mnff.graph)
810+ elif isinstance (source, GraphDef):
811+ cu_graph = < cydriver.CUgraph>< intptr_t> int (source.handle)
812+ else :
813+ raise TypeError (
814+ f" expected GraphBuilder or GraphDef, got {type(source).__name__}" )
815+
816+ cdef cydriver.CUgraphExecUpdateResultInfo result_info
817+ cdef cydriver.CUresult err
818+ with nogil:
819+ err = cydriver.cuGraphExecUpdate(cu_exec, cu_graph, & result_info)
820+ if err == cydriver.CUresult.CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE:
821+ reason = driver.CUgraphExecUpdateResult(result_info.result)
822+ msg = f" Graph update failed: {reason.__doc__.strip()} ({reason.name})"
823+ raise CUDAError(msg)
824+ HANDLE_RETURN(err)
804825
805826 def upload (self , stream: Stream ):
806827 """ Uploads the graph in a stream.
0 commit comments