Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions cuda_core/cuda/core/graph/_graph_builder.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,48 @@ class GraphBuilder:
def is_join_required(self) -> bool:
"""Returns True if this graph builder must be joined before building is ended."""

@property
def graph_definition(self) -> GraphDefinition:
"""The captured graph as an explicit :class:`~graph.GraphDefinition`.

The returned :class:`~graph.GraphDefinition` is a view of the same
graph this builder is producing: nodes added through it appear in
subsequent :meth:`complete` and :meth:`debug_dot_print` calls, and
the view stays valid even after the builder is closed.

This lets you mix the capture and explicit APIs on a single graph,
for example to inspect what was captured, augment it with extra
nodes, or build a conditional body entirely with the explicit API.

Availability:

- **Primary builders** (created by :meth:`Device.create_graph_builder`
or :meth:`Stream.create_graph_builder`): only after
:meth:`end_building`.

- **Conditional-body builders** (returned by :meth:`if_then`,
:meth:`if_else`, :meth:`while_loop`, :meth:`switch`): both before
:meth:`begin_building` and after :meth:`end_building`. The body
graph already exists when the conditional is created, so you may
populate it through this view without ever calling
:meth:`begin_building` on the body builder.

- **Forked builders** (returned by :meth:`split`): never. Forked
builders share the primary builder's graph; access it through the
primary instead.

Returns
-------
GraphDefinition
A view of the graph being built.

Raises
------
RuntimeError
If the builder is forked, currently building, or (for primary
builders) has not started building yet.
"""

def begin_building(self, mode: str | None='relaxed') -> GraphBuilder:
"""Begins the building process.

Expand Down
60 changes: 59 additions & 1 deletion cuda_core/cuda/core/graph/_graph_builder.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ from libc.stdint cimport intptr_t

from cuda.bindings cimport cydriver

from cuda.core.graph._graph_definition cimport GraphCondition
from cuda.core.graph._graph_definition cimport GraphCondition, GraphDefinition
from cuda.core.graph._utils cimport _attach_host_callback_to_graph
from cuda.core._resource_handles cimport (
GraphHandle,
Expand Down Expand Up @@ -282,6 +282,64 @@ cdef class GraphBuilder:
"""Returns True if this graph builder must be joined before building is ended."""
return self._kind == FORKED

@property
def graph_definition(self) -> GraphDefinition:
"""The captured graph as an explicit :class:`~graph.GraphDefinition`.

The returned :class:`~graph.GraphDefinition` is a view of the same
graph this builder is producing: nodes added through it appear in
subsequent :meth:`complete` and :meth:`debug_dot_print` calls, and
the view stays valid even after the builder is closed.

This lets you mix the capture and explicit APIs on a single graph,
for example to inspect what was captured, augment it with extra
nodes, or build a conditional body entirely with the explicit API.

Availability:

- **Primary builders** (created by :meth:`Device.create_graph_builder`
or :meth:`Stream.create_graph_builder`): only after
:meth:`end_building`.

- **Conditional-body builders** (returned by :meth:`if_then`,
:meth:`if_else`, :meth:`while_loop`, :meth:`switch`): both before
:meth:`begin_building` and after :meth:`end_building`. The body
graph already exists when the conditional is created, so you may
populate it through this view without ever calling
:meth:`begin_building` on the body builder.

- **Forked builders** (returned by :meth:`split`): never. Forked
builders share the primary builder's graph; access it through the
primary instead.

Returns
-------
GraphDefinition
A view of the graph being built.

Raises
------
RuntimeError
If the builder is forked, currently building, or (for primary
builders) has not started building yet.
"""
if self._kind == FORKED:
raise RuntimeError(
"graph_definition is unavailable on forked graph builders; "
"access it through the primary builder instead."
)
if self._state == CAPTURING:
raise RuntimeError(
"graph_definition is unavailable while capture is in "
"progress; call end_building() first."
)
if self._kind == PRIMARY and self._state == CAPTURE_NOT_STARTED:
raise RuntimeError(
"graph_definition is unavailable before begin_building() on "
"a primary builder; no graph has been created yet."
)
return GraphDefinition._from_handle(self._h_graph)

def begin_building(self, mode: str | None = "relaxed") -> GraphBuilder:
"""Begins the building process.

Expand Down
192 changes: 190 additions & 2 deletions cuda_core/tests/graph/test_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@

import numpy as np
import pytest
from helpers.graph_kernels import compile_common_kernels
from helpers.graph_kernels import compile_common_kernels, compile_conditional_kernels
from helpers.marks import requires_module
from helpers.misc import try_create_condition

from cuda.core import Device, LaunchConfig, LegacyPinnedMemoryResource, launch
from cuda.core.graph import GraphBuilder
from cuda.core.graph import GraphBuilder, GraphDefinition


def test_graph_is_building(init_cuda):
Expand Down Expand Up @@ -384,3 +385,190 @@ def test_graph_stream_lifetime(init_cuda):

# Destroy the stream
stream.close()


# ---------------------------------------------------------------------------
# GraphBuilder.graph_definition
# ---------------------------------------------------------------------------


def test_graph_definition_returns_graph_definition_after_end_building(init_cuda):
"""Primary builder exposes its captured graph as a GraphDefinition after end_building()."""
mod = compile_common_kernels()
empty_kernel = mod.get_kernel("empty_kernel")

gb = Device().create_graph_builder().begin_building()
launch(gb, LaunchConfig(grid=1, block=1), empty_kernel)
launch(gb, LaunchConfig(grid=1, block=1), empty_kernel)
gb.end_building()

