Skip to content

Commit 3c3dfb3

Browse files
committed
Merge branch 'main' into pathfinder_cudla
2 parents 9327aea + 5064470 commit 3c3dfb3

13 files changed

+954
-145
lines changed

cuda_core/cuda/core/_cpp/resource_handles.cpp

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -174,13 +174,8 @@ class HandleRegistry {
174174
}
175175

176176
void unregister_handle(const Key& key) noexcept {
177-
try {
178-
std::lock_guard<std::mutex> lock(mutex_);
179-
auto it = map_.find(key);
180-
if (it != map_.end() && it->second.expired()) {
181-
map_.erase(it);
182-
}
183-
} catch (...) {}
177+
std::lock_guard<std::mutex> lock(mutex_);
178+
map_.erase(key);
184179
}
185180

186181
Handle lookup(const Key& key) {
@@ -957,7 +952,7 @@ GraphHandle create_graph_handle_ref(CUgraph graph, const GraphHandle& h_parent)
957952

958953
namespace {
959954
struct GraphNodeBox {
960-
CUgraphNode resource;
955+
mutable CUgraphNode resource;
961956
GraphHandle h_graph;
962957
};
963958
} // namespace
@@ -969,15 +964,36 @@ static const GraphNodeBox* get_box(const GraphNodeHandle& h) {
969964
);
970965
}
971966

967+
static HandleRegistry<CUgraphNode, GraphNodeHandle> graph_node_registry;
968+
972969
GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_graph) {
970+
if (node) {
971+
if (auto h = graph_node_registry.lookup(node)) {
972+
return h;
973+
}
974+
}
973975
auto box = std::make_shared<const GraphNodeBox>(GraphNodeBox{node, h_graph});
974-
return GraphNodeHandle(box, &box->resource);
976+
GraphNodeHandle h(box, &box->resource);
977+
if (node) {
978+
graph_node_registry.register_handle(node, h);
979+
}
980+
return h;
975981
}
976982

977983
GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept {
978984
return h ? get_box(h)->h_graph : GraphHandle{};
979985
}
980986

987+
void invalidate_graph_node(const GraphNodeHandle& h) noexcept {
988+
if (h) {
989+
CUgraphNode node = get_box(h)->resource;
990+
if (node) {
991+
graph_node_registry.unregister_handle(node);
992+
}
993+
get_box(h)->resource = nullptr;
994+
}
995+
}
996+
981997
// ============================================================================
982998
// Graphics Resource Handles
983999
// ============================================================================

cuda_core/cuda/core/_cpp/resource_handles.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,9 @@ GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_
415415
// Extract the owning graph handle from a node handle.
416416
GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept;
417417

