Skip to content
42 changes: 31 additions & 11 deletions cuda_core/cuda/core/_graph/_graph_builder.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import weakref
from dataclasses import dataclass

from libc.stdint cimport intptr_t

from cuda.bindings cimport cydriver

from cuda.core._graph._utils cimport _attach_host_callback_to_graph
Expand All @@ -14,6 +16,7 @@ from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
from cuda.core._utils.version cimport cy_binding_version, cy_driver_version

from cuda.core._utils.cuda_utils import (
CUDAError,
driver,
handle_return,
)
Expand Down Expand Up @@ -783,24 +786,41 @@ class Graph:
"""
return self._mnff.graph

def update(self, builder: GraphBuilder):
"""Update the graph using new build configuration from the builder.
def update(self, source: "GraphBuilder | GraphDef") -> None:
"""Update the graph using a new graph definition.

The topology of the provided builder must be identical to this graph.
The topology of the provided source must be identical to this graph.

Parameters
----------
builder : :obj:`~_graph.GraphBuilder`
The builder to update the graph with.
source : :obj:`~_graph.GraphBuilder` or :obj:`~_graph._graph_def.GraphDef`
The graph definition to update from. A GraphBuilder must have
finished building.

"""
if not builder._building_ended:
raise ValueError("Graph has not finished building.")
from cuda.core._graph._graph_def import GraphDef

cdef cydriver.CUgraph cu_graph
cdef cydriver.CUgraphExec cu_exec = <cydriver.CUgraphExec><intptr_t>int(self._mnff.graph)

# Update the graph with the new nodes from the builder
exec_update_result = handle_return(driver.cuGraphExecUpdate(self._mnff.graph, builder._mnff.graph))
if exec_update_result.result != driver.CUgraphExecUpdateResult.CU_GRAPH_EXEC_UPDATE_SUCCESS:
raise RuntimeError(f"Failed to update graph: {exec_update_result.result()}")
if isinstance(source, GraphBuilder):
if not source._building_ended:
raise ValueError("Graph has not finished building.")
cu_graph = <cydriver.CUgraph><intptr_t>int(source._mnff.graph)
elif isinstance(source, GraphDef):
cu_graph = <cydriver.CUgraph><intptr_t>int(source.handle)
else:
raise TypeError(
f"expected GraphBuilder or GraphDef, got {type(source).__name__}")

cdef cydriver.CUgraphExecUpdateResultInfo result_info
cdef cydriver.CUresult err
with nogil:
err = cydriver.cuGraphExecUpdate(cu_exec, cu_graph, &result_info)
if err != cydriver.CUresult.CUDA_SUCCESS:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used Cursor GPT-5.4 1M High to "comb" through this "very complex and very large" PR. It only found this one "High" item:


I think this would be a bit safer if it distinguished the graph-update failure case from ordinary driver errors, e.g.

        cdef cydriver.CUgraphExecUpdateResultInfo result_info
        cdef cydriver.CUresult err
        with nogil:
            err = cydriver.cuGraphExecUpdate(cu_exec, cu_graph, &result_info)
        if err == cydriver.CUresult.CUDA_SUCCESS:
            return
        if err == cydriver.CUresult.CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE:
            reason = driver.CUgraphExecUpdateResult(result_info.result)
            msg = f"Graph update failed: {reason.__doc__.strip()} ({reason.name})"
            raise CUDAError(msg)
        raise CUDAError(err)

Rationale:

  • Using cydriver.cuGraphExecUpdate(...) directly here makes sense, since the higher-level binding drops resultInfo on non-success and would lose the detailed update reason entirely.
  • But resultInfo appears to be the structured explanation for the specific CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE path, not necessarily for every possible non-success CUresult.
  • Even when result_info.result == CU_GRAPH_EXEC_UPDATE_ERROR, the enum docs say the actual explanation is described by the function return value. The current code discards err, so it may collapse distinct driver failures into the same generic resultInfo-based message.
  • This shape preserves the nice detailed message for graph-update incompatibilities while still surfacing ordinary driver errors accurately.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to the docs, cuGraphExecUpdate only returns CUDA_SUCCESS or CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"very complex and very large"

To be clear, nearly all of this change is refactoring and code movement. The graph tests were regrouped slightly and renamed. The huge _graphdef module was split into three parts. The are not many functional changes here.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to the docs, cuGraphExecUpdate only returns CUDA_SUCCESS or CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE.

Documentation tends to be imprecise, or become imprecise over time without anyone noticing.

The suggested change improves the quality of implementation at a very small cost.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked the driver code and the docs are indeed incorrect.

reason = driver.CUgraphExecUpdateResult(result_info.result)
msg = f"Graph update failed: {reason.__doc__.strip()} ({reason.name})"
raise CUDAError(msg)

def upload(self, stream: Stream):
"""Uploads the graph in a stream.
Expand Down
23 changes: 23 additions & 0 deletions cuda_core/cuda/core/_graph/_graph_def/__init__.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

from cuda.core._graph._graph_def._graph_def cimport Condition, GraphDef
from cuda.core._graph._graph_def._graph_node cimport GraphNode
from cuda.core._graph._graph_def._subclasses cimport (
AllocNode,
ChildGraphNode,
ConditionalNode,
EmptyNode,
EventRecordNode,
EventWaitNode,
FreeNode,
HostCallbackNode,
IfElseNode,
IfNode,
KernelNode,
MemcpyNode,
MemsetNode,
SwitchNode,
WhileNode,
)
51 changes: 51 additions & 0 deletions cuda_core/cuda/core/_graph/_graph_def/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

"""Explicit CUDA graph construction — GraphDef, GraphNode, and node subclasses."""

from cuda.core._graph._graph_def._graph_def import (
Condition,
GraphAllocOptions,
GraphDef,
)
from cuda.core._graph._graph_def._graph_node import GraphNode
from cuda.core._graph._graph_def._subclasses import (
AllocNode,
ChildGraphNode,
ConditionalNode,
EmptyNode,
EventRecordNode,
EventWaitNode,
FreeNode,
HostCallbackNode,
IfElseNode,
IfNode,
KernelNode,
MemcpyNode,
MemsetNode,
SwitchNode,
WhileNode,
)

__all__ = [
"AllocNode",
"ChildGraphNode",
"Condition",
"ConditionalNode",
"EmptyNode",
"EventRecordNode",
"EventWaitNode",
"FreeNode",
"GraphAllocOptions",
"GraphDef",
"GraphNode",
"HostCallbackNode",
"IfElseNode",
"IfNode",
"KernelNode",
"MemcpyNode",
"MemsetNode",
"SwitchNode",
"WhileNode",
]
21 changes: 21 additions & 0 deletions cuda_core/cuda/core/_graph/_graph_def/_graph_def.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

from cuda.bindings cimport cydriver
from cuda.core._resource_handles cimport GraphHandle


cdef class Condition:
cdef:
cydriver.CUgraphConditionalHandle _c_handle
object __weakref__


cdef class GraphDef:
cdef:
GraphHandle _h_graph
object __weakref__

@staticmethod
cdef GraphDef _from_handle(GraphHandle h_graph)
Loading
Loading