gd = gb.graph_definition
assert isinstance(gd, GraphDefinition)
# The captured graph must contain the launched kernels.
assert len(gd.nodes()) == 2


def test_graph_definition_raises_before_begin_building(init_cuda):
"""Primary builder has no graph allocated before begin_building()."""
gb = Device().create_graph_builder()
with pytest.raises(RuntimeError, match="before begin_building"):
_ = gb.graph_definition


def test_graph_definition_raises_during_capture(init_cuda):
"""graph_definition is unsafe while the driver is actively capturing."""
gb = Device().create_graph_builder().begin_building()
try:
with pytest.raises(RuntimeError, match="capture is in"):
_ = gb.graph_definition
finally:
gb.end_building()


def test_graph_definition_raises_for_forked(init_cuda):
"""Forked builders share the primary's graph; their property must raise."""
mod = compile_common_kernels()
empty_kernel = mod.get_kernel("empty_kernel")

gb = Device().create_graph_builder().begin_building()
launch(gb, LaunchConfig(grid=1, block=1), empty_kernel)
primary, sibling = gb.split(2)
try:
with pytest.raises(RuntimeError, match="forked"):
_ = sibling.graph_definition
finally:
sibling = GraphBuilder.join(primary, sibling)
sibling.end_building()


def test_graph_definition_shares_ownership(init_cuda):
"""Closing the builder must not invalidate a held GraphDefinition."""
mod = compile_common_kernels()
empty_kernel = mod.get_kernel("empty_kernel")

gb = Device().create_graph_builder().begin_building()
launch(gb, LaunchConfig(grid=1, block=1), empty_kernel)
gb.end_building()

gd = gb.graph_definition
gb.close()
# The shared CUgraph keeps the graph alive.
assert len(gd.nodes()) == 1


def test_graph_definition_round_trips_through_explicit_api(init_cuda):
"""Mutating via the explicit API survives complete() and runs correctly."""
mod = compile_common_kernels()
add_one = mod.get_kernel("add_one")

launch_stream = Device().create_stream()
mr = LegacyPinnedMemoryResource()
b = mr.allocate(4)
arr = np.from_dlpack(b).view(np.int32)
arr[0] = 0

gb = launch_stream.create_graph_builder().begin_building()
launch(gb, LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data)
gb.end_building()

# Add a second add_one through the explicit GraphDefinition view.
gd = gb.graph_definition
captured_node = next(iter(gd.nodes()))
captured_node.launch(LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data)
assert len(gd.nodes()) == 2

graph = gb.complete()
graph.launch(launch_stream)
launch_stream.sync()
assert arr[0] == 2

b.close()


@requires_module(np, "2.1")
def test_graph_definition_hybrid_conditional_body(init_cuda):
"""Populate a conditional body entirely through the explicit API.

This is the headline hybrid flow enabled by the new property:
``if_then`` returns a ``GraphBuilder`` for the body, but instead of
calling ``begin_building`` and capturing into it, we reach for
``graph_definition`` and add nodes through the explicit API.
"""
mod = compile_conditional_kernels(int)
add_one = mod.get_kernel("add_one")
set_handle = mod.get_kernel("set_handle")

launch_stream = Device().create_stream()
mr = LegacyPinnedMemoryResource()
b = mr.allocate(4)
arr = np.from_dlpack(b).view(np.int32)
arr[0] = 0

gb = Device().create_graph_builder().begin_building()
condition = try_create_condition(gb)
launch(gb, LaunchConfig(grid=1, block=1), set_handle, condition, 1)
body_gb = gb.if_then(condition)

# Skip body_gb.begin_building() entirely -- the body graph already
# exists at conditional-node creation time and is exposed here.
body_def = body_gb.graph_definition
assert isinstance(body_def, GraphDefinition)
assert len(body_def.nodes()) == 0
body_def.launch(LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data)

graph = gb.end_building().complete()
graph.launch(launch_stream)
launch_stream.sync()
assert arr[0] == 1

b.close()


@requires_module(np, "2.1")
def test_graph_definition_conditional_body_after_capture(init_cuda):
"""Capture into a conditional body, then augment it via the explicit API."""
mod = compile_conditional_kernels(int)
add_one = mod.get_kernel("add_one")
set_handle = mod.get_kernel("set_handle")

launch_stream = Device().create_stream()
mr = LegacyPinnedMemoryResource()
b = mr.allocate(4)
arr = np.from_dlpack(b).view(np.int32)
arr[0] = 0

gb = Device().create_graph_builder().begin_building()
condition = try_create_condition(gb)
launch(gb, LaunchConfig(grid=1, block=1), set_handle, condition, 1)
body_gb = gb.if_then(condition).begin_building()

# Capture one increment into the body.
launch(body_gb, LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data)
body_gb.end_building()

# Add a second increment via the explicit API on the same body graph.
body_def = body_gb.graph_definition
captured_node = next(iter(body_def.nodes()))
captured_node.launch(LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data)
assert len(body_def.nodes()) == 2

graph = gb.end_building().complete()
graph.launch(launch_stream)
launch_stream.sync()
assert arr[0] == 2

b.close()


@requires_module(np, "2.1")
def test_graph_definition_conditional_body_during_capture_raises(init_cuda):
"""The CAPTURING-state guard fires for conditional bodies too."""
gb = Device().create_graph_builder().begin_building()
condition = try_create_condition(gb)
body_gb = gb.if_then(condition).begin_building()
try:
with pytest.raises(RuntimeError, match="capture is in"):
_ = body_gb.graph_definition
finally:
body_gb.end_building()
gb.end_building()
Loading