418+
// Zero the CUgraphNode resource inside the handle, marking it invalid.
419+
void invalidate_graph_node(const GraphNodeHandle& h) noexcept;
420+
418421
// ============================================================================
419422
// Graphics resource handle functions
420423
// ============================================================================
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""Mutable-set proxy for graph node predecessors and successors."""
6+
7+
from libc.stddef cimport size_t
8+
from libcpp.vector cimport vector
9+
from cuda.bindings cimport cydriver
10+
from cuda.core._graph._graph_def._graph_node cimport GraphNode
11+
from cuda.core._resource_handles cimport (
12+
GraphHandle,
13+
GraphNodeHandle,
14+
as_cu,
15+
graph_node_get_graph,
16+
)
17+
from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
18+
from collections.abc import MutableSet
19+
20+
21+
# ---- Python MutableSet wrapper ----------------------------------------------
22+
23+
class AdjacencySetProxy(MutableSet):
24+
"""Mutable set proxy for a node's predecessors or successors. Mutations
25+
write through to the underlying CUDA graph."""
26+
27+
__slots__ = ("_core",)
28+
29+
def __init__(self, node, bint is_fwd):
30+
self._core = _AdjacencySetCore(node, is_fwd)
31+
32+
# Used by operators such as &|^ to create non-proxy views when needed.
33+
@classmethod
34+
def _from_iterable(cls, it):
35+
return set(it)
36+
37+
# --- abstract methods required by MutableSet ---
38+
39+
def __contains__(self, x):
40+
if not isinstance(x, GraphNode):
41+
return False
42+
return x in (<_AdjacencySetCore>self._core).query()
43+
44+
def __iter__(self):
45+
return iter((<_AdjacencySetCore>self._core).query())
46+
47+
def __len__(self):
48+
return (<_AdjacencySetCore>self._core).count()
49+
50+
def add(self, value):
51+
if not isinstance(value, GraphNode):
52+
raise TypeError(
53+
f"expected GraphNode, got {type(value).__name__}")
54+
if value in self:
55+
return
56+
(<_AdjacencySetCore>self._core).add_edge(<GraphNode>value)
57+
58+
def discard(self, value):
59+
if not isinstance(value, GraphNode):
60+
return
61+
if value not in self:
62+
return
63+
(<_AdjacencySetCore>self._core).remove_edge(<GraphNode>value)
64+
65+
# --- override for bulk efficiency ---
66+
67+
def clear(self):
68+
"""Remove all edges in a single driver call."""
69+
members = (<_AdjacencySetCore>self._core).query()
70+
if members:
71+
(<_AdjacencySetCore>self._core).remove_edges(members)
72+
73+
def __isub__(self, it):
74+
"""Remove edges to all nodes in *it* in a single driver call."""
75+
if it is self:
76+
self.clear()
77+
else:
78+
to_remove = [v for v in it if isinstance(v, GraphNode) and v in self]
79+
if to_remove:
80+
(<_AdjacencySetCore>self._core).remove_edges(to_remove)
81+
return self
82+
83+
def update(self, *others):
84+
"""Add edges to multiple nodes at once."""
85+
nodes = []
86+
for other in others:
87+
if isinstance(other, GraphNode):
88+
nodes.append(other)
89+
else:
90+
nodes.extend(other)
91+
if not nodes:
92+
return
93+
for n in nodes:
94+
if not isinstance(n, GraphNode):
95+
raise TypeError(
96+
f"expected GraphNode, got {type(n).__name__}")
97+
new = [n for n in nodes if n not in self]
98+
if new:
99+
(<_AdjacencySetCore>self._core).add_edges(new)
100+
101+
def __ior__(self, it):
102+
"""Add edges to all nodes in *it* in a single driver call."""
103+
self.update(it)
104+
return self
105+
106+
def __repr__(self):
107+
return "{" + ", ".join(repr(n) for n in self) + "}"
108+
109+
110+
# ---- cdef core holding a function pointer ------------------------------------
111+
112+
# Signature shared by driver_get_preds and driver_get_succs.
113+
ctypedef cydriver.CUresult (*_adj_fn_t)(
114+
cydriver.CUgraphNode, cydriver.CUgraphNode*, size_t*) noexcept nogil
115+
116+
117+
cdef class _AdjacencySetCore:
118+
"""Cythonized core implementing AdjacencySetProxy"""
119+
cdef:
120+
GraphNodeHandle _h_node
121+
GraphHandle _h_graph
122+
_adj_fn_t _query_fn
123+
bint _is_fwd
124+
125+
def __init__(self, GraphNode node, bint is_fwd):
126+
self._h_node = node._h_node
127+
self._h_graph = graph_node_get_graph(node._h_node)
128+
self._is_fwd = is_fwd
129+
self._query_fn = driver_get_succs if is_fwd else driver_get_preds
130+
131+
cdef inline void _resolve_edge(
132+
self, GraphNode other,
133+
cydriver.CUgraphNode* c_from,
134+
cydriver.CUgraphNode* c_to) noexcept:
135+
if self._is_fwd:
136+
c_from[0] = as_cu(self._h_node)
137+
c_to[0] = as_cu(other._h_node)
138+
else:
139+
c_from[0] = as_cu(other._h_node)
140+
c_to[0] = as_cu(self._h_node)
141+
142+
cdef list query(self):
143+
cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
144+
if c_node == NULL:
145+
return []
146+
cdef size_t count = 0
147+
with nogil:
148+
HANDLE_RETURN(self._query_fn(c_node, NULL, &count))
149+
if count == 0:
150+
return []
151+
cdef vector[cydriver.CUgraphNode] nodes_vec
152+
nodes_vec.resize(count)
153+
with nogil:
154+
HANDLE_RETURN(self._query_fn(
155+
c_node, nodes_vec.data(), &count))
156+
return [GraphNode._create(self._h_graph, nodes_vec[i])
157+
for i in range(count)]
158+
159+
cdef Py_ssize_t count(self):
160+
cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
161+
if c_node == NULL:
162+
return 0
163+
cdef size_t n = 0
164+
with nogil:
165+
HANDLE_RETURN(self._query_fn(c_node, NULL, &n))
166+
return <Py_ssize_t>n
167+
168+
cdef void add_edge(self, GraphNode other):
169+
cdef cydriver.CUgraphNode c_from, c_to
170+
self._resolve_edge(other, &c_from, &c_to)
171+
with nogil:
172+
HANDLE_RETURN(driver_add_edges(as_cu(self._h_graph), &c_from, &c_to, 1))
173+
174+
cdef void add_edges(self, list nodes):
175+
cdef size_t n = len(nodes)
176+
cdef vector[cydriver.CUgraphNode] from_vec
177+
cdef vector[cydriver.CUgraphNode] to_vec
178+
from_vec.resize(n)
179+
to_vec.resize(n)
180+
cdef size_t i
181+
for i in range(n):
182+
self._resolve_edge(<GraphNode>nodes[i], &from_vec[i], &to_vec[i])
183+
with nogil:
184+
HANDLE_RETURN(driver_add_edges(
185+
as_cu(self._h_graph), from_vec.data(), to_vec.data(), n))
186+
187+
cdef void remove_edge(self, GraphNode other):
188+
cdef cydriver.CUgraphNode c_from, c_to
189+
self._resolve_edge(other, &c_from, &c_to)
190+
with nogil:
191+
HANDLE_RETURN(driver_remove_edges(as_cu(self._h_graph), &c_from, &c_to, 1))
192+
193+
cdef void remove_edges(self, list nodes):
194+
cdef size_t n = len(nodes)
195+
cdef vector[cydriver.CUgraphNode] from_vec
196+
cdef vector[cydriver.CUgraphNode] to_vec
197+
from_vec.resize(n)
198+
to_vec.resize(n)
199+
cdef size_t i
200+
for i in range(n):
201+
self._resolve_edge(<GraphNode>nodes[i], &from_vec[i], &to_vec[i])
202+
with nogil:
203+
HANDLE_RETURN(driver_remove_edges(
204+
as_cu(self._h_graph), from_vec.data(), to_vec.data(), n))
205+
206+
207+
# ---- driver wrappers: absorb CUDA version differences ----
208+
209+
cdef inline cydriver.CUresult driver_get_preds(
210+
cydriver.CUgraphNode node, cydriver.CUgraphNode* out,
211+
size_t* count) noexcept nogil:
212+
IF CUDA_CORE_BUILD_MAJOR >= 13:
213+
return cydriver.cuGraphNodeGetDependencies(node, out, NULL, count)
214+
ELSE:
215+
return cydriver.cuGraphNodeGetDependencies(node, out, count)
216+
217+
218+
cdef inline cydriver.CUresult driver_get_succs(
219+
cydriver.CUgraphNode node, cydriver.CUgraphNode* out,
220+
size_t* count) noexcept nogil:
221+
IF CUDA_CORE_BUILD_MAJOR >= 13:
222+
return cydriver.cuGraphNodeGetDependentNodes(node, out, NULL, count)
223+
ELSE:
224+
return cydriver.cuGraphNodeGetDependentNodes(node, out, count)
225+
226+
227+
cdef inline cydriver.CUresult driver_add_edges(
228+
cydriver.CUgraph graph, cydriver.CUgraphNode* from_arr,
229+
cydriver.CUgraphNode* to_arr, size_t count) noexcept nogil:
230+
IF CUDA_CORE_BUILD_MAJOR >= 13:
231+
return cydriver.cuGraphAddDependencies(
232+
graph, from_arr, to_arr, NULL, count)
233+
ELSE:
234+
return cydriver.cuGraphAddDependencies(
235+
graph, from_arr, to_arr, count)
236+
237+
238+
cdef inline cydriver.CUresult driver_remove_edges(
239+
cydriver.CUgraph graph, cydriver.CUgraphNode* from_arr,
240+
cydriver.CUgraphNode* to_arr, size_t count) noexcept nogil:
241+
IF CUDA_CORE_BUILD_MAJOR >= 13:
242+
return cydriver.cuGraphRemoveDependencies(
243+
graph, from_arr, to_arr, NULL, count)
244+
ELSE:
245+
return cydriver.cuGraphRemoveDependencies(
246+
graph, from_arr, to_arr, count)

cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -314,12 +314,12 @@ cdef class GraphDef:
314314
with nogil:
315315
HANDLE_RETURN(cydriver.cuGraphDebugDotPrint(as_cu(self._h_graph), c_path, flags))
316316
317-
def nodes(self) -> tuple:
317+
def nodes(self) -> set:
318318
"""Return all nodes in the graph.
319319

