@@ -56,10 +56,19 @@ from cuda.core._graph._utils cimport (
5656 _attach_user_object,
5757)
5858
59+ import weakref
60+
5961from cuda.core import Device
6062from cuda.core._graph._graph_def._adjacency_set_proxy import AdjacencySetProxy
6163from cuda.core._utils.cuda_utils import driver, handle_return
6264
65+ _node_cache = weakref.WeakValueDictionary()
66+
67+
68+ cdef inline GraphNode _cached(GraphNode n):
# Register n in the module-level _node_cache and hand the same object back,
# so the call can be wrapped directly around a node constructor in a
# return statement.
# The key is the value of n._h_node.get() cast to uintptr_t. _node_cache is
# a weakref.WeakValueDictionary, so the entry alone does not keep n alive.
# An existing entry under the same key is overwritten.
69+ _node_cache[< uintptr_t> n._h_node.get()] = n
70+ return n
71+
6372
6473cdef class GraphNode:
6574 """ Base class for all graph nodes.
@@ -510,18 +519,30 @@ cdef inline ConditionalNode _make_conditional_node(
510519 n._cond_type = cond_type
511520 n._branches = branches
512521
513- return n
522+ return _cached(n)
523+
514524
515525cdef inline GraphNode GN_create(GraphHandle h_graph, cydriver.CUgraphNode node):
# Return the GraphNode object for `node` within `h_graph`.
#   - node == NULL: a plain GraphNode is built and returned without being
#     placed in _node_cache.
#   - otherwise: an object already present in _node_cache is returned, or a
#     new one is built by GN_create_impl and registered through _cached.
# The handle is now created once, up front, and shared by both paths.
526+ cdef GraphNodeHandle h_node = create_graph_node_handle(node, h_graph)
527+
528+ # Sentinel: virtual node to represent the graph entry point.
516529 if node == NULL :
517530 n = GraphNode.__new__ (GraphNode)
518- (< GraphNode> n)._h_node = create_graph_node_handle(node, h_graph)
531+ (< GraphNode> n)._h_node = h_node
519532 return n
520533
521- cdef GraphNodeHandle h_node = create_graph_node_handle(node, h_graph)
534+ # Return a cached object or create and cache a new one.
# NOTE(review): the lookup key is h_node.get() of the handle created above,
# while _cached stores under n._h_node.get(). A hit therefore relies on
# create_graph_node_handle yielding the same get() address for the same
# node on every call - confirm against the handle implementation.
535+ cached = _node_cache.get(< uintptr_t> h_node.get())
536+ if cached is not None :
537+ return < GraphNode> cached
538+ else :
539+ return _cached(GN_create_impl(h_node))
540+
541+
542+ cdef inline GraphNode GN_create_impl(GraphNodeHandle h_node):
522543 cdef cydriver.CUgraphNodeType node_type
523544 with nogil:
524- HANDLE_RETURN(cydriver.cuGraphNodeGetType(node , & node_type))
545+ HANDLE_RETURN(cydriver.cuGraphNodeGetType(as_cu(h_node) , & node_type))
525546
526547 if node_type == cydriver.CU_GRAPH_NODE_TYPE_EMPTY:
527548 return EmptyNode._create_impl(h_node)
@@ -583,10 +604,10 @@ cdef inline KernelNode GN_launch(GraphNode self, LaunchConfig conf, Kernel ker,
583604 _attach_user_object(as_cu(h_graph), < void * > new KernelHandle(ker._h_kernel),
584605 < cydriver.CUhostFn> _destroy_kernel_handle_copy)
585606
586- return KernelNode._create_with_params(
607+ return _cached( KernelNode._create_with_params(
587608 create_graph_node_handle(new_node, h_graph),
588609 conf.grid, conf.block, conf.shmem_size,
589- ker._h_kernel)
610+ ker._h_kernel))
590611
591612
592613cdef inline EmptyNode GN_join(GraphNode self , tuple nodes):
@@ -612,7 +633,7 @@ cdef inline EmptyNode GN_join(GraphNode self, tuple nodes):
612633 HANDLE_RETURN(cydriver.cuGraphAddEmptyNode(
613634 & new_node, as_cu(h_graph), deps_ptr, num_deps))
614635
615- return EmptyNode._create_impl(create_graph_node_handle(new_node, h_graph))
636+ return _cached( EmptyNode._create_impl(create_graph_node_handle(new_node, h_graph) ))
616637
617638
618639cdef inline AllocNode GN_alloc(GraphNode self , size_t size, object options):
@@ -688,9 +709,9 @@ cdef inline AllocNode GN_alloc(GraphNode self, size_t size, object options):
688709 HANDLE_RETURN(cydriver.cuGraphAddMemAllocNode(
689710 & new_node, as_cu(h_graph), deps, num_deps, & alloc_params))
690711
691- return AllocNode._create_with_params(
712+ return _cached( AllocNode._create_with_params(
692713 create_graph_node_handle(new_node, h_graph), alloc_params.dptr, size,
693- device_id, memory_type, tuple (peer_ids))
714+ device_id, memory_type, tuple (peer_ids)))
694715
695716
696717cdef inline FreeNode GN_free(GraphNode self , cydriver.CUdeviceptr c_dptr):
@@ -708,7 +729,7 @@ cdef inline FreeNode GN_free(GraphNode self, cydriver.CUdeviceptr c_dptr):
708729 HANDLE_RETURN(cydriver.cuGraphAddMemFreeNode(
709730 & new_node, as_cu(h_graph), deps, num_deps, c_dptr))
710731
711- return FreeNode._create_with_params(create_graph_node_handle(new_node, h_graph), c_dptr)
732+ return _cached( FreeNode._create_with_params(create_graph_node_handle(new_node, h_graph), c_dptr) )
712733
713734
714735cdef inline MemsetNode GN_memset(
@@ -743,9 +764,9 @@ cdef inline MemsetNode GN_memset(
743764 & new_node, as_cu(h_graph), deps, num_deps,
744765 & memset_params, ctx))
745766
746- return MemsetNode._create_with_params(
767+ return _cached( MemsetNode._create_with_params(
747768 create_graph_node_handle(new_node, h_graph), c_dst,
748- val, elem_size, width, height, pitch)
769+ val, elem_size, width, height, pitch))
749770
750771
751772cdef inline MemcpyNode GN_memcpy(
@@ -804,9 +825,9 @@ cdef inline MemcpyNode GN_memcpy(
804825 HANDLE_RETURN(cydriver.cuGraphAddMemcpyNode(
805826 & new_node, as_cu(h_graph), deps, num_deps, & params, ctx))
806827
807- return MemcpyNode._create_with_params(
828+ return _cached( MemcpyNode._create_with_params(
808829 create_graph_node_handle(new_node, h_graph), c_dst, c_src, size,
809- c_dst_type, c_src_type)
830+ c_dst_type, c_src_type))
810831
811832
812833cdef inline ChildGraphNode GN_embed(GraphNode self , GraphDef child_def):
@@ -831,8 +852,8 @@ cdef inline ChildGraphNode GN_embed(GraphNode self, GraphDef child_def):
831852
832853 cdef GraphHandle h_embedded = create_graph_handle_ref(embedded_graph, h_graph)
833854
834- return ChildGraphNode._create_with_params(
835- create_graph_node_handle(new_node, h_graph), h_embedded)
855+ return _cached( ChildGraphNode._create_with_params(
856+ create_graph_node_handle(new_node, h_graph), h_embedded))
836857
837858
838859cdef inline EventRecordNode GN_record_event(GraphNode self , Event ev):
@@ -853,8 +874,8 @@ cdef inline EventRecordNode GN_record_event(GraphNode self, Event ev):
853874 _attach_user_object(as_cu(h_graph), < void * > new EventHandle(ev._h_event),
854875 < cydriver.CUhostFn> _destroy_event_handle_copy)
855876
856- return EventRecordNode._create_with_params(
857- create_graph_node_handle(new_node, h_graph), ev._h_event)
877+ return _cached( EventRecordNode._create_with_params(
878+ create_graph_node_handle(new_node, h_graph), ev._h_event))
858879
859880
860881cdef inline EventWaitNode GN_wait_event(GraphNode self , Event ev):
@@ -875,8 +896,8 @@ cdef inline EventWaitNode GN_wait_event(GraphNode self, Event ev):
875896 _attach_user_object(as_cu(h_graph), < void * > new EventHandle(ev._h_event),
876897 < cydriver.CUhostFn> _destroy_event_handle_copy)
877898
878- return EventWaitNode._create_with_params(
879- create_graph_node_handle(new_node, h_graph), ev._h_event)
899+ return _cached( EventWaitNode._create_with_params(
900+ create_graph_node_handle(new_node, h_graph), ev._h_event))
880901
881902
882903cdef inline HostCallbackNode GN_callback(GraphNode self , object fn, object user_data):
@@ -902,6 +923,6 @@ cdef inline HostCallbackNode GN_callback(GraphNode self, object fn, object user_
902923 & new_node, as_cu(h_graph), deps, num_deps, & node_params))
903924
904925 cdef object callable_obj = fn if not isinstance (fn, ct._CFuncPtr) else None
905- return HostCallbackNode._create_with_params(
926+ return _cached( HostCallbackNode._create_with_params(
906927 create_graph_node_handle(new_node, h_graph), callable_obj,
907- node_params.fn, node_params.userData)
928+ node_params.fn, node_params.userData))
0 commit comments