Skip to content

Commit 0424a46

Browse files
committed
Add GraphNode identity cache for stable Python object round-trips
Nodes retrieved via GraphDef.nodes(), edges(), or pred/succ traversal now return the same Python object that was originally created, enabling identity checks with `is`. A C++ HandleRegistry deduplicates CUgraphNode handles, and a `weakref.WeakValueDictionary` in the Cython module caches the Python wrapper objects.

Made-with: Cursor
1 parent 8554d30 commit 0424a46

File tree

4 files changed

+54
-24
lines changed

4 files changed

+54
-24
lines changed

cuda_core/cuda/core/_cpp/resource_handles.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -969,9 +969,16 @@ static const GraphNodeBox* get_box(const GraphNodeHandle& h) {
969969
);
970970
}
971971

972+
// Registry keyed by CUgraphNode. create_graph_node_handle() consults it first
// so that repeated requests for the same node reuse one GraphNodeHandle
// instead of allocating a fresh box on every call.
static HandleRegistry<CUgraphNode, GraphNodeHandle> graph_node_registry;
973+
972974
// Return the shared handle for `node`, creating and registering one on first
// use.
//
// Parameters:
//   node    - driver graph node; may be null (see below).
//   h_graph - handle of the graph that owns `node`; stored in the box next to
//             the node.
//
// Non-null nodes are deduplicated through graph_node_registry: if lookup()
// finds a handle for `node`, that handle is returned unchanged; otherwise a
// new GraphNodeBox is allocated, wrapped in a handle that points at the box's
// `resource` member, registered, and returned.
// NOTE(review): how long a registration stays valid is decided by
// HandleRegistry, which is not visible here.
//
// A null `node` is the sentinel callers pass for a graph's virtual entry
// point. NULL is not a unique key -- every graph's entry sentinel is NULL --
// so it must bypass the registry. Otherwise the sentinel requested for one
// graph could be handed the box (and therefore the h_graph) that was created
// for a different graph.
GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_graph) {
    const bool is_real_node = (node != nullptr);

    if (is_real_node) {
        if (auto h = graph_node_registry.lookup(node)) {
            return h;
        }
    }

    auto box = std::make_shared<const GraphNodeBox>(GraphNodeBox{node, h_graph});
    // Aliasing handle: shares ownership of the box, points at its resource.
    GraphNodeHandle h(box, &box->resource);

    if (is_real_node) {
        graph_node_registry.register_handle(node, h);
    }
    return h;
}
976983