320320
Returns
321321
-------
322-
tuple of GraphNode
322+
set of GraphNode
323323
All nodes in the graph.
324324
"""
325325
cdef size_t num_nodes = 0
@@ -328,21 +328,21 @@ cdef class GraphDef:
328328
HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), NULL, &num_nodes))
329329
330330
if num_nodes == 0:
331-
return ()
331+
return set()
332332
333333
cdef vector[cydriver.CUgraphNode] nodes_vec
334334
nodes_vec.resize(num_nodes)
335335
with nogil:
336336
HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), nodes_vec.data(), &num_nodes))
337337
338-
return tuple(GraphNode._create(self._h_graph, nodes_vec[i]) for i in range(num_nodes))
338+
return set(GraphNode._create(self._h_graph, nodes_vec[i]) for i in range(num_nodes))
339339
340-
def edges(self) -> tuple:
340+
def edges(self) -> set:
341341
"""Return all edges in the graph as (from_node, to_node) pairs.
342342

343343
Returns
344344
-------
345-
tuple of tuple
345+
set of tuple
346346
Each element is a (from_node, to_node) pair representing
347347
a dependency edge in the graph.
348348
"""
@@ -355,7 +355,7 @@ cdef class GraphDef:
355355
HANDLE_RETURN(cydriver.cuGraphGetEdges(as_cu(self._h_graph), NULL, NULL, &num_edges))
356356
357357
if num_edges == 0:
358-
return ()
358+
return set()
359359
360360
cdef vector[cydriver.CUgraphNode] from_nodes
361361
cdef vector[cydriver.CUgraphNode] to_nodes
@@ -369,7 +369,7 @@ cdef class GraphDef:
369369
HANDLE_RETURN(cydriver.cuGraphGetEdges(
370370
as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), &num_edges))
371371
372-
return tuple(
372+
return set(
373373
(GraphNode._create(self._h_graph, from_nodes[i]),
374374
GraphNode._create(self._h_graph, to_nodes[i]))
375375
for i in range(num_edges)

cuda_core/cuda/core/_graph/_graph_def/_graph_node.pxd

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ from cuda.core._resource_handles cimport GraphHandle, GraphNodeHandle
99
cdef class GraphNode:
1010
cdef:
1111
GraphNodeHandle _h_node
12-
tuple _pred_cache
13-
tuple _succ_cache
1412
object __weakref__
1513

1614
@staticmethod

0 commit comments

Comments
 (0)