Skip to content

Commit 2096f7d

Browse files
authored
Merge branch 'main' into dependabot/github_actions/actions-monthly-81fd04c4ef
2 parents d39fd0a + 66a687c commit 2096f7d

24 files changed

+2456
-2223
lines changed

cuda_core/cuda/core/_graph/_graph_builder.pyx

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import weakref
66
from dataclasses import dataclass
77

8+
from libc.stdint cimport intptr_t
9+
810
from cuda.bindings cimport cydriver
911

1012
from cuda.core._graph._utils cimport _attach_host_callback_to_graph
@@ -14,6 +16,7 @@ from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
1416
from cuda.core._utils.version cimport cy_binding_version, cy_driver_version
1517

1618
from cuda.core._utils.cuda_utils import (
19+
CUDAError,
1720
driver,
1821
handle_return,
1922
)
@@ -783,24 +786,42 @@ class Graph:
783786
"""
784787
return self._mnff.graph
785788

786-
def update(self, builder: GraphBuilder):
787-
"""Update the graph using new build configuration from the builder.
789+
def update(self, source: "GraphBuilder | GraphDef") -> None:
790+
"""Update the graph using a new graph definition.
788791

789-
The topology of the provided builder must be identical to this graph.
792+
The topology of the provided source must be identical to this graph.
790793

791794
Parameters
792795
----------
793-
builder : :obj:`~_graph.GraphBuilder`
794-
The builder to update the graph with.
796+
source : :obj:`~_graph.GraphBuilder` or :obj:`~_graph._graph_def.GraphDef`
797+
The graph definition to update from. A GraphBuilder must have
798+
finished building.
795799

796800
"""
797-
if not builder._building_ended:
798-
raise ValueError("Graph has not finished building.")
801+
from cuda.core._graph._graph_def import GraphDef
802+
803+
cdef cydriver.CUgraph cu_graph
804+
cdef cydriver.CUgraphExec cu_exec = <cydriver.CUgraphExec><intptr_t>int(self._mnff.graph)
799805

800-
# Update the graph with the new nodes from the builder
801-
exec_update_result = handle_return(driver.cuGraphExecUpdate(self._mnff.graph, builder._mnff.graph))
802-
if exec_update_result.result != driver.CUgraphExecUpdateResult.CU_GRAPH_EXEC_UPDATE_SUCCESS:
803-
raise RuntimeError(f"Failed to update graph: {exec_update_result.result()}")
806+
if isinstance(source, GraphBuilder):
807+
if not source._building_ended:
808+
raise ValueError("Graph has not finished building.")
809+
cu_graph = <cydriver.CUgraph><intptr_t>int(source._mnff.graph)
810+
elif isinstance(source, GraphDef):
811+
cu_graph = <cydriver.CUgraph><intptr_t>int(source.handle)
812+
else:
813+
raise TypeError(
814+
f"expected GraphBuilder or GraphDef, got {type(source).__name__}")
815+
816+
cdef cydriver.CUgraphExecUpdateResultInfo result_info
817+
cdef cydriver.CUresult err
818+
with nogil:
819+
err = cydriver.cuGraphExecUpdate(cu_exec, cu_graph, &result_info)
820+
if err == cydriver.CUresult.CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE:
821+
reason = driver.CUgraphExecUpdateResult(result_info.result)
822+
msg = f"Graph update failed: {reason.__doc__.strip()} ({reason.name})"
823+
raise CUDAError(msg)
824+
HANDLE_RETURN(err)
804825

805826
def upload(self, stream: Stream):
806827
"""Uploads the graph in a stream.
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from cuda.core._graph._graph_def._graph_def cimport Condition, GraphDef
6+
from cuda.core._graph._graph_def._graph_node cimport GraphNode
7+
from cuda.core._graph._graph_def._subclasses cimport (
8+
AllocNode,
9+
ChildGraphNode,
10+
ConditionalNode,
11+
EmptyNode,
12+
EventRecordNode,
13+
EventWaitNode,
14+
FreeNode,
15+
HostCallbackNode,
16+
IfElseNode,
17+
IfNode,
18+
KernelNode,
19+
MemcpyNode,
20+
MemsetNode,
21+
SwitchNode,
22+
WhileNode,
23+
)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""Explicit CUDA graph construction — GraphDef, GraphNode, and node subclasses."""
6+
7+
from cuda.core._graph._graph_def._graph_def import (
8+
Condition,
9+
GraphAllocOptions,
10+
GraphDef,
11+
)
12+
from cuda.core._graph._graph_def._graph_node import GraphNode
13+
from cuda.core._graph._graph_def._subclasses import (
14+
AllocNode,
15+
ChildGraphNode,
16+
ConditionalNode,
17+
EmptyNode,
18+
EventRecordNode,
19+
EventWaitNode,
20+
FreeNode,
21+
HostCallbackNode,
22+
IfElseNode,
23+
IfNode,
24+
KernelNode,
25+
MemcpyNode,
26+
MemsetNode,
27+
SwitchNode,
28+
WhileNode,
29+
)
30+
31+
__all__ = [
32+
"AllocNode",
33+
"ChildGraphNode",
34+
"Condition",
35+
"ConditionalNode",
36+
"EmptyNode",
37+
"EventRecordNode",
38+
"EventWaitNode",
39+
"FreeNode",
40+
"GraphAllocOptions",
41+
"GraphDef",
42+
"GraphNode",
43+
"HostCallbackNode",
44+
"IfElseNode",
45+
"IfNode",
46+
"KernelNode",
47+
"MemcpyNode",
48+
"MemsetNode",
49+
"SwitchNode",
50+
"WhileNode",
51+
]
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from cuda.bindings cimport cydriver
6+
from cuda.core._resource_handles cimport GraphHandle
7+
8+
9+
cdef class Condition:
10+
cdef:
11+
cydriver.CUgraphConditionalHandle _c_handle
12+
object __weakref__
13+
14+
15+
cdef class GraphDef:
16+
cdef:
17+
GraphHandle _h_graph
18+
object __weakref__
19+
20+
@staticmethod
21+
cdef GraphDef _from_handle(GraphHandle h_graph)

0 commit comments

Comments
 (0)