977984
GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept {

cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,19 @@ from cuda.core._graph._utils cimport (
5656
_attach_user_object,
5757
)
5858

59+
import weakref
60+
5961
from cuda.core import Device
6062
from cuda.core._graph._graph_def._adjacency_set_proxy import AdjacencySetProxy
6163
from cuda.core._utils.cuda_utils import driver, handle_return
6264

65+
# Weak-valued cache from a node-handle address (stored as an integer) to the
# Python GraphNode wrapper created for it. Being a WeakValueDictionary, an
# entry disappears on its own once nothing else strongly references the
# wrapper, so the cache never keeps nodes alive.
_node_cache = weakref.WeakValueDictionary()
66+
67+
68+
cdef inline GraphNode _cached(GraphNode n):
    """Record ``n`` in the module-level node cache and hand it back.

    The cache key is the address held by the node's handle, converted to an
    integer. Returning ``n`` lets callers wrap a constructor expression
    directly: ``return _cached(SomeNode._create_impl(...))``.
    """
    cdef uintptr_t key = <uintptr_t>n._h_node.get()
    _node_cache[key] = n
    return n
71+
6372

6473
cdef class GraphNode:
6574
"""Base class for all graph nodes.
@@ -510,18 +519,30 @@ cdef inline ConditionalNode _make_conditional_node(
510519
n._cond_type = cond_type
511520
n._branches = branches
512521

513-
return n
522+
return _cached(n)
523+
514524

515525
cdef inline GraphNode GN_create(GraphHandle h_graph, cydriver.CUgraphNode node):
    """Return the Python wrapper for ``node`` within the graph ``h_graph``.

    A NULL ``node`` yields a fresh, uncached plain ``GraphNode`` that stands
    for the graph's virtual entry point. For any other node the wrapper
    already stored in ``_node_cache`` is returned when there is one;
    otherwise a new wrapper is built by ``GN_create_impl``, stored in the
    cache, and returned.
    """
    cdef GraphNodeHandle h_node = create_graph_node_handle(node, h_graph)
    cdef GraphNode entry

    if node == NULL:
        # Sentinel: virtual node to represent the graph entry point.
        entry = GraphNode.__new__(GraphNode)
        entry._h_node = h_node
        return entry

    # Reuse the existing wrapper if one is still alive; build and cache it
    # otherwise.
    existing = _node_cache.get(<uintptr_t>h_node.get())
    if existing is None:
        return _cached(GN_create_impl(h_node))
    return <GraphNode>existing
540+
541+
542+
cdef inline GraphNode GN_create_impl(GraphNodeHandle h_node):
522543
cdef cydriver.CUgraphNodeType node_type
523544
with nogil:
524-
HANDLE_RETURN(cydriver.cuGraphNodeGetType(node, &node_type))
545+
HANDLE_RETURN(cydriver.cuGraphNodeGetType(as_cu(h_node), &node_type))
525546

526547
if node_type == cydriver.CU_GRAPH_NODE_TYPE_EMPTY:
527548
return EmptyNode._create_impl(h_node)
@@ -583,10 +604,10 @@ cdef inline KernelNode GN_launch(GraphNode self, LaunchConfig conf, Kernel ker,
583604
_attach_user_object(as_cu(h_graph), <void*>new KernelHandle(ker._h_kernel),
584605
<cydriver.CUhostFn>_destroy_kernel_handle_copy)
585606

586-
return KernelNode._create_with_params(
607+
return _cached(KernelNode._create_with_params(
587608
create_graph_node_handle(new_node, h_graph),
588609
conf.grid, conf.block, conf.shmem_size,
589-
ker._h_kernel)
610+
ker._h_kernel))
590611

591612

592613
cdef inline EmptyNode GN_join(GraphNode self, tuple nodes):
@@ -612,7 +633,7 @@ cdef inline EmptyNode GN_join(GraphNode self, tuple nodes):
612633
HANDLE_RETURN(cydriver.cuGraphAddEmptyNode(
613634
&new_node, as_cu(h_graph), deps_ptr, num_deps))
614635

615-
return EmptyNode._create_impl(create_graph_node_handle(new_node, h_graph))
636+
return _cached(EmptyNode._create_impl(create_graph_node_handle(new_node, h_graph)))
616637

617638

618639
cdef inline AllocNode GN_alloc(GraphNode self, size_t size, object options):
@@ -688,9 +709,9 @@ cdef inline AllocNode GN_alloc(GraphNode self, size_t size, object options):
688709
HANDLE_RETURN(cydriver.cuGraphAddMemAllocNode(
689710
&new_node, as_cu(h_graph), deps, num_deps, &alloc_params))
690711

691-
return AllocNode._create_with_params(
712+
return _cached(AllocNode._create_with_params(
692713
create_graph_node_handle(new_node, h_graph), alloc_params.dptr, size,
693-
device_id, memory_type, tuple(peer_ids))
714+
device_id, memory_type, tuple(peer_ids)))
694715

695716

696717
cdef inline FreeNode GN_free(GraphNode self, cydriver.CUdeviceptr c_dptr):
@@ -708,7 +729,7 @@ cdef inline FreeNode GN_free(GraphNode self, cydriver.CUdeviceptr c_dptr):
708729
HANDLE_RETURN(cydriver.cuGraphAddMemFreeNode(
709730
&new_node, as_cu(h_graph), deps, num_deps, c_dptr))
710731

711-
return FreeNode._create_with_params(create_graph_node_handle(new_node, h_graph), c_dptr)
732+
return _cached(FreeNode._create_with_params(create_graph_node_handle(new_node, h_graph), c_dptr))
712733

713734

714735
cdef inline MemsetNode GN_memset(
@@ -743,9 +764,9 @@ cdef inline MemsetNode GN_memset(
743764
&new_node, as_cu(h_graph), deps, num_deps,
744765
&memset_params, ctx))
745766

746-
return MemsetNode._create_with_params(
767+
return _cached(MemsetNode._create_with_params(
747768
create_graph_node_handle(new_node, h_graph), c_dst,
748-
val, elem_size, width, height, pitch)
769+
val, elem_size, width, height, pitch))
749770

750771

751772
cdef inline MemcpyNode GN_memcpy(
@@ -804,9 +825,9 @@ cdef inline MemcpyNode GN_memcpy(
804825
HANDLE_RETURN(cydriver.cuGraphAddMemcpyNode(
805826
&new_node, as_cu(h_graph), deps, num_deps, &params, ctx))
806827

807-
return MemcpyNode._create_with_params(
828+
return _cached(MemcpyNode._create_with_params(
808829
create_graph_node_handle(new_node, h_graph), c_dst, c_src, size,
809-
c_dst_type, c_src_type)
830+
c_dst_type, c_src_type))
810831

811832

812833
cdef inline ChildGraphNode GN_embed(GraphNode self, GraphDef child_def):
@@ -831,8 +852,8 @@ cdef inline ChildGraphNode GN_embed(GraphNode self, GraphDef child_def):
831852

832853
cdef GraphHandle h_embedded = create_graph_handle_ref(embedded_graph, h_graph)
833854

834-
return ChildGraphNode._create_with_params(
835-
create_graph_node_handle(new_node, h_graph), h_embedded)
855+
return _cached(ChildGraphNode._create_with_params(
856+
create_graph_node_handle(new_node, h_graph), h_embedded))
836857

837858

838859
cdef inline EventRecordNode GN_record_event(GraphNode self, Event ev):
@@ -853,8 +874,8 @@ cdef inline EventRecordNode GN_record_event(GraphNode self, Event ev):
853874
_attach_user_object(as_cu(h_graph), <void*>new EventHandle(ev._h_event),
854875
<cydriver.CUhostFn>_destroy_event_handle_copy)
855876

856-
return EventRecordNode._create_with_params(
857-
create_graph_node_handle(new_node, h_graph), ev._h_event)
877+
return _cached(EventRecordNode._create_with_params(
878+
create_graph_node_handle(new_node, h_graph), ev._h_event))
858879

859880

860881
cdef inline EventWaitNode GN_wait_event(GraphNode self, Event ev):
@@ -875,8 +896,8 @@ cdef inline EventWaitNode GN_wait_event(GraphNode self, Event ev):
875896
_attach_user_object(as_cu(h_graph), <void*>new EventHandle(ev._h_event),
876897
<cydriver.CUhostFn>_destroy_event_handle_copy)
877898

878-
return EventWaitNode._create_with_params(
879-
create_graph_node_handle(new_node, h_graph), ev._h_event)
899+
return _cached(EventWaitNode._create_with_params(
900+
create_graph_node_handle(new_node, h_graph), ev._h_event))
880901

881902

882903
cdef inline HostCallbackNode GN_callback(GraphNode self, object fn, object user_data):
@@ -902,6 +923,6 @@ cdef inline HostCallbackNode GN_callback(GraphNode self, object fn, object user_
902923
&new_node, as_cu(h_graph), deps, num_deps, &node_params))
903924

904925
cdef object callable_obj = fn if not isinstance(fn, ct._CFuncPtr) else None
905-
return HostCallbackNode._create_with_params(
926+
return _cached(HostCallbackNode._create_with_params(
906927
create_graph_node_handle(new_node, h_graph), callable_obj,
907-
node_params.fn, node_params.userData)
928+
node_params.fn, node_params.userData))

cuda_core/tests/graph/test_graphdef.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,7 @@ def test_node_type_preserved_by_nodes(node_spec):
661661
matched = [n for n in all_nodes if n == node]
662662
assert len(matched) == 1
663663
assert isinstance(matched[0], spec.roundtrip_class)
664+
assert matched[0] is node
664665

665666

666667
def test_node_type_preserved_by_pred_succ(node_spec):
@@ -670,6 +671,7 @@ def test_node_type_preserved_by_pred_succ(node_spec):
670671
matched = [s for s in predecessor.succ if s == node]
671672
assert len(matched) == 1
672673
assert isinstance(matched[0], spec.roundtrip_class)
674+
assert matched[0] is node
673675

674676

675677
def test_node_attrs(node_spec):

cuda_core/tests/graph/test_graphdef_mutation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ def test_convert_linear_to_fan_in(init_cuda):
372372
for node in g.nodes():
373373
if isinstance(node, MemsetNode):
374374
node.pred.clear()
375-
elif isinstance(node, KernelNode) and node != reduce_node:
375+
elif isinstance(node, KernelNode) and node is not reduce_node:
376376
node.succ.add(reduce_node)
377377

378378
assert len(g.edges()) == 8

0 commit comments

Comments
 (0)