|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | + |
| 5 | +"""Mutable-set proxy for graph node predecessors and successors.""" |
| 6 | + |
| 7 | +from libc.stddef cimport size_t |
| 8 | +from libcpp.vector cimport vector |
| 9 | +from cuda.bindings cimport cydriver |
| 10 | +from cuda.core._graph._graph_def._graph_node cimport GraphNode |
| 11 | +from cuda.core._resource_handles cimport ( |
| 12 | + GraphHandle, |
| 13 | + GraphNodeHandle, |
| 14 | + as_cu, |
| 15 | + graph_node_get_graph, |
| 16 | +) |
| 17 | +from cuda.core._utils.cuda_utils cimport HANDLE_RETURN |
| 18 | +from collections.abc import MutableSet |
| 19 | + |
| 20 | + |
| 21 | +# ---- Python MutableSet wrapper ---------------------------------------------- |
| 22 | + |
| 23 | +class AdjacencySetProxy(MutableSet): |
| 24 | + """Mutable set proxy for a node's predecessors or successors. Mutations |
| 25 | + write through to the underlying CUDA graph.""" |
| 26 | + |
| 27 | + __slots__ = ("_core",) |
| 28 | + |
| 29 | + def __init__(self, node, bint is_fwd): |
| 30 | + self._core = _AdjacencySetCore(node, is_fwd) |
| 31 | + |
| 32 | + # Used by operators such as &|^ to create non-proxy views when needed. |
| 33 | + @classmethod |
| 34 | + def _from_iterable(cls, it): |
| 35 | + return set(it) |
| 36 | + |
| 37 | + # --- abstract methods required by MutableSet --- |
| 38 | + |
| 39 | + def __contains__(self, x): |
| 40 | + if not isinstance(x, GraphNode): |
| 41 | + return False |
| 42 | + return x in (<_AdjacencySetCore>self._core).query() |
| 43 | + |
| 44 | + def __iter__(self): |
| 45 | + return iter((<_AdjacencySetCore>self._core).query()) |
| 46 | + |
| 47 | + def __len__(self): |
| 48 | + return (<_AdjacencySetCore>self._core).count() |
| 49 | + |
| 50 | + def add(self, value): |
| 51 | + if not isinstance(value, GraphNode): |
| 52 | + raise TypeError( |
| 53 | + f"expected GraphNode, got {type(value).__name__}") |
| 54 | + if value in self: |
| 55 | + return |
| 56 | + (<_AdjacencySetCore>self._core).add_edge(<GraphNode>value) |
| 57 | + |
| 58 | + def discard(self, value): |
| 59 | + if not isinstance(value, GraphNode): |
| 60 | + return |
| 61 | + if value not in self: |
| 62 | + return |
| 63 | + (<_AdjacencySetCore>self._core).remove_edge(<GraphNode>value) |
| 64 | + |
| 65 | + # --- override for bulk efficiency --- |
| 66 | + |
| 67 | + def clear(self): |
| 68 | + """Remove all edges in a single driver call.""" |
| 69 | + members = (<_AdjacencySetCore>self._core).query() |
| 70 | + if members: |
| 71 | + (<_AdjacencySetCore>self._core).remove_edges(members) |
| 72 | + |
| 73 | + def __isub__(self, it): |
| 74 | + """Remove edges to all nodes in *it* in a single driver call.""" |
| 75 | + if it is self: |
| 76 | + self.clear() |
| 77 | + else: |
| 78 | + to_remove = [v for v in it if isinstance(v, GraphNode) and v in self] |
| 79 | + if to_remove: |
| 80 | + (<_AdjacencySetCore>self._core).remove_edges(to_remove) |
| 81 | + return self |
| 82 | + |
| 83 | + def update(self, *others): |
| 84 | + """Add edges to multiple nodes at once.""" |
| 85 | + nodes = [] |
| 86 | + for other in others: |
| 87 | + if isinstance(other, GraphNode): |
| 88 | + nodes.append(other) |
| 89 | + else: |
| 90 | + nodes.extend(other) |
| 91 | + if not nodes: |
| 92 | + return |
| 93 | + for n in nodes: |
| 94 | + if not isinstance(n, GraphNode): |
| 95 | + raise TypeError( |
| 96 | + f"expected GraphNode, got {type(n).__name__}") |
| 97 | + new = [n for n in nodes if n not in self] |
| 98 | + if new: |
| 99 | + (<_AdjacencySetCore>self._core).add_edges(new) |
| 100 | + |
| 101 | + def __ior__(self, it): |
| 102 | + """Add edges to all nodes in *it* in a single driver call.""" |
| 103 | + self.update(it) |
| 104 | + return self |
| 105 | + |
| 106 | + def __repr__(self): |
| 107 | + return "{" + ", ".join(repr(n) for n in self) + "}" |
| 108 | + |
| 109 | + |
| 110 | +# ---- cdef core holding a function pointer ------------------------------------ |
| 111 | + |
| 112 | +# Signature shared by driver_get_preds and driver_get_succs. |
| 113 | +ctypedef cydriver.CUresult (*_adj_fn_t)( |
| 114 | + cydriver.CUgraphNode, cydriver.CUgraphNode*, size_t*) noexcept nogil |
| 115 | + |
| 116 | + |
| 117 | +cdef class _AdjacencySetCore: |
| 118 | + """Cythonized core implementing AdjacencySetProxy""" |
| 119 | + cdef: |
| 120 | + GraphNodeHandle _h_node |
| 121 | + GraphHandle _h_graph |
| 122 | + _adj_fn_t _query_fn |
| 123 | + bint _is_fwd |
| 124 | + |
| 125 | + def __init__(self, GraphNode node, bint is_fwd): |
| 126 | + self._h_node = node._h_node |
| 127 | + self._h_graph = graph_node_get_graph(node._h_node) |
| 128 | + self._is_fwd = is_fwd |
| 129 | + self._query_fn = driver_get_succs if is_fwd else driver_get_preds |
| 130 | + |
| 131 | + cdef inline void _resolve_edge( |
| 132 | + self, GraphNode other, |
| 133 | + cydriver.CUgraphNode* c_from, |
| 134 | + cydriver.CUgraphNode* c_to) noexcept: |
| 135 | + if self._is_fwd: |
| 136 | + c_from[0] = as_cu(self._h_node) |
| 137 | + c_to[0] = as_cu(other._h_node) |
| 138 | + else: |
| 139 | + c_from[0] = as_cu(other._h_node) |
| 140 | + c_to[0] = as_cu(self._h_node) |
| 141 | + |
| 142 | + cdef list query(self): |
| 143 | + cdef cydriver.CUgraphNode c_node = as_cu(self._h_node) |
| 144 | + if c_node == NULL: |
| 145 | + return [] |
| 146 | + cdef size_t count = 0 |
| 147 | + with nogil: |
| 148 | + HANDLE_RETURN(self._query_fn(c_node, NULL, &count)) |
| 149 | + if count == 0: |
| 150 | + return [] |
| 151 | + cdef vector[cydriver.CUgraphNode] nodes_vec |
| 152 | + nodes_vec.resize(count) |
| 153 | + with nogil: |
| 154 | + HANDLE_RETURN(self._query_fn( |
| 155 | + c_node, nodes_vec.data(), &count)) |
| 156 | + return [GraphNode._create(self._h_graph, nodes_vec[i]) |
| 157 | + for i in range(count)] |
| 158 | + |
| 159 | + cdef Py_ssize_t count(self): |
| 160 | + cdef cydriver.CUgraphNode c_node = as_cu(self._h_node) |
| 161 | + if c_node == NULL: |
| 162 | + return 0 |
| 163 | + cdef size_t n = 0 |
| 164 | + with nogil: |
| 165 | + HANDLE_RETURN(self._query_fn(c_node, NULL, &n)) |
| 166 | + return <Py_ssize_t>n |
| 167 | + |
| 168 | + cdef void add_edge(self, GraphNode other): |
| 169 | + cdef cydriver.CUgraphNode c_from, c_to |
| 170 | + self._resolve_edge(other, &c_from, &c_to) |
| 171 | + with nogil: |
| 172 | + HANDLE_RETURN(driver_add_edges(as_cu(self._h_graph), &c_from, &c_to, 1)) |
| 173 | + |
| 174 | + cdef void add_edges(self, list nodes): |
| 175 | + cdef size_t n = len(nodes) |
| 176 | + cdef vector[cydriver.CUgraphNode] from_vec |
| 177 | + cdef vector[cydriver.CUgraphNode] to_vec |
| 178 | + from_vec.resize(n) |
| 179 | + to_vec.resize(n) |
| 180 | + cdef size_t i |
| 181 | + for i in range(n): |
| 182 | + self._resolve_edge(<GraphNode>nodes[i], &from_vec[i], &to_vec[i]) |
| 183 | + with nogil: |
| 184 | + HANDLE_RETURN(driver_add_edges( |
| 185 | + as_cu(self._h_graph), from_vec.data(), to_vec.data(), n)) |
| 186 | + |
| 187 | + cdef void remove_edge(self, GraphNode other): |
| 188 | + cdef cydriver.CUgraphNode c_from, c_to |
| 189 | + self._resolve_edge(other, &c_from, &c_to) |
| 190 | + with nogil: |
| 191 | + HANDLE_RETURN(driver_remove_edges(as_cu(self._h_graph), &c_from, &c_to, 1)) |
| 192 | + |
| 193 | + cdef void remove_edges(self, list nodes): |
| 194 | + cdef size_t n = len(nodes) |
| 195 | + cdef vector[cydriver.CUgraphNode] from_vec |
| 196 | + cdef vector[cydriver.CUgraphNode] to_vec |
| 197 | + from_vec.resize(n) |
| 198 | + to_vec.resize(n) |
| 199 | + cdef size_t i |
| 200 | + for i in range(n): |
| 201 | + self._resolve_edge(<GraphNode>nodes[i], &from_vec[i], &to_vec[i]) |
| 202 | + with nogil: |
| 203 | + HANDLE_RETURN(driver_remove_edges( |
| 204 | + as_cu(self._h_graph), from_vec.data(), to_vec.data(), n)) |
| 205 | + |
| 206 | + |
| 207 | +# ---- driver wrappers: absorb CUDA version differences ---- |
| 208 | + |
| 209 | +cdef inline cydriver.CUresult driver_get_preds( |
| 210 | + cydriver.CUgraphNode node, cydriver.CUgraphNode* out, |
| 211 | + size_t* count) noexcept nogil: |
| 212 | + IF CUDA_CORE_BUILD_MAJOR >= 13: |
| 213 | + return cydriver.cuGraphNodeGetDependencies(node, out, NULL, count) |
| 214 | + ELSE: |
| 215 | + return cydriver.cuGraphNodeGetDependencies(node, out, count) |
| 216 | + |
| 217 | + |
| 218 | +cdef inline cydriver.CUresult driver_get_succs( |
| 219 | + cydriver.CUgraphNode node, cydriver.CUgraphNode* out, |
| 220 | + size_t* count) noexcept nogil: |
| 221 | + IF CUDA_CORE_BUILD_MAJOR >= 13: |
| 222 | + return cydriver.cuGraphNodeGetDependentNodes(node, out, NULL, count) |
| 223 | + ELSE: |
| 224 | + return cydriver.cuGraphNodeGetDependentNodes(node, out, count) |
| 225 | + |
| 226 | + |
| 227 | +cdef inline cydriver.CUresult driver_add_edges( |
| 228 | + cydriver.CUgraph graph, cydriver.CUgraphNode* from_arr, |
| 229 | + cydriver.CUgraphNode* to_arr, size_t count) noexcept nogil: |
| 230 | + IF CUDA_CORE_BUILD_MAJOR >= 13: |
| 231 | + return cydriver.cuGraphAddDependencies( |
| 232 | + graph, from_arr, to_arr, NULL, count) |
| 233 | + ELSE: |
| 234 | + return cydriver.cuGraphAddDependencies( |
| 235 | + graph, from_arr, to_arr, count) |
| 236 | + |
| 237 | + |
| 238 | +cdef inline cydriver.CUresult driver_remove_edges( |
| 239 | + cydriver.CUgraph graph, cydriver.CUgraphNode* from_arr, |
| 240 | + cydriver.CUgraphNode* to_arr, size_t count) noexcept nogil: |
| 241 | + IF CUDA_CORE_BUILD_MAJOR >= 13: |
| 242 | + return cydriver.cuGraphRemoveDependencies( |
| 243 | + graph, from_arr, to_arr, NULL, count) |
| 244 | + ELSE: |
| 245 | + return cydriver.cuGraphRemoveDependencies( |
| 246 | + graph, from_arr, to_arr, count) |
0 commit comments