diff --git a/cuda_core/cuda/core/graph/_graph_builder.pyi b/cuda_core/cuda/core/graph/_graph_builder.pyi
index af1748ad86..00af261423 100644
--- a/cuda_core/cuda/core/graph/_graph_builder.pyi
+++ b/cuda_core/cuda/core/graph/_graph_builder.pyi
@@ -129,6 +129,48 @@ class GraphBuilder:
     def is_join_required(self) -> bool:
         """Returns True if this graph builder must be joined before building is ended."""
 
+    @property
+    def graph_definition(self) -> GraphDefinition:
+        """The captured graph as an explicit :class:`~graph.GraphDefinition`.
+
+        The returned :class:`~graph.GraphDefinition` is a view of the same
+        graph this builder is producing: nodes added through it appear in
+        subsequent :meth:`complete` and :meth:`debug_dot_print` calls, and
+        the view stays valid even after the builder is closed.
+
+        This lets you mix the capture and explicit APIs on a single graph,
+        for example to inspect what was captured, augment it with extra
+        nodes, or build a conditional body entirely with the explicit API.
+
+        Availability:
+
+        - **Primary builders** (created by :meth:`Device.create_graph_builder`
+          or :meth:`Stream.create_graph_builder`): only after
+          :meth:`end_building`.
+
+        - **Conditional-body builders** (returned by :meth:`if_then`,
+          :meth:`if_else`, :meth:`while_loop`, :meth:`switch`): both before
+          :meth:`begin_building` and after :meth:`end_building`. The body
+          graph already exists when the conditional is created, so you may
+          populate it through this view without ever calling
+          :meth:`begin_building` on the body builder.
+
+        - **Forked builders** (returned by :meth:`split`): never. Forked
+          builders share the primary builder's graph; access it through the
+          primary instead.
+
+        Returns
+        -------
+        GraphDefinition
+            A view of the graph being built.
+
+        Raises
+        ------
+        RuntimeError
+            If the builder is forked, currently building, or (for primary
+            builders) has not started building yet.
+        """
+
     def begin_building(self, mode: str | None='relaxed') -> GraphBuilder:
         """Begins the building process.
 
diff --git a/cuda_core/cuda/core/graph/_graph_builder.pyx b/cuda_core/cuda/core/graph/_graph_builder.pyx
index c7b2ba5f74..dfd106fb0d 100644
--- a/cuda_core/cuda/core/graph/_graph_builder.pyx
+++ b/cuda_core/cuda/core/graph/_graph_builder.pyx
@@ -9,7 +9,7 @@ from libc.stdint cimport intptr_t
 
 from cuda.bindings cimport cydriver
 
-from cuda.core.graph._graph_definition cimport GraphCondition
+from cuda.core.graph._graph_definition cimport GraphCondition, GraphDefinition
 from cuda.core.graph._utils cimport _attach_host_callback_to_graph
 from cuda.core._resource_handles cimport (
     GraphHandle,
@@ -282,6 +282,64 @@ cdef class GraphBuilder:
         """Returns True if this graph builder must be joined before building is ended."""
         return self._kind == FORKED
 
+    @property
+    def graph_definition(self) -> GraphDefinition:
+        """The captured graph as an explicit :class:`~graph.GraphDefinition`.
+
+        The returned :class:`~graph.GraphDefinition` is a view of the same
+        graph this builder is producing: nodes added through it appear in
+        subsequent :meth:`complete` and :meth:`debug_dot_print` calls, and
+        the view stays valid even after the builder is closed.
+
+        This lets you mix the capture and explicit APIs on a single graph,
+        for example to inspect what was captured, augment it with extra
+        nodes, or build a conditional body entirely with the explicit API.
+
+        Availability:
+
+        - **Primary builders** (created by :meth:`Device.create_graph_builder`
+          or :meth:`Stream.create_graph_builder`): only after
+          :meth:`end_building`.
+
+        - **Conditional-body builders** (returned by :meth:`if_then`,
+          :meth:`if_else`, :meth:`while_loop`, :meth:`switch`): both before
+          :meth:`begin_building` and after :meth:`end_building`. The body
+          graph already exists when the conditional is created, so you may
+          populate it through this view without ever calling
+          :meth:`begin_building` on the body builder.
+
+        - **Forked builders** (returned by :meth:`split`): never. Forked
+          builders share the primary builder's graph; access it through the
+          primary instead.
+
+        Returns
+        -------
+        GraphDefinition
+            A view of the graph being built.
+
+        Raises
+        ------
+        RuntimeError
+            If the builder is forked, currently building, or (for primary
+            builders) has not started building yet.
+        """
+        if self._kind == FORKED:  # forked builders share the primary's graph
+            raise RuntimeError(
+                "graph_definition is unavailable on forked graph builders; "
+                "access it through the primary builder instead."
+            )
+        if self._state == CAPTURING:  # checked for every builder kind
+            raise RuntimeError(
+                "graph_definition is unavailable while capture is in "
+                "progress; call end_building() first."
+            )
+        if self._kind == PRIMARY and self._state == CAPTURE_NOT_STARTED:  # conditional bodies are exempt
+            raise RuntimeError(
+                "graph_definition is unavailable before begin_building() on "
+                "a primary builder; no graph has been created yet."
+            )
+        return GraphDefinition._from_handle(self._h_graph)
+
     def begin_building(self, mode: str | None = "relaxed") -> GraphBuilder:
         """Begins the building process.
 
diff --git a/cuda_core/tests/graph/test_graph_builder.py b/cuda_core/tests/graph/test_graph_builder.py
index 18dfe21cc1..efb70fe75d 100644
--- a/cuda_core/tests/graph/test_graph_builder.py
+++ b/cuda_core/tests/graph/test_graph_builder.py
@@ -5,11 +5,12 @@
 
 import numpy as np
 import pytest
-from helpers.graph_kernels import compile_common_kernels
+from helpers.graph_kernels import compile_common_kernels, compile_conditional_kernels
 from helpers.marks import requires_module
+from helpers.misc import try_create_condition
 
 from cuda.core import Device, LaunchConfig, LegacyPinnedMemoryResource, launch
-from cuda.core.graph import GraphBuilder
+from cuda.core.graph import GraphBuilder, GraphDefinition
 
 
 def test_graph_is_building(init_cuda):
@@ -384,3 +385,190 @@ def test_graph_stream_lifetime(init_cuda):
 
     # Destroy the stream
     stream.close()
+
+
+# ---------------------------------------------------------------------------
+# GraphBuilder.graph_definition
+# ---------------------------------------------------------------------------
+
+
+def test_graph_definition_returns_graph_definition_after_end_building(init_cuda):
+    """Primary builder exposes its captured graph as a GraphDefinition after end_building()."""
+    mod = compile_common_kernels()
+    empty_kernel = mod.get_kernel("empty_kernel")
+
+    gb = Device().create_graph_builder().begin_building()
+    launch(gb, LaunchConfig(grid=1, block=1), empty_kernel)
+    launch(gb, LaunchConfig(grid=1, block=1), empty_kernel)
+    gb.end_building()
+
+    gd = gb.graph_definition
+    assert isinstance(gd, GraphDefinition)
+    # The captured graph must contain the launched kernels.
+    assert len(gd.nodes()) == 2
+
+
+def test_graph_definition_raises_before_begin_building(init_cuda):
+    """Primary builder has no graph allocated before begin_building()."""
+    gb = Device().create_graph_builder()
+    with pytest.raises(RuntimeError, match="before begin_building"):
+        _ = gb.graph_definition
+
+
+def test_graph_definition_raises_during_capture(init_cuda):
+    """graph_definition is unsafe while the driver is actively capturing."""
+    gb = Device().create_graph_builder().begin_building()
+    try:
+        with pytest.raises(RuntimeError, match="capture is in"):
+            _ = gb.graph_definition
+    finally:
+        gb.end_building()
+
+
+def test_graph_definition_raises_for_forked(init_cuda):
+    """Forked builders share the primary's graph; their property must raise."""
+    mod = compile_common_kernels()
+    empty_kernel = mod.get_kernel("empty_kernel")
+
+    gb = Device().create_graph_builder().begin_building()
+    launch(gb, LaunchConfig(grid=1, block=1), empty_kernel)
+    primary, sibling = gb.split(2)
+    try:
+        with pytest.raises(RuntimeError, match="forked"):
+            _ = sibling.graph_definition
+    finally:
+        # Forked builders must be joined before building can be ended.
+        joined = GraphBuilder.join(primary, sibling)
+        joined.end_building()
+
+
+def test_graph_definition_shares_ownership(init_cuda):
+    """Closing the builder must not invalidate a held GraphDefinition."""
+    mod = compile_common_kernels()
+    empty_kernel = mod.get_kernel("empty_kernel")
+
+    gb = Device().create_graph_builder().begin_building()
+    launch(gb, LaunchConfig(grid=1, block=1), empty_kernel)
+    gb.end_building()
+
+    gd = gb.graph_definition
+    gb.close()
+    # The shared CUgraph keeps the graph alive.
+    assert len(gd.nodes()) == 1
+
+
+def test_graph_definition_round_trips_through_explicit_api(init_cuda):
+    """Mutating via the explicit API survives complete() and runs correctly."""
+    mod = compile_common_kernels()
+    add_one = mod.get_kernel("add_one")
+
+    launch_stream = Device().create_stream()
+    mr = LegacyPinnedMemoryResource()
+    b = mr.allocate(4)
+    arr = np.from_dlpack(b).view(np.int32)
+    arr[0] = 0
+
+    gb = launch_stream.create_graph_builder().begin_building()
+    launch(gb, LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data)
+    gb.end_building()
+
+    # Add a second add_one through the explicit GraphDefinition view.
+    gd = gb.graph_definition
+    captured_node = next(iter(gd.nodes()))
+    captured_node.launch(LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data)
+    assert len(gd.nodes()) == 2
+
+    graph = gb.complete()
+    graph.launch(launch_stream)
+    launch_stream.sync()
+    assert arr[0] == 2
+
+    b.close()
+
+
+@requires_module(np, "2.1")
+def test_graph_definition_hybrid_conditional_body(init_cuda):
+    """Populate a conditional body entirely through the explicit API.
+
+    This is the headline hybrid flow enabled by the new property:
+    ``if_then`` returns a ``GraphBuilder`` for the body, but instead of
+    calling ``begin_building`` and capturing into it, we reach for
+    ``graph_definition`` and add nodes through the explicit API.
+    """
+    mod = compile_conditional_kernels(int)
+    add_one = mod.get_kernel("add_one")
+    set_handle = mod.get_kernel("set_handle")
+
+    launch_stream = Device().create_stream()
+    mr = LegacyPinnedMemoryResource()
+    b = mr.allocate(4)
+    arr = np.from_dlpack(b).view(np.int32)
+    arr[0] = 0
+
+    gb = Device().create_graph_builder().begin_building()
+    condition = try_create_condition(gb)
+    launch(gb, LaunchConfig(grid=1, block=1), set_handle, condition, 1)
+    body_gb = gb.if_then(condition)
+
+    # Skip body_gb.begin_building() entirely -- the body graph already
+    # exists at conditional-node creation time and is exposed here.
+    body_def = body_gb.graph_definition
+    assert isinstance(body_def, GraphDefinition)
+    assert len(body_def.nodes()) == 0
+    body_def.launch(LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data)
+
+    graph = gb.end_building().complete()
+    graph.launch(launch_stream)
+    launch_stream.sync()
+    assert arr[0] == 1
+
+    b.close()
+
+
+@requires_module(np, "2.1")
+def test_graph_definition_conditional_body_after_capture(init_cuda):
+    """Capture into a conditional body, then augment it via the explicit API."""
+    mod = compile_conditional_kernels(int)
+    add_one = mod.get_kernel("add_one")
+    set_handle = mod.get_kernel("set_handle")
+
+    launch_stream = Device().create_stream()
+    mr = LegacyPinnedMemoryResource()
+    b = mr.allocate(4)
+    arr = np.from_dlpack(b).view(np.int32)
+    arr[0] = 0
+
+    gb = Device().create_graph_builder().begin_building()
+    condition = try_create_condition(gb)
+    launch(gb, LaunchConfig(grid=1, block=1), set_handle, condition, 1)
+    body_gb = gb.if_then(condition).begin_building()
+
+    # Capture one increment into the body.
+    launch(body_gb, LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data)
+    body_gb.end_building()
+
+    # Add a second increment via the explicit API on the same body graph.
+    body_def = body_gb.graph_definition
+    captured_node = next(iter(body_def.nodes()))
+    captured_node.launch(LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data)
+    assert len(body_def.nodes()) == 2
+
+    graph = gb.end_building().complete()
+    graph.launch(launch_stream)
+    launch_stream.sync()
+    assert arr[0] == 2
+
+    b.close()
+
+
+def test_graph_definition_conditional_body_during_capture_raises(init_cuda):
+    """The CAPTURING-state guard fires for conditional bodies too."""
+    gb = Device().create_graph_builder().begin_building()
+    condition = try_create_condition(gb)
+    body_gb = gb.if_then(condition).begin_building()
+    try:
+        with pytest.raises(RuntimeError, match="capture is in"):
+            _ = body_gb.graph_definition
+    finally:
+        body_gb.end_building()
+        gb.end_building()